# Coding Attention in PyTorch!!!

By Josh Starmer

<p style="background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>

---- 

In this tutorial, we will code a class that is capable of all **3** types of **Attention** that we have studied, **Self-Attention**, **Masked Self-Attention**, and **Encoder-Decoder Attention**. We'll also code a few lines that will make **Multi-Headed Attention** work.

In this tutorial, you will...

- **[Code an Attention Class!!!](#attention)** This class will be able to perform **Self-Attention**, **Masked-Self Attention**, and **Encoder-Decoder Attention**.

- **[Calculate Encoder-Decoder Attention Values!!!](#calculate)** We'll then use the class that we created, Attention, to calculate **Encoder-Decoder Attention** values for some sample data.
 
- **[Code Multi-Head Attention!!!](#multi)** We'll code **Multi-Head Attention**.

- **[Calculate Mult-Head Attention!!!!](#calcMulti)** Lastly, we calculate **Multi-Head Attention** values for some sample data.


----

# Import the modules that will do all the work

In [1]:
import torch ## torch let's us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.module() and nn.Linear()
import torch.nn.functional as F # This gives us the softmax()

<p style="background-color:#fff6ff; padding:15px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px"> 💻 &nbsp; <b>Access <code>requirements.txt</code> file:</b> 1) click on the <em>"File"</em> option on the top menu of the notebook and then 2) click on <em>"Open"</em>. For more help, please see the <em>"Appendix - Tips and Help"</em> Lesson.</p>

----

# Code Attention
<a id="attention"></a>

## Again Masked self-attention for experimentation

In [15]:
class MaskedSelfAttention(nn.Module): 
                            
    def __init__(self, d_model=2,  
                 row_dim=0, 
                 col_dim=1):
        
        super().__init__()
        
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        
        self.row_dim = row_dim
        self.col_dim = col_dim

        
    def forward(self, token_encodings, mask=None):

        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        if mask is not None:
            ## Here we are masking out things we don't want to pay attention to
            ##
            ## We replace values we wanted masked out
            ## with a very small negative number so that the SoftMax() function
            ## will give all masked elements an output value (or "probability") of 0.
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9) # I've also seen -1e20 and -9e15 used in masking

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

## Attention implemneted for this chapter

In [3]:

class Attention(nn.Module): 
                            
    def __init__(self, d_model=2,  
                 row_dim=0, 
                 col_dim=1):
        
        super().__init__()
        
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        
        self.row_dim = row_dim
        self.col_dim = col_dim


    ## The only change from SelfAttention and attention is that
    ## now we expect 3 sets of encodings to be passed in...
    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        ## ...and we pass those sets of encodings to the various weight matrices.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
            
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

## My Previous implementation of attention

In [21]:
class MyAttention(nn.Module): 
                            
    def __init__(self, 
                 emb_d=2,
                 token_emb_dim_idx=0,
                 emb_dim_idx=1):
        """
        Basically upgrades the entering token encodings by:
        1. Calculating similarity scores with the other tokens by the information learned in
        two matrices: q and k
        2. Scale those scores by sqrt(emb_dim_of_k)
        3. Calculate weights out of those scores with softmax
        4. Calculate the influence of those scores on the values though the info learn in v.
        
        emb_d : the number of embedding values per token.
        token_emb_dim_idx: the index of the dimension signaling each token (rows)
        emb_dim_idx: the index of the dimension of the embedding size (columns)
        """
        
        super().__init__()

        self.emb_d = emb_d
        self.token_emb_dim_idx=token_emb_dim_idx
        self.emb_dim_idx=emb_dim_idx
        self.q_w = nn.Linear(in_features=emb_d, out_features=emb_d, bias=False)
        self.k_w = nn.Linear(in_features=emb_d, out_features=emb_d, bias=False)
        self.v_w = nn.Linear(in_features=emb_d, out_features=emb_d, bias=False)
                
    def forward(self, token_encodings: torch.Tensor, mask: torch.Tensor = None):
        q = self.q_w(token_encodings)
        print(q.shape)  # Note the shape is 3x2 now, as token_encodings is a 3x2 matrix which is 
        # multiplied by q, which is a 2x2 matrix
        k = self.k_w(token_encodings)
        v = self.v_w(token_encodings)
        
        similarity_scores = torch.matmul(q, torch.transpose(k, dim0=self.token_emb_dim_idx, dim1=self.emb_dim_idx))
        scaled_scores = similarity_scores / torch.tensor(self.emb_d ** 0.5)
        if mask is not None:
            scaled_scores.masked_fill_(mask, value=-9e15)
        softmax_scores = F.softmax(scaled_scores, dim=self.emb_dim_idx)
        weighted_scores = torch.matmul(softmax_scores, v)
        return weighted_scores
    
## create a matrix of token encodings (which are word embedings + positional encoding)...
## The first dimension of the encoding matrix refers to the token encodings themselves
## and the second dimension refers to the embedding dimensions of each token
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])
print(encodings_matrix)

