## Naive self-attention without trainable weights

In [15]:
import torch

# Toy 3-dimensional embeddings for the six tokens of "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

# The second token ("journey") acts as the query.
query = inputs[1]

# Unnormalized attention scores: dot product of the query with every token.
attn_scores_2 = torch.stack([torch.dot(token, query) for token in inputs])

print("Attention scores:", attn_scores_2)


Attention scores: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [12]:
# Plain sum-normalization: divide every score by the total so the weights sum to 1.
score_total = attn_scores_2.sum()
attn_weights_2_tmp = attn_scores_2 / score_total
print("Attention weights (generic normalization):", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights (generic normalization): tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [13]:
# Naive softmax, written out by hand for comparison with torch.softmax.
def softmax_naive(x):
    """Softmax along dim 0 computed literally as exp(x) / sum(exp(x)).

    Deliberately naive: there is no max-subtraction, so sufficiently large
    inputs overflow exp() to inf and the result becomes NaN. Prefer
    torch.softmax outside of teaching examples.
    """
    exps = torch.exp(x)
    return exps / exps.sum(dim=0)

# Softmax weights for query 2 using the hand-written version above.
attn_weights_2_naive = softmax_naive(attn_scores_2)
naive_total = attn_weights_2_naive.sum()
print("Attention weights (naive softmax):", attn_weights_2_naive)
print("Sum:", naive_total)

Attention weights (naive softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [14]:
# Reference: PyTorch's built-in softmax over the (only) score dimension.
attn_weights_2 = attn_scores_2.softmax(dim=0)
torch_total = attn_weights_2.sum()
print("Attention weights (torch softmax):", attn_weights_2)
print("Sum:", torch_total)

Attention weights (torch softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [19]:
# Context vector for query x^2: attention-weighted sum of all input vectors,
# accumulated one token at a time so each step can be inspected.
query = inputs[1]
context_vector_2 = torch.zeros(query.shape)
for i, (weight, token) in enumerate(zip(attn_weights_2, inputs)):
    context_vector_2 += weight * token
    print(f"Step {i}:", context_vector_2, " (attention ", weight, "; input ", token, ")")
print("Context vector:", context_vector_2)

Step 0: tensor([0.0596, 0.0208, 0.1233])  (attention  tensor(0.1385) ; input  tensor([0.4300, 0.1500, 0.8900]) )
Step 1: tensor([0.1904, 0.2277, 0.2803])  (attention  tensor(0.2379) ; input  tensor([0.5500, 0.8700, 0.6600]) )
Step 2: tensor([0.3234, 0.4260, 0.4296])  (attention  tensor(0.2333) ; input  tensor([0.5700, 0.8500, 0.6400]) )
Step 3: tensor([0.3507, 0.4979, 0.4705])  (attention  tensor(0.1240) ; input  tensor([0.2200, 0.5800, 0.3300]) )
Step 4: tensor([0.4340, 0.5250, 0.4813])  (attention  tensor(0.1082) ; input  tensor([0.7700, 0.2500, 0.1000]) )
Step 5: tensor([0.4419, 0.6515, 0.5683])  (attention  tensor(0.1581) ; input  tensor([0.0500, 0.8000, 0.5500]) )
Context vector: tensor([0.4419, 0.6515, 0.5683])


In [26]:
"""
Steps - 
1. Compute attention scores as dot products between query and each input.
2. Normalize attention scores using PyTorch softmax.
3. Compute context vector as weighted sum of inputs using attention weights.
"""

# Number of tokens, taken from the data instead of a hard-coded 6.
num_tokens = inputs.shape[0]

# Explicit double loop: score[i, j] = x_i . x_j (slow, kept for illustration).
attn_scores = torch.empty(num_tokens, num_tokens)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print("Attention scores matrix:\n", attn_scores)

# Python for-loops are slow - the same matrix is a single matrix multiplication.
attn_scores_matmul = inputs @ inputs.T
# Sanity check: both constructions must agree.
assert torch.allclose(attn_scores, attn_scores_matmul)
print("Attention scores matrix (matrix multiplication):\n", attn_scores_matmul)

# Softmax along each row (dim=-1) so every query's weights sum to 1.
attn_weights_matmul = torch.softmax(attn_scores_matmul, dim=-1)
print("Normalized Attention weights matrix:\n", attn_weights_matmul)

# Row i of the product is the context vector for token i.
context_vectors = attn_weights_matmul @ inputs
print("Context vectors:\n", context_vectors)

Attention scores matrix:
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
Attention scores matrix (matrix multiplication):
 tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
Normalized Attention weights matrix:
 tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.143

## Self-attention with trainable weights (scaled dot-product attention)

In [None]:
x_2 = inputs[1]          # second input token ("journey")
d_in = inputs.shape[1]   # input embedding size
d_out = 2                # projection size for queries, keys and values

torch.manual_seed(123)
# Frozen (requires_grad=False) random projection matrices, drawn in
# query -> key -> value order so the seed yields the same values per matrix.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

# Project the single token x_2 into query / key / value space.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))

print("Query vector:", query_2)
print("Key vector:", key_2)
print("Value vector:", value_2)

# Attending from query 2 needs the keys and values of *all* inputs.
keys = inputs @ W_key
values = inputs @ W_value
print("Keys matrix:\n", keys, "shape:", keys.shape)
print("Values matrix:\n", values, "shape:", values.shape)

Query vector: tensor([0.4306, 1.4551])
Key vector: tensor([0.4433, 1.1419])
Value vector: tensor([0.3951, 1.0037])
Keys matrix:
 tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]]) shape: torch.Size([6, 2])
