In [107]:
from importlib.metadata import version

import torch
import torch.nn as nn

print("torch version:", version("torch"))

# Toy embedding matrix: one row per token of the example sentence,
# three embedding dimensions per token.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Input embedding width and projected (query/key/value) width used by the
# single-head attention examples.
d_in = 3
d_out = 2

# First version of self-attention using parameter matrices directly
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialised with ``torch.rand`` (uniform in
    [0, 1)), and are applied as ``x @ W``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projected queries, keys,
            values, and therefore of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute context vectors for every token in ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two axes so that batched input works as well;
        # for 2-D input this is identical to ``keys.T``.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega

        # Scale by sqrt(d_k) before normalising over the key axis.
        scale = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

# Seed the global RNG immediately before construction so the torch.rand
# initialisation inside SelfAttention_v1 is reproducible. These initial
# weights are replaced further down, where sa_v1 is given sa_v2's weights
# for the side-by-side comparison.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)

# Second version using Linear layers
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built from ``nn.Linear``.

    Same computation as the parameter-matrix version, but the query, key and
    value projections are bias-free ``nn.Linear`` layers.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projected queries, keys,
            values, and therefore of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Compute context vectors for every token in ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes so that batched input works as well;
        # for 2-D input this is identical to ``keys.T``.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Normalise over the key axis. ``dim=-1`` equals ``dim=1`` for 2-D
        # input but stays correct when leading batch dimensions are present.
        scale = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

# Build the Linear-layer variant from a fixed seed so its weights are
# reproducible.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)

# Give sa_v1 the same projections as sa_v2 so both compute the same function.
# nn.Linear stores its weight as (d_out, d_in), whereas SelfAttention_v1
# multiplies by a (d_in, d_out) matrix, hence the transpose.
for _proj_name in ("W_query", "W_key", "W_value"):
    _transposed = getattr(sa_v2, _proj_name).weight.T
    setattr(sa_v1, _proj_name, nn.Parameter(_transposed))

# Print both results; they should match row for row.
print("SA v1 output:", sa_v1(inputs), sep="\n")
print("\nSA v2 output:", sa_v2(inputs), sep="\n")

# Multi-head attention implementation
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no mask, no dropout).

    The input is projected to ``d_out`` features by bias-free linear layers,
    split into ``num_heads`` heads of ``d_out // num_heads`` features each,
    attended per head, re-merged, and passed through a final bias-free
    ``d_out -> d_out`` linear layer.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total projected width across all heads; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, num_heads=2):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_out // num_heads  # per-head feature width

        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.W_out = nn.Linear(d_out, d_out, bias=False)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(seq_len, d_in)`` or
                ``(batch_size, seq_len, d_in)``. A 2-D input is treated as a
                batch of one.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_out)``. The batch
            dimension is kept (size 1) even when ``x`` was 2-D.
        """
        # Add batch dimension if it's missing
        if x.dim() == 2:
            x = x.unsqueeze(0)

        # Read the sizes only after x is known to be 3-D. Taking seq_len from
        # x.size(0) beforehand would pick up the batch size for batched input.
        batch_size, seq_len, _ = x.shape

        # Linear projections: (batch, seq, d_out)
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Split the feature axis into heads and move the head axis forward:
        # (batch, seq, d_out) -> (batch, heads, seq, d_k)
        head_shape = (batch_size, seq_len, self.num_heads, self.d_k)
        queries = queries.view(head_shape).transpose(1, 2)
        keys = keys.view(head_shape).transpose(1, 2)
        values = values.view(head_shape).transpose(1, 2)

        # Scaled dot-product attention, normalised over the key axis
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)

        # Apply attention to values: (batch, heads, seq, d_k)
        context = torch.matmul(attn_weights, values)

        # Merge heads back into one feature axis; contiguous() is required
        # because view() cannot operate on the transposed layout.
        context = context.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.num_heads * self.d_k
        )
        output = self.W_out(context)

        return output

# Exercise the multi-head module on the sample inputs
torch.manual_seed(123)  # fixed seed so the Linear initialisations are reproducible
num_heads = 2
d_in, d_out = 3, 4  # d_out must be divisible by num_heads

mha = MultiHeadAttention(d_in, d_out, num_heads)
mha_output = mha(inputs)

print("\nMulti-Head Attention output:", mha_output, sep="\n")
print("mha_output.shape:", mha_output.shape)

# Function to count parameters
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set contribute; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# mha holds four bias-free Linear layers: three d_in x d_out projections
# (query, key, value) and one d_out x d_out output layer. With d_in=3 and
# d_out=4 that is 3 * 12 + 16 = 52 trainable weights.
print("\nNumber of parameters in MHA:", count_parameters(mha))

torch version: 2.6.0
SA v1 output:
tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

SA v2 output:
tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

Multi-Head Attention output:
tensor([[[-0.1288,  0.1090, -0.0044,  0.1115],
         [-0.1304,  0.1060, -0.0043,  0.1148],
         [-0.1304,  0.1061, -0.0043,  0.1147],
         [-0.1310,  0.1064, -0.0043,  0.1154],
         [-0.1304,  0.1075, -0.0044,  0.1141],
         [-0.1311,  0.1059, -0.0043,  0.1157]]], grad_fn=<UnsafeViewBackward0>)
mha_output.shape: torch.Size([1, 6, 4])

Number of parameters in MHA: 52