torch.manual_seed(42)
mask = torch.ones((3,3))
mask = torch.tril(mask)
mask = mask == 0

print(f"mask: {mask}")
print(encodings_matrix.shape[1])
my_attention = MyAttention(encodings_matrix.shape[1])
print(f"My attention: {my_attention(encodings_matrix, mask)}")
torch.manual_seed(42)
attention = MaskedSelfAttention(encodings_matrix.shape[1])
print(f"This Attention: {attention(encodings_matrix, mask)}")

tensor([[ 1.1600,  0.2300],
        [ 0.5700,  1.3600],
        [ 4.4100, -2.1600]])
mask: tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
2
torch.Size([3, 2])
My attention: tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)
Attention: tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)


# BAM!

----

# Calculate Encoder-Decoder Attention
<a id="calculate"></a>

In [22]:
## create matrices of token encodings...
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
attention = Attention(d_model=2,
                      row_dim=0,
                      col_dim=1)

## calculate encoder-decoder attention
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

----

# Code Mutli-Head Attention
<a id="multi"></a>

In [24]:
class MultiHeadAttention(nn.Module):

    def __init__(self, 
                 d_model=2,  
                 row_dim=0, 
                 col_dim=1, 
                 num_heads=1):
        
        super().__init__()

        ## create a bunch of attention heads
        self.heads = nn.ModuleList(
            [Attention(d_model, row_dim, col_dim) 
             for _ in range(num_heads)]
        )

        self.col_dim = col_dim
        
    def forward(self, 
                encodings_for_q, 
                encodings_for_k,
                encodings_for_v):

        ## run the data through all of the attention heads
        return torch.cat([head(encodings_for_q, 
                               encodings_for_k,
                               encodings_for_v) 
                          for head in self.heads], dim=self.col_dim)

In [23]:
class MyMultiHeadAttention(nn.Module): 
                            
    def __init__(self, 
                 emb_dim=2,
                 token_emb_dim_idx=0,
                 emb_dim_idx=1, 
                 num_heads=1):
        """
        Basically upgrades the entering token encodings by:
        1. Calculating similarity scores with the other tokens by the information learned in
        two matrices: q and k
        2. Scale those scores by sqrt(emb_dim_of_k)
        3. Calculate weights out of those scores with softmax
        4. Calculate the influence of those scores on the values though the info learn in v.
        
        emb_dim : the number of embedding values per token.
        token_emb_dim_idx: the index of the dimension signaling each token (rows)
        emb_dim_idx: the index of the dimension of the embedding size (columns)
        """
        
        super().__init__()

        self.num_heads = num_heads
        self.head_dim = emb_dim
        self.q_w = nn.Linear(in_features=emb_dim, out_features=emb_dim * num_heads, bias=False)
        self.k_w = nn.Linear(in_features=emb_dim, out_features=emb_dim * num_heads, bias=False)
        self.v_w = nn.Linear(in_features=emb_dim, out_features=emb_dim * num_heads, bias=False)
        self.out_proj = nn.Linear(in_features=emb_dim * num_heads, out_features=emb_dim, bias=False)
                
    def forward(self, token_encodings: torch.Tensor, mask: torch.Tensor = None):        
        seq_len, _ = token_encodings.shape  # Expecting (T, D)
        q = self.q_w(token_encodings)
        # multiplied by q, which is a 2x2 matrix
        k = self.k_w(token_encodings)
        v = self.v_w(token_encodings)
        print(f"Q shape as input: {q.shape}")  # Note the shape is 3x2 now, as token_encodings is a 3x2 matrix which is 
        
         # Reshape into multiple heads: (T, emb_dim) → (T, num_heads, head_dim) → (num_heads, T, head_dim), so (3,2) to (3,1,2) to (1, 3, 2)
        q = q.view(seq_len, self.num_heads, self.head_dim)
        k = k.view(seq_len, self.num_heads, self.head_dim)
        v = v.view(seq_len, self.num_heads, self.head_dim)
        print(f"Q shape after reshaping: {q.shape}")  # Has to be (1, 3, 2)
        
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        print(f"Q shape after transponsing: {q.shape}")  # Has to be (1, 3, 2) (num_heads, T, head_dim)
        
        similarity_scores = torch.matmul(q, k.transpose(dim0=-2, dim1=-1))
        scaled_scores = similarity_scores / (self.head_dim ** 0.5)
        print(f"scaled scores before ({scaled_scores.shape})") #:\n{scaled_scores}")
        
        if mask is not None:
            scaled_scores = scaled_scores.masked_fill(mask.unsqueeze(0) == False, value=-1e9)
            print(f"scaled scores after ({scaled_scores.shape}):\n{scaled_scores}")
        
        softmax_scores = F.softmax(scaled_scores, dim=-1)
        print(f"softmax scores ({scaled_scores.shape}):\n") #v scores: {v.shape}")
        
        weighted_scores = torch.matmul(softmax_scores, v)
        
        # Reshape back: (T, num_heads, head_dim) → (T, emb_dim)
        weighted_scores = weighted_scores.transpose(0,1).reshape(seq_len, self.num_heads * self.head_dim)
        
        # Final projection back to original embedding dimension
        output = self.out_proj(weighted_scores)  # (T, D)
        return output
    
