In [18]:
# importing required libraries
import torch.nn as nn
import torch
import torch.nn.functional as F
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
warnings.simplefilter("ignore")
print(torch.__version__)

2.4.0


In [19]:
device="cuda" if torch.cuda.is_available() else 'cpu'
class positional_embedding(nn.Module):
    def __init__(self,max_len,embedding_size):
        super(positional_embedding,self).__init__()
        self.embed_size=embedding_size
        self.pe=torch.zeros(max_len,self.embed_size).to(device)
        for pos in range(max_len):
            for i in range(0,embedding_size,2):
                self.pe[pos,i]=np.sin(pos/(10000**((2*i)/embedding_size)))
                self.pe[pos,i+1]=np.cos(pos/(10000**((2*(i+1))/embedding_size)))
        self.pe=self.pe.unsqueeze(dim=0)
    def forward(self,x):
        x=x*np.sqrt(self.embed_size)
        seq_len=x.shape[1]
        x=x+self.pe[:,:seq_len]
        return x
        

In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embedding_size,n_heads):
        super(MultiHeadAttention,self).__init__()
        self.embed_size=embedding_size
        self.n_heads=n_heads
        self.single_head_size=int(self.embed_size/n_heads)
        self.Q_matrix=nn.Linear(self.single_head_size,self.single_head_size,bias=False)
        self.K_matrix=nn.Linear(self.single_head_size,self.single_head_size,bias=False)
        self.V_matrix=nn.Linear(self.single_head_size,self.single_head_size,bias=False)
        self.out=nn.Linear(embedding_size,embedding_size)
    def forward(self,query,key,value,mask=None):
        batch_size=key.shape[0]

        seq_length=key.shape[1]

        seq_length_query=query.shape[1]

        key = key.view(batch_size, seq_length, self.n_heads, self.single_head_size)  
        query = query.view(batch_size, seq_length_query, self.n_heads, self.single_head_size)
        value = value.view(batch_size, seq_length, self.n_heads, self.single_head_size) 
        k = self.K_matrix(key)    
        q = self.Q_matrix(query)   
        v = self.V_matrix(value)

        q = q.permute(0,2,1,3) 
        k = k.permute(0,2,1,3)  
        v = v.permute(0,2,1,3) 
        
        k_adjusted = k.transpose(-1,-2) 
        dot=torch.matmul(q,k_adjusted)
        if mask is not None:
            dot = dot.masked_fill(mask == 0, float("-1e20"))
        dot=dot/np.sqrt(self.single_head_size)
        
        scores=F.softmax(dot,dim=-1)
        
        scores=torch.matmul(scores,v)
        
        concat=scores.permute(0,2,1,3)
        
        concat=concat.reshape(batch_size, seq_length_query, self.single_head_size*self.n_heads)
        
        output=self.out(concat)
        
        return output
        


In [21]:
class TransformerBlock(nn.Module):
    def __init__(self,embed_size,expansion_factor,n_heads):
        super(TransformerBlock, self).__init__()
        self.attention=MultiHeadAttention(embed_size,n_heads)
        self.NN=nn.Sequential(
            nn.Linear(embed_size,embed_size*expansion_factor),
            nn.ReLU(),
            nn.Linear(embed_size*expansion_factor,embed_size)
        )
        self.norm1=nn.LayerNorm(embed_size)
        self.norm2=nn.LayerNorm(embed_size)
        self.dropout1=nn.Dropout(0.2)
        self.dropout2=nn.Dropout(0.2)
    def forward(self,key,query,value):
        value1=self.attention(query,key,value)
        value1=self.dropout1(self.norm1(value1+value))
        value2=self.NN(value1)
        value2=self.dropout2(self.norm2(value2+value1))
        return value2
        


In [22]:
class Encoder(nn.Module):
    def __init__(self,vocab_size,max_len,embedding_size,num_layers,expansion_factor=4,n_heads=8):
        super(Encoder,self).__init__()
        self.blocks = nn.ModuleList([TransformerBlock(embedding_size, expansion_factor, n_heads) for i in range(num_layers)])
        self.embedding=nn.Embedding(vocab_size,embedding_size)
        self.positional=positional_embedding(max_len,embedding_size)
    def forward(self,x):
        embed_out=self.embedding(x)
        out=self.positional(embed_out)
        for block in self.blocks:
            out=block(out,out,out)
        return out

