In [7]:
# utils
from utils import count_parameters
import torch
import math

# data
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

# model
import torch.nn as nn
import torch.nn.functional as F

# training
import torch.optim as optim
import tqdm

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
device

device(type='cpu')

## Data Preparation

In [10]:
# create data fields for source and target
source = Field(
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    tokenize="spacy",
    tokenizer_language="de",
    batch_first=True
)
target = Field(
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    tokenize="spacy",
    tokenizer_language="de",
    batch_first=True
)

In [11]:
# download the parallel corpus
train, val, test = Multi30k.splits(
    exts=(".de", ".en"),
    fields=(source, target)
)

In [12]:
# build the vocab
source.build_vocab(train)
target.build_vocab(train)

In [13]:
# create data loaders
BATCH_SIZE = 64
train_loader, val_loader, test_loader = BucketIterator.splits(
    datasets=(train, val, test),
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=True
)

In [14]:
batch =  next(iter(train_loader))
print(batch.src.shape, batch.trg.shape)

torch.Size([64, 23]) torch.Size([64, 26])


### PyTorch's Transformer's Model

In [73]:
class EmbeddingLayer(nn.Module):
    """
        takes input as token and convert it into embeddings
    """
    
    def __init__(self, vocab_size, embedding_dim):
        
        super(EmbeddingLayer, self).__init__()  
        
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
    
    def forward(self, x):
        """
            Ideally x.shape -> [batch, seq_len]
        """
        embedded = self.embedding(x)
        return embedded

In [74]:
# encoder model
class Encoder(nn.Module):
    """
        encoder will get source token and will produce contextualized embedding
    """
    def __init__(self, embedding_dim, vocab_size, num_heads=8, num_layers=1, max_len=100, dropout=0.15):
        super(Encoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        
        # same embedding can be used for both sys
        self.tok_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.pos_embedding = nn.Embedding(num_embeddings=max_len, embedding_dim=embedding_dim)
        
        # scaling
        self.scale = torch.sqrt(torch.FloatTensor([embedding_dim])).to(device)
    
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads), 
            num_layers=num_layers, 
            norm=nn.LayerNorm(normalized_shape=embedding_dim)
        )
        self.dropout = nn.Dropout(p=dropout)
    
    
    def forward(self, src, src_mask=None):
        
        """
            src.shape -> [batch, src_len]
        """
        batch, src_len = src.shape[0], src.shape[1]
        
        # create position tensor, shape will be [batch, src_len] by dooing so batch_first will be True
        position  = torch.arange(start=0, end=src_len, device=device).unsqueeze(0).repeat(batch, 1)
        
        # embeddings
        tok_embedded = self.tok_embedding(src)
        pos_embedded = self.pos_embedding(position)
        
        # scale the token embeddings by multiplyig it with srqt(d_model) where d_model is embedding_dim
        tok_scaled = tok_embedded * self.scale
        
        # add the scaled_tok and position embedding and then apply dropout, that will be input to the encoder
        encoder_input = self.dropout(tok_scaled + pos_embedded)
        
        
        encoded = self.encoder(encoder_input)
        
        return encoded
        
        

In [75]:
class Decoder(nn.Module):
    
    def __init__(self, embedding_dim, vocab_size, num_heads=8, num_layers=8, max_len=100, dropout=0.2):
        super(Decoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        
        self.tok_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.pos_embedding = nn.Embedding(num_embeddings=max_len, embedding_dim=embedding_dim)
        
        # scaling
        self.scale = torch.sqrt(torch.FloatTensor([embedding_dim])).to(device)
        
        # dropout layer
        self.dropout = nn.Dropout(p=0.15)
        
        self.decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=num_heads),
            num_layers=num_layers,
            norm=nn.LayerNorm(normalized_shape=embedding_dim)
        )
        
        
        self.fc_out = nn.Linear(in_features=embedding_dim, out_features=vocab_size)
    
    def forward(self, trg, src_encoded):
        
        """
            trg.shape -> [batch, trg_len]
        """
        batch, trg_len = trg.shape[0], trg.shape[1]
        
        # create position tensor, shape will be [batch, src_len] by dooing so batch_first will be True
        position  = torch.arange(start=0, end=trg_len, device=device).unsqueeze(0).repeat(batch, 1)
        
        # embeddings
        tok_embedded = self.tok_embedding(trg)
        pos_embedded = self.pos_embedding(position)
        
        # scale the token embeddings by multiplyig it with srqt(d_model) where d_model is embedding_dim
        tok_scaled = tok_embedded * self.scale
        
        # add the scaled_tok and position embedding and then apply dropout, that will be input to the encoder
        decoder_input = self.dropout(tok_scaled + pos_embedded)
        
        outputs = self.decoder(decoder_input.permute(1, 0, 2), src_encoded.permute(1, 0, 2))
        prediction = self.fc_out(outputs)
        return prediction
         

### Training with PyTorch-Lightning

In [76]:
import pytorch_lightning as pl

