## Define MultiHeadAttention Class




In [None]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Skeleton multi-head attention module that only records its configuration.

    Args:
        d_model: Model (embedding) dimension.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability, stored for later use.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        # Only the hyperparameters are stored at this stage; no layers are built.
        self.d_model, self.num_heads, self.dropout_rate = d_model, num_heads, dropout_rate
        print("MultiHeadAttention class initialized with d_model, num_heads, and dropout_rate.")

## Initialize Projection Layers (Wq, Wk, Wv, Wo) and Dropout


In [None]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention module holding its projection and dropout layers.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability handed to ``nn.Dropout``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Every head must receive an equal, whole-number slice of d_model.
        head_dim, leftover = divmod(d_model, num_heads)
        if leftover:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_k = head_dim

        # Four d_model -> d_model projections, registered in the order
        # Wq, Wk, Wv, Wo (same order as writing them out one by one).
        for layer_name in ("Wq", "Wk", "Wv", "Wo"):
            setattr(self, layer_name, nn.Linear(d_model, d_model))

        self.dropout = nn.Dropout(dropout_rate)

        print("MultiHeadAttention class initialized with Wq, Wk, Wv, Wo, and dropout layers.")

## Implement Split Heads Function


In [None]:
import torch.nn as nn
import torch

class MultiHeadAttention(nn.Module):
    """Multi-head attention module with projection layers and head splitting.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability handed to ``nn.Dropout``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Every head must receive an equal, whole-number slice of d_model.
        head_dim, leftover = divmod(d_model, num_heads)
        if leftover:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_k = head_dim

        # Four d_model -> d_model projections, registered in the order
        # Wq, Wk, Wv, Wo (same order as writing them out one by one).
        for layer_name in ("Wq", "Wk", "Wv", "Wo"):
            setattr(self, layer_name, nn.Linear(d_model, d_model))

        self.dropout = nn.Dropout(dropout_rate)

        print("MultiHeadAttention class initialized with Wq, Wk, Wv, Wo, and dropout layers.")

    def _split_heads(self, x, batch_size):
        """Split the feature dimension of ``x`` into attention heads.

        Reshapes ``(batch_size, sequence_length, d_model)`` into
        ``(batch_size, num_heads, sequence_length, d_k)`` as a view.
        """
        per_head = x.view(batch_size, -1, self.num_heads, self.d_k)
        # Move the head axis in front of the sequence axis.
        return per_head.permute(0, 2, 1, 3)

    # forward() is introduced in a later step.

## Calculate Scaled Dot-Product Attention




In [None]:
import torch.nn as nn
import torch
import math

class MultiHeadAttention(nn.Module):
    """Multi-head attention module (this stage computes masked, scaled scores).

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability passed to ``nn.Dropout``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Ensure d_model is divisible by num_heads so each head gets d_k features.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_k = d_model // num_heads

        # Linear layers for Wq, Wk, Wv, and Wo (each d_model -> d_model).
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

        # Dropout layer
        self.dropout = nn.Dropout(dropout_rate)

        print("MultiHeadAttention class initialized with Wq, Wk, Wv, Wo, and dropout layers.")

    def _split_heads(self, x, batch_size):
        """Split the feature dimension of ``x`` into attention heads.

        Reshapes ``(batch_size, sequence_length, d_model)`` into
        ``(batch_size, num_heads, sequence_length, d_k)``.
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        # Transpose to (batch_size, num_heads, sequence_length, d_k)
        return x.transpose(1, 2)

    @staticmethod
    def _mask_fill_value(dtype):
        """Return the value written into masked score positions for ``dtype``.

        ``-1e9`` is kept whenever the dtype can represent it (float32, float64,
        bfloat16), so results are unchanged there. float16 cannot hold it (its
        most negative finite value is -65504) and ``masked_fill`` raises on
        overflow, so the dtype's most negative finite value is used instead.
        """
        return max(-1e9, torch.finfo(dtype).min)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute masked, scaled attention scores ``Q @ K^T / sqrt(d_k)``.

        Args:
            Q: Queries, ``(batch_size, num_heads, seq_len_q, d_k)``.
            K: Keys, ``(batch_size, num_heads, seq_len_k, d_k)``.
            V: Values; not used yet at this stage.
            mask: Optional tensor broadcastable to the scores, e.g.
                ``(batch_size, 1, 1, seq_len_k)`` or
                ``(batch_size, 1, seq_len_q, seq_len_k)``. Positions where the
                mask equals 0 are masked out.

        Returns:
            Scaled (and masked) scores of shape
            ``(batch_size, num_heads, seq_len_q, seq_len_k)``.
        """
        # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1))

        # Scale by sqrt(d_k) to keep the dot products in a reasonable range.
        scaled_scores = scores / math.sqrt(self.d_k)

        if mask is not None:
            # Use a dtype-safe large negative value so half precision does not overflow.
            scaled_scores = scaled_scores.masked_fill(
                mask == 0, self._mask_fill_value(scaled_scores.dtype)
            )

        return scaled_scores

    # The forward method will be added in a subsequent step

