In [48]:
import math
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l

In [49]:
num_hiddens = 32
num_layers = 2
dropout = 0.1
batch_size = 64
num_steps = 10
lr = 0.005
num_epochs = 200
device = d2l.try_gpu()
train_iter,src_vocab,tgt_vocab = d2l.load_data_nmt(batch_size,num_steps)

In [50]:
for x in train_iter:
    print(x[0].shape) # src sentence
    print(x[1].shape) # valid length
    print(x[2].shape) # tgt sentence
    print(x[3].shape) # valid length
    break
print(len(src_vocab))
print(len(tgt_vocab))

torch.Size([64, 10])
torch.Size([64])
torch.Size([64, 10])
torch.Size([64])
184
201


In [51]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math 
device = torch.device('cpu')

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size,embedding_size) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_size)
        self.embedding_size = embedding_size
    
    def forward(self,x):
        return self.embedding(x.long())*math.sqrt(self.embedding_size)

class PositionalEncoding(nn.Module):
    def __init__(self,dimen,dropout=0.1,max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        PE = torch.zeros(max_len,dimen)
        position = torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,dimen,2).float() * (-math.log(10000.0) / dimen))
        PE[:,0::2] = torch.sin(position * div_term)
        PE[:,1::2] = torch.cos(position * div_term)
        PE = PE.unsqueeze(0).transpose(0,1)
        self.register_buffer('PE',PE)
    
    def forward(self,x):
        x = x + self.PE[:x.size(0),:]
        return self.dropout(x)

class seq2seqTrans(nn.Module):
    def __init__(self,num_encoder_layers,num_decoder_layers,emb_size,nhead,src_vocab_size,tgt_vocab_size,dim_feedforward=512,dropout=0.1):
        super().__init__()
        self.src_tok_emb = TokenEmbedding(src_vocab_size,emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size,emb_size)

        self.pos_emb = PositionalEncoding(emb_size,dropout)

        self.transformer = Transformer(d_model=emb_size,nhead=nhead,num_encoder_layers=num_encoder_layers,num_decoder_layers=num_decoder_layers,dim_feedforward=dim_feedforward,dropout=dropout)
        self.linear = nn.Linear(emb_size,tgt_vocab_size)
    
    def forward(self,src,tgt,src_mask,tgt_mask,src_padding_mask,tgt_padding_mask,src_key_padding_mask,tgt_key_padding_mask,memory_key_padding_mask):
        src_emb = self.pos_emb(self.src_tok_emb(src))

        tgt_emb = self.pos_emb(self.tgt_tok_emb(tgt))
        memory = self.transformer(src_emb,tgt_emb,src_mask,tgt_mask,None,src_key_padding_mask,tgt_key_padding_mask,memory_key_padding_mask)
        return self.linear(memory)
    
    def encode(self,src,src_mask):
        return self.transformer.encoder(self.pos_emb(self.src_tok_emb(src)),src_mask)
    
    def decode(self,tgt,memory,tgt_mask):
        return self.transformer.decoder(self.pos_emb(self.tgt_tok_emb(tgt)),memory,tgt_mask)



model = seq2seqTrans(num_encoder_layers=num_layers,
                     num_decoder_layers=num_layers,
                     emb_size=num_hiddens,
                     nhead=4,
                     src_vocab_size=len(src_vocab),
                     tgt_vocab_size=len(tgt_vocab),
                     dim_feedforward=512,
                     dropout=dropout)

In [52]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)

    src_padding_mask = (src == src_vocab['<pad>']).transpose(0, 1)
    tgt_padding_mask = (tgt == tgt_vocab['<pad>']).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

## Training the model
attention to the mask when input the model for gradient decent

In [53]:
loss_fn = nn.CrossEntropyLoss(ignore_index=src_vocab['<pad>'])
optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    

In [66]:
model.to(device)
model.train()
from torch.utils.data import DataLoader as Dataloader
for src,src_valid,tgt,tgt_valid in train_iter:
    src.to(device)
    tgt.to(device)
    print('src shape:',src.shape)
    print('tgt shape:',tgt.shape)
    src_mask,tgt_mask,src_padding_mask,tgt_padding_mask = create_mask(src,tgt)
    src_mask.to(device)
    tgt_mask.to(device)
    src_padding_mask.to(device)
    tgt_padding_mask.to(device)
    

    print(src_mask.shape)
    print(tgt_mask.shape)
    print(src_padding_mask.shape)
    print(tgt_padding_mask.shape)

    logits = model(src,
                   tgt,
                   src_mask,
                   tgt_mask,
                   src_padding_mask,
                   tgt_padding_mask,
                   src_padding_mask,
                   tgt_padding_mask,
                   src_padding_mask)
    print(logits.shape)
    optimizer.zero_grad()
    l = loss_fn(logits.reshape(-1,len(tgt_vocab)),tgt.reshape(-1))
    l.backward()
    optimizer.step()

src shape: torch.Size([64, 10])
tgt shape: torch.Size([64, 10])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([10, 64])
torch.Size([10, 64])
torch.Size([64, 10, 201])
src shape: torch.Size([64, 10])
tgt shape: torch.Size([64, 10])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([10, 64])
torch.Size([10, 64])
torch.Size([64, 10, 201])
src shape: torch.Size([64, 10])
tgt shape: torch.Size([64, 10])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([10, 64])
torch.Size([10, 64])
torch.Size([64, 10, 201])
src shape: torch.Size([64, 10])
tgt shape: torch.Size([64, 10])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([10, 64])
torch.Size([10, 64])
torch.Size([64, 10, 201])
src shape: torch.Size([64, 10])
tgt shape: torch.Size([64, 10])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([10, 64])
torch.Size([10, 64])
torch.Size([64, 10, 201])
src shape: torch.Size([64, 10])
tgt shape: torch.Size([64, 10])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([10, 64])
tor