# ENCODER-DECODER MHA WITH MASKING USING IN-BUILT PYTHON



### Training part of the transformer


In [38]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.shape[1]].detach()


class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)
        
        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        print("Tgt mask shape = ", tgt_mask.shape)

        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)

        output = self.fc(output)

        print("Fully connected layer op = ", output)
        print()

        
        return output
    

In [None]:
torch.manual_seed(0)

src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 1
num_decoder_layers = 1
d_ff = 20
max_seq_len = 5
dropout = 0

transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)

# Generate random sample data
# src_data = torch.randint(1, src_vocab_size, (max_seq_len , 3))  # (seq_length, batch_size,)
# tgt_data = torch.randint(1, tgt_vocab_size, ( max_seq_len, 3))  # (seq_length, batch_size)

src_data = torch.tensor([[0, 2, 4], [1, 0, 7], [2, 2, 0], [3, 5, 6], [6, 1, 9]])
tgt_data = torch.tensor([[1, 7, 9], [3, 4, 1], [5, 2, 8], [8, 0, 3], [4, 5, 9]])  # Target sequence

# src_data = torch.tensor([[2], [1], [5], [4]])
# tgt_data = torch.tensor([[1], [16], [5], [3], [9]]) 


state_dict = transformer.state_dict()

In [40]:
import copy

state_dict1 = copy.deepcopy(state_dict)

In [41]:
src_data.shape, tgt_data.shape

(torch.Size([5, 3]), torch.Size([5, 3]))

In [42]:
tgt_data.view(-1)
tgt_data.shape

torch.Size([5, 3])

In [43]:
src_data.shape, tgt_data.shape

(torch.Size([5, 3]), torch.Size([5, 3]))

In [44]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(1):

    optimizer.zero_grad()

    print("FWD PASS START\n")
    print("src_data shape = ", src_data.shape)
    print("tgt_data shape = ", tgt_data[:-1, :].shape)
    
    output = transformer(src_data, tgt_data[:-1, :])
    print("FWD PASS END\n")

    print("output shape = ",output.shape)

    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[1:, :].contiguous().view(-1))


    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

FWD PASS START

src_data shape =  torch.Size([5, 3])
tgt_data shape =  torch.Size([4, 3])
Tgt mask shape =  torch.Size([4, 4])
torch.Size([5, 3, 16]) torch.Size([4, 3, 16])
MASK =  None
query =  tensor([[[-1.1258, -0.1524, -0.2506,  0.5661,  0.8487,  1.6920, -0.3160,
          -1.1152,  0.3223, -0.2633,  0.3500,  1.3081,  0.1198,  2.2377,
           1.1168,  0.7527],
         [ 0.2279,  0.5719, -0.1817,  1.1988,  0.5395,  1.1074,  0.6724,
           1.4407, -0.0923,  1.7924, -0.2865,  1.0525,  0.5239,  3.3022,
          -1.4686, -0.5867],
         [ 0.3400,  0.5038,  1.7019,  2.0965, -1.2795,  3.5473, -0.4099,
           1.3336, -1.6093,  0.4501, -0.4735,  0.5003, -1.0650,  2.1149,
          -0.1400,  1.8058]],

        [[-1.3527, -0.6959,  0.5667,  1.7935,  0.5988, -0.5551, -0.3414,
           2.8530,  0.7502,  0.4145, -0.1734,  1.1835,  1.3894,  2.5863,
           0.9463,  0.1563],
         [-0.2844, -0.6121,  0.0604,  0.5165,  0.9485,  1.6870, -0.2844,
          -1.1157,  0.3323, -0

In [46]:
src_data, tgt_data[:-1, :]

(tensor([[0, 2, 4],
         [1, 0, 7],
         [2, 2, 0],
         [3, 5, 6],
         [6, 1, 9]]),
 tensor([[1, 7, 9],
         [3, 4, 1],
         [5, 2, 8],
         [8, 0, 3]]))