In [3]:
# Build an Attention Neural Network using PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

In [18]:
class AttentionNeuralNet(nn.Module):
  def __init__(self, d_model, num_heads):
    super().__init__()
    self.d_model = d_model    # embedding dimension (e.g., 512)
    self.num_heads = num_heads  # number of attention heads (e.g., 8)
    self.head_dim = d_model // num_heads  # dimension per head (e.g., 64)

    # Create the Q, K, V projection layers
    self.q_proj = nn.Linear(d_model, d_model)
    self.k_proj = nn.Linear(d_model, d_model)
    self.v_proj = nn.Linear(d_model, d_model)

    # Final output projection
    self.out_proj = nn.Linear(d_model, d_model)

  # scaled dot product attention
  def attention(self, Q, K, V):
    """
    Q, K, V are expected to be of shape:
      [batch_size, seq_len, d_k]
    or possibly
      [batch_size, num_heads, seq_len, d_k]
    if you’re already doing multi-head splitting.
    """
    d_k = K.shape[-1]
    scores = Q @ K.transpose(-2, -1)
    scores = scores / math.sqrt(d_k)
    attention_weights = F.softmax(scores, dim=-1)
    output = attention_weights @ V

    return output, attention_weights

  def reshape_attention(self, output, attention_weights, batch_size, seq_len, d_model):
    # re-order dimensions back to original
    output = torch.permute(output, (0, 2, 1, 3))
    # reshape the dimensions to "combine" the attention heads outputs
    output = output.reshape(batch_size, seq_len, d_model)
    # attention_weights has shape [batch_size, num_heads, seq_len, seq_len]
    # Average across the heads dimension (dim=1)
    attention_weights = attention_weights.mean(dim=1)

    return output, attention_weights

  def forward(self, x):
    batch_size, seq_len, d_model = x.shape
    Q = self.q_proj(x)
    K = self.k_proj(x)
    V = self.v_proj(x)

    # Reshape to separate the heads
    Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
    K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
    V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

    # re-order dimensions to be compatible with attention method
    Q = torch.permute(Q, (0, 2, 1, 3))
    K = torch.permute(K, (0, 2, 1, 3))
    V = torch.permute(V, (0, 2, 1, 3))

    output, attention_weights = self.attention(Q, K, V)
    output, attention_weights = self.reshape_attention(output, attention_weights, batch_size, seq_len, d_model)

    output = self.out_proj(output)

    return output, attention_weights


In [21]:
def test_attention_shapes(batch_size=32, seq_len=10, d_model=512, num_heads=8):
    # Create model
    model = AttentionNeuralNet(d_model=d_model, num_heads=num_heads)

    # Create dummy input
    x = torch.randn(batch_size, seq_len, d_model)

    # Forward pass
    output, attention_weights = model(x)

    # Check shapes
    assert output.shape == (batch_size, seq_len, d_model)
    assert attention_weights.shape == (batch_size, seq_len, seq_len)

def test_attention_weights_sum_to_one(batch_size=32, seq_len=10, d_model=512, num_heads=8):
  # Create model
  model = AttentionNeuralNet(d_model=d_model, num_heads=num_heads)

  # Create dummy input
  x = torch.randn(batch_size, seq_len, d_model)

  # Forward pass
  output, attention_weights = model(x)
  assert torch.allclose(attention_weights.sum(dim=-1), torch.ones_like(attention_weights.sum(dim=-1)))

In [25]:
test_attention_shapes()
test_attention_weights_sum_to_one()

In [28]:
def create_sequence_dataset(num_sequences=1000, seq_length=10, d_model=512):
    # Create random input sequences
    X = torch.randn(num_sequences, seq_length, d_model)
    # Create target sequences (initially same as input)
    y = X.clone()

    # For positions 2, 5, 8, etc., make the target the sum of previous two tokens
    for pos in range(2, seq_length, 3):
        y[:, pos] = X[:, pos-1] + X[:, pos-2]

    return X, y

# Let's test the dataset creation
def test_dataset():
    X, y = create_sequence_dataset(num_sequences=5, seq_length=10, d_model=4)
    print("Input shape:", X.shape)
    print("Target shape:", y.shape)

    # Verify the pattern for first sequence
    print("\nFirst sequence, first few dimensions:")
    print("Position 2 should equal sum of positions 0 and 1:")
    print(f"X[0, 0]: {X[0, 0][:2]}")  # First token
    print(f"X[0, 1]: {X[0, 1][:2]}")  # Second token
    print(f"y[0, 2]: {y[0, 2][:2]}")  # Third token (should be sum)

