In [1]:
import torch

# Suppose we have the following token embedding vector.
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x1)
   [0.55, 0.87, 0.66], # journey  (x2)
   [0.57, 0.85, 0.64], # starts   (x3)
   [0.22, 0.58, 0.33], # with     (x4)
   [0.77, 0.25, 0.10], # one      (x5)
   [0.05, 0.80, 0.55]] # step     (x6)
)

In [2]:
# This code implements a self-attention mechanism using PyTorch. Here's a breakdown:
# Dimensions shared by the attention examples below.
d_in = inputs.size(1)   # width of each input token embedding (second axis of `inputs`)
d_out = 2               # width of the query/key/value projections, and so of each context vector
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Args:
        d_in: Size of the last dimension of the input (embedding width).
        d_out: Size of the query/key/value projections, and therefore of
            each returned context vector.
        qkv_bias: Whether the query, key and value linear layers carry a
            bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Three independent linear projections of the same input:
        #   query -> what a token is looking for
        #   key   -> what a token offers to be matched against
        #   value -> the content that is mixed into the context vector
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor whose last two dimensions are (num_tokens, d_in).
                Any leading dimensions (e.g. a batch dimension) are
                carried through unchanged.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last two
            dimensions (num_tokens, d_out).
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Dot product of every query with every key. Swapping only the last
        # two axes (instead of `.T`, which reverses *all* axes) gives the same
        # result for 2-D input and also works for batched input.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Divide by sqrt(d_k) (d_k = keys.shape[-1] = d_out) before the
        # softmax so the scores do not grow with the projection width, then
        # normalise each row so a token's weights over all keys sum to 1.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        # Each context vector is the attention-weighted sum of the values.
        context_vec = attn_weights @ values
        return context_vec

In [3]:
# Build a SelfAttention module under a fixed seed (so its randomly
# initialised projection weights are reproducible) and run it on the
# six-token example. The instance keeps the name `sa_v2` because later
# cells read its projection layers.
torch.manual_seed(789)
sa_v2 = SelfAttention(d_in, d_out)
sa_v2_context = sa_v2(inputs)
print(sa_v2_context)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [4]:
# Recompute sa_v2's intermediate tensors outside the module so that the
# queries, keys, raw scores and attention weights can be inspected.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = torch.matmul(queries, keys.t())

# NOTE(review): unlike SelfAttention.forward, the scores are NOT divided by
# sqrt(d_k) here, so these weights differ from the ones the module uses
# internally. Kept as-is because the recorded outputs depend on it.
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)

tensor([[0.2034, 0.1635, 0.1642, 0.1501, 0.1741, 0.1446],
        [0.2212, 0.1650, 0.1654, 0.1425, 0.1658, 0.1401],
        [0.2204, 0.1651, 0.1655, 0.1429, 0.1657, 0.1405],
        [0.1958, 0.1665, 0.1667, 0.1531, 0.1657, 0.1522],
        [0.1902, 0.1668, 0.1670, 0.1556, 0.1654, 0.1551],
        [0.2054, 0.1659, 0.1662, 0.1490, 0.1662, 0.1472]],
       grad_fn=<SoftmaxBackward0>)


In [5]:
# Hiding future words with causal attention

# attn_scores is square, with one row and one column per token (6 x 6 for
# the six-token example), so its first dimension is the sequence length.
context_length = attn_scores.size(0)

# A matrix of ones reduced to its strictly-upper triangle (diagonal=1 keeps
# only entries above the main diagonal): a 1 at (i, j) means "token j comes
# after token i".
mask = torch.ones(context_length, context_length).triu(diagonal=1)

