In [2]:
import torch
import torch.nn as nn
from datasets import load_dataset

In [3]:
class PositionEncod(nn.Module):
    def __init__(self,d_model,seq_len,dropout):
        super().__init__()
        self.dropout=nn.Dropout(dropout)
        pos_enc=torch.zeros((seq_len,d_model))
        pos=torch.arange(0,seq_len,dtype=torch.float).unsqueeze(1)
        div=torch.exp(torch.arange(0,d_model,2).float()*(-torch.log(torch.tensor(10000.0))/d_model))
        pos_enc[:,0::2]=torch.sin(pos*div)
        pos_enc[:,1::2]=torch.cos(pos*div)
        pos_enc=pos_enc.unsqueeze(0)
        self.register_buffer('pos_enc',pos_enc)
    def forward(self,x):
        return self.dropout(x+self.pos_enc[:,x.shape[1]-1,:])

In [4]:
class MHA(nn.Module):
    def __init__(self,d_model:int,h:int,dropout:float):
        super().__init__()
        self.d_model=d_model
        self.h=h
        assert d_model%h==0,'emd_dim is not divisible by heads'
        self.d_key=self.d_model//self.h
        self.q=nn.Linear(d_model,d_model)
        self.k=nn.Linear(d_model,d_model)
        self.v=nn.Linear(d_model,d_model)
        self.o=nn.Linear(d_model,d_model)
        self.dropout=nn.Dropout(dropout)
    @staticmethod
    def compute_attention(q,k,v,dropout:nn.Dropout,mask):
        d_key=k.shape[-1]
        a_scores=(q@k.transpose(-1,-2))/d_key**0.5
        if mask is not None:
            a_scores=a_scores.masked_fill_(mask==0,-1e9)
        a_scores=a_scores.softmax(-1)
        if dropout is not None:
            a_scores=dropout(a_scores)
        return (a_scores@v),a_scores
    def forward(self,q,k,v,mask=None):
        query=self.q(q)
        key=self.k(k)
        value=self.v(v)
        query=query.reshape(query.shape[0],-1,self.h,self.d_key).transpose(1,2)
        key=key.reshape(key.shape[0],-1,self.h,self.d_key).transpose(1,2)
        value=value.reshape(value.shape[0],-1,self.h,self.d_key).transpose(1,2)
        x,self.attention_scores=MHA.compute_attention(query,key,value,self.dropout,mask)
        x=x.transpose(1,2).reshape(x.shape[0],-1,self.h*self.d_key)
        return self.o(x)

In [5]:
class ResidualConnection(nn.Module):
    def __init__(self,d_model: int,dropout: float):
        super().__init__()
        self.dropout=nn.Dropout(dropout)
        self.norm=nn.LayerNorm(d_model)
    def forward(self,x,sublayer):
        return x+self.dropout(sublayer(self.norm(x)))

In [6]:
class FeedForwardBlock(nn.Module):
    def __init__(self,d_model,dropout):
        super().__init__()
        self.l1=nn.Linear(d_model,2*d_model)
        self.l2=nn.Linear(2*d_model,d_model)
        self.dropout=nn.Dropout(dropout)
    def forward(self,x):
        x=self.dropout(self.l1(x))
        return self.dropout(self.l2(x))

In [7]:
class EncoderBlock(nn.Module):
    def __init__(self,d_model,h,dropout):
        super().__init__()
        self.mha=MHA(d_model,h,dropout)
        self.ffb=FeedForwardBlock(d_model,dropout)
        self.rc=nn.ModuleList([ResidualConnection(d_model,dropout) for _ in range(2)])
    def forward(self,x,src_mask=None):
        x=self.rc[0](x,lambda x:self.mha(x,x,x,src_mask))
        x=self.rc[1](x,self.ffb)
        return x

In [8]:
class Encoder(nn.Module):
    def __init__(self,num_blocks,d_model,h,dropout):
        super().__init__()
        self.layers=nn.ModuleList([EncoderBlock(d_model,h,dropout) for _ in range(num_blocks)])
        self.norm=nn.LayerNorm(d_model)
    def forward(self,x,mask=None):
        for layer in self.layers:
            x=layer(x,mask)
        return self.norm(x)

