In [1]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
query = inputs[1]  # 2nd input token is the query

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)
attn_scores_2_ = query @ inputs.T
print(attn_scores_2)
print(attn_scores_2_)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [3]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [4]:

# context_vec_2 = torch.zeros(query.shape)
# for i,x_i in enumerate(inputs):
#     context_vec_2 += attn_weights_2[i]*x_i
context_vec_2= attn_weights_2 @ inputs
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


##**Scaled dot-product attention**

The most notable difference is the introduction of weight matrices that are updated during model training.
These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce "good" context vectors

*Implementing the self-attention mechanism step by step, we will start by introducing the three training weight matrices $W_q$,$W_k$ and $W_v$

*These three matrices are used to project the embedded input tokens,
 into query, key, and value vectors via matrix multiplication:

These three matrices are used to project the embedded input tokens,  into query, key, and value vectors via matrix multiplication:

Query vector: $q^i=W_q *x^i $ \\
Key vector: $k^i=W_k *x^i $ \\
Value vector: $v^i=W_v *x^i $ \\

In [5]:
x = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2

In [6]:
import torch.nn as nn
torch.manual_seed(123)

W_query = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [7]:
query_2 = x @ W_query # _2 because it's with respect to the 2nd input element
key_2 = x @ W_key
value_2 = x @ W_value

print(query_2,query_2.size())

tensor([0.4306, 1.4551]) torch.Size([2])


In [8]:
keys = inputs @ W_key
values = inputs @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


we compute the attention weights by computing the dot product between the query and each key vector:

In [9]:
attn_score= query_2.unsqueeze(dim=0) @ keys.T
print(attn_score)
attn_weights = torch.softmax(attn_score/ d_out**0.5,dim=1)
print(attn_weights)

tensor([[1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440]])
tensor([[0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820]])


In [10]:
context_vec_2 = attn_weights @ values
context_vec_2

tensor([[0.3061, 0.8210]])

##Implementing a compact SelfAttention class

In [11]:
import torch.nn as nn

class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(789)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))


tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


##Causal self-attention

Causal self-attention ensures that the model's prediction for a certain position in a sequence is only dependent on the known outputs at previous positions, not on future positions.


In [12]:
#Let's start with some attention weights
attn_weights=torch.tensor([[0.1972, 0.1910, 0.1894, 0.1361, 0.1344, 0.1520],
        [0.1476, 0.2164, 0.2134, 0.1365, 0.1240, 0.1621],
        [0.1479, 0.2157, 0.2129, 0.1366, 0.1260, 0.1608],
        [0.1505, 0.1952, 0.1933, 0.1525, 0.1375, 0.1711],
        [0.1571, 0.1874, 0.1885, 0.1453, 0.1819, 0.1399],
        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])
attn_weights

tensor([[0.1972, 0.1910, 0.1894, 0.1361, 0.1344, 0.1520],
        [0.1476, 0.2164, 0.2134, 0.1365, 0.1240, 0.1621],
        [0.1479, 0.2157, 0.2129, 0.1366, 0.1260, 0.1608],
        [0.1505, 0.1952, 0.1933, 0.1525, 0.1375, 0.1711],
        [0.1571, 0.1874, 0.1885, 0.1453, 0.1819, 0.1399],
        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])

In [13]:
block_size = attn_weights.size(0)
mask_simple = torch.tril(torch.ones(block_size, block_size))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [14]:
#multiply the attention weights with this mask to zero out the attention scores above the diagonal:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1972, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1476, 0.2164, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1479, 0.2157, 0.2129, 0.0000, 0.0000, 0.0000],
        [0.1505, 0.1952, 0.1933, 0.1525, 0.0000, 0.0000],
        [0.1571, 0.1874, 0.1885, 0.1453, 0.1819, 0.0000],
        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])


In [15]:
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4055, 0.5945, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2565, 0.3742, 0.3693, 0.0000, 0.0000, 0.0000],
        [0.2176, 0.2823, 0.2795, 0.2205, 0.0000, 0.0000],
        [0.1826, 0.2179, 0.2191, 0.1689, 0.2115, 0.0000],
        [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])


So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function

In [16]:
# mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)
# masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
# print(masked)
# attn_weights = torch.softmax(masked / d_out**0.5, dim=1)
# print(attn_weights)

**Masking additional attention weights with dropout** \\
In addition, we also apply dropout to reduce overfitting during training.

we will apply the dropout mask after computing the attention weights because it's more common.

In [17]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout rate of 50%
example = torch.ones(6, 6) # create a matrix of ones
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [18]:
attn_dropout = dropout(attn_weights)
attn_dropout

