In [1]:
%config Completer.use_jedi = False

In [2]:
import random
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
hyp_params = {
    "batch_size": 64,
    "num_epochs": 1,

    # Encoder parameters
    "encoder_embedding_size": 300,
    "encoder_dropout": 0.5,

    # Decoder parameters
    "decoder_dropout": 0.5,
    "decoder_embedding_size": 300,

    # Common parameters
    "hidden_size": 1024,
    "num_layers": 2
}

In [5]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout):
        super(Encoder, self).__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        self.LSTM = nn.LSTM(embedding_dim, hidden_size, num_layers, dropout=dropout)
        
    def forward(self, x):
        embedding = self.dropout(self.embedding(x))
        
        outputs, (hidden_state, cell_state) = self.LSTM(embedding)

        return hidden_state, cell_state
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout, output_size):
        super(Decoder, self).__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        self.LSTM = nn.LSTM(embedding_dim, hidden_size, num_layers, dropout=dropout)
        
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x, hidden_state, cell_state):
        # as we are not feeding whole sentence we will each token a time
        # hence our sequence length would be just 1 however shape of x is batch_size
        # to add sequence length we will unsequeeze it
        x = x.unsqueeze(0)
        
        embedded = self.dropout(self.embedding(x))
        
        outputs, (hidden_state, cell_state) = self.LSTM(embedded, (hidden_state, cell_state))

        # Shape --> predictions  (1, batch_size, target_vocab_size)
        predictions = self.fc(outputs)

        # Shape --> predictions (32, 4556) (batch_size , target_vocab_size)
        predictions = self.softmax(predictions.squeeze(0))

        return predictions, hidden_state, cell_state

class SeqtoSeq(nn.Module):
    def __init__(self, gen_params, target_vocab, device):
        super(SeqtoSeq, self).__init__()

        self.Encoder = Encoder(gen_params["input_size_encoder"],
                          gen_params["encoder_embedding_size"],
                          gen_params["hidden_size"],
                          gen_params["num_layers"],
                          gen_params["encoder_dropout"]).to(device)

        self.Decoder = Decoder(gen_params["input_size_decoder"],
                          gen_params["decoder_embedding_size"],
                          gen_params["hidden_size"],
                          gen_params["num_layers"],
                          gen_params["decoder_dropout"],
                          gen_params["output_size"]).to(device)

        self.target_vocab = target_vocab
        self.device = device
    
    def forward(self, source, target, tfr=0.5):
        # Shape -> (Sentence length, Batch_size)
        batch_size = source.shape[1]

        target_len = target.shape[0]  # Length of target sentences
        target_vocab_size = len(self.target_vocab)
        
        # here we will store all the outputs
        # so outputs is arrange in a way that sentences are in column and batch size is row and every element
        # will consist of probability of each word from the vocab
        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(self.device)

        # Shape --> (hs, cs) (num_layers, batch_size size, hidden_size) (contains encoder's hs, cs - context vectors)
        hidden_state, cell_state = self.Encoder(source)

        # Shape of x (32 elements)
        x = target[0]  # First token (Trigger)

        for i in range(1, target_len):
            # Shape --> output (batch_size, target_vocab_size)
            output, hidden_state, cell_state = self.Decoder(x, hidden_state, cell_state)
            outputs[i] = output
            best_guess = output.argmax(1)  # 0th dimension is batch size, 1st dimension is word embedding
            # Schedule sampling
            x = target[
                i] if random.random() < tfr else best_guess  # Either pass the next word correctly from the dataset
            # or use the earlier predicted word

        # Shape --> (sentence length, batch size, vocab size)
        return outputs

In [6]:
from torch.utils.data import Dataset
from collections import Counter
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

