In [1]:
import os
import numpy as np
import pandas as pd
import nltk.tokenize
import re
import random
from nltk.util import ngrams
import tqdm
from nltk.tokenize import RegexpTokenizer
import torch

### Read Data

In [2]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
from torchtext.data.utils import get_tokenizer
from collections import Counter, OrderedDict
from torchtext.vocab import vocab
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [3]:
def collate_fn(data):
    data.sort(key=lambda x: len(x[0]), reverse=True)
    text_data = []
    target_data = []
    for unit in data:
        text_data.append(torch.tensor(unit[0]))
        target_data.append(torch.tensor(unit[1]))
    text = pad_sequence(text_data, batch_first=True)
    target = pad_sequence(target_data, batch_first=True)
    return text, target

In [4]:
train_data_loader = torch.load('train_data_loader.pth')
test_data_loader = torch.load('test_data_loader.pth')
text_vocab = torch.load('text_vocab.pth')

### Build Model

In [5]:
class Encoder(torch.nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dec_hidden_dim, num_layers,dropout=0.5):
        super().__init__()
        self.input_dim=input_dim
        self.hidden_dim=hidden_dim
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.num_layers=num_layers
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 2, dec_hidden_dim)

        self.layer=nn.LSTM(input_size=emb_dim,hidden_size=hidden_dim, \
                        num_layers=num_layers,batch_first=True, \
                        dropout=dropout,bidirectional=True)
    
    def forward(self,x):
        batch_size = x.shape[0]
        
        embedded = self.dropout(self.embedding(x))     
        
        out,(hidden,c)=self.layer(embedded)
        
        s = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
        
        return out, s

Attention Mechanism

In [6]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        # [size(h_t)+size(s_{t-1}), dec_hid_dim]
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim, bias=False)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, s, enc_output):
        # s = [batch_size, dec_hid_dim]
        # enc_output = [batch_size, src_len, enc_hid_dim * 2]

        batch_size = enc_output.shape[0]
        src_len = enc_output.shape[1]

        # repeat decoder hidden state src_len times
        # s = [batch_size, src_len, enc_hid_dim * 2]
        # enc_output = [batch_size, src_len, enc_hid_dim * 2]
        s = s.unsqueeze(1).repeat(1, src_len, 1)

        # energy = [batch_size, src_len, dec_hid_dim]
        energy = torch.tanh(self.attn(torch.cat((s, enc_output), dim=2)))

        # attention = [batch_size, src_len]
        attention = self.v(energy).squeeze(2)

        return F.softmax(attention, dim=1)

In [7]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention, device):
        super().__init__()
        self.output_dim = output_dim
        self.dec_hid_dim = dec_hid_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.device = device
        
        self.layer=nn.LSTM(input_size=enc_hid_dim * 2 + emb_dim, hidden_size=dec_hid_dim, \
                        num_layers=1,batch_first=True, \
                        dropout=dropout,bidirectional=False)
        
        self.fc_out = nn.Linear(enc_hid_dim * 2 + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_input, s, enc_output):
        # dec_input = [batch_size]
        # s = [batch_size, dec_hid_dim]
        # enc_output = [src_len, batch_size, enc_hid_dim *2]
        
        batch_size = dec_input.shape[0]

        # dec_input = [batch_size,1]
        dec_input = dec_input.unsqueeze(1)

        # embedded = [batch_size, 1, emb_dim]
        embedded = self.dropout(self.embedding(dec_input))

        # s = [batch_size, dec_hid_dim]
        # enc_output = [batch_size, src_len, enc_hid_dim *2]

        # a = [batch_size, 1, src_len]
        a = self.attention(s, enc_output).unsqueeze(1)

        # c = [batch_size, 1, enc_hid_dim * 2]
        c = torch.bmm(a, enc_output)

        # lstm_input = [batch_size, 1, (enc_hid_dim*2) + emb_dim]
        lstm_input = torch.cat((embedded, c), dim=2)
        
        c0 = torch.randn(1, batch_size, self.dec_hid_dim).to(self.device)

        # dec_output = [batch_size, src_len(=1), dec_hid_dim]
        # dec_hidden = [n_layers*num_directions, batch_size, dec_hid_dim]
        dec_output, (dec_hidden, _) = self.layer(lstm_input, (s.unsqueeze(0), c0))

        # embedded = [batch_size, emb_dim]
        # dec_output = [batch_size, dec_hid_dim]
        # c = [batch_size, enc_hid_dim * 2]
        embedded = embedded.squeeze(1)
        dec_output = dec_output.squeeze(1)
        c = c.squeeze(1)

        # pred = [batch_size, output_dim]
        pred = self.fc_out(torch.cat((dec_output, c, embedded), dim=1))

        return pred, dec_hidden.squeeze(0)


