In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import random

# Single source of truth for every RNG seed in this notebook, so the three
# libraries below can never end up seeded with different values.
SEED = 24

random.seed(SEED)  # Python random seed
np.random.seed(SEED)  # NumPy seed
torch.manual_seed(SEED)  # PyTorch seed; returns the default Generator (shown as the cell output)

<torch._C.Generator at 0x7fd37bc9b8b0>

In [3]:
# Set print options: no scientific notation, 4 decimal places (precision=4)
torch.set_printoptions(sci_mode=False, precision=4)

# Task-1: Implement MultiHeadAttention class

In [4]:
# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through a learned linear
    layer, split into ``num_heads`` heads of width ``d_model // num_heads``,
    combined per head as ``softmax(Q K^T / sqrt(d_k)) V``, concatenated back
    to width ``d_model`` and passed through a final linear layer.
    """

    def __init__(self, d_model, num_heads):
        """Create the four projection layers.

        Args:
            d_model: Size of the last dimension of the inputs and of the output.
            num_heads: Number of attention heads; must be positive and divide
                ``d_model`` evenly.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``d_model``.
        """
    ### BEGIN SOLUTION
        super().__init__()
        # Fail fast here: otherwise the problem only surfaces later as an
        # opaque reshape error inside forward().
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.d_v = d_model // num_heads

        # Layer creation order is kept (query, key, value, fc) so that seeded
        # parameter initialisation is unchanged.
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)
    ### END SOLUTION

    def forward(self, q_input, k_input, v_input, mask=None):
        """Apply multi-head attention.

        Args:
            q_input: Tensor of shape ``[batch_size, q_len, d_model]``.
            k_input: Tensor of shape ``[batch_size, kv_len, d_model]``.
            v_input: Tensor of shape ``[batch_size, kv_len, d_model]``.
                ``kv_len`` may differ from ``q_len`` (cross-attention).
            mask: Optional tensor broadcastable to
                ``[batch_size, num_heads, q_len, kv_len]``. Positions where
                ``mask == 0`` are excluded from attention. If every key
                position of a query row is masked, the softmax for that row
                is taken over all ``-inf`` scores and the row becomes NaN.

        Returns:
            Tensor of shape ``[batch_size, q_len, d_model]``.
        """
    ### BEGIN SOLUTION
        batch_size, q_len, _ = q_input.size()
        # Keys/values carry their own sequence length; reusing q_len for them
        # breaks whenever the two sequences differ in length.
        k_len = k_input.size(1)
        v_len = v_input.size(1)

        Q = self.query(q_input)
        K = self.key(k_input)
        V = self.value(v_input)

        # Split the feature dimension into heads, then move heads before the
        # sequence axis so matmul treats (batch, head) as batch dimensions.
        q = Q.reshape(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, q_len, d_k]
        k = K.reshape(batch_size, k_len, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, kv_len, d_k]
        v = V.reshape(batch_size, v_len, self.num_heads, self.d_v).transpose(1, 2)  # [batch_size, num_heads, kv_len, d_v]

        # A plain Python scalar avoids allocating a tensor on every call.
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # [batch_size, num_heads, q_len, kv_len]
        if mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # transpose() leaves the tensor non-contiguous, so contiguous() is
        # required before view() can merge the heads back into d_model.
        attention_output = (
            torch.matmul(attn_weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, q_len, self.d_model)
        )

        output = self.fc(attention_output)
        return output
    ### END SOLUTION

In [5]:
# Define a simple test function
def test_multi_head_attention():
    """Smoke-test MultiHeadAttention on a tiny random batch.

    Checks the output shape, that the output takes part in autograd, and that
    softmax weights rebuilt from the module's own query/key projections sum
    to 1 along the last axis. The weights are recomputed here rather than
    read back from the forward pass.
    """
    # Deliberately tiny sizes so the check is instantaneous.
    d_model_t, num_heads_t = 8, 2
    batch_size_t, max_seq_length_t = 4, 5
    expected_shape_t = (batch_size_t, max_seq_length_t, d_model_t)

    # Build the module first, then draw q, k, v in that order.
    mha_t = MultiHeadAttention(d_model_t, num_heads_t)
    q_input_t, k_input_t, v_input_t = (torch.randn(*expected_shape_t) for _ in range(3))

    output_t = mha_t(q_input_t, k_input_t, v_input_t)

    # Shape must match the query input.
    assert output_t.shape == expected_shape_t, \
        f"Expected shape {expected_shape_t}, got {output_t.shape}"

    # The result must be attached to the autograd graph.
    assert output_t.requires_grad, "Output should require gradients"

    with torch.no_grad():
        def split_heads_t(projected_t):
            # [batch, seq, d_model] -> [batch, heads, seq, d_k]
            per_head_t = projected_t.reshape(batch_size_t, max_seq_length_t, num_heads_t, mha_t.d_k)
            return per_head_t.transpose(1, 2)

        q_proj_t = split_heads_t(mha_t.query(q_input_t))
        k_proj_t = split_heads_t(mha_t.key(k_input_t))

        scale_t = torch.sqrt(torch.tensor(mha_t.d_k, dtype=torch.float))
        attn_scores_t = torch.matmul(q_proj_t, k_proj_t.transpose(-2, -1)) / scale_t
        attn_weights_t = F.softmax(attn_scores_t, dim=-1)

        # Every row of weights is a probability distribution over the keys.
        row_sums_t = attn_weights_t.sum(dim=-1)
        assert torch.allclose(row_sums_t, torch.ones_like(row_sums_t)), \
            "Attention weights should sum to 1 across the sequence dimension"

    print("All tests passed!")

# Run the test function
test_multi_head_attention()


All tests passed!


In [6]:
# Toy hyperparameters shared by the attention demos below
batch_size, max_sequence_length = 2, 5  # sequences per batch, tokens per sequence
d_model, num_heads = 16, 4  # model width, attention heads (16 // 4 = 4 dims per head)

In [7]:
# Random input of shape [batch_size, max_sequence_length, d_model], uniform on [0, 1)
x = torch.rand(batch_size, max_sequence_length, d_model)

# Self-Attention in the Encoder
## Purpose: Allows each token to attend to every token in the input sequence, including itself.

### Masking: No causal masking (tokens can see all positions).

#### Input: x (same for query, key, value).

#### Mask: Usually a padding mask (not needed for random input).

In [8]:
# Encoder self-attention: the same tensor serves as query, key and value
mha_enc = MultiHeadAttention(d_model, num_heads)

# No mask, so every position may attend to every position
mask = None

# Forward pass; the trailing expression displays the output shape
output = mha_enc(q_input=x, k_input=x, v_input=x, mask=mask)
output.shape

torch.Size([2, 5, 16])

# Masked Self-Attention in the Decoder
## Purpose: Prevents each token from attending to future tokens.

### Masking: Causal mask applied.

### Input: x (same for query, key, value).

### Mask: Lower triangular mask to enforce causality.

In [9]:
# Decoder self-attention with a causal mask
mha_dec = MultiHeadAttention(d_model, num_heads)

# Ones on and below the diagonal: position i may attend to positions 0..i only
mask = torch.ones(max_sequence_length, max_sequence_length).tril()

# Forward pass; the trailing expression displays the output shape
output = mha_dec(x, x, x, mask=mask)
output.shape

torch.Size([2, 5, 16])

# Encoder-Decoder Cross-Attention in the Decoder
## Purpose: Allows decoder tokens to attend to all encoder tokens.

## Masking: No causal mask (attends to all encoder tokens).

### Input:

#### q (decoder representation).

#### k, v (encoder output).

In [10]:
# Cross-attention module: queries come from the decoder, keys and values from the encoder
mha = MultiHeadAttention(d_model, num_heads)

# Random stand-ins for the two streams (decoder tensor is drawn first)
decoder_input = torch.rand(batch_size, max_sequence_length, d_model)  # query source
encoder_output = torch.rand(batch_size, max_sequence_length, d_model)  # key & value source

# Decoder positions may look at every encoder position, so no mask is applied
mask = None

# Forward pass
cross_attn_output = mha(
    q_input=decoder_input,
    k_input=encoder_output,
    v_input=encoder_output,
    mask=mask,
)
print(cross_attn_output.shape)  # Expected: (batch_size, seq_len, d_model)

torch.Size([2, 5, 16])


# Understanding Contiguous and Non-Contiguous Tensors in PyTorch

In [11]:
# A freshly created tensor is contiguous
x = torch.randn(2, 3)
print("Original tensor:\n", x)
print("Is contiguous?", x.is_contiguous())  # True

# Swapping the two dimensions yields a tensor that is no longer contiguous
x_t = torch.transpose(x, 0, 1)
print("\nTransposed tensor:\n", x_t)
print("Is contiguous?", x_t.is_contiguous())  # False

Original tensor:
 tensor([[ 0.8104,  0.0467,  1.0917],
        [-0.5887,  0.4677,  1.1150]])
Is contiguous? True

Transposed tensor:
 tensor([[ 0.8104, -0.5887],
        [ 0.0467,  0.4677],
        [ 1.0917,  1.1150]])
Is contiguous? False
