In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``num_heads`` heads.

    Queries, keys and values are each projected with their own
    ``nn.Linear(d_model, d_model)``, split into ``num_heads`` heads of width
    ``d_model // num_heads``, attended, re-joined and passed through a final
    ``nn.Linear(d_model, d_model)``.

    Args:
        d_model: Size of the last dimension of every input and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Without this check the head split below would silently drop features
        # (d_k * num_heads < d_model) and only fail later inside reshape().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

    def forward(self, q_input, k_input, v_input, mask=None):
        """Attend from ``q_input`` over ``k_input`` / ``v_input``.

        Args:
            q_input: Tensor of shape ``[batch_size, q_len, d_model]``.
            k_input: Tensor of shape ``[batch_size, kv_len, d_model]``.
            v_input: Tensor of shape ``[batch_size, kv_len, d_model]``.
            mask: Optional tensor broadcastable to
                ``[batch_size, num_heads, q_len, kv_len]``. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape ``[batch_size, q_len, d_model]``.
        """
        batch_size, q_len, _ = q_input.size()
        # Keys/values may come from a different sequence (cross-attention), so
        # their length must be read from their own tensors, not from q_input.
        k_len = k_input.size(1)
        v_len = v_input.size(1)

        Q = self.query(q_input)
        K = self.key(k_input)
        V = self.value(v_input)
        print(f"*** QKV before MHA: *** ")
        print(f"Q.size(): {Q.size()}")
        print(f"K.size(): {K.size()}")
        print(f"V.size(): {V.size()}")
        q = Q.reshape(batch_size, q_len, self.num_heads, self.d_k)
        k = K.reshape(batch_size, k_len, self.num_heads, self.d_k)
        v = V.reshape(batch_size, v_len, self.num_heads, self.d_v)

        q = q.transpose(1, 2)  # [batch_size, num_heads, q_len, d_k]
        k = k.transpose(1, 2)  # [batch_size, num_heads, k_len, d_k]
        v = v.transpose(1, 2)  # [batch_size, num_heads, v_len, d_v]
        print(f"*** qkv after MHA: *** ")
        print(f"q.size(): {q.size()}")
        print(f"k.size(): {k.size()}")
        print(f"v.size(): {v.size()}")
        # Plain Python scalar: no per-call tensor allocation and no device mismatch.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value instead of -inf: masked positions
            # still get exactly zero weight after softmax, but a row in which
            # every position is masked yields a uniform distribution, not NaN.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        attn_weights = F.softmax(attn_scores, dim=-1)

        # [batch_size, num_heads, q_len, d_v] -> [batch_size, q_len, d_model]
        attention_output = (
            torch.matmul(attn_weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, q_len, self.d_model)
        )

        output = self.fc(attention_output)
        print(f"***  output after MHA: ***  {output.size()} ")
        return output
    
### END SOLUTION

