In [1]:
from importlib.metadata import version

print("torch version:", version("torch"))

torch version: 2.4.0


In [2]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2

In [4]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

In [5]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [6]:
inputs.shape

torch.Size([6, 3])

In [7]:
sa_v1(inputs).shape

torch.Size([6, 2])

In [8]:
import torch.nn as nn

class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key   = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_key(x)
        values = self.W_query(x)
        
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [9]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v2(d_in, d_out)
print(sa_v1(inputs))

tensor([[-0.6648, -0.4255],
        [-0.6640, -0.4314],
        [-0.6640, -0.4310],
        [-0.6632, -0.4309],
        [-0.6642, -0.4230],
        [-0.6629, -0.4346]], grad_fn=<MmBackward0>)


In [10]:
mask = torch.triu(torch.ones(4, 4), diagonal=1)

In [11]:
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [12]:
batch = torch.stack((inputs, inputs), dim=0)

In [13]:
batch.shape

torch.Size([2, 6, 3])

In [14]:
inputs.shape

torch.Size([6, 3])

In [15]:
b, num_tokens, d_in = batch.shape # New batch dimension b

In [16]:
d_out = 32
W_query = nn.Linear(d_in, 32)
W_key   = nn.Linear(d_in, 32)
W_value = nn.Linear(d_in, 32)
dropout = nn.Dropout(0.25) # New

In [17]:
x = batch
keys = W_key(x)
queries = W_query(x)
values = W_value(x)

In [18]:
context_length = 32

In [19]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

In [20]:
mask

tensor([[0., 1., 1.,  ..., 1., 1., 1.],
        [0., 0., 1.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [21]:
mask.shape

torch.Size([32, 32])

In [22]:
mask.bool()[:num_tokens,:num_tokens].shape

torch.Size([6, 6])

In [26]:
attn_scores = queries@keys.transpose(1,2)

In [32]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal = 1)

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [48]:
attn_scores.masked_fill_(
    mask.bool()[:num_tokens,:num_tokens],
    -torch.inf    
)


tensor([[[1.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
         [0.5112, 0.4888,   -inf,   -inf,   -inf,   -inf],
         [0.3438, 0.3276, 0.3287,   -inf,   -inf,   -inf],
         [0.2532, 0.2440, 0.2445, 0.2583,   -inf,   -inf],
         [0.2095, 0.1888, 0.1892, 0.2048, 0.2076,   -inf],
         [0.1627, 0.1621, 0.1624, 0.1713, 0.1744, 0.1670]],

        [[1.0000,   -inf,   -inf,   -inf,   -inf,   -inf],
         [0.5112, 0.4888,   -inf,   -inf,   -inf,   -inf],
         [0.3438, 0.3276, 0.3287,   -inf,   -inf,   -inf],
         [0.2532, 0.2440, 0.2445, 0.2583,   -inf,   -inf],
         [0.2095, 0.1888, 0.1892, 0.2048, 0.2076,   -inf],
         [0.1627, 0.1621, 0.1624, 0.1713, 0.1744, 0.1670]]],
       grad_fn=<MaskedFillBackward0>)

In [45]:
attn_scores = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim = -1)

In [46]:
attn_scores

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5112, 0.4888, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3438, 0.3276, 0.3287, 0.0000, 0.0000, 0.0000],
         [0.2532, 0.2440, 0.2445, 0.2583, 0.0000, 0.0000],
         [0.2095, 0.1888, 0.1892, 0.2048, 0.2076, 0.0000],
         [0.1627, 0.1621, 0.1624, 0.1713, 0.1744, 0.1670]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5112, 0.4888, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3438, 0.3276, 0.3287, 0.0000, 0.0000, 0.0000],
         [0.2532, 0.2440, 0.2445, 0.2583, 0.0000, 0.0000],
         [0.2095, 0.1888, 0.1892, 0.2048, 0.2076, 0.0000],
         [0.1627, 0.1621, 0.1624, 0.1713, 0.1744, 0.1670]]],
       grad_fn=<SoftmaxBackward0>)