class nmtDataset(Dataset):
    def __init__(self, ds_path, split):
        self.en = open(ds_path + split + '.en', encoding='utf-8').readlines()
        self.de = open(ds_path + split + '.de', encoding='utf-8').readlines()
        
        self.tokenizers = { 'en': get_tokenizer('spacy', language='en'), 
                            'de': get_tokenizer('spacy', language='de') }
        
        self.src_vocab, self.trg_vocab = self._build_vocab()
    
    def __len__(self):
        return len(self.en)

    def __getitem__(self, item):
        src_tokens = self.tokenizers['en'](self.en[item].strip())
        trg_tokens = self.tokenizers['de'](self.de[item].strip())
  
        return {
            "src": self.src_vocab.lookup_indices(src_tokens),
            "trg": [self.trg_vocab['<sos>']] + self.trg_vocab.lookup_indices(trg_tokens) + [self.trg_vocab['<eos>']]
        }
        
    def _build_vocab(self):
        src_vocab = build_vocab_from_iterator(self._get_tokens(self.en, 'en'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=10)  # discarding words occurs < 10 times
        trg_vocab = build_vocab_from_iterator(self._get_tokens(self.de, 'de'), specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=10)
  
        trg_vocab.set_default_index(trg_vocab['<unk>'])
        src_vocab.set_default_index(src_vocab['<unk>'])

        return src_vocab, trg_vocab
    
    def _get_tokens(self, corpus, lang):
        for line in corpus:
            yield self.tokenizers[lang](line.strip())
    
nmtds_train = nmtDataset('datasets/Multi30k/', 'train')
nmtds_valid = nmtDataset('datasets/Multi30k/', 'val')



In [7]:
from torch.utils.data import DataLoader


def collate_fn(batch, device):
    PAD_IDX = 1
    trgs = []
    srcs = []
    for row in batch:
        srcs.append(torch.tensor(row["src"]).to(device))
        trgs.append(torch.tensor(row["trg"]).to(device))

    padded_srcs = pad_sequence(srcs, padding_value=PAD_IDX)
    padded_trgs = pad_sequence(trgs, padding_value=PAD_IDX)
    return {"src": padded_srcs, "trg": padded_trgs}

train_dataloader = DataLoader(nmtds_train, batch_size=hyp_params['batch_size'], shuffle=True,
                              collate_fn=lambda batch_size: collate_fn(batch_size, device))

valid_dataloader = DataLoader(nmtds_valid, batch_size=hyp_params['batch_size'], shuffle=True,
                              collate_fn=lambda batch_size: collate_fn(batch_size, device))

In [8]:
hyp_params["input_size_encoder"] = len(nmtds_train.src_vocab)
hyp_params["input_size_decoder"] = len(nmtds_train.trg_vocab)
hyp_params["output_size"] = len(nmtds_train.trg_vocab)

model = SeqtoSeq(hyp_params, target_vocab=nmtds_train.trg_vocab, device=device)
optimizer = optim.Adam(model.parameters())

pad_idx = nmtds_train.trg_vocab["<pad>"]
criterion = nn.NLLLoss(ignore_index=pad_idx).to(device)

In [211]:
model.train()

epoch_loss = 0
for epoch in range(hyp_params["num_epochs"]):
    for batch_idx, batch in enumerate(tqdm(train_dataloader)):
        src = batch["src"]  # shape --> e.g. (19, 2) sentence len, batch size
        trg = batch["trg"]  # shape --> e.g. (3, 2) sentence len, batch size

        # Clear the accumulating gradients
        optimizer.zero_grad()

        # Pass the input and target for model's forward method
        # Shape --> (sentence len of TRG, batch size, vocab size) e.g (3, 2, 196)
        # Explanation:
        #    It just outputs probabilities for every single word in our vocab
        #    for each word in sentence and each sentence in batch size
        output = model(src, trg)

        # Updating output shape --> [sentence length * batch size , vocab size]
        # e.g (6, 196)
        output = output.reshape(-1, output.shape[2])

        # sentence len  * batch size
        target = trg.reshape(-1)

        # Calculate the loss value for every epoch
        loss = criterion(output, target)

        # Calculate the gradients for weights & biases using back-propagation
        loss.backward()

        epoch_loss += loss.item()

        # Clip the gradient value is it exceeds > 1
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Update the weights values
        optimizer.step()

100%|██████████| 454/454 [40:01<00:00,  5.29s/it]


2401.5422139167786

In [212]:
!pip freeze

anyio==3.3.0
appnope==0.1.2
argon2-cffi==20.1.0
async-generator==1.10
attrs==21.2.0
Babel==2.9.1
backcall==0.2.0
bleach==4.1.0
blis==0.7.4
catalogue==2.0.6
certifi==2021.5.30
cffi==1.14.6
charset-normalizer==2.0.4
click==7.1.2
contextvars==2.4
cymem==2.0.5
dataclasses==0.8
de-core-news-sm @ https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.1.0/de_core_news_sm-3.1.0-py3-none-any.whl
decorator==5.0.9
defusedxml==0.7.1
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0-py3-none-any.whl
entrypoints==0.3
idna==3.2
immutables==0.16
importlib-metadata==4.6.4
ipykernel==5.5.5
ipython==7.16.1
ipython-genutils==0.2.0
jedi==0.18.0
Jinja2==3.0.1
json5==0.9.6
jsonschema==3.2.0
jupyter-client==7.0.1
jupyter-core==4.7.1
jupyter-server==1.10.2
jupyterlab==3.1.9
jupyterlab-pygments==0.1.2
jupyterlab-server==2.7.2
MarkupSafe==2.0.1
mistune==0.8.4
murmurhash==1.0.5
nbclassic==0.3.1
nbclient==0.5.4
nbconvert==6