In [2]:
import torch
import torch.nn.functional as F
from torch import nn

In [3]:
# Fix the global RNG so the random token embeddings below are reproducible.
torch.manual_seed(seed=3333);

Consider the rows of `b` to be the embeddings of three tokens in a simple sequence.

In [4]:
# three tokens, each a 2-d "embedding" of random integers in [0, 10), cast to float32
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
b

tensor([[2., 9.],
        [7., 9.],
        [4., 4.]])

Our goal is for each vector (each token embedding) to use information from the other tokens in the sequence.
However, only the previous tokens (and the token itself) should be considered — never future tokens.

One way of injecting information from the previous tokens into the current token is to take a simple (uniformly weighted) average of all tokens up to and including the current token.

How can we obtain a weighted sum of the rows in `b` such that row 0 is just row 0, row 1 is a weighted sum of rows 0 and 1, and row 2 is a weighted sum of rows 0, 1 and 2 (with equal weights, i.e. a running average)?

In [6]:
# Lower-triangular averaging matrix: row t holds 1/(t+1) in columns 0..t.
# One matrix multiplication then averages the rows of b (faster than a Python loop).
a = torch.ones(3, 3).tril()
a = a / a.sum(dim=1, keepdim=True)
print("a:\n", a)
c = a @ b
print("c:\n", c)

a:
 tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
c:
 tensor([[2.0000, 9.0000],
        [4.5000, 9.0000],
        [4.3333, 7.3333]])


In [8]:
# Sanity check for rows 1 and 2: c must equal the explicitly computed running means.
# Assert instead of relying on eyeballing the printed tensors.
row1_mean = (b[0] + b[1]) / 2
row2_mean = (b[0] + b[1] + b[2]) / 3
assert torch.allclose(c[1], row1_mean) and torch.allclose(c[2], row2_mean)
row1_mean, row2_mean

(tensor([4.5000, 9.0000]), tensor([4.3333, 7.3333]))

# Toy example

In [12]:
# Reproducible toy input: a batch with one sequence of three tokens, each a 2-d embedding.
torch.manual_seed(1337)
B, T, C = 1, 3, 2  # batch size, time steps (sequence length), channels (embedding dim)
x = torch.randn(size=(B, T, C))
x, x.shape

(tensor([[[-2.0260, -2.0655],
          [-1.2054, -0.9122],
          [-1.2502,  0.8032]]]),
 torch.Size([1, 3, 2]))

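What we want is, for every batch element $b$ and position $t$, the average of all tokens up to and including $t$:

$\bar{x}_{b,t} = \text{mean}_{i \leq t} \, x_{b,i} = \frac{1}{t+1} \sum_{i=0}^{t} x_{b,i}$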


In [13]:
# Version 1: explicit for loops. x_new[batch_idx, t] is the mean of x[batch_idx, 0..t].
# The batch index is deliberately NOT called `b`: that would overwrite the tensor `b`
# defined earlier and break re-running the cells above.
x_new = torch.zeros(B, T, C)
for batch_idx in range(B):  # every sequence in the batch
    for t in range(T):  # every position in the sequence
        xprev = x[batch_idx, :t + 1]  # (t+1, C): tokens 0..t
        x_new[batch_idx, t] = torch.mean(xprev, 0)
x_new

tensor([[[-2.0260, -2.0655],
         [-1.6157, -1.4889],
         [-1.4939, -0.7248]]])

In [10]:
# Version 2: the same running mean as a single matrix multiplication (faster than looping).
# Row t of `a` holds 1/(t+1) in columns 0..t and zeros elsewhere.
a = torch.ones(T, T).tril()
a = a / a.sum(dim=1, keepdim=True)
xnew2 = a @ x  # (T, T) @ (B, T, C) broadcasts over the batch -> (B, T, C)
xnew2

tensor([[[-2.0260, -2.0655],
         [-1.6157, -1.4889],
         [-1.4939, -0.7248]]])

In [11]:
# Version 3: start from all-zero "affinities", hide future positions with -inf,
# and let softmax turn each row into a uniform average over the visible positions.
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
a = wei.softmax(dim=1)
wei, a

(tensor([[0., -inf, -inf],
         [0., 0., -inf],
         [0., 0., 0.]]),
 tensor([[1.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000],
         [0.3333, 0.3333, 0.3333]]))

In [12]:
xnew3 = a @ x
# All three versions must compute the same causal running mean.
assert torch.allclose(x_new, xnew2) and torch.allclose(xnew2, xnew3)
xnew3

tensor([[[-2.0260, -2.0655],
         [-1.6157, -1.4889],
         [-1.4939, -0.7248]]])

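So far the weights were fixed by hand so that we end up with a plain mean (giving equal importance to each previous token). Instead, we can let the network learn the weights in the `wei` matrix, making them data-dependent: each token's query is compared with the keys of the tokens it may attend to.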

In [13]:
torch.manual_seed(3333)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)  # input: 4 sequences, each 8 tokens long, each token a 32-dimensional embedding

# let's compute the wei matrix with a single self-attention head
head_size = 16
key = nn.Linear(C, head_size, bias=False)    # maps C=32 -> head_size=16 (weight shape (16, 32); computes x @ W.T)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, C) -> (B, T, head_size): (4, 8, 32) -> (4, 8, 16)
q = query(x)  # (B, T, head_size)
# affinities between every query position and every key position
# (unscaled here; scaled dot-product attention also multiplies by head_size**-0.5)
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

# causal mask: position t may only attend to positions 0..t
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row now sums to 1 over the visible positions

v = value(x)   # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, head_size) ---> (B, T, head_size)
out.shape

torch.Size([4, 8, 16])

In [14]:
wei.size()  # one (T, T) attention matrix per sequence in the batch

torch.Size([4, 8, 8])

# Self-attention class

In [15]:
class SelfAttention(nn.Module):
    """A single causal (decoder-style) self-attention head.

    Args:
        C: embedding dimension of the input tokens.
        head_size: dimension of the query/key/value projections and of the output.

    forward(x) maps x of shape (B, T, C) to (B, T, head_size); position t only
    attends to positions 0..t.
    """

    def __init__(self, C, head_size):
        super(SelfAttention, self).__init__()
        self.key = nn.Linear(C, head_size, bias=False)
        self.query = nn.Linear(C, head_size, bias=False)
        self.value = nn.Linear(C, head_size, bias=False)

    def forward(self, x):
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        # Scaled dot-product: scale by 1/sqrt(d_k), where d_k is the key/query size
        # (head_size), not the input embedding size C. This keeps the scores at roughly
        # unit variance so the softmax does not saturate.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, head_size) @ (B, head_size, T) ---> (B, T, T)
        # Causal mask, built on x's device so the module also works for GPU inputs.
        tril = torch.tril(torch.ones(T, T, device=x.device))
        wei = wei.masked_fill(tril == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        out = wei @ v  # (B, T, T) @ (B, T, head_size) ---> (B, T, head_size)
        return out
    
# quick check: a single head maps (B, T, C) -> (B, T, head_size)
torch.manual_seed(3333)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

sa = SelfAttention(C, head_size=16)
out = sa(x)
out.shape

torch.Size([4, 8, 16])

# Multi-head self-attention

In [16]:
class MultiHeadSelfAttention(nn.Module):
    """Several causal SelfAttention heads run side by side, then mixed by a linear layer.

    Args:
        C: embedding dimension of the input; must be divisible by ``num_heads``.
        num_heads: number of heads; each head has size ``C // num_heads``.

    forward(x) maps x of shape (B, T, C) to (B, T, C).
    """

    def __init__(self, C, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_size = C // num_heads
        assert self.head_size * num_heads == C, "C must be divisible by num_heads"
        # Use the per-instance head size, not a notebook-global `head_size`, so the
        # concatenated head outputs always have exactly C features.
        self.heads = nn.ModuleList([SelfAttention(C, self.head_size) for _ in range(num_heads)])
        self.fc = nn.Linear(self.head_size * num_heads, C)  # combines the "information" from each head

    def forward(self, x):
        # for each token the head outputs are concatenated: (B, T, head_size * num_heads) == (B, T, C)
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        # the info from each head is combined by a linear layer: (B, T, C)
        return self.fc(out)
    
# quick check: multi-head attention keeps the input shape (B, T, C)
torch.manual_seed(3333)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

mhsa = MultiHeadSelfAttention(C=C, num_heads=2)
mhsa(x).shape

torch.Size([4, 8, 32])

# Multi-head self-attention with parallel computation

In [14]:
torch.ones(3, 3).triu(diagonal=1)  # 1s strictly above the diagonal mark the "future" positions a causal mask hides

tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [15]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with all heads computed in parallel.

    One projection each for queries, keys and values produces ``d_out`` features,
    which are split into ``num_heads`` heads of size ``head_dim = d_out // num_heads``.

    Args:
        d_in: embedding dimension of the input.
        d_out: total output dimension across all heads; must be divisible by num_heads.
        context_length: maximum sequence length; sizes the causal mask buffer, so
            inputs with more than ``context_length`` tokens are not supported.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head size; all heads together span d_out

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        # 1s above the diagonal mark "future" positions that must be hidden. Registered
        # as a buffer so it follows the module across devices and is stored in state_dict.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Map x of shape (b, num_tokens, d_in) to (b, num_tokens, d_out)."""
        b, num_tokens, _ = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(-2, -1)  # (b, num_heads, num_tokens, num_tokens)

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores (in place; -inf becomes 0 after softmax)
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # scale by 1/sqrt(head_dim) before the softmax
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
    
    
# quick check: the parallel implementation also maps (B, T, C) -> (B, T, C)
torch.manual_seed(3333)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

mha = MultiHeadAttention(d_in=C, d_out=C, context_length=T, num_heads=4, dropout=0.1)
mha(x).shape
        


torch.Size([4, 8, 32])