# Training loop
def train_attention_model(model, num_epochs=10):
    X_train, y_train = create_sequence_dataset()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        output, _ = model(X_train)
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()

        if epoch % 2 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

In [33]:
batch_size=32
seq_len=10
d_model=512
num_heads=8
model = AttentionNeuralNet(d_model=d_model, num_heads=num_heads)
train_attention_model(model, num_epochs=50)

Epoch 0, Loss: 1.3141
Epoch 2, Loss: 1.2966
Epoch 4, Loss: 1.2799
Epoch 6, Loss: 1.2622
Epoch 8, Loss: 1.2419
Epoch 10, Loss: 1.2172
Epoch 12, Loss: 1.1863
Epoch 14, Loss: 1.1490
Epoch 16, Loss: 1.1069
Epoch 18, Loss: 1.0641
Epoch 20, Loss: 1.0253
Epoch 22, Loss: 0.9915
Epoch 24, Loss: 0.9599
Epoch 26, Loss: 0.9270
Epoch 28, Loss: 0.8917
Epoch 30, Loss: 0.8549
Epoch 32, Loss: 0.8181
Epoch 34, Loss: 0.7823
Epoch 36, Loss: 0.7476
Epoch 38, Loss: 0.7134
Epoch 40, Loss: 0.6789
Epoch 42, Loss: 0.6442
Epoch 44, Loss: 0.6099
Epoch 46, Loss: 0.5767
Epoch 48, Loss: 0.5448


In [36]:
def analyze_model(model, seq_length=10, d_model=512):
    # Create a test sequence
    X_test = torch.randn(1, seq_length, d_model)
    y_test = X_test.clone()
    for pos in range(2, seq_length, 3):
        y_test[:, pos] = X_test[:, pos-1] + X_test[:, pos-2]

    # Get model predictions and attention weights
    with torch.no_grad():
        pred, attention_weights = model(X_test)

    # Calculate prediction error
    mse = nn.MSELoss()(pred, y_test)
    print(f"Test MSE: {mse.item():.4f}")

    # Analyze attention patterns
    print("\nAttention patterns for summed positions:")
    for pos in range(2, seq_length, 3):
        print(f"\nPosition {pos} attention weights:")
        print(attention_weights[0, pos, pos-2:pos+1])  # Show attention to previous tokens

def analyze_predictions(model, seq_length=10, d_model=512):
    X_test = torch.randn(1, seq_length, d_model)
    y_test = X_test.clone()

    # Create expected sums
    for pos in range(2, seq_length, 3):
        y_test[:, pos] = X_test[:, pos-1] + X_test[:, pos-2]

    with torch.no_grad():
        pred, _ = model(X_test)

    # Compare predictions with expected sums
    for pos in range(2, seq_length, 3):
        expected_sum = X_test[0, pos-2] + X_test[0, pos-1]
        print(f"\nPosition {pos}:")
        print(f"Expected sum: {expected_sum[:5]}")  # Show first 5 dimensions
        print(f"Prediction:   {pred[0, pos][:5]}")

In [37]:
analyze_model(model)

Test MSE: 0.9548

Attention patterns for summed positions:

Position 2 attention weights:
tensor([0.0551, 0.0303, 0.6893])

Position 5 attention weights:
tensor([0.0646, 0.0256, 0.6734])

Position 8 attention weights:
tensor([0.0313, 0.0477, 0.5942])


In [38]:
analyze_predictions(model)


Position 2:
Expected sum: tensor([-3.8203, -0.2196, -0.1141, -0.7116,  0.2415])
Prediction:   tensor([-0.2672, -0.2304,  0.3185,  0.7512,  1.3425])

Position 5:
Expected sum: tensor([ 0.1566, -0.0233,  0.6566, -1.4544, -2.7014])
Prediction:   tensor([-1.0759, -1.0394,  0.6194, -1.0045,  0.4681])

Position 8:
Expected sum: tensor([ 0.1154, -0.6889,  1.8166, -0.6379, -1.6967])
Prediction:   tensor([-0.3267, -0.0108, -0.0664, -0.3468, -0.2812])


# NOTES

Input embedding dimension is the embedding of each token's embedding matrix that comes into the attention layer. -> Split words into tokens and each token is converted to an embedding matrix that numerically represents what the word is in some language embedding space.

Attention weights have a different dimension than the outputs.

For the attention weights:

They come from the scores calculation: scores = Q @ K.transpose(-2, -1)

* Q shape: [batch_size, num_heads, seq_len, head_dim]
* K.transpose shape: [batch_size, num_heads, head_dim, seq_len]
* When you multiply these, you get: [batch_size, num_heads, seq_len, seq_len]

The key difference is that attention weights represent how much each token attends to every other token.