tensor([[0.3944, 0.3820, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4268, 0.0000, 0.0000, 0.3242],
        [0.0000, 0.0000, 0.4258, 0.2732, 0.0000, 0.3216],
        [0.0000, 0.3904, 0.0000, 0.3050, 0.2750, 0.0000],
        [0.0000, 0.3748, 0.3770, 0.2906, 0.3638, 0.2798],
        [0.2946, 0.4066, 0.0000, 0.0000, 0.2320, 0.3678]])

In [19]:
#Normalized attn_weights after dropout
sum_weights=torch.sum(attn_dropout, dim=1, keepdim=True)
attn_dropout=attn_dropout/sum_weights
attn_dropout

tensor([[0.5080, 0.4920, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5683, 0.0000, 0.0000, 0.4317],
        [0.0000, 0.0000, 0.4172, 0.2677, 0.0000, 0.3151],
        [0.0000, 0.4023, 0.0000, 0.3143, 0.2834, 0.0000],
        [0.0000, 0.2223, 0.2236, 0.1724, 0.2158, 0.1660],
        [0.2264, 0.3125, 0.0000, 0.0000, 0.1783, 0.2827]])

One more thing is to implement the code to **handle batches** consisting of more than one input so that our CausalSelfAttention class supports the batch outputs

In [20]:
batch = torch.stack((inputs, inputs, inputs), dim=0)
print(batch.shape) # batch with size 3

torch.Size([3, 6, 3])


In [23]:
class CausalSelfAttention(nn.Module):

    def __init__(self, d_in, d_out, block_size, dropout):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key   = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.dropout = nn.Dropout(dropout) # New
        self.mask=torch.triu(torch.ones(block_size, block_size), diagonal=1)
        # self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New

    def forward(self, x):
        b, n_tokens, d_in = x.shape # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = torch.bmm(queries,keys.permute(0,2,1)) #queries @ keys.transpose(1, 2) # Changed transpose
        attn_scores.masked_fill_( self.mask.bool()[:n_tokens, :n_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)
        attn_weights = self.dropout(attn_weights) # New

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

block_size = batch.shape[1]
csa = CausalSelfAttention(d_in, d_out, block_size, 0.0)

context_vecs = csa(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.0844,  0.0414],
         [-0.2264, -0.0039],
         [-0.4163, -0.0564],
         [-0.5014, -0.1011],
         [-0.7754, -0.1867],
         [-1.1632, -0.3303]],

        [[-0.0844,  0.0414],
         [-0.2264, -0.0039],
         [-0.4163, -0.0564],
         [-0.5014, -0.1011],
         [-0.7754, -0.1867],
         [-1.1632, -0.3303]],

        [[-0.0844,  0.0414],
         [-0.2264, -0.0039],
         [-0.4163, -0.0564],
         [-0.5014, -0.1011],
         [-0.7754, -0.1867],
         [-1.1632, -0.3303]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([3, 6, 2])


**Note that dropout is only applied during training, not during inference.**

##Extending single-head attention to multi-head attention

The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections. This allows the model to jointly attend to information from different representation subspaces at different positions.

In [25]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, block_size, dropout, num_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalSelfAttention(d_in, d_out, block_size, dropout)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(123)

block_size = batch.shape[1]
mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]],

        [[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]],

        [[-0.0844,  0.0414,  0.0766,  0.0171],
         [-0.2264, -0.0039,  0.2143,  0.1185],
         [-0.4163, -0.0564,  0.3878,  0.2453],
         [-0.5014, -0.1011,  0.4992,  0.3401],
         [-0.7754, -0.1867,  0.7387,  0.4868],
         [-1.1632, -0.3303,  1.1224,  0.8460]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([3, 6, 4])


**More efficient implementation** \\
While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head attention CausalSelfAttention implementation from earlier), we can write a stand-alone class called MultiHeadAttention to achieve the same.

We don't concatenate single attention heads for this stand-alone MultiHeadAttention class. Instead, we create single W_query, W_key, and W_value weight matrices and then split those into individual matrices for each attention head:

In [26]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))

    def forward(self, x):
        b, n_tokens, d_in = x.shape
        # (b, n_heads, T) -> (b, T, n_heads, head_dim)

        keys = self.W_key(x) # Shape: (b, T, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitely split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, T, d_out) -> (b, T, num_heads, head_dim)
        keys = keys.view(b, n_tokens, self.num_heads, self.head_dim)
        values = values.view(b, n_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, n_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, T, num_heads, head_dim) -> (b, num_heads, T, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
        attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, block_size, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([3, 6, 2])