In [3]:
class PositionwiseFeedForward(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    The first linear layer maps ``d_model`` features to ``hidden`` features and
    the second maps them back, so the output has the same last dimension as
    the input. The tensor size is printed after every stage.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        hidden: Width of the intermediate representation.
        drop_prob: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        # Expansion and contraction layers are created first, in this order,
        # so parameter initialisation consumes the RNG exactly as before.
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Run ``x`` through the block and return the transformed tensor."""
        expanded = self.linear1(x)
        print(f"x after first linear layer: {expanded.size()}")

        activated = self.relu(expanded)
        print(f"x after activation: {activated.size()}")

        regularised = self.dropout(activated)
        print(f"x after dropout: {regularised.size()}")

        projected = self.linear2(regularised)
        print(f"x after 2nd linear layer: {projected.size()}")
        return projected


In [4]:
# Step 6: Decoder Layer
class DecoderLayer(nn.Module):
    """One decoder block made of three residual sub-layers.

    Each sub-layer follows the same pattern: compute a transformation, apply
    dropout to it, add the sub-layer's input back, then apply LayerNorm.
    The three transformations are masked self-attention, attention over the
    encoder outputs, and a position-wise feed-forward network.

    Args:
        d_model: Feature size handed to the attention, norm and FFN modules.
        ffn_hidden: Hidden width of the feed-forward network.
        num_heads: Number of heads for both attention modules.
        drop_prob: Dropout probability used by all dropout modules.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        # Sub-layer 1: self-attention over the decoder input.
        self.masked_multihead_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        ###BEGIN SOLUTION
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = nn.Dropout(p=drop_prob)
        ###END SOLUTION
        # Sub-layer 2: attention whose keys/values come from the encoder.
        self.encoder_decoder_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        ###BEGIN SOLUTION
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout2 = nn.Dropout(p=drop_prob)
        # Sub-layer 3: position-wise feed-forward network.
        self.feedforward = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout3 = nn.Dropout(p=drop_prob)
        ###END SOLUTION

    def forward(self, x, enc_out_k, enc_out_v, decoder_mask):
        """Apply the three sub-layers in order and return the result.

        Args:
            x: Decoder input; used as query, key and value of the first attention.
            enc_out_k: Tensor used as keys of the second attention.
            enc_out_v: Tensor used as values of the second attention.
            decoder_mask: Mask forwarded to the first attention only; the
                second attention is always called with ``mask=None``.
        """
        # Masked self-attention + residual + norm.
        self_attended = self.masked_multihead_attn(x, x, x, mask=decoder_mask)
        hidden = self.norm1(self.dropout1(self_attended) + x)

        # Encoder-decoder attention + residual + norm.
        cross_attended = self.encoder_decoder_attention(hidden, enc_out_k, enc_out_v, mask=None)
        hidden = self.norm2(self.dropout2(cross_attended) + hidden)

        # Feed-forward + residual + norm.
        transformed = self.feedforward(hidden)
        return self.norm3(self.dropout3(transformed) + hidden)

In [5]:
def test_decoder_layer():
    """Smoke-test ``DecoderLayer``: structure, output shape and absence of NaN.

    Builds a layer with small random inputs, checks that it contains exactly
    three ``nn.LayerNorm`` modules and at least one ``PositionwiseFeedForward``,
    runs one forward pass and verifies the output shape and that no value is NaN.
    """
    d_model = 512
    num_heads = 8
    seq_length = 10
    batch_size = 4
    drop_prob = 0.2
    ffn_hidden = 2048

    # Create random tensors as decoder input and encoder output
    x = torch.rand(batch_size, seq_length, d_model)
    encoder_output = torch.rand(batch_size, seq_length, d_model)
    # Create a random 0/1 mask. A purely random mask can contain a row of all
    # zeros, which masks every key for that query and makes softmax return NaN,
    # so the diagonal is forced to 1: every position can at least see itself.
    mask = torch.randint(0, 2, (batch_size, num_heads, seq_length, seq_length))
    mask = mask.masked_fill(torch.eye(seq_length, dtype=torch.bool), 1)

    # Initialize the DecoderLayer
    decoder_layer = DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
    # Check that the layer contains the expected LayerNorm and feed-forward modules
    norm_instances = [module for module in decoder_layer.modules() if isinstance(module, nn.LayerNorm)]
    assert len(norm_instances) == 3, "Missing LayerNorm instance(s) in DecoderLayer"
    ff_instance = [module for module in decoder_layer.modules() if isinstance(module, PositionwiseFeedForward)]
    assert len(ff_instance) > 0, "Missing FeedForward instance in DecoderLayer"

    # Forward pass
    output = decoder_layer(x, encoder_output, encoder_output, mask)

    # Assertions
    assert output.shape == x.shape, "Output shape mismatch"
    assert not torch.isnan(output).any(), "Output contains NaN values"

# Run the test
test_decoder_layer()
print("Test Case DecoderLayer Passed")

*** QKV before MHA: *** 
Q.size(): torch.Size([4, 10, 512])
K.size(): torch.Size([4, 10, 512])
V.size(): torch.Size([4, 10, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([4, 8, 10, 64])
k.size(): torch.Size([4, 8, 10, 64])
v.size(): torch.Size([4, 8, 10, 64])
***  output after MHA: ***  torch.Size([4, 10, 512]) 
*** QKV before MHA: *** 
Q.size(): torch.Size([4, 10, 512])
K.size(): torch.Size([4, 10, 512])
V.size(): torch.Size([4, 10, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([4, 8, 10, 64])
k.size(): torch.Size([4, 8, 10, 64])
v.size(): torch.Size([4, 8, 10, 64])
***  output after MHA: ***  torch.Size([4, 10, 512]) 
x after first linear layer: torch.Size([4, 10, 2048])
x after activation: torch.Size([4, 10, 2048])
x after dropout: torch.Size([4, 10, 2048])
x after 2nd linear layer: torch.Size([4, 10, 512])
Test Case DecoderLayer Passed


In [6]:
class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant whose layers take four positional inputs.

    Every contained module is called as
    ``module(x, enc_out_k, enc_out_v, mask)``. The output of one module
    becomes the ``x`` of the next; ``enc_out_k``, ``enc_out_v`` and ``mask``
    are passed unchanged to every module.
    """

    def forward(self, *inputs):
        """Run the layers in registration order and return the final output.

        Args:
            *inputs: Exactly four values: ``x, enc_out_k, enc_out_v, mask``.

        Returns:
            The output of the last module, or ``x`` itself when the container
            holds no modules.
        """
        x, enc_out_k, enc_out_v, mask = inputs
        for module in self:
            # Feed each layer's output into the next one. Previously every layer
            # received the original x and all but the last result was discarded.
            x = module(x, enc_out_k, enc_out_v, mask)
        return x

In [7]:
class Decoder(nn.Module):
    """Container holding ``num_layers`` identically configured ``DecoderLayer`` modules.

    Args:
        d_model: Feature size passed to every ``DecoderLayer``.
        ffn_hidden: Feed-forward hidden width passed to every ``DecoderLayer``.
        num_heads: Number of attention heads passed to every ``DecoderLayer``.
        drop_prob: Dropout probability passed to every ``DecoderLayer``.
        num_layers: How many ``DecoderLayer`` instances to create.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        decoder_blocks = []
        for _ in range(num_layers):
            decoder_blocks.append(DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob))
        self.layers = SequentialDecoder(*decoder_blocks)

    def forward(self, x, enc_out_k, enc_out_v, mask):
        """Delegate to the ``SequentialDecoder`` and return whatever it returns."""
        return self.layers(x, enc_out_k, enc_out_v, mask)

In [8]:
# Hyperparameters for the demo decoder run in the cells below.
d_model = 512  # size of the last (feature) dimension of the demo inputs
num_heads = 8  # attention heads per MultiHeadAttention module
drop_prob = 0.1  # dropout probability passed to the Decoder
batch_size = 30  # number of sequences in the random demo batch
max_sequence_length = 100  # sequence length of the demo inputs and side length of the mask
ffn_hidden = 2048  # hidden width of the feed-forward network
num_layers = 2  # number of DecoderLayer instances created by Decoder

In [9]:
# Random stand-ins for the decoder inputs, both shaped (batch_size, max_sequence_length, d_model).
x = torch.randn(batch_size, max_sequence_length, d_model)  # Telugu sentence positional encoded
enc_out = torch.randn(batch_size, max_sequence_length, d_model)  # Telugu sentence - encoder outputs

In [10]:
# Lower-triangular mask of ones: entry (i, j) is 1 only when j <= i.
mask = torch.ones(max_sequence_length, max_sequence_length).tril()
decoder = Decoder(
    d_model=d_model,
    ffn_hidden=ffn_hidden,
    num_heads=num_heads,
    drop_prob=drop_prob,
    num_layers=num_layers,
)
# The same encoder output tensor is supplied as both keys and values.
out = decoder(x, enc_out, enc_out, mask)

*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([30, 8, 100, 64])
k.size(): torch.Size([30, 8, 100, 64])
v.size(): torch.Size([30, 8, 100, 64])
***  output after MHA: ***  torch.Size([30, 100, 512]) 
*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([30, 8, 100, 64])
k.size(): torch.Size([30, 8, 100, 64])
v.size(): torch.Size([30, 8, 100, 64])
***  output after MHA: ***  torch.Size([30, 100, 512]) 
x after first linear layer: torch.Size([30, 100, 2048])
x after activation: torch.Size([30, 100, 2048])
x after dropout: torch.Size([30, 100, 2048])
x after 2nd linear layer: torch.Size([30, 100, 512])
*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 51