## Apply Softmax and Multiply by Values





In [None]:
import torch.nn as nn
import torch
import math

class MultiHeadAttention(nn.Module):
    """Multi-head attention module (this stage implements per-head attention).

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Ensure d_model is divisible by num_heads so each head gets d_k features.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_k = d_model // num_heads

        # Linear layers for Wq, Wk, Wv, and Wo (each d_model -> d_model).
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

        # Dropout layer
        self.dropout = nn.Dropout(dropout_rate)

        print("MultiHeadAttention class initialized with Wq, Wk, Wv, Wo, and dropout layers.")

    def _split_heads(self, x, batch_size):
        """Split the feature dimension of ``x`` into attention heads.

        Reshapes ``(batch_size, sequence_length, d_model)`` into
        ``(batch_size, num_heads, sequence_length, d_k)``.
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        # Transpose to (batch_size, num_heads, sequence_length, d_k)
        return x.transpose(1, 2)

    @staticmethod
    def _mask_fill_value(dtype):
        """Return the value written into masked score positions for ``dtype``.

        ``-1e9`` is kept whenever the dtype can represent it (float32, float64,
        bfloat16), so results are unchanged there. float16 cannot hold it (its
        most negative finite value is -65504) and ``masked_fill`` raises on
        overflow, so the dtype's most negative finite value is used instead.
        """
        return max(-1e9, torch.finfo(dtype).min)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Apply scaled dot-product attention independently per head.

        Args:
            Q: Queries, ``(batch_size, num_heads, seq_len_q, d_k)``.
            K: Keys, ``(batch_size, num_heads, seq_len_k, d_k)``.
            V: Values, ``(batch_size, num_heads, seq_len_k, d_k)``.
            mask: Optional tensor broadcastable to the scores, e.g.
                ``(batch_size, 1, 1, seq_len_k)`` or
                ``(batch_size, 1, seq_len_q, seq_len_k)``. Positions where the
                mask equals 0 are masked out.

        Returns:
            ``(output, attention_weights)``. ``output`` is
            ``attention_weights @ V`` and ``attention_weights`` has shape
            ``(batch_size, num_heads, seq_len_q, seq_len_k)``. The weights are
            returned after dropout.
        """
        # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1))

        # Scale by sqrt(d_k) to keep the dot products in a reasonable range.
        scaled_scores = scores / math.sqrt(self.d_k)

        if mask is not None:
            # Use a dtype-safe large negative value so half precision does not overflow.
            scaled_scores = scaled_scores.masked_fill(
                mask == 0, self._mask_fill_value(scaled_scores.dtype)
            )

        # Normalise over the key axis to get attention weights.
        attention_weights = torch.softmax(scaled_scores, dim=-1)

        # Apply dropout to attention weights
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the values.
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    # The forward method will be added in a subsequent step

## Concatenate Heads and Output Projection


In [None]:
import torch.nn as nn
import torch
import math

class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Ensure d_model is divisible by num_heads so each head gets d_k features.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_k = d_model // num_heads

        # Linear layers for Wq, Wk, Wv, and Wo (each d_model -> d_model).
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

        # Dropout layer
        self.dropout = nn.Dropout(dropout_rate)

        print("MultiHeadAttention class initialized with Wq, Wk, Wv, Wo, and dropout layers.")

    def _split_heads(self, x, batch_size):
        """Split the feature dimension of ``x`` into attention heads.

        Reshapes ``(batch_size, sequence_length, d_model)`` into
        ``(batch_size, num_heads, sequence_length, d_k)``.
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        # Transpose to (batch_size, num_heads, sequence_length, d_k)
        return x.transpose(1, 2)

    @staticmethod
    def _mask_fill_value(dtype):
        """Return the value written into masked score positions for ``dtype``.

        ``-1e9`` is kept whenever the dtype can represent it (float32, float64,
        bfloat16), so results are unchanged there. float16 cannot hold it (its
        most negative finite value is -65504) and ``masked_fill`` raises on
        overflow, so the dtype's most negative finite value is used instead.
        """
        return max(-1e9, torch.finfo(dtype).min)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Apply scaled dot-product attention independently per head.

        Args:
            Q: Queries, ``(batch_size, num_heads, seq_len_q, d_k)``.
            K: Keys, ``(batch_size, num_heads, seq_len_k, d_k)``.
            V: Values, ``(batch_size, num_heads, seq_len_k, d_k)``.
            mask: Optional tensor broadcastable to the scores, e.g.
                ``(batch_size, 1, 1, seq_len_k)`` or
                ``(batch_size, 1, seq_len_q, seq_len_k)``. Positions where the
                mask equals 0 are masked out.

        Returns:
            ``(output, attention_weights)``. ``output`` is
            ``attention_weights @ V`` and ``attention_weights`` has shape
            ``(batch_size, num_heads, seq_len_q, seq_len_k)``. The weights are
            returned after dropout.
        """
        # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1))

        # Scale by sqrt(d_k) to keep the dot products in a reasonable range.
        scaled_scores = scores / math.sqrt(self.d_k)

        if mask is not None:
            # Use a dtype-safe large negative value so half precision does not overflow.
            scaled_scores = scaled_scores.masked_fill(
                mask == 0, self._mask_fill_value(scaled_scores.dtype)
            )

        # Normalise over the key axis to get attention weights.
        attention_weights = torch.softmax(scaled_scores, dim=-1)

        # Apply dropout to attention weights
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of the values.
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: ``(batch_size, seq_len_q, d_model)``.
            key: ``(batch_size, seq_len_k, d_model)``.
            value: ``(batch_size, seq_len_k, d_model)``.
            mask: Optional mask; see ``scaled_dot_product_attention``.

        Returns:
            ``(output, attention_weights)``. ``output`` has shape
            ``(batch_size, seq_len_q, d_model)``. ``attention_weights`` has shape
            ``(batch_size, num_heads, seq_len_q, seq_len_k)`` and is taken after
            dropout.
        """
        batch_size = query.shape[0]

        # 1. Linear projections -> (batch_size, seq_len, d_model)
        Q = self.Wq(query)
        K = self.Wk(key)
        V = self.Wv(value)

        # 2. Split heads -> (batch_size, num_heads, seq_len, d_k)
        Q = self._split_heads(Q, batch_size)
        K = self._split_heads(K, batch_size)
        V = self._split_heads(V, batch_size)

        # 3. Per-head scaled dot-product attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads: back to (batch_size, seq_len_q, num_heads, d_k).
        #    contiguous() is required because view() cannot follow the transpose.
        attention_output = attention_output.transpose(1, 2).contiguous()
        concat_attention = attention_output.view(batch_size, -1, self.d_model)

        # 5. Final output projection (Wo)
        output = self.Wo(concat_attention)

        return output, attention_weights

