In [3]:
# Build an Attention Neural Network using PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

In [6]:
class AttentionNeuralNet(nn.Module):
  """Multi-head self-attention layer.

  Projects the input into queries, keys and values, splits each into
  ``num_heads`` heads of size ``d_model // num_heads``, applies scaled
  dot-product attention per head, merges the heads back together and applies
  a final linear projection.

  Args:
    d_model: embedding dimension of each input token (e.g. 512).
    num_heads: number of attention heads (e.g. 8); must evenly divide d_model.

  Raises:
    ValueError: if num_heads is not positive or does not divide d_model.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    # Heads split d_model evenly; otherwise head_dim would be silently
    # truncated and the .view() in forward would fail with a cryptic error.
    if num_heads <= 0 or d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model      # embedding dimension (e.g., 512)
    self.num_heads = num_heads  # number of attention heads (e.g., 8)
    self.head_dim = d_model // num_heads  # dimension per head (e.g., 64)

    # Q, K, V projection layers
    self.q_proj = nn.Linear(d_model, d_model)
    self.k_proj = nn.Linear(d_model, d_model)
    self.v_proj = nn.Linear(d_model, d_model)

    # Final output projection, applied after the heads are merged
    self.out_proj = nn.Linear(d_model, d_model)

  def attention(self, Q, K, V):
    """Scaled dot-product attention.

    Args:
      Q, K, V: tensors of shape [..., seq_len, d_k], e.g.
        [batch_size, seq_len, d_k] or, after head splitting,
        [batch_size, num_heads, seq_len, head_dim].

    Returns:
      (output, attention_weights) where output has shape [..., seq_len, d_k]
      and attention_weights has shape [..., seq_len, seq_len]; each row of
      attention_weights sums to 1 (softmax over the key axis).
    """
    d_k = K.shape[-1]
    # Scale by sqrt(d_k) so the softmax inputs don't grow with the head size.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    attention_weights = F.softmax(scores, dim=-1)
    output = attention_weights @ V
    return output, attention_weights

  def reshape_attention(self, output, attention_weights, batch_size, seq_len):
    """Merge per-head outputs and average the attention weights over heads.

    Args:
      output: per-head outputs, [batch_size, num_heads, seq_len, head_dim].
      attention_weights: per-head weights,
        [batch_size, num_heads, seq_len, seq_len].
      batch_size: batch size of the original input.
      seq_len: sequence length of the original input.

    Returns:
      (output, attention_weights) with shapes [batch_size, seq_len, d_model]
      and [batch_size, seq_len, seq_len].
    """
    # [B, H, S, Dh] -> [B, S, H, Dh] so each token's heads are adjacent,
    # then flatten the heads into a single d_model-sized vector.
    output = output.permute(0, 2, 1, 3)
    output = output.reshape(batch_size, seq_len, self.d_model)
    # The weights are still head-major ([B, H, S, S]), so dim=1 is the heads
    # axis. They must NOT be permuted first, or dim=1 would be the query axis.
    attention_weights = attention_weights.mean(dim=1)
    return output, attention_weights

  def _split_heads(self, t, batch_size, seq_len):
    """Reshape [batch, seq_len, d_model] -> [batch, num_heads, seq_len, head_dim]."""
    t = t.view(batch_size, seq_len, self.num_heads, self.head_dim)
    return t.permute(0, 2, 1, 3)

  def forward(self, x):
    """Apply multi-head self-attention.

    Args:
      x: input tensor of shape [batch_size, seq_len, d_model].

    Returns:
      (output, attention_weights): output has shape
      [batch_size, seq_len, d_model]; attention_weights has shape
      [batch_size, seq_len, seq_len] (averaged over heads).
    """
    batch_size, seq_len, _ = x.shape
    Q = self._split_heads(self.q_proj(x), batch_size, seq_len)
    K = self._split_heads(self.k_proj(x), batch_size, seq_len)
    V = self._split_heads(self.v_proj(x), batch_size, seq_len)

    output, attention_weights = self.attention(Q, K, V)
    output, attention_weights = self.reshape_attention(
        output, attention_weights, batch_size, seq_len)

    output = self.out_proj(output)
    return output, attention_weights


# NOTES

The input embedding dimension (`d_model`) is the length of the embedding vector for each token entering the attention layer. Text is split into tokens, and each token is mapped to an embedding vector that numerically represents its meaning in a learned embedding space; stacking the vectors for a whole sequence gives a `[seq_len, d_model]` matrix.

The attention weights have a different shape than the outputs.

For the attention weights:

They come from the scores calculation `scores = Q @ K.transpose(-2, -1)`, scaled by $1/\sqrt{d_k}$ and passed through a softmax over the last dimension:

* Q shape: [batch_size, num_heads, seq_len, head_dim]
* K.transpose shape: [batch_size, num_heads, head_dim, seq_len]
* When you multiply these, you get: [batch_size, num_heads, seq_len, seq_len]

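The key difference is that the attention weights represent how much each token (query) attends to every other token (key), so they are indexed by two sequence positions. The output, by contrast, has shape `[batch_size, seq_len, d_model]`: one new vector per token.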