In [9]:
class DecoderBlock(nn.Module):
    def __init__(self,d_model,h,dropout):
        super().__init__()
        self.smha=MHA(d_model,h,dropout)
        self.cmha=MHA(d_model,h,dropout)
        self.ffb=FeedForwardBlock(d_model,dropout)
        self.rc=nn.ModuleList([ResidualConnection(d_model,dropout) for _ in range(3)])
    def forward(self,x,encoder_output,src_mask,tar_mask):
        x=self.rc[0](x,lambda x:self.smha(x,x,x,tar_mask))
        x=self.rc[1](x,lambda x:self.cmha(x,encoder_output,encoder_output,src_mask))
        x=self.rc[2](x,self.ffb)
        return x

In [10]:
class Decoder(nn.Module):
    def __init__(self,num_blocks,d_model,h,dropout):
        super().__init__()
        self.layers=nn.ModuleList([DecoderBlock(d_model,h,dropout) for _ in range(num_blocks)])
        self.norm=nn.LayerNorm(d_model)
    def forward(self,x,encoder_output,src_mask,tar_mask):
        for layer in self.layers:
            x=layer(x,encoder_output,src_mask,tar_mask)
        return self.norm(x)

In [11]:
class Transformer(nn.Module):
    def __init__(self,num_blocks,src_len,tar_len,vocab_size_src,vocab_size_tar,d_model,h,dropout):
        super().__init__()
        self.src_emb=nn.Embedding(vocab_size_src,d_model)
        self.tar_emb=nn.Embedding(vocab_size_tar,d_model)
        self.src_pos=PositionEncod(d_model,src_len,dropout)
        self.tar_pos=PositionEncod(d_model,tar_len,dropout)
        self.encoder=Encoder(num_blocks,d_model,h,dropout)
        self.decoder=Decoder(num_blocks,d_model,h,dropout)
        self.linear=nn.Linear(d_model,vocab_size_tar)
    def encode(self,x,src_mask):
        x=self.src_pos(self.src_emb(x))
        x=self.encoder(x,src_mask)
        return x
    def decode(self,x,encoder_output,src_mask,tar_mask):
        x=self.tar_pos(self.tar_emb(x))
        x=self.decoder(x,encoder_output,src_mask,tar_mask)
        return x
    def forward(self,x):
        e_input,d_input,e_mask,d_mask,a_mask=x['encoder_input'].int(),x['decoder_input'].int(),x['encoder_mask'],x['decoder_mask'],x['attention_mask']
        x=self.encode(e_input,e_mask)
        x=self.decode(d_input,x,d_mask,a_mask)
        x=self.linear(x)
        return x

In [12]:
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import Tokenizer
from tokenizers.trainers import BpeTrainer
from torch.utils.data import Dataset,DataLoader
from tokenizers.processors import TemplateProcessing

In [13]:
ds=load_dataset('Helsinki-NLP/tatoeba_mt','eng-tel')

In [14]:
def get_sentences(ds,lang='eng'):
    if lang=='eng':
        for i in ds['test']['sourceString']:
            yield i
    else:
        for i in ds['test']['targetString']:
            yield i

In [15]:
def get_build_tokenizer(ds,lang):
    tokenizer=Tokenizer(BPE(unk_token='[UNK]'))
    tokenizer.pre_tokenizer=Whitespace()
    trainer=BpeTrainer(special_tokens=['[UNK]','[PAD]','[SOS]','[EOS]'],min_frequency=1)
    tokenizer.train_from_iterator(get_sentences(ds,lang),trainer=trainer)
    return tokenizer

In [16]:
tokenizer_src=get_build_tokenizer(ds,'eng')
tokenizer_tar=get_build_tokenizer(ds,'tel')

