In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected with their own ``nn.Linear``,
    split into ``num_heads`` heads of ``d_model // num_heads`` features,
    attended per head with ``softmax(q @ k^T / sqrt(d_k)) @ v``, and the heads
    are then concatenated back to ``d_model`` features and passed through ``fc``.

    Args:
        d_model: Feature size of the inputs and of the output. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        verbose: When True (the default, which keeps the notebook's existing
            trace output), print intermediate tensor sizes on every forward pass.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, verbose=True):
        super(MultiHeadAttention, self).__init__()
        # Fail fast with a clear message instead of an opaque reshape error
        # on the first forward pass.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads
        self.verbose = verbose

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

    def forward(self, q_input, k_input, v_input, mask=None):
        """Attend from ``q_input`` over ``k_input`` / ``v_input``.

        Args:
            q_input: Queries, shape ``(batch, q_len, d_model)``.
            k_input: Keys, shape ``(batch, kv_len, d_model)``. ``kv_len`` may
                differ from ``q_len`` (e.g. cross-attention).
            v_input: Values, shape ``(batch, kv_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, q_len, kv_len)``. Score positions where
                ``mask == 0`` are set to ``-inf`` before the softmax. A query
                row whose keys are all masked yields NaN weights.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size, q_len, _ = q_input.size()
        # Keys and values may come from a different sequence than the queries,
        # so they are reshaped with their own length rather than q_len.
        kv_len = k_input.size(1)

        Q = self.query(q_input)
        K = self.key(k_input)
        V = self.value(v_input)
        if self.verbose:
            print(f"*** QKV before MHA: *** ")
            print(f"Q.size(): {Q.size()}")
            print(f"K.size(): {K.size()}")
            print(f"V.size(): {V.size()}")

        # Split the feature axis into heads, then move heads ahead of the
        # sequence axis: [batch_size, num_heads, seq_len, d_k].
        q = Q.reshape(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        k = K.reshape(batch_size, kv_len, self.num_heads, self.d_k).transpose(1, 2)
        v = V.reshape(batch_size, kv_len, self.num_heads, self.d_v).transpose(1, 2)
        if self.verbose:
            print(f"*** qkv after MHA: *** ")
            print(f"q.size(): {q.size()}")
            print(f"k.size(): {k.size()}")
            print(f"v.size(): {v.size()}")

        # Scores: [batch_size, num_heads, q_len, kv_len].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Merge heads back: [batch_size, q_len, num_heads * d_v] == [..., d_model].
        attention_output = (
            torch.matmul(attn_weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, q_len, self.d_model)
        )

        output = self.fc(attention_output)
        if self.verbose:
            print(f"***  output after MHA: ***  {output.size()} ")
        return output

In [3]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    The same two linear maps are applied independently at every position of
    the last axis, expanding ``d_model`` to ``hidden`` and projecting back.

    Args:
        d_model: Input and output feature size.
        hidden: Inner (expanded) feature size.
        drop_prob: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Run the four stages in order, printing the tensor size after each."""
        pipeline = (
            (self.linear1, "x after first linear layer"),
            (self.relu, "x after activation"),
            (self.dropout, "x after dropout"),
            (self.linear2, "x after 2nd linear layer"),
        )
        for stage, label in pipeline:
            x = stage(x)
            print(f"{label}: {x.size()}")
        return x


In [4]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention and a position-wise feed-forward network, each followed by
    dropout, a residual connection and LayerNorm.

    Args:
        d_model: Feature size of the input and output.
        ffn_hidden: Inner size of the feed-forward network.
        num_heads: Number of attention heads.
        drop_prob: Dropout probability for both sub-layers.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        ###BEGIN SOLUTION
        super(EncoderLayer, self).__init__()
        self.multihead_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.feedforward = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout2 = nn.Dropout(p=drop_prob)
        ###END SOLUTION

    def forward(self, x, mask=None):
        """Apply the layer.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            mask: Optional attention mask forwarded to self-attention.
                Positions where ``mask == 0`` are not attended to.

        Returns:
            Tensor with the same shape as ``x``.
        """
        ### BEGIN SOLUTION
        # Pass the caller's mask through so masked (e.g. padding) positions
        # are excluded from attention.
        attn_output = self.multihead_attn(x, x, x, mask=mask)
        attn_output = self.dropout1(attn_output)
        x = self.norm1(attn_output + x)

        ff_output = self.feedforward(x)
        ff_output = self.dropout2(ff_output)
        enc_out = self.norm2(ff_output + x)

        return enc_out
        ###END SOLUTION