print("MultiHeadAttention class with forward method implemented.")

MultiHeadAttention class with forward method implemented.


## Test Implementation

In [None]:
import torch

# 1. Define input dimensions
d_model = 512
num_heads = 8
seq_len = 60
batch_size = 2
dropout_rate = 0.1

# Fix the RNG so the random inputs and weight initialisation are reproducible.
torch.manual_seed(0)

# 2. Create dummy input tensors
query = torch.randn(batch_size, seq_len, d_model)
key = torch.randn(batch_size, seq_len, d_model)
value = torch.randn(batch_size, seq_len, d_model)

print(f"Dummy query tensor shape: {query.shape}")
print(f"Dummy key tensor shape: {key.shape}")
print(f"Dummy value tensor shape: {value.shape}")

# 3. Instantiate the MultiHeadAttention class
multi_head_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)

# 4. Call the forward method
output, attention_weights = multi_head_attention(query, key, value)

# 5. Print the shape of the returned tensors
print(f"Output tensor shape: {output.shape}")
print(f"Attention weights tensor shape: {attention_weights.shape}")

# Verify expected shapes
expected_output_shape = (batch_size, seq_len, d_model)
expected_attention_weights_shape = (batch_size, num_heads, seq_len, seq_len)

assert output.shape == expected_output_shape, f"Output shape mismatch! Expected {expected_output_shape}, got {output.shape}"
assert attention_weights.shape == expected_attention_weights_shape, f"Attention weights shape mismatch! Expected {expected_attention_weights_shape}, got {attention_weights.shape}"