## create a matrix of token encodings (which are word embedings + positional encoding)...
## The first dimension of the encoding matrix refers to the token encodings themselves
## and the second dimension refers to the embedding dimensions of each token
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

torch.manual_seed(42)
N_HEADS = 2
mask = torch.ones((3,3))
mask = torch.tril(mask)
mask = mask == 0
print(f"mask: {mask}\n{mask.shape}")

my_attention = MyMultiHeadAttention(encodings_matrix.shape[1], num_heads=N_HEADS)
print(f"My multihead attention with no mask: {my_attention(encodings_matrix)}")


mask: tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
torch.Size([3, 3])
Q shape as input: torch.Size([3, 4])
Q shape after reshaping: torch.Size([3, 2, 2])
Q shape after transponsing: torch.Size([2, 3, 2])
scaled scores before (torch.Size([2, 3, 3]))
softmax scores (torch.Size([2, 3, 3])):

My multihead attention with no mask: tensor([[-0.1703,  0.7427],
        [-0.0230,  1.0481],
        [-0.2958,  0.2219]], grad_fn=<MmBackward0>)


In [None]:
torch.manual_seed(42)
m_attention = MaskedSelfAttention(encodings_matrix.shape[1])
print(f"Masked Attention with no mask: {m_attention(encodings_matrix)}")


In [66]:
torch.manual_seed(42)
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])
regular_attention = Attention(encodings_matrix.shape[1])
print(f"Attention: {regular_attention(encodings_for_q, encodings_for_k, encodings_for_v)}")
torch.manual_seed(42)
mh_attention = MultiHeadAttention(encodings_matrix.shape[1], num_heads=N_HEADS)
print(f"Tihs MultiHeadAttention: {mh_attention(encodings_for_q, encodings_for_k, encodings_for_v)}")


Attention: tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)
Tihs MultiHeadAttention: tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)


----

# Calcualte Multi-Head Attention
<a id="calcMulti"></a>

First, verify that we can still correctly calculate attention with a single head...

In [25]:
## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=1)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)

Second, calculate attention with multiple heads...

In [26]:
## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=2)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)

# TRIPLE BAM!!

# Another MH Implementation

### Option 1. Split the embedding dimension (standard approach)

Set things up so that the projections go from [T, emb_dim] to [T, emb_dim] (not to [T, emb_dim × num_heads]), then split that into num_heads pieces. This means each head gets a slice of size

\text{head\_dim} = \frac{\text{emb\_dim}}{\text{num\_heads}}

