# Coding Attention Mechanisms

## The problem with modeling long sequences

Before we dive into the self-attention mechanism at the heart of LLMs, let's consider the problem with pre-LLM architectures that do not include attention mechanisms.

Suppose we want to develop a language translation model that translates text from one language to another. We can't simply translate text word by word, because the grammatical structures of the source and target languages differ.

To address this problem, it is common to use a **deep neural network** with two submodules, an *encoder and a decoder*.

- Encoder: Reads in and processes the entire text
- Decoder: Produces the translated text

Before the advent of transformers, recurrent neural networks (RNNs) were the most popular encoder-decoder architecture for language translation.

An RNN is a type of neural network in which the output from the previous step is fed as input to the current step, making it well-suited for sequential data like text.

**RNNs:**
- RNNs process sequences step-by-step, feeding the output of one step into the next
- In translation:
    - The encoder processes each word sequentially, updating its internal hidden state
    - The final hidden state summarizes the whole input sentence
    - The decoder uses this final state to generate the translated text one word at a time

**Problem:**
- RNNs can't easily access earlier words once they have been processed
- They rely only on the final hidden state, which compresses all sentence information into one vector
- This causes loss of context
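A minimal sketch of this bottleneck, using PyTorch's built-in `nn.RNN` with arbitrary toy sizes (6 tokens, 3-dimensional inputs, a 4-dimensional hidden state; all assumptions for illustration): the encoder produces one hidden state per token, but a classic RNN encoder-decoder hands only the final one to the decoder.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy input: 1 sentence, 6 tokens, 3-dim embeddings
seq = torch.randn(1, 6, 3)
encoder = nn.RNN(input_size=3, hidden_size=4, batch_first=True)

# all_states: one hidden state per token; final_state: only the last one
all_states, final_state = encoder(seq)
print(all_states.shape)   # torch.Size([1, 6, 4])  (batch, num_tokens, hidden)
print(final_state.shape)  # torch.Size([1, 1, 4])  (num_layers, batch, hidden)

# The decoder in a classic RNN encoder-decoder only receives final_state,
# so everything about the 6 tokens must be squeezed into these 4 numbers
```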

**Why Attention Was Invented:**
- Because RNNs struggled to retain long-term dependencies, researchers created attention mechanisms
- Attention allows the model to access every word in the input sequence at each step, instead of relying only on the final hidden state, which leads to better, context-aware predictions

![](pic1.png)

## Attending to different parts of the input with self-attention

*Learning and coding the self-attention mechanism from scratch*

**The "SELF" in SELF-ATTENTION**
- The "self" means the model is paying attention to itself - to different parts of the same input sequence
- It learns how each token relates to every other token in that same sequence
- So, each word gets a context-aware representation that blends information from all other words in that sequence

### Simple self-attention mechanism without trainable weights
- Context vectors play a crucial role in self-attention
- Their purpose is to create enriched representations of each element in an input sequence by incorporating information from all other elements in the sequence
- This is essential in LLMs, which need to understand the relationship and relevance of words in a sequence to each other
- Later, we will add trainable weights that help an LLM learn to construct these context vectors so that they are relevant for the LLM to generate the next token

In [4]:
import torch

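# Embeddings for a 6-token input sequence (one 3-dimensional vector per token)
inputs = torch.tensor([[0.43, 0.15, 0.89],
                       [0.55, 0.87, 0.66],
                       [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33],
                       [0.77, 0.25, 0.10],
                       [0.05, 0.80, 0.55]])

# Use the second token as the query and score it against every token
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    # Dot product = unnormalized similarity between token i and the query
    attn_scores_2[i] = torch.dot(x_i, query)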

print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


- Attention is a way for the model to decide which parts of the input matter most for what it is doing right now

**Core Idea**
When processing a token, the model looks back at all tokens and assigns each a weight: pay a lot of attention to these, less to those. It then mixes information from all tokens using those weights to get a focused summary -> a context vector

**Self-attention:** Tokens attend to other tokens in the same sequence (used in Transformers to let every word see every other word)

**Cross-attention:** One sequence attends to another (e.g., decoder attending to encoder outputs in translation, or text attending to image features)
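To make the difference concrete, here is a minimal cross-attention sketch on toy, randomly generated tensors (the names `decoder_states` and `encoder_states` are hypothetical). Like the simplified self-attention in this section, it uses no trainable projections; a real decoder would project queries, keys, and values first. The only change from self-attention is that the queries come from one sequence while the scores and mixed vectors come from the other.

```python
import torch

torch.manual_seed(0)
# Hypothetical toy data: 2 "decoder" tokens attend over 5 "encoder" tokens, all 3-dim
decoder_states = torch.randn(2, 3)
encoder_states = torch.randn(5, 3)

scores = decoder_states @ encoder_states.T   # (2, 5): one row of scores per decoder token
weights = torch.softmax(scores, dim=-1)      # each row sums to 1
context = weights @ encoder_states           # (2, 3): one context vector per decoder token
print(weights.shape, context.shape)          # torch.Size([2, 5]) torch.Size([2, 3])
```

The output has one row per query token, so its length follows the attending sequence, not the attended one.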

In [5]:
# Simple normalization: divide each score by the total so the weights sum to 1
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print(attn_weights_2_tmp)
print(attn_weights_2_tmp.sum())

tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
tensor(1.0000)


In [6]:
# Turns a list of arbitrary scores into a probability distribution
# It is used in attention to convert similarity scores into attention weights that tell how much to focus on each token
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum()


attn_weights_naive = softmax_naive(attn_scores_2)
print(attn_weights_naive)
print(attn_weights_naive.sum())

# Softmax ensures that the attention weights are always positive and sum to 1
# This makes the output interpretable as probabilities or relative importance, where
# higher weights mean more focus on that token

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


In [7]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(attn_weights_2)   
print(attn_weights_2.sum())

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)
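One reason to prefer `torch.softmax` over `softmax_naive`: the naive version overflows when scores are large, because `exp` of a large number becomes `inf`. A common fix, sketched here with a hypothetical helper `softmax_stable`, is to subtract the maximum score first. The shift cancels in the ratio, so the result is mathematically unchanged.

```python
import torch

def softmax_stable(x):
    # Subtracting the max keeps every exponent <= 0, so exp() cannot overflow;
    # the shift cancels between numerator and denominator
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum()

big_scores = torch.tensor([1000.0, 1001.0, 1002.0])
naive = torch.exp(big_scores) / torch.exp(big_scores).sum()
print(naive)                            # tensor([nan, nan, nan]) because inf / inf
print(softmax_stable(big_scores))       # finite, sums to 1
print(torch.softmax(big_scores, dim=0)) # same values as softmax_stable
```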


- After computing the attention weights, we build one summary vector for the current token, the context vector, by:
    1. Taking each token's embedding
    2. Multiplying it by its attention weight
    3. Adding them all up

- Query: What am I looking for?
    - A vector derived from the current token. It encodes what this token needs, i.e., what information it wants from the others

- Key: What do I contain?
    - A vector derived from each token in the sequence. It summarizes what that token can offer to others

- How they are used:
    - For the current token, we compare its query with every token's key. A large dot product q · k means that token is relevant. We apply softmax to these similarities to get attention weights, then use those weights to mix the tokens' value vectors into a single context vector for the current token
    - In the simplified version above there are no separate projections: each input embedding plays the role of query, key, and value at once

In [8]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)

# Context vector = attention-weighted sum of all input vectors
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])
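The loop above is a weighted sum, which is exactly a vector-matrix product. This self-contained check recreates the same `inputs` and shows that `attn_weights_2 @ inputs` gives the same context vector in one step. It also computes all six scores at once with `inputs @ query`.

```python
import torch

inputs = torch.tensor([[0.43, 0.15, 0.89],
                       [0.55, 0.87, 0.66],
                       [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33],
                       [0.77, 0.25, 0.10],
                       [0.05, 0.80, 0.55]])
query = inputs[1]
# All six dot products with the query at once, then softmax
attn_weights_2 = torch.softmax(inputs @ query, dim=0)

# Loop version: weighted sum of the input vectors
loop_vec = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    loop_vec += attn_weights_2[i] * x_i

# Vectorized version: (6,) @ (6, 3) -> (3,)
vec_vec = attn_weights_2 @ inputs
print(torch.allclose(loop_vec, vec_vec))  # True
```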


## Computing Attention Weights for all input tokens

![](pic2.png)

In [9]:
attn_scores = torch.empty(6,6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i,j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
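The nested loop fills in every pairwise dot product. The same 6×6 matrix comes from a single matrix multiplication of the inputs with their own transpose, which is the usual way to compute it:

```python
import torch

inputs = torch.tensor([[0.43, 0.15, 0.89],
                       [0.55, 0.87, 0.66],
                       [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33],
                       [0.77, 0.25, 0.10],
                       [0.05, 0.80, 0.55]])

# Loop version, as in the cell above
loop_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        loop_scores[i, j] = torch.dot(x_i, x_j)

# Matrix version: entry (i, j) of X @ X.T is the dot product of rows i and j
matmul_scores = inputs @ inputs.T
print(torch.allclose(loop_scores, matmul_scores))  # True
```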


In [10]:
# Normalize each row of attention scores to get attention weights
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [10]:
# Use these attention weights to compute all context vectors via matrix multiplication
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
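In equation form, the three steps above (pairwise scores, row-wise softmax, weighted sum) compute the following, where $X$ is the $T \times d$ matrix whose rows are the input embeddings $x^{(1)}$ through $x^{(T)}$ and $z^{(i)}$ is the context vector of token $i$:

$$
\omega_{ij} = x^{(i)} \cdot x^{(j)}, \qquad
\alpha_{ij} = \frac{\exp(\omega_{ij})}{\sum_{k=1}^{T} \exp(\omega_{ik})}, \qquad
z^{(i)} = \sum_{j=1}^{T} \alpha_{ij}\, x^{(j)}
$$

$$
Z = \operatorname{softmax}_{\text{row}}\!\left(X X^{\top}\right) X
$$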


## Implementing self-attention with trainable weights

- This mechanism is also called scaled dot-product attention
- The self-attention mechanism with trainable weights builds on the previous concepts: we want to compute context vectors as weighted sums over the input vectors specific to a certain input element
- The most notable difference from the previous part is the introduction of weight matrices that are updated during model training
- These trainable weight matrices are crucial so that the model can learn to produce "good" context vectors, that is, so that the LLM can be trained

### Computing the attention weights step by step
- Introduce three trainable weight matrices:
    1. W_q (query)
    2. W_k (key)
    3. W_v (value)
- These matrices are used to project the embedded input tokens into query, key, and value vectors

![](pic3.png)

In [11]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2 

torch.manual_seed(123)
# Initialize weight matrices
# Use requires_grad=False to indicate these are not trainable in this example
# In practice, these would be trainable parameters -> use requires_grad=True

W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

In [12]:
query_2 = x_2 @ W_query
key_2   = x_2 @ W_key       
value_2 = x_2 @ W_value
print("Query vector:", query_2)
print("Key vector:", key_2)
print("Value vector:", value_2)

Query vector: tensor([-1.1729, -0.0048])
Key vector: tensor([-0.1142, -0.7676])
Value vector: tensor([0.4107, 0.6274])


In [13]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
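The cells above stop after projecting the inputs. A sketch of the remaining steps for the second token, under the assumption that we reuse the same seed and shapes as the earlier cell (with the same random calls, the matrices should match `W_query`, `W_key`, and `W_value` there): score the query against all keys, scale by $\sqrt{d_k}$, apply softmax, and take the weighted sum of the values. The scaling is the "scaled" in scaled dot-product attention; without it, dot products grow with the embedding dimension and push softmax toward a near one-hot distribution with very small gradients.

```python
import torch

inputs = torch.tensor([[0.43, 0.15, 0.89],
                       [0.55, 0.87, 0.66],
                       [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33],
                       [0.77, 0.25, 0.10],
                       [0.05, 0.80, 0.55]])
torch.manual_seed(123)
d_in, d_out = 3, 2
W_query = torch.randn(d_in, d_out)
W_key = torch.randn(d_in, d_out)
W_value = torch.randn(d_in, d_out)

query_2 = inputs[1] @ W_query   # (2,)
keys = inputs @ W_key           # (6, 2)
values = inputs @ W_value       # (6, 2)

attn_scores_2 = query_2 @ keys.T                  # (6,): one score per input token
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
context_vec_2 = attn_weights_2 @ values           # (2,): lives in the d_out space
print(attn_weights_2)
print(context_vec_2)
```

Note that the context vector now has `d_out` = 2 entries rather than 3, because it mixes the projected value vectors instead of the raw inputs.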


![](pic%204.png)

### Implementing a compact self-attention Python class

In [14]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self,d_in,d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key   = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        values = x @ self.W_value
        queries = x @ self.W_query

        attn_scores = queries @ keys.T
        # Scale by sqrt(d_out) before softmax (scaled dot-product attention)
        attn_weights = torch.softmax(attn_scores/keys.shape[1]**0.5, dim=-1)

        context_vecs = attn_weights @ values
        return context_vecs


In [15]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)

# Calling the module directly invokes forward() (and any registered hooks)
print(sa_v1(inputs))

tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)


We can improve the *SelfAttention_v1* implementation further by using PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled.

Another advantage is that nn.Linear uses a more sophisticated **weight initialization scheme** than plain `torch.randn`, contributing to more stable and effective model training.
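One detail matters when comparing the two versions: `nn.Linear` stores its weight with shape `(d_out, d_in)` and computes `x @ weight.T`. This small check, on hypothetical random data, shows that a bias-free `nn.Linear` is the same as the `x @ W` pattern in `SelfAttention_v1` with `W = weight.T`. So to copy weights from a v2 layer into v1, one would assign the transposed matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
x = torch.rand(6, 3)                     # hypothetical: 6 tokens, 3 features
layer = nn.Linear(3, 2, bias=False)

print(layer.weight.shape)                # torch.Size([2, 3]) -> stored as (d_out, d_in)
manual = x @ layer.weight.T              # same as x @ W with W of shape (d_in, d_out)
print(torch.allclose(layer(x), manual))  # True
```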

In [18]:
import torch.nn as nn

# nn.Linear: Fully connected (dense) layer that performs a linear transformation

class SelfAttention_v2(nn.Module):
    def __init__(self,d_in,d_out, qkv_bias = False):
        super().__init__()
        self.W_query = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in,d_out, bias=qkv_bias)
        
    def forward(self, x):
        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[1]**0.5, dim=-1)

        context_vecs = attn_weights @ values
        return context_vecs


In [19]:
import torch

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in,d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


Next, we extend this class in two ways:

1. Causal attention: The model can't peek into the future. When predicting the next word, it should only look at the previous words, not the ones that come after.

2. Multi-head attention: Instead of one big attention calculation, the model uses multiple smaller attention layers (called "heads"). Each head learns to focus on different parts of the sentence or different kinds of relationships.

## Hiding Future Words with Causal Attention

For many LLM tasks, we want the self-attention mechanism to consider only the tokens that appear before the current position when predicting the next token in a sequence. This is what causal attention does.

Causal attention (also called masked attention) restricts the model to the previous and current inputs in a sequence when computing the attention scores for any given token. This contrasts with standard self-attention, which allows access to the entire input sequence at once.

![](pic5.png)

### Applying a causal attention mask

In [20]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [28]:
# Create a mask where values above the diagonal are zero
# torch.tril(input, diagonal=0,*,out=None): Creates a lower triangular matrix
# torch.ones(size,*,out=None,dtype=None,...): Creates a tensor filled entirely with 1s

context_length = attn_scores.shape[0]
ones_layer = torch.ones(context_length,context_length)
print(ones_layer)
print("\n")
mask_simple = torch.tril(ones_layer)
print(mask_simple)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [29]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [30]:
# Renormalize the attention weights to sum up to 1 again in each row
row_sums = masked_simple.sum(dim=-1,keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [32]:
# More efficient: build a mask with 1s above the diagonal, then fill those attention scores with -inf
# so that softmax assigns them exactly zero weight and no renormalization step is needed

mask = torch.triu(ones_layer,diagonal = 1)
print(mask)
print("\n")
masked = attn_scores.masked_fill(mask.bool(),-torch.inf)
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [34]:
attn_weights = torch.softmax(masked/keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
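Both routes give identical weights because renormalizing a truncated softmax is the same as taking a softmax over only the kept entries (the discarded terms never appear in either numerator or denominator). A quick check on random, hypothetical scores:

```python
import torch

torch.manual_seed(0)
n = 5
scores = torch.randn(n, n)   # hypothetical attention scores
ones = torch.ones(n, n)

# Route 1: softmax, zero out future positions, renormalize each row
weights = torch.softmax(scores, dim=-1) * torch.tril(ones)
route1 = weights / weights.sum(dim=-1, keepdim=True)

# Route 2: set future scores to -inf, then softmax once
masked = scores.masked_fill(torch.triu(ones, diagonal=1).bool(), -torch.inf)
route2 = torch.softmax(masked, dim=-1)

print(torch.allclose(route1, route2))  # True
```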


### Masking additional attention weights with dropout

**Dropout** is a technique where randomly selected hidden layer units are ignored during training, effectively "dropping" them out.

This helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. Dropout is only used during training and is disabled at inference.

Dropout in the attention mechanism is typically applied in one of two places: after computing the attention weights, or after applying the attention weights to the value vectors.

In [35]:
torch.manual_seed(123)

dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
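The surviving entries are 2 rather than 1 because `nn.Dropout` rescales kept values by 1/(1 - p), here 1/0.5, during training, so the expected sum of the outputs stays the same. In evaluation mode the layer does nothing. A quick check of both behaviors:

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)

# Modules start in training mode, so dropout is active here
out = dropout(example)
kept = out[out != 0]
# Surviving entries are scaled by 1 / (1 - p) = 1 / (1 - 0.5) = 2
print(torch.all(kept == 2.0))  # tensor(True)

# In evaluation mode, dropout passes its input through unchanged
dropout.eval()
print(torch.equal(dropout(example), example))  # True
```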


In [36]:
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4925, 0.4638, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3941, 0.0000],
        [0.3869, 0.3327, 0.0000, 0.3084, 0.3331, 0.3058]],
       grad_fn=<MulBackward0>)


### Implementing a compact causal attention class

Next, we incorporate the causal attention mask and dropout into the SelfAttention Python class. The resulting class serves as a template for developing multi-head attention, the final attention class.

In [37]:
# torch.stack() = joins a sequence of tensors along a new dimension -> adds a new axis and stacks tensors on top of each other like layers

batch = torch.stack((inputs,inputs),dim=0)
print(batch.shape)
print(batch)

torch.Size([2, 6, 3])
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [40]:
class CasualAttention(nn.Module):
    def __init__(self, d_in, d_out,context_length,dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Adds a tensor to your model that is not a learnable parameter but still gets saved with the model
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b,num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens],-torch.inf)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec
        


In [41]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CasualAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print(context_vecs.shape)

torch.Size([2, 6, 2])
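A quick way to check that the mask works: under causal attention, changing the last token should leave the context vectors of all earlier tokens untouched. The sketch below uses a hypothetical standalone function `causal_attention` that mirrors the class's forward pass without dropout, with random weight matrices standing in for trained ones.

```python
import torch

def causal_attention(x, W_q, W_k, W_v):
    # x: (batch, num_tokens, d_in); W_*: (d_in, d_out)
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.transpose(1, 2)
    n = x.shape[1]
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(future, -torch.inf)   # hide future tokens
    weights = torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1)
    return weights @ v

torch.manual_seed(123)
d_in, d_out = 3, 2
W_q, W_k, W_v = (torch.randn(d_in, d_out) for _ in range(3))
x = torch.rand(1, 6, d_in)

x_changed = x.clone()
x_changed[:, -1] = torch.rand(d_in)   # alter only the last token

out = causal_attention(x, W_q, W_k, W_v)
out_changed = causal_attention(x_changed, W_q, W_k, W_v)
# Earlier tokens never see the last one, so their outputs are identical
print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # True
```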


## Extending single-head attention to multi-head attention

Multi-head attention divides the attention mechanism into multiple "heads," each operating independently. In this context, a single causal attention module can be considered single-head attention, where only one set of attention weights processes the input sequentially.



### Stacking multiple single-head attention layers

Implementing multi-head attention involves creating multiple instances of the self-attention mechanism, each with its own weights, and then combining their outputs.

![](pic6.png)

In [42]:
# Main idea of multi-head is to run attention mechanism multiple times in parallel
# Achieve this by implementing a simple wrapper class that stacks multiple instances of CasualAttention class

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # nn.ModuleList: A list that can store nn.Module objects -> ensures all modules are properly registered and their parameters are tracked
        # Designed to hold multiple layers in a way that PyTorch can track and train them
        self.heads = nn.ModuleList([
            CasualAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ])

    # torch.cat(): concatenates tensors along an existing dimension
    
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
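Because each head returns `d_out` features per token and `torch.cat` joins them on the last dimension, the wrapper's output width is `num_heads * d_out`. A shape-only sketch with random stand-ins for the per-head outputs (sizes chosen to match the next cell):

```python
import torch

b, num_tokens, d_out, num_heads = 2, 6, 2, 4
# Stand-ins for the per-head outputs, each (b, num_tokens, d_out)
head_outputs = [torch.randn(b, num_tokens, d_out) for _ in range(num_heads)]
combined = torch.cat(head_outputs, dim=-1)
print(combined.shape)  # torch.Size([2, 6, 8]) -> 4 heads x 2 features each
```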


In [43]:
torch.manual_seed(123)
context_length = batch.shape[1] # Number of tokens
d_in, d_out = 3,2 # Input and output dimensions
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=4)
context_vecs = mha(batch)
print(context_vecs.shape)
print(context_vecs)

torch.Size([2, 6, 8])
tensor([[[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729, -0.5684,
           0.5063],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011, -0.5388,
           0.6447],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102, -0.5242,
           0.6954],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785, -0.4578,
           0.6471],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520, -0.4006,
           0.5921],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499, -0.3997,
           0.5971]],

        [[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729, -0.5684,
           0.5063],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011, -0.5388,
           0.6447],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102, -0.5242,
           0.6954],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785, -0.4578,
           0.6471],
         [-0.5526, -0.0981,  0.5321,  0.3428, 

### Implementing multi-head attention with weight splits

In the new class, you won't need separate CasualAttention objects anymore.

Instead, it will:
1. Split the projected queries, keys, and values into multiple heads (by reshaping the tensors)
2. Compute attention for all heads in one go
3. Merge the results back together
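The split and merge steps are pure reshapes, so they should undo each other exactly. A sketch on a random tensor standing in for a projected query matrix (sizes are arbitrary assumptions), showing which slice of each token ends up in which head:

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 4, 2, 4
d_out = num_heads * head_dim
projected = torch.randn(b, num_tokens, d_out)   # stand-in for e.g. W_query(x)

# 1. Split: (b, tokens, d_out) -> (b, tokens, heads, head_dim) -> (b, heads, tokens, head_dim)
split = projected.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(split.shape)  # torch.Size([2, 2, 4, 4])

# Head h of token t holds features h*head_dim up to (h+1)*head_dim of that token
print(torch.equal(split[0, 1, 3], projected[0, 3, head_dim:2 * head_dim]))  # True

# 3. Merge: undo the transpose, then flatten the heads back into d_out
merged = split.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, projected))  # True
```

The `.contiguous()` call is needed because `transpose` returns a view with rearranged strides, and `.view()` requires a compatible memory layout.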

In [None]:
class MultiHeadAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        # Calls the constructor of the parent class nn.Module
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        
        # Initialize weights and biases for query, key, and value projections
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Output projection that combines the concatenated outputs of all heads
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',

            # torch.triu(input, diagonal=0, *, out=None): Creates an upper triangular matrix
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        # Compute queries, keys, and values
        # Q = x @ W_query.weight.T (+ bias if qkv_bias=True)
        # K = x @ W_key.weight.T
        # V = x @ W_value.weight.T
        # nn.Linear acts on the last dimension, so the batch dimension is handled automatically
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Reshape to (batch_size, num_tokens, num_heads, head_dim)
        # .view(): Returns a new tensor with the same data as the original tensor but with a different shape
        """
        batch
            head 0
                token 0 -> head_dim
                token 1 -> head_dim
                ...
            
            head 1
                token 0 -> head_dim
                token 1 -> head_dim
                ...
            ...
        
        """
        # input: (b, num_tokens, d_out)
        # b: batch size
        # num_tokens: number of tokens in the sequence
        # d_out: output dimension - aka number of features per token after projection
        # Before view: Each token has d_out features (e.g., 8), which must be split into num_heads heads of head_dim features each

        # .view(b, num_tokens, self.num_heads, self.head_dim): Reshape the tensor to have separate dimensions for heads
        # .transpose(1,2): Swap the num_tokens and num_heads dimensions to facilitate attention computation

        """
        Example: If a token embedding is: [v0,v1,v2,v3,v4,v5,v6,v7] and num_heads=2, head_dim=4
        After view:
        [v0,v1,v2,v3]  # head 0
        [v4,v5,v6,v7]  # head 1
        After transpose:
        (batch, heads, tokens, head_dim)
        """
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1,2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1,2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1,2)

        attn_scores = queries @ keys.transpose(2,3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vecs = attn_weights @ values

        context_vecs = context_vecs.transpose(1,2).contiguous().view(b, num_tokens, self.d_out)
        output = self.out_proj(context_vecs)
        return output
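As a cross-check, and assuming PyTorch 2.0 or later, the core of `forward` (scores, causal mask, scaling, softmax, weighted sum of values) can be compared against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`. The sketch uses random per-head tensors in the `(batch, heads, tokens, head_dim)` layout and leaves out dropout; it tests only the attention math, not the projections or `out_proj`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
b, num_heads, num_tokens, head_dim = 2, 2, 6, 4
queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_heads, num_tokens, head_dim)
values = torch.randn(b, num_heads, num_tokens, head_dim)

# Manual path, mirroring MultiHeadAttention_v2.forward (dropout omitted)
attn_scores = queries @ keys.transpose(2, 3)
mask_bool = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / head_dim**0.5, dim=-1)
manual = attn_weights @ values

# Fused built-in; its default scale is 1 / sqrt(head_dim)
fused = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```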


In [None]:
# # Multihead - rewrite - learn

# class MultiHeadAttention_v3(nn.Module):
#     def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
#         super().__init__()
#         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
#         self.d_out = d_out
#         self.num_heads = num_heads
#         self.head_dim = d_out // num_heads

#         # Initialize weights and biases for QKV
#         self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
#         self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
#         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

#         # Output projection layer
#         self.out_proj = nn.Linear(d_out, d_out)
#         # Dropout layer for regularization
#         self.dropout = nn.Dropout(dropout)

#         # Register a buffer for the causal mask
#         self.register_buffer(
#             'mask',
#             torch.triu(torch.ones(context_length, context_length), diagonal=1)
#         )

#     # Forward pass: compute the output of the multi-head attention
#     # (must be indented at method level, not nested inside __init__)
#     def forward(self, x):
#         # b: batch size
#         # num_tokens: number of tokens in the sequence
#         # d_in: input dimension
#         b, num_tokens, d_in = x.shape

#         # Compute Q, K, V
#         # No need for explicit matrix multiplication, nn.Linear handles it internally
#         keys = self.W_key(x)
#         queries = self.W_query(x)
#         values = self.W_value(x)

#         # Reshape and transpose for multi-head attention
#         keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
#         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
#         values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

#         # Calculating attention scores
#         attn_scores = queries @ keys.transpose(2, 3)
#         # Make the mask boolean for the current number of tokens
#         mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
#         # Attention scores masking
#         attn_scores.masked_fill_(mask_bool, -torch.inf)
#         # Turning scores into weights
#         attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
#         # Apply dropout to attention weights
#         attn_weights = self.dropout(attn_weights)

#         # Construct context vectors
#         context_vecs = attn_weights @ values

#         # Merge heads: (b, heads, tokens, head_dim) -> (b, tokens, d_out)
#         context_vecs = context_vecs.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
#         output = self.out_proj(context_vecs)
#         return output



In [8]:
import torch
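x = torch.randn(2,4,3)
# Note: the seed is set after torch.randn, so it does not make x reproducible;
# it only affects later random calls, such as the nn.Linear initialization below
torch.manual_seed(123)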

print(x[0])
print("\n")
print(x[1])
print("\n")
print(x)

b, num_tokens, d_in = x.shape
print("b:", b)
print("num_tokens:", num_tokens)
print("d_in:", d_in)

tensor([[ 0.2403, -0.5516, -0.5697],
        [ 1.0076, -0.0770, -1.0205],
        [-0.1690,  0.9178, -0.3885],
        [-0.9343, -0.4991, -1.0867]])


tensor([[ 0.9624,  0.2492, -0.4845],
        [-2.0929,  0.0983, -0.0935],
        [ 0.2662, -0.5850, -0.3430],
        [-0.6821, -0.9887, -1.7018]])


tensor([[[ 0.2403, -0.5516, -0.5697],
         [ 1.0076, -0.0770, -1.0205],
         [-0.1690,  0.9178, -0.3885],
         [-0.9343, -0.4991, -1.0867]],

        [[ 0.9624,  0.2492, -0.4845],
         [-2.0929,  0.0983, -0.0935],
         [ 0.2662, -0.5850, -0.3430],
         [-0.6821, -0.9887, -1.7018]]])
b: 2
num_tokens: 4
d_in: 3


In [13]:
import torch.nn as nn

W_query = nn.Linear(3,8)
print(W_query.weight)
print("\n")
queries = W_query(x)
print(queries)
print(queries.shape)

Parameter containing:
tensor([[ 0.5144,  0.3530,  0.2049],
        [ 0.1255,  0.1361,  0.2230],
        [-0.0746, -0.5366, -0.3570],
        [ 0.4928,  0.0345, -0.4677],
        [ 0.0911,  0.4770, -0.5456],
        [-0.3887, -0.2299,  0.0232],
        [-0.1347, -0.0634, -0.5629],
        [ 0.2704,  0.5068,  0.3529]], requires_grad=True)


tensor([[[-0.5967, -0.6374,  0.7212,  0.3788,  0.3062, -0.5439,  0.2889,
          -0.0084],
         [-0.1268, -0.5771,  0.5702,  0.9841,  0.8485, -0.9617,  0.4092,
           0.2804],
         [-0.2514, -0.4483, -0.1014,  0.1431,  0.8711, -0.7185,  0.1488,
           0.6895],
         [-1.2884, -0.8930,  0.9653,  0.0435,  0.5063, -0.1114,  0.7347,
          -0.4819]],

        [[ 0.0749, -0.4188,  0.2072,  0.7225,  0.7075, -1.0068,  0.0929,
           0.6227],
         [-1.4700, -0.7355,  0.3765, -0.9713,  0.1440,  0.2246,  0.2938,
          -0.1419],
         [-0.5487, -0.5882,  0.6563,  0.2844,  0.1690, -0.5411,  0.1599,
           0.0616],
      

In [None]:
# .view(b, num_tokens, num_heads, head_dim)
# head_dim: number of features per head
# num_heads: number of parallel attention heads
# b: batch size
# num_tokens: number of tokens per sequence
queries_reshaped = queries.view(2,4,2,4)
print(queries_reshaped)

print("After transpose\n")

print(queries_reshaped.transpose(1,2))

tensor([[[[-0.5967, -0.6374,  0.7212,  0.3788],
          [ 0.3062, -0.5439,  0.2889, -0.0084]],

         [[-0.1268, -0.5771,  0.5702,  0.9841],
          [ 0.8485, -0.9617,  0.4092,  0.2804]],

         [[-0.2514, -0.4483, -0.1014,  0.1431],
          [ 0.8711, -0.7185,  0.1488,  0.6895]],

         [[-1.2884, -0.8930,  0.9653,  0.0435],
          [ 0.5063, -0.1114,  0.7347, -0.4819]]],


        [[[ 0.0749, -0.4188,  0.2072,  0.7225],
          [ 0.7075, -1.0068,  0.0929,  0.6227]],

         [[-1.4700, -0.7355,  0.3765, -0.9713],
          [ 0.1440,  0.2246,  0.2938, -0.1419]],

         [[-0.5487, -0.5882,  0.6563,  0.2844],
          [ 0.1690, -0.5411,  0.1599,  0.0616]],

         [[-1.4576, -1.0652,  1.4288,  0.4386],
          [ 0.6314, -0.1111,  1.0780, -0.8789]]]], grad_fn=<ViewBackward0>)
After transpose

tensor([[[[-0.5967, -0.6374,  0.7212,  0.3788],
          [-0.1268, -0.5771,  0.5702,  0.9841],
          [-0.2514, -0.4483, -0.1014,  0.1431],
          [-1.2884, -0.8930