print("Output and attention weights shapes are correct for unmasked attention.")

# 6. Create a dummy mask tensor and test with it
# Example: a causal mask for an autoregressive model
mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
mask = mask.unsqueeze(0).unsqueeze(0) # Add batch and head dimensions
mask = ~mask # Invert to have True where attention is allowed, False where masked
mask = mask.repeat(batch_size, 1, 1, 1) # Repeat for batch

print(f"Dummy mask tensor shape: {mask.shape}")

output_masked, attention_weights_masked = multi_head_attention(query, key, value, mask=mask)

print(f"Output tensor shape with mask: {output_masked.shape}")
print(f"Attention weights tensor shape with mask: {attention_weights_masked.shape}")

assert output_masked.shape == expected_output_shape, f"Output shape mismatch with mask! Expected {expected_output_shape}, got {output_masked.shape}"
assert attention_weights_masked.shape == expected_attention_weights_shape, f"Attention weights shape mismatch with mask! Expected {expected_attention_weights_shape}, got {attention_weights_masked.shape}"

# Shapes alone do not prove the mask took effect. masked_fill(mask, 0) zeroes
# every allowed position, so only the masked (future) entries remain, and each
# of them must carry exactly zero attention weight. Dropout only rescales
# entries, so zeros stay zero even in training mode.
assert torch.all(attention_weights_masked.masked_fill(mask, 0) == 0), "Masked positions received non-zero attention weight!"

print("Output and attention weights shapes are correct for masked attention.")
print("MultiHeadAttention implementation test successful!")

Dummy query tensor shape: torch.Size([2, 60, 512])
Dummy key tensor shape: torch.Size([2, 60, 512])
Dummy value tensor shape: torch.Size([2, 60, 512])
MultiHeadAttention class initialized with Wq, Wk, Wv, Wo, and dropout layers.
Output tensor shape: torch.Size([2, 60, 512])
Attention weights tensor shape: torch.Size([2, 8, 60, 60])
Output and attention weights shapes are correct for unmasked attention.
Dummy mask tensor shape: torch.Size([2, 1, 60, 60])
Output tensor shape with mask: torch.Size([2, 60, 512])
Attention weights tensor shape with mask: torch.Size([2, 8, 60, 60])
Output and attention weights shapes are correct for masked attention.
MultiHeadAttention implementation test successful!
