In [124]:
# importing required libraries
import torch.nn as nn
import torch
import torch.nn.functional as F
import warnings
import pandas as pd
import numpy as np
import seaborn as sns
import nltk
import matplotlib.pyplot as plt
warnings.simplefilter("ignore")
print(torch.__version__)

2.4.0


In [125]:
device="cuda" if torch.cuda.is_available() else 'cpu'
class positional_embedding(nn.Module):
    def __init__(self,max_len,embedding_size):
        super(positional_embedding,self).__init__()
        self.embed_size=embedding_size
        self.pe=torch.zeros(max_len,self.embed_size).to(device)
        for pos in range(max_len):
            for i in range(0,embedding_size,2):
                self.pe[pos,i]=np.sin(pos/(10000**((2*i)/embedding_size)))
                self.pe[pos,i+1]=np.cos(pos/(10000**((2*(i+1))/embedding_size)))
        self.pe=self.pe.unsqueeze(dim=0)
    def forward(self,x):
        x=x*np.sqrt(self.embed_size)
        seq_len=x.shape[1]
        x=x+self.pe[:,:seq_len]
        return x
        

In [126]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embedding_size,n_heads):
        super(MultiHeadAttention,self).__init__()
        self.embed_size=embedding_size
        self.n_heads=n_heads
        self.single_head_size=int(self.embed_size/n_heads)
        self.Q_matrix=nn.Linear(self.single_head_size,self.single_head_size,bias=False)
        self.K_matrix=nn.Linear(self.single_head_size,self.single_head_size,bias=False)
        self.V_matrix=nn.Linear(self.single_head_size,self.single_head_size,bias=False)
        self.out=nn.Linear(embedding_size,embedding_size)
    def forward(self,query,key,value,mask=None):
        batch_size=key.shape[0]

        seq_length=key.shape[1]

        seq_length_query=query.shape[1]

        key = key.view(batch_size, seq_length, self.n_heads, self.single_head_size)  #batch_size x sequence_length x n_heads x single_head_size = (32x10x8x64)
        query = query.view(batch_size, seq_length_query, self.n_heads, self.single_head_size) #(32x10x8x64)
        value = value.view(batch_size, seq_length, self.n_heads, self.single_head_size) #(32x10x8x64)
        k = self.K_matrix(key)       # (32x10x8x64)
        q = self.Q_matrix(query)   
        v = self.V_matrix(value)

        q = q.permute(0,2,1,3)  # (batch_size, n_heads, seq_len, single_head_size)    # (32 x 8 x 10 x 64)
        k = k.permute(0,2,1,3)  # (batch_size, n_heads, seq_len, single_head_size)
        v = v.permute(0,2,1,3) 
        
        k_adjusted = k.transpose(-1,-2)  #(batch_size, n_heads, single_head_size, seq_ken)  #(32 x 8 x 64 x 10)
        dot=torch.matmul(q,k_adjusted)
        if mask is not None:
            dot = dot.masked_fill(mask == 0, float("-1e20"))
        dot=dot/np.sqrt(self.single_head_size)
        
        scores=F.softmax(dot,dim=-1)
        
        scores=torch.matmul(scores,v)
        
        concat=scores.permute(0,2,1,3)
        
        concat=concat.reshape(batch_size, seq_length_query, self.single_head_size*self.n_heads)
        
        output=self.out(concat)
        
        return output
        


In [127]:
class TransformerBlock(nn.Module):
    def __init__(self,embed_size,expansion_factor,n_heads):
        super(TransformerBlock, self).__init__()
        self.attention=MultiHeadAttention(embed_size,n_heads)
        self.NN=nn.Sequential(
            nn.Linear(embed_size,embed_size*expansion_factor),
            nn.ReLU(),
            nn.Linear(embed_size*expansion_factor,embed_size)
        )
        self.norm1=nn.LayerNorm(embed_size)
        self.norm2=nn.LayerNorm(embed_size)
        self.dropout1=nn.Dropout(0.2)
        self.dropout2=nn.Dropout(0.2)
    def forward(self,key,query,value):
        value1=self.attention(query,key,value)
        value1=self.dropout1(self.norm1(value1+value))
        value2=self.NN(value1)
        value2=self.dropout2(self.norm2(value2+value1))
        return value2
        


In [128]:
class Encoder(nn.Module):
    def __init__(self,vocab_size,max_len,embedding_size,num_layers,expansion_factor=4,n_heads=8):
        super(Encoder,self).__init__()
        self.blocks = nn.ModuleList([TransformerBlock(embedding_size, expansion_factor, n_heads) for i in range(num_layers)])
        self.embedding=nn.Embedding(vocab_size,embedding_size)
        self.positional=positional_embedding(max_len,embedding_size)
    def forward(self,x):
        embed_out=self.embedding(x)
        out=self.positional(embed_out)
        for block in self.blocks:
            out=block(out,out,out)
        return out