so that the total remains emb_dim. In this case, if you set num_heads = 1 you get the full embedding per token, but if you set num_heads = 2 each head sees half the features. (That’s why head 0 from a 2‐head model won’t match a single-head model—the projections are done on a smaller slice.) This is the most common approach in Transformer models.

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyMultiHeadAttention(nn.Module):
    def __init__(self, emb_dim=2, token_emb_dim_idx=0, emb_dim_idx=1, num_heads=1):
        super().__init__()
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads

        # These layers now map from [T, emb_dim] to [T, emb_dim]
        self.q_w = nn.Linear(in_features=emb_dim, out_features=emb_dim, bias=False)
        self.k_w = nn.Linear(in_features=emb_dim, out_features=emb_dim, bias=False)
        self.v_w = nn.Linear(in_features=emb_dim, out_features=emb_dim, bias=False)
    
    def forward(self, token_encoding, mask=None, concatenate_heads: bool = True):
        T = token_encoding.size(0)  # assuming tokens are in dimension 0
        # 1. Project to queries, keys, and values: shape [T, emb_dim]
        q = self.q_w(token_encoding)
        k = self.k_w(token_encoding)
        v = self.v_w(token_encoding)
        
        # 2. Split into heads: reshape to [T, num_heads, head_dim]
        q = q.view(T, self.num_heads, self.head_dim)
        k = k.view(T, self.num_heads, self.head_dim)
        v = v.view(T, self.num_heads, self.head_dim)
        
        # 3. Transpose to get [num_heads, T, head_dim]
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        
        # 4. Scaled dot-product attention per head
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # shape: [num_heads, T, head_dim]
        
        if concatenate_heads:
            # 5. Concatenate heads back: reshape to [T, emb_dim]
            attn_output = attn_output.transpose(0, 1).reshape(T, self.emb_dim)
        return attn_output, attn_weights

torch.manual_seed(42)
token_encoding = torch.tensor([[1.16, 0.23],
                               [0.57, 1.36],
                               [4.41, -2.16]])
# With num_heads=1, head_dim = 2 → same as single-head attention.
mha_single = MyMultiHeadAttention(emb_dim=2, num_heads=1)
out_single, attn_weigths = mha_single(token_encoding)
print(f"Single-head output:\n{out_single}\nWeights:\n{attn_weights}")

torch.manual_seed(42)
# With num_heads=2, head_dim = 1 → each head gets half the features.
mha_multi = MyMultiHeadAttention(emb_dim=2, num_heads=2)
out_multi, attn_weigths = mha_multi(token_encoding)
print(f"\nMulti-head (2 heads) output:\n{out_multi}\nWeights:\n{attn_weights}")

torch.manual_seed(42)
out_multi, attn_weigths = mha_multi(token_encoding, concatenate_heads=False)
print(f"\nMulti-head (2 heads) output:\n{out_multi}\nWeights:\n{attn_weights}")

Single-head output:
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<ViewBackward0>)
Weights:
tensor([[[0.1403, 0.0845, 0.7752],
         [0.0292, 0.0123, 0.9586],
         [0.3715, 0.2413, 0.3872]],

        [[0.3656, 0.3994, 0.2350],
         [0.2863, 0.2599, 0.4538],
         [0.3330, 0.6550, 0.0119]]], grad_fn=<SoftmaxBackward0>)

Multi-head (2 heads) output:
tensor([[0.8080, 1.2760],
        [0.6116, 0.7280],
        [0.6064, 2.4013]], grad_fn=<ReshapeAliasBackward0>)
Weights:
tensor([[[0.1403, 0.0845, 0.7752],
         [0.0292, 0.0123, 0.9586],
         [0.3715, 0.2413, 0.3872]],

        [[0.3656, 0.3994, 0.2350],
         [0.2863, 0.2599, 0.4538],
         [0.3330, 0.6550, 0.0119]]], grad_fn=<SoftmaxBackward0>)

Multi-head (2 heads) output:
tensor([[[0.8080],
         [0.6116],
         [0.6064]],

        [[1.2760],
         [0.7280],
         [2.4013]]], grad_fn=<UnsafeViewBackward0>)
