In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected to queries, keys and values, split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, attended
    per head, concatenated back to ``d_model`` and mixed by ``W_o``.

    Args:
        d_model: Width of every input/output representation.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model  # The dimensionality of all representations
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Per-head width

        # Linear projections for Q, K, V and the output mix
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_len, _ = x.size()
        # transpose() returns a non-contiguous view; view() needs contiguous memory.
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    ## Forward pass
    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape (batch, q_len, d_model).
            key: Tensor of shape (batch, k_len, d_model).
            value: Tensor of shape (batch, k_len, d_model).
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, k_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, d_model). A query row whose keys
            are all masked attends to nothing (all-zero weights) instead of
            producing NaN.
        """
        # Linear projections, then split into heads: (batch, num_heads, len, d_k)
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))

        # Scaled dot-product attention scores: (batch, num_heads, q_len, k_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask (e.g. causal mask hiding future positions, or padding mask)
        fully_masked = None
        if mask is not None:
            blocked = mask == 0
            # A row with every key blocked would be all -inf, and softmax of
            # that is NaN (in both the output and the gradients). Leave such
            # rows unmasked for the softmax and zero their weights afterwards.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(blocked & ~fully_masked, float('-inf'))

        # Softmax over the keys, then dropout on the weights
        attn_weights = F.softmax(scores, dim=-1)
        if fully_masked is not None:
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, heads concatenated back to d_model
        attn_output = self.combine_heads(torch.matmul(attn_weights, V))

        # Final linear projection mixes information across heads
        return self.W_o(attn_output)

`d_model` --> the dimensionality of all representations
- if `d_model` = 256
  - each token embedding is a 256-dimensional vector
  - each position embedding is 256-dimensional
  - output of each attention layer is 256-dimensional
  - everything flowing through the model lives in a 256-dimensional space
  - Memory size of the model
- Think 'width' of a neural network - how much information capacity each token holds.
  - Each token's size depends on d_model's vector size
- So is larger better?
  - Allows each token to store more information, more working memory, and allows more heads.
  - BUT, the weights scale with d_model so this means more parameters, more computation, and more memory
- Scales quadratically in parameters -> each d_model x d_model weight matrix has d_model*d_model = d_model^2 entries

`num_heads` / `d_k` --> number of attention heads / dimension of each head
- `num_heads` is used to split the attention operation over d_model into independent heads
- `d_k` is the dimension of each head's key (as well as query and value) vectors 
  - (`d_k` = `d_model` / `num_heads`)
- d_model = 256, heads = 8 -> 32 dimension heads to work with
  - Head 1: 0-31, Head 2: 32-63, etc...
  - Different heads learn different patterns, head 1 might learn subject-verb, head 2 might learn adjective-noun
- Concatenate all heads back together later to get the full d_model vector

`nn.Linear` --> Applies a learned linear projection (matrix multiplication plus a bias)
- Creates a weight matrix
  - Let `d_model = 4` then W_q weight matrix is now 4x4
- Gradient descent updates the weights inside W_q, W_k and W_v to make better predictions (backward pass), but the projection itself - the `nn.Linear` call - is just a matrix multiplication (forward pass)

`Q`, `K`, `V` , `O`
- `Q` -> Asks a question on what information is needed. "What am I looking for?"
- `K` -> Key mapping to the question. "What do I contain?"
- `V` -> The content of the key. Think key-value pair in python `dict`. "What information do I have?"
- Think soft database lookup
- `O` -> Takes concatenated attention outputs from all heads and mixes them
- Full Flow is as follows 
  - Input (d_model=512) -> [W_q, W_k, W_v] is split into 8 heads (64 dim) -> Attention computation per head -> Concatenate back to 512 dim -> W_o final output at 512 dim

`seq_len` --> Sequence length (number of tokens)
- Attention computation is O(n^2) in `seq_len`, because every token is scored against every other token
- While `d_model` holds information capacity for token (How smart is each token), `seq_len` is how much context it can see (How much can it remember at once)
- **Note** that a token is roughly 0.75 of a word. 

`Forward Function Walkthrough` 
- 1. Set up -> batch=2, seq_len=4, d_model=512, num_heads=8, d_k=64 
- 2. Linear Projections -> x=(2,4,512) - `batch, seq_len, dim` , project into Q, K, V Spaces
- 3. Split into heads -> Before x=(2,4,512), After x=(2,8,4,64) - `batch, num_heads, seq_len, d_k` to now have 8 independent attention mechanisms with 64 dimensions each
- **Important - Core Mechanism** Scaled dot-product attention 
  1. Compute attention Scores -> Measure how much each token should focus on every other token
    - `K.transpose(-2,-1):` Swap last two dims -> K = (batch, num_heads, d_k, seq_len)
    - `Q @ K.T:` Becomes (batch, num_heads, seq_len, seq_len). If Q and K point in similar directions, their dot product is high = strong attention. The last two dims are always (seq_len, seq_len), so for example with seq_len = 3 each head's score matrix is (3,3)
  2. Scale the scores -> As d_k gets larger, dot product gets larger, more dims to sum
    - Large scores -> softmax saturates -> gradients vanish
    - `sqrt(d_k):` To normalize: if the entries of `Q` and `K` have zero mean and unit variance, their dot product has variance = d_k. Therefore dividing by the square root normalizes the variance back to 1. 
    - Variance grows linearly with d_k as Var(q . k) = Var(q[0] * k[0]) + Var(q[1] * k[1]) + ... + Var(q[d_k-1] * k[d_k-1])
      - For independent entries each term has variance 1, so the sum is 1 + 1 + 1... = d_k
  3. Apply mask if provided -> A boolean tensor indicating which positions to ignore
    - `-1e9/-inf instead of 0:` A score of 0 still receives non-zero weight after softmax (exp(0) = 1), so padding would still be attended to; a score of -inf (or -1e9) gives a weight of (effectively) 0, so the masked position is ignored.  
    - `Example:` (mask = [1, 1, 0, 0]), scores = [2.5, 1.3, -1e9, -1e9]. Last 2 positions masked
  4. Softmax to get attention weight -> Converts scores into a probability distribution (each row sums to 1)
    - **Refer to quick lesson on softmax**
    - Allows gradients to flow smoothly
      - Too small = No learning, Too large = unstable.
      - The derivative of each softmax output with respect to its own score, p * (1 - p), is bounded between **0 and 0.25**
      - Average gradient often around 0.1-0.15
    - Amplifies differences as larger scores get proportionally more weight
      - Before `[2.0, 4.0, 1.0]`: 4.0 is 2x bigger than 2.0. After `[0.11, 0.84, 0.04]`: 0.84 is about 7x bigger than 0.11
      - `dim=-1:` Normalize along the last dimension (the keys), so that each query's row of weights sums to 1.
        - Shape: (Batch=2, num_heads=8, seq_len=4, seq_len=4)
        - dim=0: batch, dim=1: num_heads, dim=2: seq_len (queries), dim=3: seq_len (keys)
        - dim=-4: batch, dim=-3:num_heads, dim=-2: seq_len (queries), **dim=-1: seq_len (keys)** 
        - This means for each query, we normalize across all keys
  5. Apply attention to values -> compute weighted average of value vectors and mix
    - Attention weights from softmax * value vectors
    - `Example:` sequence = 'The', 'Cat', 'Sat', d_k = 4 -> value vectors V_The, V_cat, V_sat -> weights from softmax for cat = [0.3, 0.38, 0.32]. Compute the weighted sum for the word -> 0.3 * V_The + 0.38 * V_cat + 0.32 * V_sat -> Cat is now a mix of all tokens weighted by attention **Influenced 30% by "The", 38% by "Cat/Itself", 32% by "Sat"**
- 4. Combine heads -> Before (2,8,4,64), After (2,4,512), concatenated back
  - Uses .contiguous() to reorganize memory after transpose before reshaping
- 5. Output projection -> Mix information across all heads into final output

`Mask` --> Prevents attention to certain positions
- Used to ignore padding tokens, or to stop a token from looking at future tokens (causal mask)

Attention formula 

`scores = Q @ K^T / sqrt(d_k)`   How relevant is each token?

`attention_weights = softmax(scores)`   Normalize to probabilities

`output = attention_weights @ V`   Weighted sum of values

Quick lesson on softmax 
- Given [0.9, -1.5, 3.2], we first exponentiate every value, which makes them all positive and amplifies the differences -> [2.45960311116, 0.22313016014, 24.5325301971].
- We then normalize by dividing each exponential by their sum (27.2152634684) to get the probability distribution -> [2.45960311116/27.2152634684, 0.22313016014/27.2152634684, 24.5325301971/27.2152634684] --> [0.090376136, 0.008198731, 0.901425133], which sums to 1

`dropout`
- Controls regularization.
- In training, dropout randomly sets a fraction of the activations ("neurons") to zero.
- If `dropout` = 0.1, each activation is zeroed with probability 0.1, so on average 10% are deactivated and 90% remain active.
  - remaining activations are scaled up by 1/(1 - 0.1) to compensate, keeping the expected value unchanged
- This prevents overfitting so that the training does not rely on a single/few neurons, making the model more robust.
- Only activated during training, auto-disabled during inference/test. 

In [None]:
class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Width of the input and of the output.
        d_ff: Width of the hidden layer.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.linear1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [None]:
class TransformerLayer(nn.Module):
    """One transformer block: self-attention and feed-forward sub-layers.

    Each sub-layer's output goes through dropout, is added back to its input
    (residual connection), and the sum is layer-normalised.

    Args:
        d_model: Width of the representations.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Sub-layers
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # One LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # One dropout per residual branch
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run ``x`` through the block; ``mask`` is forwarded to the attention."""
        # Self-attention branch: query, key and value are all x.
        attended = self.attention(x, x, x, mask)
        attended = self.dropout1(attended)
        x = self.norm1(x + attended)

        # Feed-forward branch.
        transformed = self.feed_forward(x)
        transformed = self.dropout2(transformed)
        x = self.norm2(x + transformed)

        return x

In [None]:
class TwoLayerTransformer(nn.Module):
    """Two-layer causal transformer that maps token ids to vocabulary logits.

    Token and learned positional embeddings are summed, passed through two
    ``TransformerLayer`` blocks under a causal mask, and projected to
    ``vocab_size`` logits per position.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Width of the representations.
        num_heads: Number of attention heads in each layer.
        d_ff: Hidden width of each layer's feed-forward network.
        max_seq_len: Number of positions in the positional embedding table;
            the longest sequence ``forward`` accepts.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size, d_model=256, num_heads=4, d_ff=1024, 
                 max_seq_len=512, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional embeddings (one learned vector per position)
        self.positional_embedding = nn.Embedding(max_seq_len, d_model)
        
        # Two transformer layers
        self.layer1 = TransformerLayer(d_model, num_heads, d_ff, dropout)
        self.layer2 = TransformerLayer(d_model, num_heads, d_ff, dropout)
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        This covers the embedding tables and linear weight matrices; 1-D
        parameters (biases, LayerNorm weights) keep their default values.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def create_causal_mask(self, seq_len, device):
        """Create a causal mask to prevent attending to future tokens.

        Returns a (1, 1, seq_len, seq_len) lower-triangular tensor of ones:
        entry [i, j] is 1 when position i may attend to position j (j <= i).
        """
        mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
        return mask.view(1, 1, seq_len, seq_len)
    
    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer tensor of token ids with shape (batch, seq_len).

        Returns:
            Tensor of logits with shape (batch, seq_len, vocab_size).

        Raises:
            IndexError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = x.size()

        # The positional table only has max_seq_len rows. Fail early with a
        # clear message instead of an opaque out-of-range embedding lookup.
        if seq_len > self.max_seq_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        
        # Create position indices: (1, seq_len), broadcast over the batch
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        
        # Token + positional embeddings
        token_emb = self.token_embedding(x)
        pos_emb = self.positional_embedding(positions)
        x = self.dropout(token_emb + pos_emb)
        
        # Create causal mask
        mask = self.create_causal_mask(seq_len, x.device)
        
        # Pass through two transformer layers
        x = self.layer1(x, mask)
        x = self.layer2(x, mask)
        
        # Project to vocabulary
        logits = self.output_projection(x)
        
        return logits