Values matrix:
 tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]]) shape: torch.Size([6, 2])


In [35]:
# Single unnormalized score (omega_22): query 2 against key 2.
keys_2 = keys[1]
attn_scores_22 = torch.dot(query_2, keys_2)
print("Attention score w22:", attn_scores_22)

# Scores of query 2 against every key in one product.
attn_scores_2 = query_2 @ keys.T
print("Attention scores for query 2:", attn_scores_2)

# Scaled softmax: divide by sqrt(d_k), the key dimension, before normalizing.
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1)
print("Attention weights for query 2:", attn_weights_2)

# Context vector for query 2: attention-weighted sum of the value vectors.
context_vector_2 = attn_weights_2 @ values
print("Context vector for query 2:", context_vector_2)

Attention score w22: tensor(1.8524)
Attention scores for query 2: tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
Attention weights for query 2: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
Context vector for query 2: tensor([0.3061, 0.8210])


In [47]:
# Compact self attention class 

import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with raw ``nn.Parameter`` weight matrices.

    Args:
        d_in: size of each input token embedding (last input dimension).
        d_out: size of the query/key/value projections and of the output.

    ``forward`` accepts a single sequence ``(num_tokens, d_in)`` or a batch
    ``(..., num_tokens, d_in)`` and returns context vectors of shape
    ``(..., num_tokens, d_out)``.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Weights drawn with torch.rand (uniform on [0, 1)), in query/key/value order.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # transpose(-2, -1) swaps only the token/feature axes. ``keys.T`` reverses
        # *all* axes, which is deprecated for non-2D tensors and breaks batched input.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Standard scaled dot-product: divide by sqrt(d_k) before the softmax.
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        return attn_weights @ values

torch.manual_seed(123)  # fixed seed -> reproducible random weight matrices
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
context_vecs_v1 = sa_v1(inputs)
print("Context vectors from SelfAttention_v1:\n", context_vecs_v1)

Context vectors from SelfAttention_v1:
 tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [48]:
# Compact self attention class using nn.Linear

import torch.nn as nn

class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention whose projections are ``nn.Linear`` layers.

    Args:
        d_in: size of each input token embedding (last input dimension).
        d_out: size of the query/key/value projections and of the output.
        qkv_bias: whether the query/key/value ``nn.Linear`` layers have a bias.

    ``forward`` accepts a single sequence ``(num_tokens, d_in)`` or a batch
    ``(..., num_tokens, d_in)`` and returns context vectors of shape
    ``(..., num_tokens, d_out)``.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes; ``keys.T`` would reverse every axis of a
        # batched (3-D) tensor and make the matmul fail.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Standard scaled dot-product: divide by sqrt(d_k) before the softmax.
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        return attn_weights @ values

torch.manual_seed(789)  # fixed seed -> reproducible nn.Linear initialization
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)
context_vecs_v2 = sa_v2(inputs)
print("Context vectors from SelfAttention_v2:\n", context_vecs_v2)

Context vectors from SelfAttention_v2:
 tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


## Causal self-attention with trainable weights

In [54]:
# Recompute SelfAttention_v2's (unmasked) attention weights outside the module.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
print("Attention weights from SelfAttention_v2:\n", attn_weights)

# Lower-triangular 0/1 mask: row i keeps positions 0..i and zeroes later ones.
context_length = attn_scores.shape[0]
mask_simple = torch.ones(context_length, context_length).tril()
print("Simple lower triangular mask:\n", mask_simple)

masked_simple = attn_weights * mask_simple
print("Masked attention weights (simple):\n", masked_simple)

# Masking after softmax breaks the sum-to-one property, so renormalize each row.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print("Renormalized masked attention weights (simple):\n", masked_simple_norm)

Attention weights from SelfAttention_v2:
 tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
Simple lower triangular mask:
 tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
Masked attention weights (simple):
 tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1

In [56]:
# Efficient causal masking: set future positions (above the diagonal) to -inf
# *before* the softmax; exp(-inf) = 0, so no renormalization is needed afterwards.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Masked attention scores (causal):\n", masked)

attn_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

Masked attention scores (causal):
 tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [58]:
# Dropout on attention weights: with p=0.5 roughly half the entries are zeroed
# and the survivors are scaled by 1 / (1 - p) = 2. The generator is reseeded before
# each call so both printouts use the same random draw sequence.
dropout = nn.Dropout(p=0.5)
example = torch.ones(6, 6)
for target in (example, attn_weights):
    torch.manual_seed(123)
    print(dropout(target))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


In [68]:
# compact causal attention class