# Convert the 0/1 mask to booleans and overwrite every marked (future)
# score with -inf; softmax later maps those entries to exactly 0.
future_positions = mask.to(torch.bool)
masked = attn_scores.masked_fill(future_positions, float("-inf"))
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [6]:
# Final step: turn the masked scores into attention weights.
#
# softmax along dim=1 normalises each row, i.e. each token's distribution
# over the tokens it may attend to. Because exp(-inf) = 0, every masked
# (future) position receives weight 0 and the visible entries of the row
# sum to 1.
attn_weights = masked.softmax(dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5728, 0.4272, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4000, 0.2996, 0.3004, 0.0000, 0.0000, 0.0000],
        [0.2870, 0.2441, 0.2444, 0.2245, 0.0000, 0.0000],
        [0.2251, 0.1974, 0.1976, 0.1842, 0.1957, 0.0000],
        [0.2054, 0.1659, 0.1662, 0.1490, 0.1662, 0.1472]],
       grad_fn=<SoftmaxBackward0>)


In [7]:
# Masking additional attention weights with dropout
#
# Dropout is a regularisation technique: during training it zeroes a random
# fraction of the values it is applied to, which helps prevent overfitting.
# With p=0.5 each value is dropped with probability 0.5, and the surviving
# values are multiplied by 1 / (1 - 0.5) = 2 so the overall magnitude is
# maintained.
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
dropped_weights = dropout(attn_weights)
print(dropped_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8544, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6008, 0.0000, 0.0000, 0.0000],
        [0.5739, 0.4882, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4501, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3318, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


In [8]:
# Implementing a compact causal attention class

"""
    d_in: Input dimension (e.g., embedding size)
    d_out: Output dimension for queries/keys/values
    context_length: Maximum sequence length the model can handle
    dropout: Dropout rate (e.g., 0.5 for 50%)
    qkv_bias: Whether to include bias in Q/K/V projections
"""

import torch
import torch.nn as nn
class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each token may attend only to itself and to earlier tokens in the same
    sequence; attention to later tokens is masked out.

    Args:
        d_in: Size of the last dimension of the input (embedding width).
        d_out: Size of the query/key/value projections, and therefore of
            each returned context vector.
        context_length: Side length of the pre-computed causal mask, i.e.
            the longest sequence the mask covers.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query, key and value linear layers carry a
            bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out

        # Learnable projections of the input:
        #   query -> what a token is looking for
        #   key   -> what a token offers to be matched against
        #   value -> the content that is mixed into the context vector
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Dropout on the attention weights, for regularisation.
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the main diagonal mark "future" positions.
        # Registered as a buffer so it is part of the module's state without
        # being a trainable parameter.
        self.register_buffer(
           'mask',
           torch.triu(torch.ones(context_length, context_length),
           diagonal=1)
        )

    def forward(self, x):
        """Compute causally-masked context vectors for a batch of sequences.

        Args:
            x: Tensor of shape [batch_size, num_tokens, d_in].

        Returns:
            Tensor of shape [batch_size, num_tokens, d_out].
        """
        _, num_tokens, _ = x.shape           # also enforces a 3-D input
        keys = self.W_key(x)                 # [batch_size, num_tokens, d_out]
        queries = self.W_query(x)            # [batch_size, num_tokens, d_out]
        values = self.W_value(x)             # [batch_size, num_tokens, d_out]

        # Dot product of every query with every key of the same sequence:
        # [batch_size, num_tokens, num_tokens], rows = queries, cols = keys.
        attn_scores = queries @ keys.transpose(1, 2)

        # Crop the mask to the actual sequence length and set the score of
        # every future key to -inf (in place); exp(-inf) = 0 in the softmax.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        # Normalise over the LAST axis (the keys) so that each query's
        # weights sum to 1. For this 3-D tensor dim=1 would be the query
        # axis: rows would no longer sum to 1 and a token's weights would
        # depend on later tokens' queries, breaking causality.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        # Randomly zero some attention weights during training.
        attn_weights = self.dropout(attn_weights)

        # Each context vector is the attention-weighted sum of the values.
        context_vec = attn_weights @ values
        return context_vec

In [9]:
"""
   We can use the CausalAttention class as follows, similar to SelfAttention previously:
   For simplicity, to simulate such batch inputs, we duplicate the input text example:
   inputs: A single sequence tensor of shape [6, 3] (6 tokens, 3-dimensional embeddings)
   torch.stack((inputs, inputs), dim=0):
      Takes two copies of inputs
      Stacks them along a new dimension at position 0 (batch dimension)
      Creates shape [2, 6, 3], batch.shape: torch.Size([2, 6, 3])
      [2, 6, 3] = [batch_size, num_tokens, embedding_dim]
         2: Number of sequences in batch
         6: Tokens per sequence
         3: Dimensions per token embedding
"""

# Simulate a batch by stacking two copies of `inputs` along a new leading
# (batch) dimension.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

torch.manual_seed(123)
context_length = batch.size(1)   # tokens per sequence

# Causal attention module configured with:
#   d_in            input embedding width
#   d_out           query/key/value (and context vector) width
#   context_length  side length of the causal mask
#   0.0             dropout rate, i.e. dropout disabled
ca = CausalAttention(d_in, d_out, context_length, 0.0)

# Run the whole batch through the module; the result holds one
# d_out-dimensional context vector per token of every sequence.
context_vecs = ca(batch)

print('context_vecs.shape:', context_vecs.shape)

torch.Size([2, 6, 3])
context_vecs.shape: torch.Size([2, 6, 2])


In [10]:
# Implementing multi-head attention with weight splits
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention using weight splits.

    One d_out-wide projection is computed for queries, keys and values and
    then split into ``num_heads`` slices of ``head_dim = d_out // num_heads``
    features each, so all heads are evaluated in one batched matmul.

    Args:
        d_in: Size of the last dimension of the input (embedding width).
        d_out: Total output width across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the pre-computed causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query, key and value linear layers carry a
            bias term.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # d_out is shared evenly between heads
        # (e.g. d_out=8 with num_heads=4 gives 2 features per head).
        assert (d_out % num_heads == 0), \
            'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # One large projection per role instead of one small one per head.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer that mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the main diagonal mark future positions.
        future = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', future)

    def _split_heads(self, projected, b, num_tokens):
        """Reshape [b, T, d_out] to [b, num_heads, T, head_dim]."""
        per_head = projected.view(b, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Compute causally-masked multi-head context vectors.

        Args:
            x: Tensor of shape [batch_size, num_tokens, d_in].

        Returns:
            Tensor of shape [batch_size, num_tokens, d_out].
        """
        b, num_tokens, _ = x.shape

        # Project, then split into heads: each tensor is [b, h, T, d_h].
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # Query-key dot products for all heads at once: [b, h, T, T].
        attn_scores = torch.matmul(queries, keys.transpose(2, 3))

        # Crop the causal mask to the actual sequence length and set the
        # scores of future tokens to -inf (they become 0 after softmax).
        causal = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores = attn_scores.masked_fill(causal, -torch.inf)

        # Scaled softmax over the key axis, followed by dropout.
        scaled_scores = attn_scores / self.head_dim ** 0.5
        attn_weights = self.dropout(torch.softmax(scaled_scores, dim=-1))

        # Weighted sum of values, moved back to [b, T, h, d_h] ...
        per_head_context = torch.matmul(attn_weights, values).transpose(1, 2)
        # ... and flattened over the heads: [b, T, h * d_h] = [b, T, d_out].
        merged = per_head_context.reshape(b, num_tokens, self.d_out)

        # Output projection mixes information across heads.
        return self.out_proj(merged)

In [11]:
#The MultiHeadAttention class can be used similar to the SelfAttention and CausalAttention classes we implemented earlier:

torch.manual_seed(123)

# Two copies of `inputs` stacked along a new leading (batch) dimension.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

# Read every size straight from the batch; d_in obtained this way equals
# inputs.shape[1] because `batch` is a stack of `inputs`.
batch_size, context_length, d_in = batch.shape
d_out = 2               # total output width, split evenly across the heads

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print('context_vecs.shape:', context_vecs.shape)

torch.Size([2, 6, 3])
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