Weights:
tensor([[[0.1403, 0.0845, 0.7752],
         

### Option 2. Preserve full embedding per head and use a final projection

Goal:t one of the heads (say head 0) behaves exactly as the single-head attention does, then each head must process the full embedding. In that case, linear layers will map from [T, emb_dim] to [T, emb_dim * num_heads] (so that each head gets an emb_dim-dimensional representation), and then—after computing attention independently per head—you combine them with a final linear layer that projects from the concatenated space back to emb_dim. 

This means if you set num_heads = 1, you get the same computation as the single-head case, and for num_heads = 2 the first head’s result will match the single-head’s output (provided you initialize that head’s parameters identically).

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyMultiHeadAttention(nn.Module):
    def __init__(self, emb_dim=2, token_emb_dim_idx=0, emb_dim_idx=1, num_heads=1):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads        
        # Here we keep each head's dimension equal to emb_dim.
        self.head_dim = emb_dim  
        # Project from [T, emb_dim] to [T, emb_dim * num_heads]
        self.q_w = nn.Linear(in_features=emb_dim, out_features=emb_dim * num_heads, bias=False)
        self.k_w = nn.Linear(in_features=emb_dim, out_features=emb_dim * num_heads, bias=False)
        self.v_w = nn.Linear(in_features=emb_dim, out_features=emb_dim * num_heads, bias=False)
        # Final projection to bring concatenated heads back to emb_dim.
        self.out_proj = nn.Linear(in_features=emb_dim * num_heads, out_features=emb_dim, bias=False)
    
    def forward(self, token_encoding, mask=None, concatenate_heads: bool = True, project: bool = False):
        T = token_encoding.size(0)
        q = self.q_w(token_encoding)  # shape: [T, emb_dim*num_heads]
        k = self.k_w(token_encoding)
        v = self.v_w(token_encoding)
        
        # Reshape to separate heads: [T, num_heads, head_dim]
        q = q.view(T, self.num_heads, self.head_dim)
        k = k.view(T, self.num_heads, self.head_dim)
        v = v.view(T, self.num_heads, self.head_dim)
        
        # Transpose: [num_heads, T, head_dim]
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # shape: [num_heads, T, head_dim]
        
        if concatenate_heads:            
            # Concatenate heads: [T, num_heads * head_dim]
            attn_output = attn_output.transpose(0, 1).reshape(T, self.num_heads * self.head_dim)

        # Final projection: [T, emb_dim]
        if project: 
            attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights

# Example usage:
torch.manual_seed(42)
token_encoding = torch.tensor([[1.16, 0.23],
                               [0.57, 1.36],
                               [4.41, -2.16]])
# Single-head attention
mha_single = MyMultiHeadAttention(emb_dim=2, num_heads=1)
out_single, attn_weights = mha_single(token_encoding)
print(f"Single-head output:\n{out_single}\nWeights:\n{attn_weights}")

torch.manual_seed(42)
# Two-head attention – if you initialize head0 to match the single-head parameters,
# then after the final projection the corresponding contribution from head0 can match.
mha_multi = MyMultiHeadAttention(emb_dim=2, num_heads=2)
out_multi, attn_weights = mha_multi(token_encoding, concatenate_heads=False)
print(f"\nMulti-head (2 heads) output:\n{out_multi}\nWeights:\n{attn_weights}")

torch.manual_seed(42)
out_multi, attn_weights = mha_multi(token_encoding, concatenate_heads=True)
print(f"\nMulti-head (2 heads) output:\n{out_multi}\nWeights:\n{attn_weights}")

torch.manual_seed(42)
out_multi, attn_weights = mha_multi(token_encoding, concatenate_heads=True, project=True)
print(f"\nMulti-head (2 heads) output:\n{out_multi}\nWeights:\n{attn_weights}")

Single-head output:
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<ViewBackward0>)
Weights:
tensor([[[0.3573, 0.4011, 0.2416],
         [0.3410, 0.6047, 0.0542],
         [0.0722, 0.0320, 0.8959]]], grad_fn=<SoftmaxBackward0>)

Multi-head (2 heads) output:
tensor([[[ 1.8188, -1.4734],
         [ 2.1126, -1.7779],
         [ 1.1966, -0.8276]],

        [[-0.5599, -0.4288],
         [-0.7620, -0.9759],
         [-0.3427,  0.2084]]], grad_fn=<UnsafeViewBackward0>)
Weights:
tensor([[[0.1403, 0.0845, 0.7752],
         [0.0292, 0.0123, 0.9586],
         [0.3715, 0.2413, 0.3872]],

        [[0.3656, 0.3994, 0.2350],
         [0.2863, 0.2599, 0.4538],
         [0.3330, 0.6550, 0.0119]]], grad_fn=<SoftmaxBackward0>)

Multi-head (2 heads) output:
tensor([[ 1.8188, -1.4734, -0.5599, -0.4288],
        [ 2.1126, -1.7779, -0.7620, -0.9759],
        [ 1.1966, -0.8276, -0.3427,  0.2084]], grad_fn=<UnsafeViewBackward0>)
Weights:
tensor([[[0.1403, 0.0845, 0.7752]