In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each projected with their own linear
    layer, split into ``num_heads`` heads of width ``d_model // num_heads``,
    attended per head, merged back to ``d_model`` and passed through a final
    linear layer (``fc``).

    Args:
        d_model: Size of the last dimension of every input and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        if d_model % num_heads != 0:
            # Otherwise the per-head reshape in forward() cannot succeed.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

    def forward(self, q_input, k_input, v_input, mask=None):
        """Attend from ``q_input`` over ``k_input`` / ``v_input``.

        Args:
            q_input: Tensor of shape (batch, q_len, d_model).
            k_input: Tensor of shape (batch, kv_len, d_model).
            v_input: Tensor of shape (batch, kv_len, d_model). ``kv_len`` may
                differ from ``q_len`` (e.g. cross-attention).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, kv_len). Positions where the mask
                equals 0 are blocked; every other position may be attended to.
                A query row whose keys are all blocked produces NaN, because
                softmax is taken over a row of -inf.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size, q_len, _ = q_input.size()
        # Keys/values carry their own sequence length; it must not be assumed
        # equal to the query length.
        k_len = k_input.size(1)
        v_len = v_input.size(1)

        Q = self.query(q_input)
        K = self.key(k_input)
        V = self.value(v_input)
        print(f"*** QKV before MHA: *** ")
        print(f"Q.size(): {Q.size()}")
        print(f"K.size(): {K.size()}")
        print(f"V.size(): {V.size()}")

        # Split the model dimension into heads, then move the head axis ahead
        # of the sequence axis: [batch_size, num_heads, seq_len, d_k / d_v].
        q = Q.reshape(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        k = K.reshape(batch_size, k_len, self.num_heads, self.d_k).transpose(1, 2)
        v = V.reshape(batch_size, v_len, self.num_heads, self.d_v).transpose(1, 2)
        print(f"*** qkv after MHA: *** ")
        print(f"q.size(): {q.size()}")
        print(f"k.size(): {k.size()}")
        print(f"v.size(): {v.size()}")

        # Scaled dot-product scores: [batch_size, num_heads, q_len, k_len].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # 0 in the mask means "do not attend"; -inf becomes weight 0 after softmax.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of values, then merge the heads back into d_model.
        attention_output = (
            torch.matmul(attn_weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, q_len, self.d_model)
        )

        output = self.fc(attention_output)
        print(f"***  output after MHA: ***  {output.size()} ")
        return output
    
### END SOLUTION

In [3]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied over the last input dimension.

    Computes ``linear2(dropout(relu(linear1(x))))`` and prints the tensor size
    after each of the four stages.

    Args:
        d_model: Input and output feature size.
        hidden: Feature size between the two linear layers.
        drop_prob: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Return the transformed tensor; only the last dimension is mixed."""
        # d_model -> hidden
        expanded = self.linear1(x)
        print(f"x after first linear layer: {expanded.size()}")

        activated = self.relu(expanded)
        print(f"x after activation: {activated.size()}")

        regularized = self.dropout(activated)
        print(f"x after dropout: {regularized.size()}")

        # hidden -> d_model
        projected = self.linear2(regularized)
        print(f"x after 2nd linear layer: {projected.size()}")
        return projected