In [77]:
class Transformer(pl.LightningModule):
    
    def __init__(self, encoder, decoder, PAD_IDX):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, src, trg):
        src_encoded = self.encoder(src)
        outputs = self.decoder(trg, src_encoded)
        return outputs
    
    def configure_optimizers(self):
        return optim.Adam(params=self.parameters(), lr=1e-3)
    
    def train_dataloader(self):
        return train_loader
    
    def training_step(self, batch, batch_idx):
        src, trg = batch.src, batch.trg
        batch_size, trg_len = trg.shape[0], trg.shape[1] # batch_first was true in BucketIterato
        outputs = self(src, trg)
        loss = F.cross_entropy(outputs.view(batch_size*trg_len, -1), trg.view(-1), ignore_index=PAD_IDX)
        ppl = torch.exp(loss)
        tensorboard_logs = {"loss":loss, "ppl":ppl}
        return {"loss":loss, "ppl":ppl,"log":tensorboard_logs}
    
    def val_dataloader(self):
        return val_loader
    
    def validation_step(self, batch, batch_idx):
        src, trg = batch.src, batch.trg
        batch_size, trg_len = trg.shape[0], trg.shape[1] # batch_first was true in BucketIterato
        outputs = self(src, trg)
        loss = F.cross_entropy(outputs.view(batch_size*trg_len, -1), trg.view(-1), ignore_index=PAD_IDX)
        ppl = torch.exp(loss)
        tensorboard_logs = {"val_loss":loss, "val_ppl":ppl}
        return {"val_loss":loss, "val_ppl":ppl,"log":tensorboard_logs}
    
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_ppl = torch.stack([x['val_ppl'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss, 'val_ppl':avg_ppl}
        return {'val_loss': avg_loss,'val_ppl':avg_ppl, 'log': tensorboard_logs}    

In [78]:
src_vocab_size = len(source.vocab)
trg_vocab_size = len(target.vocab)
PAD_IDX = target.vocab.stoi[target.pad_token]
embedding_dim = 256

In [79]:
encoder = Encoder(embedding_dim=embedding_dim, vocab_size=src_vocab_size)
decoder = Decoder(embedding_dim=embedding_dim, vocab_size=trg_vocab_size)
transformer = Transformer(encoder=encoder, decoder=decoder, PAD_IDX=PAD_IDX)

In [80]:
trainer = pl.Trainer(max_epochs=2)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


In [81]:
trainer.fit(transformer)


  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 6 M   
1 | decoder | Decoder | 17 M  


Epoch 1:  97%|█████████▋| 454/470 [09:44<00:20,  1.29s/it, loss=5.413, v_num=5]
Validating: 0it [00:00, ?it/s][A
Epoch 1:  97%|█████████▋| 455/470 [09:44<00:19,  1.28s/it, loss=5.413, v_num=5]
Epoch 1:  97%|█████████▋| 456/470 [09:44<00:17,  1.28s/it, loss=5.413, v_num=5]
Epoch 1:  97%|█████████▋| 457/470 [09:44<00:16,  1.28s/it, loss=5.413, v_num=5]
Epoch 1:  97%|█████████▋| 458/470 [09:44<00:15,  1.28s/it, loss=5.413, v_num=5]
Epoch 1:  98%|█████████▊| 459/470 [09:44<00:14,  1.27s/it, loss=5.413, v_num=5]
Epoch 1:  98%|█████████▊| 460/470 [09:45<00:12,  1.27s/it, loss=5.413, v_num=5]
Epoch 1:  98%|█████████▊| 461/470 [09:45<00:11,  1.27s/it, loss=5.413, v_num=5]
Epoch 1:  98%|█████████▊| 462/470 [09:45<00:10,  1.27s/it, loss=5.413, v_num=5]
Epoch 1:  99%|█████████▊| 463/470 [09:45<00:08,  1.26s/it, loss=5.413, v_num=5]
Epoch 1:  99%|█████████▊| 464/470 [09:45<00:07,  1.26s/it, loss=5.413, v_num=5]
Epoch 1:  99%|█████████▉| 465/470 [09:46<00:06,  1.26s/it, loss=5.413, v_num=5]
Epoch 

1

In [82]:
trainer.run_evaluation()


Validating: 0it [00:00, ?it/s][A
Validating:   6%|▋         | 1/16 [00:00<00:03,  4.80it/s][A
Validating:  12%|█▎        | 2/16 [00:00<00:02,  5.14it/s][A
Validating:  19%|█▉        | 3/16 [00:00<00:02,  5.48it/s][A
Validating:  25%|██▌       | 4/16 [00:00<00:02,  5.73it/s][A
Validating:  31%|███▏      | 5/16 [00:00<00:01,  5.53it/s][A
Validating:  38%|███▊      | 6/16 [00:01<00:01,  5.15it/s][A
Validating:  44%|████▍     | 7/16 [00:01<00:01,  5.14it/s][A
Validating:  50%|█████     | 8/16 [00:01<00:01,  5.08it/s][A
Validating:  56%|█████▋    | 9/16 [00:01<00:01,  5.37it/s][A
Validating:  62%|██████▎   | 10/16 [00:01<00:01,  5.58it/s][A
Validating:  69%|██████▉   | 11/16 [00:02<00:00,  5.60it/s][A
Validating:  75%|███████▌  | 12/16 [00:02<00:00,  5.43it/s][A
Validating:  81%|████████▏ | 13/16 [00:02<00:00,  4.93it/s][A
Validating:  88%|████████▊ | 14/16 [00:02<00:00,  4.90it/s][A
Validating:  94%|█████████▍| 15/16 [00:02<00:00,  4.69it/s][A
Validating: 100%|██████████| 

{'val_loss': tensor(5.3635), 'val_ppl': tensor(216.9025)}