In [5]:
def test_encoder_layer():
    """Smoke-test EncoderLayer.

    It must contain and call a MultiHeadAttention module, preserve the input
    shape, and LayerNorm-normalise its output.
    """
    d_model = 16  # Model dimension
    num_heads = 4  # Number of attention heads
    seq_len = 5  # Sequence length
    batch_size = 2  # Batch size

    encoder_layer = EncoderLayer(d_model, 32, num_heads, 0.2)

    # Check if there is an instance of MultiHeadAttention in the layer
    mha_instances = [module for module in encoder_layer.modules() if isinstance(module, MultiHeadAttention)]
    assert len(mha_instances) > 0, "No MultiHeadAttention instance found in EncoderLayer"

    # Create a random input tensor
    x = torch.rand((batch_size, seq_len, d_model))

    # Capture intermediate outputs using hooks to detect multi-head attention calls
    activations = []

    def hook_fn(module, inp, out):
        activations.append(out)

    handles = [module.register_forward_hook(hook_fn) for module in mha_instances]
    try:
        output = encoder_layer(x, mask=None)
    finally:
        # Detach hooks even if the forward pass raises.
        for h in handles:
            h.remove()

    # Ensure at least one attention mechanism modified input
    assert len(activations) > 0, "MultiHeadAttention was not used in forward pass"
    assert activations[0].shape == x.shape, "Attention output shape mismatch"
    assert not torch.equal(activations[0], x), "MultiHeadAttention does not modify input"
    assert output.shape == x.shape, "EncoderLayer output shape mismatch"

    # The layer ends in a freshly initialised LayerNorm (affine weight 1,
    # bias 0), so every d_model-vector should have ~zero mean and ~unit
    # (biased) variance. A bare `std > 0` check would pass without LayerNorm.
    assert torch.all(output.std(dim=-1) > 0), "LayerNorm not applied correctly!"
    assert torch.allclose(output.mean(dim=-1), torch.zeros(batch_size, seq_len), atol=1e-5), \
        "LayerNorm not applied correctly!"
    assert torch.allclose(output.var(dim=-1, unbiased=False), torch.ones(batch_size, seq_len), atol=1e-3), \
        "LayerNorm not applied correctly!"

    print("Test Case 3:EncoderLayer tests passed!")


test_encoder_layer()

*** QKV before MHA: *** 
Q.size(): torch.Size([2, 5, 16])
K.size(): torch.Size([2, 5, 16])
V.size(): torch.Size([2, 5, 16])
*** qkv after MHA: *** 
q.size(): torch.Size([2, 4, 5, 4])
k.size(): torch.Size([2, 4, 5, 4])
v.size(): torch.Size([2, 4, 5, 4])
***  output after MHA: ***  torch.Size([2, 5, 16]) 
x after first linear layer: torch.Size([2, 5, 32])
x after activation: torch.Size([2, 5, 32])
x after dropout: torch.Size([2, 5, 32])
x after 2nd linear layer: torch.Size([2, 5, 16])
Test Case 3:EncoderLayer tests passed!


In [6]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer blocks applied in sequence.

    Args:
        d_model: Feature size of the input and output.
        ffn_hidden: Inner size of each layer's feed-forward network.
        num_heads: Number of attention heads per layer.
        drop_prob: Dropout probability used inside each layer.
        num_layers: Number of stacked layers.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers):
        super().__init__()
        # ModuleList (not Sequential) so forward can hand the mask to every
        # layer. Parameter names stay `layers.0.`, `layers.1.`, ...
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
            for _ in range(num_layers)
        )

    def forward(self, x, mask=None):
        """Run ``x`` through every layer, passing the same optional ``mask``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for layer in self.layers:
            x = layer(x, mask=mask)
        return x

In [7]:
# Hyperparameters for the encoder shape walk-through below.
d_model = 512          # embedding size; must be divisible by num_heads
num_heads = 8          # 512 / 8 = 64 features per head
drop_prob = 0.1
batch_size = 30
max_sequence_length = 100
ffn_hidden = 2048      # feed-forward inner size
num_layers = 2
SEED = 0

assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

# Seed PyTorch's global RNG so weight initialisation (and the random input
# drawn in the next cell) is reproducible on Restart & Run All.
torch.manual_seed(SEED)

encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)
     

In [8]:
# Random stand-in for an embedded, positionally-encoded input batch.
# No real embedding or positional encoding is applied here.
x = torch.randn((batch_size, max_sequence_length, d_model))
print(f"x input to encoder: {x.size()}")
out = encoder(x)
assert out.shape == x.shape, "encoder must preserve (batch, seq_len, d_model)"

x input to encoder: torch.Size([30, 100, 512])
*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([30, 8, 100, 64])
k.size(): torch.Size([30, 8, 100, 64])
v.size(): torch.Size([30, 8, 100, 64])
***  output after MHA: ***  torch.Size([30, 100, 512]) 
x after first linear layer: torch.Size([30, 100, 2048])
x after activation: torch.Size([30, 100, 2048])
x after dropout: torch.Size([30, 100, 2048])
x after 2nd linear layer: torch.Size([30, 100, 512])
*** QKV before MHA: *** 
Q.size(): torch.Size([30, 100, 512])
K.size(): torch.Size([30, 100, 512])
V.size(): torch.Size([30, 100, 512])
*** qkv after MHA: *** 
q.size(): torch.Size([30, 8, 100, 64])
k.size(): torch.Size([30, 8, 100, 64])
v.size(): torch.Size([30, 8, 100, 64])
***  output after MHA: ***  torch.Size([30, 100, 512]) 
x after first linear layer: torch.Size([30, 100, 2048])
x after activation: torch.Siz