In [17]:
class BiLingualDataset(Dataset):
    def __init__(self,ds,token_src,token_tar,seq_len,post_pr):
        super().__init__()
        self.token_src=token_src
        self.token_tar=token_tar
        self.seq_len=seq_len
        self.ds=ds

        self.pad_id=self.token_tar.token_to_id('[PAD]')
        self.sos=self.token_src.token_to_id('[SOS]')
        self.eos=self.token_src.token_to_id('[EOS]')
        self.token_src.enable_padding(pad_id=self.token_src.token_to_id('[PAD]'),pad_token='[PAD]',length=seq_len)
        self.token_tar.enable_padding(pad_id=self.token_tar.token_to_id('[PAD]'),pad_token='[PAD]',length=seq_len+1)
        self.token_src.post_processor=post_pr
        self.token_tar.post_processor=post_pr
    def __len__(self):
        return len(ds['test'])
    def __getitem__(self,idx):
        sample_src=ds['test']['sourceString'][idx]
        sample_tar=ds['test']['targetString'][idx]
        s_src=torch.tensor(self.token_src.encode(sample_src).ids,dtype=torch.float)
        t_src=self.token_tar.encode(sample_tar).ids
        t_src.remove(self.sos)
        label=torch.tensor(t_src,dtype=torch.float)
        t_src.remove(self.eos)
        t_src=torch.tensor([self.sos]+t_src,dtype=torch.float)
        encoder_mask=torch.tensor(self.token_src.encode(sample_src).attention_mask).unsqueeze(0).unsqueeze(0)
        decoder_mask=torch.tensor(self.token_tar.encode(sample_tar).attention_mask)
        eos_idx=decoder_mask.nonzero().max().item()
        decoder_mask=torch.concat([decoder_mask[:eos_idx],decoder_mask[eos_idx+1:]]).unsqueeze(0).unsqueeze(0)
        self_att_mask=torch.ones((self.seq_len,self.seq_len)).tril().int()&(decoder_mask)
        return {'encoder_input':s_src,'decoder_input':t_src,'label':label,
                'encoder_mask':encoder_mask,'decoder_mask':decoder_mask,'attention_mask':self_att_mask}
        

In [18]:
processor=TemplateProcessing(
    single='[SOS] $A [EOS]',
    special_tokens=[
        ('[SOS]',tokenizer_src.token_to_id('[SOS]')),
        ('[EOS]',tokenizer_tar.token_to_id('[EOS]')),
    ]
    )

In [19]:
seq_len=22
torch_ds=BiLingualDataset(ds,tokenizer_src,tokenizer_tar,seq_len,processor)
train_ds=DataLoader(torch_ds,batch_size=32,shuffle=True)

In [81]:
s_size=tokenizer_src.get_vocab_size()
t_size=tokenizer_tar.get_vocab_size()
s_len=t_len=22
num_blocks,d_model,h=4,256,8
epochs=50

In [86]:
model=Transformer(num_blocks,s_len,t_len,s_size,t_size,d_model,h,0.1)

In [87]:
def train_model(model,optim,loss_fn,data,epochs):
    t_loss=[]
    for epoch in range(epochs):
        model.train()
        for idx,input in enumerate(data):
            optim.zero_grad()
            logits=model(input)
            target=input['label'].long()
            b_size,seq_len=input['encoder_input'].shape[0],input['encoder_input'].shape[1]
            preds=logits.view(b_size,-1,seq_len)
            loss=loss_fn(preds,target)
            loss.backward()
            optim.step()
            t_loss.append(loss.item())
        print('Epoch:',epoch,'loss:',round(sum(t_loss)/len(t_loss),4))

In [88]:
loss_fn=nn.CrossEntropyLoss(ignore_index=0)
optimizer=torch.optim.Adam(model.parameters())

In [None]:
train_model(model,optimizer,loss_fn,train_ds,epochs)

Epoch: 0 loss: 5.2385
Epoch: 1 loss: 3.9777
Epoch: 2 loss: 3.1959
Epoch: 3 loss: 2.7639
Epoch: 4 loss: 2.5024
Epoch: 5 loss: 2.3199