In [71]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # New
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights) # New

        context_vec = attn_weights @ values
        return context_vec

In [72]:
batch.shape

torch.Size([2, 6, 3])

In [89]:
cs_att = CausalAttention(d_in = 3, d_out = 32, context_length = 6, dropout = 0.24)

In [90]:
cs_att

CausalAttention(
  (W_query): Linear(in_features=3, out_features=32, bias=False)
  (W_key): Linear(in_features=3, out_features=32, bias=False)
  (W_value): Linear(in_features=3, out_features=32, bias=False)
  (dropout): Dropout(p=0.24, inplace=False)
)

In [91]:
cs_att(batch).shape

torch.Size([2, 6, 32])

In [98]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 3
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=8
)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.3326,  0.5659, -0.3132,  0.0752,  0.4566,  0.2729, -0.0322,
          -0.4610,  0.2847, -0.3886, -0.1471,  0.4106,  0.2869,  0.2964,
          -0.1885,  0.5982, -0.6053,  0.6575,  0.3087, -0.4569,  0.2649,
           0.3223,  0.1336, -0.2467],
         [ 0.3456,  0.5650, -0.2237,  0.0313,  0.5977,  0.3053, -0.2642,
          -0.3017,  0.4047, -0.1700, -0.2532,  0.3527,  0.3626,  0.2247,
          -0.3241,  0.6507, -0.5515,  0.6564,  0.2733, -0.6272,  0.3056,
           0.2533,  0.0792, -0.2004],
         [ 0.3440,  0.5604, -0.2000,  0.0178,  0.6413,  0.3138, -0.3371,
          -0.2502,  0.4369, -0.1055, -0.2813,  0.3304,  0.3880,  0.2032,
          -0.3675,  0.6669, -0.5324,  0.6538,  0.2588, -0.6824,  0.3120,
           0.2272,  0.0577, -0.1865],
         [ 0.3103,  0.4941, -0.1606,  0.0089,  0.5729,  0.2785, -0.3210,
          -0.2028,  0.4068, -0.0654, -0.2649,  0.2920,  0.3533,  0.1571,
          -0.3522,  0.5905, -0.4536,  0.5714,  0.2231, -0.6161,  0.2913,
          

In [96]:
batch.shape

torch.Size([2, 6, 3])

In [107]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        # print('CONTEXT VEC: ', context_vec.shape)
        # print('CONTEXT VEC: ', context_vec.contiguous().shape)
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 16
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=8)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.4060, -0.0417, -0.7316,  0.1562, -0.0687,  0.0826,  0.5215,
           0.2318,  0.3004,  0.2771, -0.1343,  0.0142, -0.4592,  0.4092,
           0.0128, -0.6038],
         [ 0.4464,  0.1098, -0.8581,  0.2798, -0.0649,  0.0994,  0.5518,
           0.1860,  0.2879,  0.2076, -0.0533, -0.0197, -0.5703,  0.4576,
          -0.0516, -0.4887],
         [ 0.4561,  0.1537, -0.8914,  0.3179, -0.0628,  0.1025,  0.5589,
           0.1680,  0.2832,  0.1851, -0.0255, -0.0322, -0.6010,  0.4690,
          -0.0720, -0.4524],
         [ 0.4308,  0.1482, -0.8345,  0.3191, -0.0717,  0.0896,  0.5268,
           0.1333,  0.2724,  0.1739,  0.0122, -0.0243, -0.5601,  0.4471,
          -0.0932, -0.3953],
         [ 0.4071,  0.0852, -0.7406,  0.3178, -0.0619,  0.0903,  0.4936,
           0.0839,  0.2705,  0.1813,  0.0401, -0.0211, -0.5070,  0.3798,
          -0.1069, -0.3775],
         [ 0.4096,  0.1227, -0.7687,  0.3228, -0.0725,  0.0852,  0.4970,
           0.0937,  0.2659,  0.1676,  0.0434, -0.019

In [108]:
batch.shape

torch.Size([2, 6, 3])