In [8]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.device = device
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # src = [batch_size, text_length]
        # trg = [batch_size, summarizarion_length]
        # teacher_forcing_ratio is probability to use teacher forcing (scheduled sampling)
        batch_size = src.shape[0]
        trg_len = trg.shape[1]
        vocab_size = self.decoder.output_dim

        outputs = torch.zeros(trg_len, batch_size, vocab_size).to(self.device)

        # enc_output : [src_len, batch_size, enc_hid_dim * 2]
        # s : [batch_size, dec_hid_dim]
        enc_output, s = self.encoder(src)

        # first input to the decoder is the <bob> tokens
        dec_input = trg[:, 0]

        for t in range(1, trg_len):
            dec_output, s = self.decoder(dec_input, s, enc_output)

            outputs[t] = dec_output

            # decide if using teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio

            # get the highest predicted token from predictions
            prediction = dec_output.argmax(1)

            # if teacher forcing, use actural next token as input
            # if not, use predicted token
            dec_input = trg[:, t] if teacher_force else prediction

        return outputs


In [9]:
# Define Hyper parameter

INPUT_DIM = len(text_vocab)
OUTPUT_DIM = len(text_vocab)
ENC_EMB_DIM = 300
DEC_EMB_DIM = 300
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

device = "cuda"

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, 2, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn, device)

model = Seq2Seq(enc, dec, device).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)



### Option: Initialize Model with the pretrained word embedding

In [None]:
from torchtext.vocab import GloVe

In [None]:
my_glove= GloVe()
pretained_embedding = my_glove.get_vecs_by_tokens(text_vocab.get_itos())
pretained_embedding = pretained_embedding.to(device)
enc.embedding.weight.data = pretained_embedding
dec.embedding.weight.data = pretained_embedding

In [None]:
pretained_embedding = pretained_embedding.to(device)
enc.embedding.weight.data = pretained_embedding
dec.embedding.weight.data = pretained_embedding

### Option: Reload a Trained Model

In [10]:
model_dir = 'model/3/lstm_model.pt'
model.load_state_dict(torch.load(model_dir))

<All keys matched successfully>

### Train and Evaluate Model

In [11]:
# Define train function

def train(model, iterator, optimizer, criterion):
    model.train()
    train_loss = 0
    for i, batch in tqdm.tqdm(enumerate(iterator)):
        text, highlight = batch 
        
        text = text.to(device)
        highlight = highlight.to(device)

        # pred = [highlight_len, batch_size, pred_dim]
        pred = model(text, highlight)
        pred_dim = pred.shape[-1]

        # highlight = [highlight_len*batch_size]
        # pred = [highlight_len*batch_size]
        highlight = highlight.view(-1)
        pred = pred.permute([1,0,2]).contiguous().view(-1, pred_dim)

        loss = criterion(pred, highlight)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if i%500 == 0 and i!=0:
            print(train_loss/500)
            train_loss = 0


In [12]:
# Define evaluation function

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in tqdm.tqdm(enumerate(iterator)):
            text, highlight = batch
            
            text = text.to(device)
            highlight = highlight.to(device)

            # output = [highlight_len, batch_size, output_dim]
            output = model(text, highlight, 0) # turn off teacher forcing

            output_dim = output.shape[-1]

            # trg = [highlight_len*batch_size]
            # output = [highlight_len*batch_size, output_dim]
            output = output.permute([1,0,2]).contiguous().view(-1, output_dim)
            highlight = highlight.view(-1)

            loss = criterion(output, highlight)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [13]:
num_epoch = 3

for epoch in range(num_epoch):

    train(model, train_data_loader, optimizer, criterion)
    print(evaluate(model, test_data_loader, criterion))


501it [07:45,  1.18it/s]

6.219203948974609


1001it [15:28,  1.15it/s]

6.276841556549072


1501it [23:10,  1.17it/s]

6.235005108833313


2001it [30:57,  1.21it/s]

6.265779499530792


2501it [38:45,  1.00it/s]

6.279764371395111


3001it [46:22,  1.15it/s]

6.2363659567832945


3501it [54:07,  1.03it/s]

6.2297070064544675


4001it [1:01:53,  1.05s/it]

6.225161261558533


4501it [1:09:41,  1.11it/s]

6.207482703685761


5001it [1:17:25,  1.05it/s]

6.210640722751617


5501it [1:25:16,  1.00s/it]

6.199427905559539


6001it [1:33:03,  1.06it/s]

6.211570686817169


6501it [1:40:43,  1.15it/s]

6.167676355361938


7001it [1:48:29,  1.13it/s]

6.1602605438232425


7501it [1:56:11,  1.10it/s]

6.1262509484291074


8001it [2:03:56,  1.01s/it]

6.143068867683411


8501it [2:11:45,  1.14it/s]

6.128458675384522


9001it [2:19:37,  1.10it/s]

6.126866454124451


9501it [2:27:28,  1.02s/it]

6.089778331756592


10001it [2:35:17,  1.04it/s]

6.10336150598526


10501it [2:43:02,  1.10s/it]

6.128002324581146


11001it [2:50:50,  1.00it/s]

6.10438598203659


11045it [2:51:31,  1.07it/s]


KeyboardInterrupt: 

In [14]:
torch.save(model.state_dict(), "model/4/lstm_model.pt")