# Batch of two identical copies of the 6-token input: (2, num_tokens, d_in).
batch = torch.stack([inputs, inputs])
print("Batch shape:", batch.shape)

class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum sequence length; sizes the precomputed mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value ``nn.Linear`` layers have a bias.

    ``forward`` takes ``(batch, num_tokens, d_in)`` and returns
    ``(batch, num_tokens, d_out)``; token i only attends to tokens 0..i.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Created in query/key/value order so a fixed seed gives the same weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark the "future" positions to hide.
        future_positions = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        _, seq_len, _ = x.shape
        k = self.W_key(x)
        q = self.W_query(x)
        v = self.W_value(x)

        hide = self.mask.bool()[:seq_len, :seq_len]
        scores = (q @ k.transpose(1, 2)).masked_fill(hide, -torch.inf)
        weights = self.dropout(torch.softmax(scores / k.shape[-1] ** 0.5, dim=-1))
        return weights @ v

torch.manual_seed(123)  # fixed seed -> reproducible projection weights
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)  # dropout disabled
context_vecs = ca(batch)
print("Context vectors from CausalAttention:\n", context_vecs, "shape:", context_vecs.shape)



Batch shape: torch.Size([2, 6, 3])
Context vectors from CausalAttention:
 tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>) shape: torch.Size([2, 6, 2])


In [None]:
## Multi-head Self Attention

In [66]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` heads.

    ``d_out`` is the *total* output size. It is split evenly into ``num_heads``
    heads of size ``d_out // num_heads``, and their outputs are concatenated on the
    last dimension, so the result has exactly ``d_out`` features.

    Args:
        d_in: size of each input token embedding.
        d_out: total output size; must be divisible by ``num_heads``.
        context_length: maximum sequence length (size of each head's causal mask).
        dropout: dropout probability applied to each head's attention weights.
        num_heads: number of independent heads.
        qkv_bias: whether the heads' query/key/value projections have a bias.

    Raises:
        AssertionError: if ``d_out`` is not divisible by ``num_heads``. Without this
            check the integer division would silently yield fewer than ``d_out``
            output features.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Same contract and message as MultiHeadAttention.
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        head_dim = d_out // num_heads
        self.heads = nn.ModuleList([
            CausalAttention(d_in, head_dim, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ])

    def forward(self, x):
        # Each head yields (..., head_dim); concatenating gives (..., d_out).
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [71]:
torch.manual_seed(123)  # fixed seed -> reproducible head weights
context_length = batch.shape[1]
d_in, d_out = 3, 2
# Two causal heads of size d_out // num_heads = 1, concatenated back to d_out = 2.
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs_mha = mha(batch)
print("Context vectors from MultiHeadAttentionWrapper:\n", context_vecs_mha, "shape:", context_vecs_mha.shape)

Context vectors from MultiHeadAttentionWrapper:
 tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>) shape: torch.Size([2, 6, 2])


In [75]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with shared Q/K/V projections.

    The input is projected to ``d_out`` features and split into ``num_heads``
    heads of size ``d_out // num_heads``. Masked scaled dot-product attention is
    applied per head, the heads are concatenated back to ``d_out``, and a final
    linear projection is applied.

    Args:
        d_in: size of each input token embedding.
        d_out: total output size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value ``nn.Linear`` layers have a bias.
        context_length: optional maximum sequence length. When given, a causal
            mask of that size is precomputed and registered as the ``mask``
            buffer. When ``None`` (default), the mask is built in ``forward`` from
            the actual number of tokens.

    ``forward`` takes ``(batch, num_tokens, d_in)`` and returns
    ``(batch, num_tokens, d_out)``.
    """

    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False, context_length=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # The mask size used to come from an undefined (notebook-global) name
        # ``context_length``; it is now an explicit, optional argument.
        mask = None
        if context_length is not None:
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", mask)

    def _split_heads(self, t, b, num_tokens):
        """Reshape (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)."""
        return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def _causal_mask(self, num_tokens, device):
        """Boolean (num_tokens, num_tokens) mask that is True above the diagonal."""
        if self.mask is not None:
            return self.mask.bool()[:num_tokens, :num_tokens]
        return torch.triu(
            torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=device), diagonal=1
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        keys = self._split_heads(self.W_key(x), b, num_tokens)
        queries = self._split_heads(self.W_query(x), b, num_tokens)
        values = self._split_heads(self.W_value(x), b, num_tokens)

        # (b, num_heads, num_tokens, num_tokens) scores; hide future positions.
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores.masked_fill_(self._causal_mask(num_tokens, x.device), -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out).
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)


In [76]:
torch.manual_seed(123)  # fixed seed -> reproducible projection weights
batch_size, context_length, d_in = batch.shape
d_out = 2
# Two heads of head_dim = d_out // num_heads = 1, followed by an output projection.
mha = MultiHeadAttention(d_in, d_out, num_heads=2, dropout=0.0)
context_vecs_mha = mha(batch)
print("Context vectors from MultiHeadAttention:\n", context_vecs_mha, "shape:", context_vecs_mha.shape)

Context vectors from MultiHeadAttention:
 tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>) shape: torch.Size([2, 6, 2])