In [129]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, expansion_factor=4, n_heads=8):
        super(DecoderBlock, self).__init__()
       
        self.attention=MultiHeadAttention(embed_size,n_heads)
       
        self.norm=nn.LayerNorm(embed_size)
       
        self.dropout=nn.Dropout(0.2)
       
        self.transformer_block=TransformerBlock(embed_size,expansion_factor,n_heads)
    def forward(self,query,key,x,mask):
        attention=self.attention(x,x,x,mask=mask)
       
        value=self.dropout(self.norm(attention+x))
       
        out=self.transformer_block(query,key,value)
       
        return out
class TransfomerDecoder(nn.Module):
    def __init__(self,num_layers,target_vocab_size,max_len,embed_size,expansion_factor=4,n_heads=8):
        super(TransfomerDecoder,self).__init__()
        self.positional=positional_embedding(max_len,embed_size)
        
        self.embedding=nn.Embedding(target_vocab_size,embed_size)
        
        self.fc_out = nn.Linear(embed_size, target_vocab_size)

        self.dropout = nn.Dropout(0.2)

        self.blocks=nn.ModuleList([DecoderBlock(embed_size,expansion_factor,n_heads) for i in range(num_layers)])
    def forward(self,x,enc_out,mask):
        x = self.embedding(x)  #32x10x512
        x = self.positional(x) #32x10x512
        x = self.dropout(x)
     
        for layer in self.blocks:
            x = layer(x,enc_out, enc_out, mask) 

        out = F.softmax(self.fc_out(x))

        return out

    

In [130]:
class Transformer(nn.Module):
    def __init__(self,embed_size,num_layers,vocab_size,target_vocab_size,max_len,n_heads=8):
        super(Transformer,self).__init__()
        self.target_vocab_size=target_vocab_size
        self.encoder=Encoder(vocab_size,max_len,embed_size,num_layers) 
        self.decoder=TransfomerDecoder(num_layers,target_vocab_size,max_len,embed_size)
    def make_trg_mask(self, trg):
        batch_size, trg_len = trg.shape
        # returns the lower triangular part of matrix filled with ones
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            batch_size, 1, trg_len, trg_len
        ).to(device)
        return trg_mask  
    def forward(self,encoder_input,decoder_input):
        mask=self.make_trg_mask(decoder_input)
        encoder_output=self.encoder(encoder_input)
        output=self.decoder(decoder_input,encoder_output,mask)
        return output

In [131]:
src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length= 12
n_heads=8

# let 0 be sos token and 1 be eos token
src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1], 
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1], 
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

print(src.shape,target.shape)
model = Transformer(embed_size=512,num_layers=num_layers,vocab_size=src_vocab_size, 
                    target_vocab_size=target_vocab_size, max_len=seq_length,n_heads=n_heads).to(device)


torch.Size([2, 12]) torch.Size([2, 12])


In [132]:
src_data = torch.randint(1, src_vocab_size, (64, seq_length)).to(device)  # (batch_size, seq_length)
tgt_data = torch.randint(1, target_vocab_size, (64, seq_length)).to(device)
src_data.shape

torch.Size([64, 12])

In [135]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
model.train()
from tqdm.auto import tqdm
for epoch in tqdm(range(100)):
    optimizer.zero_grad()
    output = model(src_data, tgt_data[:, :])
    loss = criterion(output.contiguous().view(-1, target_vocab_size), tgt_data[:,:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 1, Loss: 2.3727633953094482
Epoch: 2, Loss: 2.3690943717956543
Epoch: 3, Loss: 2.3746025562286377
Epoch: 4, Loss: 2.374056100845337
Epoch: 5, Loss: 2.3723132610321045
Epoch: 6, Loss: 2.3733551502227783
Epoch: 7, Loss: 2.3721957206726074
Epoch: 8, Loss: 2.372122287750244
Epoch: 9, Loss: 2.371638298034668
Epoch: 10, Loss: 2.37229323387146
Epoch: 11, Loss: 2.3688580989837646
Epoch: 12, Loss: 2.370147943496704
Epoch: 13, Loss: 2.3694193363189697
Epoch: 14, Loss: 2.3709821701049805
Epoch: 15, Loss: 2.370960235595703
Epoch: 16, Loss: 2.3703792095184326
Epoch: 17, Loss: 2.3674182891845703
Epoch: 18, Loss: 2.367587089538574
Epoch: 19, Loss: 2.3661177158355713
Epoch: 20, Loss: 2.367017984390259
Epoch: 21, Loss: 2.36651349067688
Epoch: 22, Loss: 2.3698039054870605
Epoch: 23, Loss: 2.367149829864502
Epoch: 24, Loss: 2.363593101501465
Epoch: 25, Loss: 2.3651373386383057
Epoch: 26, Loss: 2.365278482437134
Epoch: 27, Loss: 2.3649349212646484
Epoch: 28, Loss: 2.36445689201355
Epoch: 29, Loss: 

'cuda'