In [1]:
import torch
import torch.nn as nn  # used by the attention classes below; kept with the other imports


In [21]:
# The toy sentence; each word gets one row of `inputs` below.
example = " ".join(["Where", "engineering", "meets", "pure", "driving", "emotion"])

In [60]:
# Print tensors in fixed-point notation so tiny attention weights stay readable.
torch.set_printoptions(sci_mode=False)

In [None]:
# Toy 8-dimensional embeddings, one row per word of the sentence, in order.
inputs = torch.tensor([
    [0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936],  # Where
    [0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294],  # engineering
    [0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317],  # meets
    [0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753],  # pure
    [0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423],  # driving
    [0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895],  # emotion
])

## Analyse the basic parts: simplified self-attention without trainable weights

##### Intermediate attention scores between each input $x_i$ and the query

**here, $\text{query} = x_4$, i.e. `inputs[4]` (0-indexed: the token "driving")**

In [None]:
# Score every input x_i against the query with a plain dot product.
query = inputs[4]
attention_scores = torch.stack([torch.dot(x_i, query) for x_i in inputs])

print(attention_scores)

tensor([3.4035, 3.2807, 2.7210, 1.4083, 3.4696, 3.1297])


*normalization* (softmax)

In [33]:
# Softmax maps the raw scores to positive weights that sum to 1.
attn_weights_4 = attention_scores.softmax(dim=0)
print("Attention weights:", attn_weights_4)

Attention weights: tensor([0.2296, 0.2031, 0.1161, 0.0312, 0.2453, 0.1746])


##### Calculating the context vector $z_4 = \sum_i \alpha_{4i}\, x_i$

In [None]:
# z_4 = sum_i alpha_{4,i} * x_i: weighted sum of all inputs using the weights for query x_4.
# `query` must be the same token that attn_weights_4 was computed for (inputs[4]);
# pointing it at another row would silently mislabel the query for later cells.
query = inputs[4]
context_vec_4 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_4 += attn_weights_4[i] * x_i

print(context_vec_4)

tensor([0.8279, 0.5189, 0.5268, 0.7836, 0.5942, 0.5756, 0.5474, 0.5553])


In [36]:
# Weights (one per token) against inputs (tokens x features): the shapes line up for a matmul.
print(inputs.shape, attn_weights_4.shape, sep="\n")

torch.Size([6, 8])
torch.Size([6])


*more efficient: the same weighted sum as a single vector–matrix product*

In [38]:
# Same weighted sum as the loop, written as one vector-matrix product: (6,) @ (6, 8) -> (8,).
context_vec_4 = torch.matmul(attn_weights_4, inputs)
print(context_vec_4)

tensor([0.8279, 0.5189, 0.5268, 0.7836, 0.5942, 0.5756, 0.5474, 0.5553])


### Computing attention weights for all input tokens

In [44]:
# Pairwise dot products: entry [row, col] scores token `row` (as query) against token `col`.
num_tokens = inputs.shape[0]
attn_scores = torch.empty(num_tokens, num_tokens)
for row in range(num_tokens):
    for col in range(num_tokens):
        attn_scores[row, col] = torch.dot(inputs[row], inputs[col])
print(attn_scores)

tensor([[3.8917, 3.0907, 3.1170, 1.1894, 3.4035, 3.2203],
        [3.0907, 3.9404, 2.5932, 1.8054, 3.2807, 3.5544],
        [3.1170, 2.5932, 2.6255, 0.9641, 2.7210, 2.6154],
        [1.1894, 1.8054, 0.9641, 1.4628, 1.4083, 1.8369],
        [3.4035, 3.2807, 2.7210, 1.4083, 3.4696, 3.1297],
        [3.2203, 3.5544, 2.6154, 1.8369, 3.1297, 3.6068]])


*more efficient: compute all scores with one matrix product, then softmax-normalize each row*

In [49]:
# Vectorized: one matrix product gives every pairwise score; softmax then normalizes each row.
attn_scores = torch.matmul(inputs, inputs.T)
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)

tensor([[0.3224, 0.1447, 0.1486, 0.0216, 0.1979, 0.1648],
        [0.1424, 0.3331, 0.0866, 0.0394, 0.1722, 0.2264],
        [0.2779, 0.1646, 0.1700, 0.0323, 0.1870, 0.1683],
        [0.1232, 0.2280, 0.0983, 0.1619, 0.1533, 0.2353],
        [0.2296, 0.2031, 0.1161, 0.0312, 0.2453, 0.1746],
        [0.1793, 0.2504, 0.0979, 0.0449, 0.1637, 0.2638]])


**all context vectors**

In [50]:
# Row i of the result is z_i = sum_j attn_weights[i, j] * x_j.
all_context_vecs = torch.matmul(attn_weights, inputs)
print(all_context_vecs)

tensor([[0.8343, 0.5808, 0.4899, 0.8088, 0.5480, 0.5669, 0.4944, 0.6081],
        [0.8159, 0.4179, 0.6220, 0.7500, 0.6682, 0.5325, 0.6206, 0.5234],
        [0.8263, 0.5528, 0.4998, 0.7879, 0.5559, 0.5520, 0.5109, 0.5982],
        [0.7124, 0.4183, 0.5635, 0.6942, 0.6357, 0.4526, 0.6544, 0.4808],
        [0.8279, 0.5189, 0.5268, 0.7836, 0.5942, 0.5756, 0.5474, 0.5553],
        [0.7975, 0.4573, 0.5875, 0.7731, 0.6414, 0.5156, 0.6022, 0.5461]])


## Self Attention with trainable params

In [52]:
x_2 = inputs[1]          # second token ("engineering"), 0-based index 1
d_in = inputs.shape[1]   # size of each input embedding
d_out = 4                # size of each query/key/value vector

# Three random projection matrices, drawn in the order query, key, value so the
# seeded RNG yields reproducible values. Gradients are switched off for this walkthrough.
torch.manual_seed(42)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [53]:
# Project the second token into query, key and value space.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))
print(query_2)

tensor([3.5132, 2.3542, 2.7826, 3.0156])


In [54]:
# Keys and values for every token at once: (6, 8) @ (8, 4) -> (6, 4).
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 4])
values.shape: torch.Size([6, 4])


compute the attention score $\omega_{22}$ (query 2 against key 2)

In [55]:
# omega_22: how strongly query 2 attends to key 2.
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(34.1881)


In [57]:
# Scores of query 2 against every key in one product: (4,) @ (4, 6) -> (6,).
attn_scores_2 = torch.matmul(query_2, keys.T)
print(attn_scores_2)

tensor([34.3551, 34.1881, 27.6320, 15.6713, 33.0249, 33.1059])


In [62]:
# Scaled dot-product: dividing by sqrt(d_k) keeps large dot products from pushing
# the softmax toward a near one-hot distribution.
d_k = keys.shape[-1]  # key embedding size
attn_weights_2 = (attn_scores_2 / d_k**0.5).softmax(dim=-1)
print(attn_weights_2)

tensor([0.3329, 0.3062, 0.0115, 0.0000, 0.1712, 0.1782])


In [63]:
print(values.size())  # one d_out-sized value vector per token

torch.Size([6, 4])


In [65]:
# z_2: value vectors weighted by query 2's attention weights, (6,) @ (6, 4) -> (4,).
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

tensor([2.6011, 2.2305, 3.1904, 1.9242])


We’ve only computed a single context vector, $z^{(2)}$.

$\rightarrow$ Next, we generalize the code to compute all context vectors in the input sequence, $z^{(1)}$ to $z^{(T)}$.

### Implementing a compact self-attention Python class

In [67]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with raw ``nn.Parameter`` projection matrices.

    Args:
        d_in: size of each input embedding (last dim of ``x``).
        d_out: size of the query/key/value projections and of the output.

    The weights are drawn with ``torch.rand`` in the order query, key, value, so a
    fixed ``torch.manual_seed`` reproduces the same parameters.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return one context vector per token.

        ``x`` is ``(num_tokens, d_in)`` or batched ``(..., num_tokens, d_in)``;
        the result keeps the leading dims and has ``d_out`` as its last dim.
        """
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        # transpose(-2, -1) swaps only the token/feature axes, so it matches `.T`
        # for 2-D input and also works for batched (3-D) input, where `.T` does not.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [68]:
# Same seed and draw order as the manual walkthrough, so the second output row
# reproduces context_vec_2.
torch.manual_seed(42)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[2.5996, 2.2310, 3.1960, 1.9265],
        [2.6011, 2.2305, 3.1904, 1.9242],
        [2.5885, 2.2279, 3.1881, 1.9231],
        [2.5467, 2.2085, 3.1475, 1.9004],
        [2.5983, 2.2291, 3.1982, 1.9245],
        [2.5964, 2.2313, 3.1890, 1.9263]], grad_fn=<MmBackward0>)


**A v2 using `nn.Linear`, which has a better default weight initialization scheme than `torch.rand` (and supports an optional bias)**

In [None]:
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention whose projections are ``nn.Linear`` layers.

    Args:
        d_in: size of each input embedding (last dim of ``x``).
        d_out: size of the query/key/value projections and of the output.
        qkv_bias: whether the query/key/value projections get a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per token.

        ``x`` is ``(num_tokens, d_in)`` or batched ``(..., num_tokens, d_in)``;
        the result keeps the leading dims and has ``d_out`` as its last dim.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # transpose(-2, -1) swaps only the token/feature axes, so it matches `.T`
        # for 2-D input and also works for batched (3-D) input, where `.T` does not.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [70]:
# nn.Linear draws its own initial weights, so these outputs differ from the v1 results.
torch.manual_seed(43)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

tensor([[ 0.3723,  0.0532, -0.0368,  0.5675],
        [ 0.3724,  0.0532, -0.0399,  0.5679],
        [ 0.3725,  0.0533, -0.0383,  0.5678],
        [ 0.3725,  0.0529, -0.0390,  0.5679],
        [ 0.3730,  0.0530, -0.0378,  0.5683],
        [ 0.3726,  0.0530, -0.0383,  0.5680]], grad_fn=<MmBackward0>)


## Hiding future words with causal attention

### 1. Applying a causal mask

In [73]:
# Recompute the (still unmasked) attention weights with sa_v2's query/key projections.
queries, keys = sa_v2.W_query(inputs), sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1739, 0.1591, 0.1735, 0.1604, 0.1671, 0.1659],
        [0.1691, 0.1647, 0.1660, 0.1661, 0.1685, 0.1656],
        [0.1726, 0.1607, 0.1693, 0.1620, 0.1683, 0.1671],
        [0.1670, 0.1664, 0.1713, 0.1670, 0.1654, 0.1629],
        [0.1714, 0.1621, 0.1723, 0.1617, 0.1657, 0.1669],
        [0.1697, 0.1638, 0.1715, 0.1642, 0.1660, 0.1648]],
       grad_fn=<SoftmaxBackward0>)


*create a mask where the values above the diagonal are zero*

In [74]:
# Lower-triangular ones: row i keeps positions 0..i and zeroes every later position.
context_length = attn_weights.shape[0]
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [76]:
# Element-wise product zeroes the weights above the diagonal (future tokens).
masked_simple = attn_weights.mul(mask_simple)
print(masked_simple)

tensor([[0.1739, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1691, 0.1647, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1726, 0.1607, 0.1693, 0.0000, 0.0000, 0.0000],
        [0.1670, 0.1664, 0.1713, 0.1670, 0.0000, 0.0000],
        [0.1714, 0.1621, 0.1723, 0.1617, 0.1657, 0.0000],
        [0.1697, 0.1638, 0.1715, 0.1642, 0.1660, 0.1648]],
       grad_fn=<MulBackward0>)


**Renormalize the attention weights so that each row sums to 1 again**

Approach 1: divide each row by its sum

In [77]:
# Divide each row by its own sum so the surviving weights add up to 1 again.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple.div(row_sums)
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5067, 0.4933, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3434, 0.3197, 0.3368, 0.0000, 0.0000, 0.0000],
        [0.2487, 0.2477, 0.2550, 0.2486, 0.0000, 0.0000],
        [0.2057, 0.1945, 0.2068, 0.1940, 0.1989, 0.0000],
        [0.1697, 0.1638, 0.1715, 0.1642, 0.1660, 0.1648]],
       grad_fn=<DivBackward0>)


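*More efficient: mask the attention **scores** with $-\infty$ before the softmax, so the softmax both zeroes the future positions and renormalizes each row*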

In [81]:
# Strictly upper-triangular mask (diagonal=1): a 1 marks a future position.
# Masking the *scores* (not the weights) with -inf lets the softmax give those positions 0.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[-0.0514,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0352, -0.0886,    -inf,    -inf,    -inf,    -inf],
        [ 0.0114, -0.1316, -0.0273,    -inf,    -inf,    -inf],
        [-0.1490, -0.1570, -0.0985, -0.1492,    -inf,    -inf],
        [-0.0603, -0.1721, -0.0493, -0.1771, -0.1280,    -inf],
        [-0.1057, -0.1760, -0.0839, -0.1710, -0.1498, -0.1642]],
       grad_fn=<MaskedFillBackward0>)


In [82]:
# Row-wise softmax (last dim, same as dim=1 for this 2-D matrix); exp(-inf) = 0
# removes the masked positions and renormalizes the rest in one step.
attn_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5067, 0.4933, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3434, 0.3197, 0.3368, 0.0000, 0.0000, 0.0000],
        [0.2487, 0.2477, 0.2550, 0.2486, 0.0000, 0.0000],
        [0.2057, 0.1945, 0.2068, 0.1940, 0.1989, 0.0000],
        [0.1697, 0.1638, 0.1715, 0.1642, 0.1660, 0.1648]],
       grad_fn=<SoftmaxBackward0>)


### 2. Masking additional attention weights with dropout

In [83]:
# Dropout demo on a tensor of ones: with p=0.5 each element is zeroed with probability 0.5,
# and survivors are scaled by 1 / (1 - 0.5) = 2, so every printed entry is 0 or 2.
torch.manual_seed(42)
dropout = torch.nn.Dropout(0.5)
ones_demo = torch.ones(6, 6)  # own name, so the `example` sentence string is not overwritten
print(dropout(ones_demo))

tensor([[2., 2., 2., 2., 0., 2.],
        [0., 0., 2., 2., 2., 2.],
        [0., 0., 2., 0., 2., 0.],
        [0., 2., 2., 0., 2., 2.],
        [2., 2., 0., 2., 2., 2.],
        [2., 2., 2., 2., 0., 0.]])


In [89]:
# Re-seed for a reproducible dropout mask, then apply it to the causal attention weights
# (surviving weights are doubled because p=0.5).
torch.manual_seed(42)
dropped_attn_weights = dropout(attn_weights)
print(dropped_attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6737, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4953, 0.5101, 0.0000, 0.0000, 0.0000],
        [0.4114, 0.3891, 0.0000, 0.3881, 0.3977, 0.0000],
        [0.3393, 0.3276, 0.3431, 0.3285, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


### 3. Implementing a compact causal attention class

In [90]:
# Two copies of the same input stacked along a new leading axis to mimic a batch of size 2.
batch = torch.stack([inputs, inputs])
print(batch.size())

torch.Size([2, 6, 8])


In [None]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product attention with a causal mask and dropout.

    Args:
        d_in: size of each input embedding (last dim of ``x``).
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum number of tokens supported; sizes the causal mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections get a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Strictly upper-triangular ones mark future positions. Registered as a
        # buffer (not a Parameter), so it is stored with the module but not trained.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Return context vectors of shape ``(..., num_tokens, d_out)``.

        ``x`` is ``(batch, num_tokens, d_in)`` or unbatched ``(num_tokens, d_in)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_length``.
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[-1]
        if num_tokens > max_tokens:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {max_tokens}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes: (..., T, d_out) @ (..., d_out, T) -> (..., T, T).
        attn_scores = queries @ keys.transpose(-2, -1)
        # In place on a fresh intermediate; -inf makes the softmax weight of every
        # future position exactly 0.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [92]:
# dropout=0.0 turns dropout off, so the forward pass is deterministic.
torch.manual_seed(42)
context_length = batch.size(1)  # tokens per sequence
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 4])


## Extending single-head attention to multi-head attention

The main idea behind multi-head attention is to run the attention
mechanism multiple times (in parallel) with different, learned linear projections.

### 1. Simply stacking multiple single-head attention layers

**final embedding dimension = (`d_out` × `num_heads`)**

In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from ``num_heads`` independent CausalAttention heads.

    Each head produces ``d_out`` features; the head outputs are concatenated along
    the last dimension, giving ``d_out * num_heads`` features per token.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        # The heads run one after another in a Python loop, not in parallel.
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [95]:
torch.manual_seed(42)
context_length = batch.shape[1]  # number of tokens per sequence (6)
d_in = 8
d_out = 4

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, dropout=0.0, num_heads=2
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.0945, -0.1076, -0.5796, -0.1055,  0.2987, -0.3192, -0.2739,
           0.2104],
         [ 0.0641, -0.1119, -0.2634, -0.0971,  0.2914, -0.2776, -0.1710,
           0.1810],
         [ 0.0490, -0.0245, -0.3543, -0.1330,  0.3192, -0.2653, -0.2086,
           0.1692],
         [ 0.1104, -0.0754, -0.1794, -0.1285,  0.2299, -0.2482, -0.1160,
           0.1229],
         [ 0.0456, -0.0607, -0.2410, -0.1031,  0.2299, -0.2520, -0.1447,
           0.1438],
         [ 0.0686, -0.1009, -0.1996, -0.1179,  0.2107, -0.2531, -0.1413,
           0.1272]],

        [[-0.0945, -0.1076, -0.5796, -0.1055,  0.2987, -0.3192, -0.2739,
           0.2104],
         [ 0.0641, -0.1119, -0.2634, -0.0971,  0.2914, -0.2776, -0.1710,
           0.1810],
         [ 0.0490, -0.0245, -0.3543, -0.1330,  0.3192, -0.2653, -0.2086,
           0.1692],
         [ 0.1104, -0.0754, -0.1794, -0.1285,  0.2299, -0.2482, -0.1160,
           0.1229],
         [ 0.0456, -0.0607, -0.2410, -0.1031,  0.2299, -0.2520, -0.1

### 2. Implementing multi-head attention with weight splits

**Processing the heads in parallel** $\rightarrow$  **matrix multiplication**

In [97]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention computed for all heads at once via weight splitting.

    One ``d_in -> d_out`` projection each for queries, keys and values is split into
    ``num_heads`` heads of size ``head_dim = d_out // num_heads``. The head outputs
    are merged back to ``d_out`` features and passed through ``out_proj``.

    Args:
        d_in: size of each input embedding (last dim of ``x``).
        d_out: total output size across all heads; must be divisible by ``num_heads``.
        context_length: maximum number of tokens supported; sizes the causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections get a bias term.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Creation order (query, key, value, out_proj) fixes which random initial
        # weights each layer receives under a given seed.
        self.W_query  = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key    = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value  = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout  = nn.Dropout(dropout)

        # Strictly upper-triangular ones mark future positions to be masked.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        )

    def forward(self, x):
        """Return context vectors of shape ``(b, num_tokens, d_out)``.

        ``x`` must have shape ``(b, num_tokens, d_in)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_length``.
        """
        b, num_tokens, _ = x.shape
        max_tokens = self.mask.shape[-1]
        if num_tokens > max_tokens:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {max_tokens}"
            )

        keys    = self.W_key(x)    # (b, num_tokens, d_out)
        queries = self.W_query(x)  # (b, num_tokens, d_out)
        values  = self.W_value(x)  # (b, num_tokens, d_out)

        # Split d_out into heads, then move the head axis forward:
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys    = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values  = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # The (num_tokens, num_tokens) mask broadcasts over batch and head axes.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # keys.shape[-1] == head_dim: scale by the per-head key size.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Merge heads back into d_out; .contiguous() is needed before .view after a transpose.
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec)  # output projection, always applied
        return context_vec

In [98]:
torch.manual_seed(42)
batch_size, context_length, d_in = batch.shape
d_out = 4
mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.2828, -0.3762, -0.0178, -0.3905],
         [ 0.3052, -0.4477, -0.1416, -0.3152],
         [ 0.3165, -0.4136, -0.1247, -0.3781],
         [ 0.3104, -0.4478, -0.1900, -0.2927],
         [ 0.2837, -0.4357, -0.1740, -0.3000],
         [ 0.2990, -0.4441, -0.1730, -0.2779]],

        [[ 0.2828, -0.3762, -0.0178, -0.3905],
         [ 0.3052, -0.4477, -0.1416, -0.3152],
         [ 0.3165, -0.4136, -0.1247, -0.3781],
         [ 0.3104, -0.4478, -0.1900, -0.2927],
         [ 0.2837, -0.4357, -0.1740, -0.3000],
         [ 0.2990, -0.4441, -0.1730, -0.2779]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])