In [23]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, expansion_factor=4, n_heads=8):
        super(DecoderBlock, self).__init__()
       
        self.attention=MultiHeadAttention(embed_size,n_heads)
       
        self.norm=nn.LayerNorm(embed_size)
       
        self.dropout=nn.Dropout(0.2)
       
        self.transformer_block=TransformerBlock(embed_size,expansion_factor,n_heads)
    def forward(self,query,key,x,mask):
        attention=self.attention(x,x,x,mask=mask)
       
        value=self.dropout(self.norm(attention+x))
       
        out=self.transformer_block(query,key,value)
       
        return out
class TransfomerDecoder(nn.Module):
    def __init__(self,num_layers,target_vocab_size,max_len,embed_size,expansion_factor=4,n_heads=8):
        super(TransfomerDecoder,self).__init__()
        self.positional=positional_embedding(max_len,embed_size)
        
        self.embedding=nn.Embedding(target_vocab_size,embed_size)
        
        self.fc_out = nn.Linear(embed_size, target_vocab_size)

        self.dropout = nn.Dropout(0.2)

        self.blocks=nn.ModuleList([DecoderBlock(embed_size,expansion_factor,n_heads) for i in range(num_layers)])
    def forward(self,x,enc_out,mask):
        x = self.embedding(x)  
        x = self.positional(x) 
        x = self.dropout(x)
     
        for layer in self.blocks:
            x = layer(x,enc_out, enc_out, mask) 

        out = F.softmax(self.fc_out(x))

        return out

    

In [24]:
class Transformer(nn.Module):
    def __init__(self,embed_size,num_layers,vocab_size,target_vocab_size,max_len,n_heads=8):
        super(Transformer,self).__init__()
        self.target_vocab_size=target_vocab_size
        self.encoder=Encoder(vocab_size,max_len,embed_size,num_layers) 
        self.decoder=TransfomerDecoder(num_layers,target_vocab_size,max_len,embed_size)
    def make_trg_mask(self, trg):
        batch_size, trg_len = trg.shape
    
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            batch_size, 1, trg_len, trg_len
        ).to(device)
        return trg_mask  
    def forward(self,encoder_input,decoder_input):
        mask=self.make_trg_mask(decoder_input)
        encoder_output=self.encoder(encoder_input)
        output=self.decoder(decoder_input,encoder_output,mask)
        return output

In [25]:
src_vocab_size = 11
target_vocab_size = 11
num_layers = 6
seq_length= 12
n_heads=8


src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1], 
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1], 
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])

print(src.shape,target.shape)
model = Transformer(embed_size=512,num_layers=num_layers,vocab_size=src_vocab_size, 
                    target_vocab_size=target_vocab_size, max_len=seq_length,n_heads=n_heads).to(device)


torch.Size([2, 12]) torch.Size([2, 12])


In [26]:
src_data = torch.randint(1, src_vocab_size, (64, seq_length)).to(device)  
tgt_data = torch.randint(1, target_vocab_size, (64, seq_length)).to(device)
src_data.shape

torch.Size([64, 12])

In [27]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
model.train()
epochs=200
from tqdm.auto import tqdm
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()
    output = model(src_data, tgt_data[:, :])
    loss = criterion(output.view(-1, target_vocab_size), tgt_data[:,:].view(-1))
    loss.backward()
    optimizer.step()
    if epoch %20==0:
        print(f"Epoch: {epoch}, Loss: {loss.item()}")

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0, Loss: 2.3974502086639404
Epoch: 20, Loss: 2.395129680633545
Epoch: 40, Loss: 2.389923095703125
Epoch: 60, Loss: 2.3847243785858154
Epoch: 80, Loss: 2.3738722801208496
Epoch: 100, Loss: 2.365818500518799
Epoch: 120, Loss: 2.3568003177642822
Epoch: 140, Loss: 2.34832501411438
Epoch: 160, Loss: 2.3421144485473633
Epoch: 180, Loss: 2.334085702896118
