In [1]:
import torch
import torch.nn as nn


In [35]:
import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    """Causal (masked) multi-head self-attention with optional shape logging.

    Queries, keys and values are produced by bias-free linear layers, split
    into ``num_heads`` heads of size ``d_out // num_heads``, and combined with
    scaled dot-product attention. An upper-triangular mask prevents each
    position from attending to later positions. The concatenated head outputs
    are passed through a final ``Linear(d_out, d_out)`` projection.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Total output size; split evenly across the heads.
        context_length: Maximum sequence length supported by the causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        verbose: Keyword-only. When true (the default), ``forward`` prints the
            intermediate tensor shapes; pass ``False`` to run silently.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, *, verbose=True):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.d_in = d_in
        self.head_dim = d_out // num_heads
        self.num_heads = num_heads
        self.verbose = verbose

        # Linear layers for Q, K, and V transformations
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.linear_projection = nn.Linear(d_out, d_out)  # Output projection layer

        self.dropout = nn.Dropout(dropout)

        # Causal mask: entries above the diagonal are 1, so position i cannot
        # attend to any position j > i. Registered as a buffer so it follows
        # the module across devices and is saved in the state dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def _log(self, message):
        """Print ``message`` only when verbose logging is enabled."""
        if self.verbose:
            print(message)

    def forward(self, inputs):
        """Apply causal multi-head self-attention.

        Args:
            inputs: Tensor of shape ``(batch, num_tokens, d_in)`` with
                ``num_tokens <= context_length``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                causal mask was built for.
        """
        batch, num_tokens, _ = inputs.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check the sliced mask would be smaller than the
            # score matrix and masked_fill_ would fail with a broadcast error.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )
        self._log(f"\n[INFO] Input Shape: {inputs.shape} (batch_size, seq_len, embedding_dim)")

        # Compute Q, K, V matrices: (batch, num_tokens, d_out)
        query = self.W_query(inputs)
        key = self.W_key(inputs)
        value = self.W_value(inputs)

        self._log("\n[INFO] Q, K, V before reshaping:")
        self._log(f"    Query Shape: {query.shape} (should be batch_size, seq_len, d_out)")
        self._log(f"    Key Shape: {key.shape}")
        self._log(f"    Value Shape: {value.shape}")

        # Split the last dimension into heads: (batch, num_tokens, num_heads, head_dim)
        query = query.view(batch, num_tokens, self.num_heads, self.head_dim)
        key = key.view(batch, num_tokens, self.num_heads, self.head_dim)
        value = value.view(batch, num_tokens, self.num_heads, self.head_dim)

        self._log(f"\n[INFO] Q, K, V after reshaping into {self.num_heads} heads:")
        self._log(f"    Query Shape: {query.shape} (batch_size, seq_len, num_heads, head_dim)")
        self._log(f"    Key Shape: {key.shape}")
        self._log(f"    Value Shape: {value.shape}")

        # Move heads before tokens so each head does an independent matmul:
        # (batch, num_heads, num_tokens, head_dim)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        self._log("\n[INFO] Q, K, V after transposing:")
        self._log(f"    Query Shape: {query.shape} (batch_size, num_heads, seq_len, head_dim)")
        self._log(f"    Key Shape: {key.shape}")
        self._log(f"    Value Shape: {value.shape}")

        # Compute attention scores: (batch, num_heads, num_tokens, num_tokens)
        attn_scores = query @ key.transpose(2, 3)
        self._log(f"\n[INFO] Attention Scores Shape: {attn_scores.shape} (batch_size, num_heads, seq_len, seq_len)")

        # Apply causal mask; -inf becomes a weight of exactly 0 after softmax.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scaled softmax over the key dimension, then dropout on the weights.
        attn_weights = torch.softmax(attn_scores / key.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        self._log(f"\n[INFO] Attention Weights Shape: {attn_weights.shape}")

        # Weighted sum of values: (batch, num_heads, num_tokens, head_dim)
        context_vec = attn_weights @ value
        self._log(f"\n[INFO] Context Vector before reshaping: {context_vec.shape}")

        # Concatenate heads back into (batch, num_tokens, d_out); contiguous()
        # is required because transpose produces a non-contiguous view.
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch, num_tokens, self.d_out)
        self._log(f"\n[INFO] Context Vector after reshaping: {context_vec.shape}")

        # Final linear projection (always applied) that mixes information
        # across the concatenated head outputs.
        context_vec = self.linear_projection(context_vec)
        self._log(f"\n[INFO] Context Vector after final projection: {context_vec.shape} (batch_size, seq_len, d_out)")

        return context_vec

# Toy input: a sequence of six tokens, each a 3-dimensional embedding.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

# Build a batch of two identical sequences (stacked along a new leading dim).
batch = torch.stack([inputs, inputs])
print(f"\n[INFO] Batch Shape: {batch.shape} (batch_size=2, seq_len=6, embedding_dim=3)")

# Derive the module's sizes from the batch and construct a 2-head attention layer.
batch_size, context_length, d_in = batch.shape
d_out = 4
multi_head = MultiheadAttention(
    d_in,
    d_out,
    context_length,
    dropout=0,
    num_heads=2,
)
print(multi_head)

# Forward pass (prints intermediate shapes), then show the result.
output = multi_head(batch)
print("\n[INFO] Final Output:")
print(output)



[INFO] Batch Shape: torch.Size([2, 6, 3]) (batch_size=2, seq_len=6, embedding_dim=3)
MultiheadAttention(
  (W_query): Linear(in_features=3, out_features=4, bias=False)
  (W_key): Linear(in_features=3, out_features=4, bias=False)
  (W_value): Linear(in_features=3, out_features=4, bias=False)
  (linear_projection): Linear(in_features=4, out_features=4, bias=True)
  (dropout): Dropout(p=0, inplace=False)
)

[INFO] Input Shape: torch.Size([2, 6, 3]) (batch_size, seq_len, embedding_dim)

[INFO] Q, K, V before reshaping:
    Query Shape: torch.Size([2, 6, 4]) (should be batch_size, seq_len, d_out)
    Key Shape: torch.Size([2, 6, 4])
    Value Shape: torch.Size([2, 6, 4])

[INFO] Q, K, V after reshaping into 2 heads:
    Query Shape: torch.Size([2, 6, 2, 2]) (batch_size, seq_len, num_heads, head_dim)
    Key Shape: torch.Size([2, 6, 2, 2])
    Value Shape: torch.Size([2, 6, 2, 2])

[INFO] Q, K, V after transposing:
    Query Shape: torch.Size([2, 2, 6, 2]) (batch_size, num_heads, seq_len, h