In [4]:
class DecoderLayer(nn.Module):
    """One decoder block made of three residual sub-layers.

    In order: masked self-attention, encoder-decoder (cross) attention and a
    position-wise feed-forward network. Each sub-layer output goes through
    dropout, is added to the sub-layer input, and is layer-normalized.
    A progress line is printed before every step.

    Args:
        d_model: Feature size of the inputs and outputs.
        ffn_hidden: Hidden size of the feed-forward network.
        num_heads: Number of heads used by both attention modules.
        drop_prob: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        # Sub-layer 1: self-attention over the decoder input.
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)
        # Sub-layer 2: attention from the decoder stream over the encoder output.
        self.encoder_decoder_attention = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)
        # Sub-layer 3: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, enc_out_k, enc_out_v, decoder_mask):
        """Run the block and return a tensor with the same shape as ``x``.

        Args:
            x: Decoder input, (batch, seq_len, d_model).
            enc_out_k: Tensor used as the key input of the cross-attention.
            enc_out_v: Tensor used as the value input of the cross-attention.
            decoder_mask: Mask forwarded to the self-attention only; the
                cross-attention is called without a mask.
        """
        # --- masked self-attention + residual ---
        residual = x
        print("MASKED SELF ATTENTION")
        self_attended = self.self_attention(x, x, x, mask=decoder_mask)
        print("DROP OUT 1")
        self_attended = self.dropout1(self_attended)
        print("ADD + LAYER NORMALIZATION 1")
        x = self.norm1(self_attended + residual)

        # --- cross-attention (query = decoder stream) + residual ---
        residual = x
        print("CROSS ATTENTION")
        cross_attended = self.encoder_decoder_attention(x, enc_out_k, enc_out_v, mask=None)
        print("DROP OUT 2")
        cross_attended = self.dropout2(cross_attended)
        print("ADD + LAYER NORMALIZATION 2")
        x = self.norm2(cross_attended + residual)

        # --- feed-forward + residual ---
        residual = x
        print("FEED FORWARD 1")
        transformed = self.ffn(x)
        print("DROP OUT 3")
        transformed = self.dropout3(transformed)
        print("ADD + LAYER NORMALIZATION 3")
        return self.norm3(transformed + residual)

In [5]:
class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant whose layers take four arguments.

    Every contained module is called as ``module(x, enc_out_k, enc_out_v, mask)``.
    The output of one module becomes the ``x`` of the next, while
    ``enc_out_k``, ``enc_out_v`` and ``mask`` are passed unchanged to all of them.
    """

    def forward(self, *inputs):
        """Run all layers in order.

        Args:
            *inputs: Exactly four values: ``(x, enc_out_k, enc_out_v, mask)``.

        Returns:
            The output of the last layer, or ``x`` itself if the container is empty.
        """
        x, enc_out_k, enc_out_v, mask = inputs
        for module in self._modules.values():
            # Feed each layer the previous layer's output, not the original input.
            x = module(x, enc_out_k, enc_out_v, mask)
        return x

In [6]:
class Decoder(nn.Module):
    """Container of ``num_layers`` identically configured ``DecoderLayer`` blocks.

    The blocks are held in a ``SequentialDecoder``, which receives the four
    forward arguments unchanged.

    Args:
        d_model: Feature size passed to every layer.
        ffn_hidden: Feed-forward hidden size passed to every layer.
        num_heads: Number of attention heads passed to every layer.
        drop_prob: Dropout probability passed to every layer.
        num_layers: How many ``DecoderLayer`` instances to create.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob))
        self.layers = SequentialDecoder(*blocks)

    def forward(self, x, enc_out_k, enc_out_v, mask):
        """Delegate to the layer container and return its result."""
        return self.layers(x, enc_out_k, enc_out_v, mask)

In [7]:
# Model hyperparameters
d_model = 512       # feature size of every token representation
num_heads = 8       # attention heads; each head gets d_model // num_heads features
ffn_hidden = 2048   # hidden size of the position-wise feed-forward network
num_layers = 2      # number of stacked decoder layers
drop_prob = 0.1     # dropout probability used in every dropout module

# Dimensions of the dummy input tensors
batch_size = 30
max_sequence_length = 100

In [8]:
# Random stand-in for the positionally encoded decoder input: (batch, seq, d_model).
x = torch.randn(batch_size, max_sequence_length, d_model)
# Random stand-in for the encoder outputs, used as keys and values in cross-attention.
enc_out = torch.randn(batch_size, max_sequence_length, d_model)

In [9]:
# Causal (look-ahead) mask. MultiHeadAttention blocks every position where
# `mask == 0`, so allowed positions must be non-zero: ones on and below the
# diagonal (current and earlier tokens), zeros above it (future tokens).
# An additive 0 / -inf mask must not be used here: its zeros sit on the allowed
# positions, which would block them and turn the last row into all -inf (NaN
# after softmax).
mask = torch.tril(torch.ones(max_sequence_length, max_sequence_length))
decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)
out = decoder(x, enc_out, enc_out, mask)

# Sanity checks: shape is preserved and no row was fully masked out.
assert out.shape == x.shape
assert not torch.isnan(out).any()

MASKED SELF ATTENTION
*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([30, 8, 100, 64])
k.size(): torch.Size([30, 8, 100, 64])
v.size(): torch.Size([30, 8, 100, 64])
***  output after MHA: ***  torch.Size([30, 100, 512]) 
DROP OUT 1
ADD + LAYER NORMALIZATION 1
CROSS ATTENTION
*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([30, 8, 100, 64])
k.size(): torch.Size([30, 8, 100, 64])
v.size(): torch.Size([30, 8, 100, 64])
***  output after MHA: ***  torch.Size([30, 100, 512]) 
DROP OUT 2
ADD + LAYER NORMALIZATION 2
FEED FORWARD 1
x after first linear layer: torch.Size([30, 100, 2048])
x after activation: torch.Size([30, 100, 2048])
x after dropout: torch.Size([30, 100, 2048])
x after 2nd linear layer: torch.Size([30, 100, 512])
D