## Approach #1: Simplified Attention

In [1]:
import torch
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
# pick one as your query
query = inputs[1]

# initialize attention scores
attention_scores_2 = torch.empty(inputs.shape[0])

# compute its value in a simplified manner.. just take
# dot products
for i, x_i in enumerate(inputs):
    attention_scores_2[i] = torch.dot(x_i, query)
print(attention_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


Next step: convert from 'scores' -> 'weights' (normalized)

Attention **Scores** are respresented by $\omega$

Attention **Weights** are represented by $\alpha$

In [3]:
attn_weights_2_tmp = attention_scores_2 / attention_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


It's preferable to normalize using Softmax instead, to manage extreme values and provide favorable gradients during training. 

Let's try that

In [4]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)    

attn_weights_2_naive = softmax_naive(attention_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Summed:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Summed: tensor(1.)


In [5]:
attn_weights_2 = torch.softmax(attention_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Summed:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Summed: tensor(1.)


Final step! We have the **attention weights**, and now need to combine them into a **context vector**. 

In this simplified version, we'll just sum them.

In [6]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


Same thing, but now we'll do it for ALL of the input tokens, treating each input as a **query**.

In [7]:
d = len(inputs)
attn_scores = torch.empty(d, d)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [8]:
attn_scores = inputs @ inputs.T # matrix multiplication ftw!
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [9]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [10]:
torch.softmax(attn_scores, dim=0)

tensor([[0.2098, 0.1385, 0.1390, 0.1435, 0.1526, 0.1385],
        [0.2006, 0.2379, 0.2369, 0.2074, 0.1958, 0.2184],
        [0.1981, 0.2333, 0.2326, 0.2046, 0.1975, 0.2128],
        [0.1242, 0.1240, 0.1242, 0.1462, 0.1367, 0.1420],
        [0.1220, 0.1082, 0.1108, 0.1263, 0.1879, 0.0988],
        [0.1452, 0.1581, 0.1565, 0.1720, 0.1295, 0.1896]])

In [11]:
torch.softmax(attn_scores, dim=1)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [12]:
# torch.softmax(attn_scores, dim=2)

In [13]:
torch.sum(attn_scores[1])

tensor(6.5617)

In [14]:
torch.sum(attn_weights[1])

tensor(1.)

In [15]:
attn_weights.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [16]:
attn_weights.sum(dim=0)

tensor([0.9220, 1.2970, 1.2788, 0.7974, 0.7540, 0.9508])

I'll have to keep practicing to get a good intuition on dim. Dim=0 seems like it refers to the "first" dimension, whereas dim=1 is the "second" dimension. dim=-1 is the "last" dimension, is the same as the second in this case.

In [17]:
all_context_vecs = attn_weights @ inputs

In [18]:
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [19]:
all_context_vecs.sum(dim=-1)

tensor([1.6141, 1.6617, 1.6598, 1.6112, 1.5847, 1.6326])

In [20]:
context_vec_2

tensor([0.4419, 0.6515, 0.5683])

In [21]:
context_vec_2.eq(all_context_vecs[1])

tensor([False,  True, False])

In [22]:
torch.isclose(context_vec_2, all_context_vecs[1])

tensor([True, True, True])

## Approach #2: Self-attention with trainable weights

### aka "scaled-dot product attention"

We will add three weight matrices which are trainable: $W_q$, $W_k$, $W_v$.

Multiplying the input $x^{(i)}$ by a weight projects it into that space. For example:

- $x^{(i)} * W_k = k^{(i)}$ ... project into the "key" space
- $x^{(i)} * W_v = v^{(i)}$ ... project into the "value" space
- $x^{(i)} * W_q = q^{(i)}$ ... project into the "query" space

Recall that $x$ is an **embedding** of the text token, with a certain dimension.
When we project into the "output" space, it is an embedding of some dimension which needn't be the same as $x$'s dimension.

In [23]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # input embedding size
d_out = 2 # output embedding size

In [24]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) # nit: we are ignoring grad for now for simplicity. we'll want it later!
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [25]:
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [26]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [27]:
keys = inputs @ W_key

In [28]:
values = inputs @ W_value

In [29]:
print("keys.shape:", keys.shape)

keys.shape: torch.Size([6, 2])


In [30]:
print("values.shape:", values.shape)

values.shape: torch.Size([6, 2])


To compute an attention score for a given token idx ("queried token"), we multiply the query by the key.

In [31]:
attn_score_22 = query_2.dot(keys[1]) # naming is 1-indexed in math variable name land and 0-indexed in computer code land
print(attn_score_22)

tensor(1.8524)


Now let's get the same value via matrix multiplication...

In [32]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [33]:
torch.isclose(attn_score_22, attn_scores_2[1])

tensor(True)

In [34]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=-1)

In [35]:
attn_weights_2

tensor([0.1401, 0.2507, 0.2406, 0.1157, 0.0687, 0.1842])

In [36]:
# that was wrong actually! We do a adjustment to scale it further
# (1) NEW: divide by the sq root of the embedding dimension of the keys
# (2) take the softmax

# The reason for the normalization by the embedding dimension size 
# is to improve the training performance by avoiding small gradients.

d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


Awesome, now we have the **Attention weights** $\alpha_{2i}$, and we want to compute the **Context Vector** $Z^{(2)}$

In [37]:
context_vec_2 = attn_weights_2 @ values

In [38]:
values

tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])

In [39]:
context_vec_2

tensor([0.3061, 0.8210])

In [40]:
values.size()

torch.Size([6, 2])

In [41]:
attn_weights_2.size()

torch.Size([6])

In [42]:
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [43]:
values.T

tensor([[0.1855, 0.3951, 0.3879, 0.2393, 0.1492, 0.3221],
        [0.8812, 1.0037, 0.9831, 0.5493, 0.3346, 0.7863]])

In [44]:
attn_weights_2.dot(values.T[0])

tensor(0.3061)

In [45]:
attn_weights_2.dot(values.T[1])

tensor(0.8210)

Let's now implement class to compute self-attention for all queried tokens.

In [46]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """ given inputs `x` (embedding dim = d_in),
            execute the whole 'attention' module, 
            returns the computed context vectors (dim = d_out)"""
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # scores ("omega") are queries * keys
        attn_scores = queries @ keys.T
        
        # weights ("alpha") are normalized via softmax AND sqroot of embedding size
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

        

In [47]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [48]:
context_vec_2

tensor([0.3061, 0.8210])

In [49]:
torch.isclose(context_vec_2, sa_v1(inputs)[1])

tensor([True, True])

Let's tweak our implementation to use `nn.Linear`, which works well because
- without a bias unit, it just performs matrix multiplication
- it has a optimizated random weights initialization scheme vs `torch.rand` -> stable and effective model training

In [50]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec
        
        

In [51]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


Ex 3.1 -- try to make v1 and v2 output the same thing by transferring the weights from V2 into a V1 instance...

>  we can transfer the weight matrices from a SelfAttention_v2 object to a SelfAttention_v1, such that both objects then produce the same results.

In [52]:
sa_v2.W_key.weight

Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True)

In [53]:
sa_v1.W_key

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True)

In [54]:
print(sa_v2.W_key.bias)

None


In [55]:
print(sa_v2.W_key.weight.data)

tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]])


In [56]:
sa_v1.W_key.data = sa_v2.W_key.weight.data.T
sa_v1.W_value.data = sa_v2.W_value.weight.data.T
sa_v1.W_query.data = sa_v2.W_query.weight.data.T

In [57]:
sa_v1(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [58]:
torch.isclose(sa_v1(inputs).data, sa_v2(inputs).data)

tensor([[True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True]])

The main trick to transfering were:
- accessing the weights via `.weight.data`... `.weight` alone is a `Parameter` and has other information about the `grad_fn`.
- transposing the weight's data (`.weight.data.T`) to pass it from V2 to V1

## Approach #3: Causal attention (hiding future words)

bonus: We'll also use NN dropout here to reduce overfitting

The general way we'll do "casual attention" is by applying a **mask**.

1. (simple) mask along the diagonal with 0s
  ... then compute attention weights as before
2. (fancier) mask along the diagonal with $-\infty$
  ... this allows lets us skip one matrix multiplication, since we're applying softmax anyway

Let's start with the simple mask...

In [59]:
queries = sa_v2.W_query(inputs)     #1
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [60]:
attn_scores

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)

In [61]:
attn_scores.shape[0]

6

In [62]:
context_length = attn_scores.shape[0]
# tril returns the "lower triangle" of the matrix, including the diagonal 
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [63]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


We now want to renormalize each row in the masked attention weights

In [64]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
print(row_sums)

tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)


In [65]:
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


Let's try the fancier mask with the matrix math optimization...

In [66]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [67]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [68]:
context_vec = attn_weights @ values

In [69]:
context_vec

tensor([[0.1855, 0.8812],
        [0.2795, 0.9361],
        [0.3133, 0.9508],
        [0.2994, 0.8595],
        [0.2702, 0.7554],
        [0.2772, 0.7618]], grad_fn=<MmBackward0>)

Let's add **dropout** to randomly ignore some hidden layer units during training. This helps prevent overfitting.

In [70]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [71]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


We didn't add dropout yet, but we've built the intuition of how it will work.

Let's do one more thing. We want to ensure the `SelfAttention` Python class we create is able to handle **batched inputs**.

In [72]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [73]:
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [101]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # ensures our buffers are moved to CPU, GPU as needed 
        self.register_buffer('mask', 
                             torch.triu(torch.ones(context_length, context_length), 
                                        diagonal=1)
                            )
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # previously, we just did keys.T
        # to handle a batch we need to transponse the later 2 dims but keep the first
        attn_scores = queries @ keys.transpose(1,2) 
        
        # This is subtly change from our experiment above.
        # [:num_tokens, :num_tokens] .. ?? only update these ??
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        # masked_fill_  -> update in place
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights  = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [102]:
torch.manual_seed(123)
print("batch.shape:", batch.shape)
context_length = batch.shape[1] 
    # [0] - number of inputs ("batch size")
    # [1] - number of tokens per input
    # [2] - size of the input embedding
dropout_ratio = 0.0
ca = CausalAttention(d_in, d_out, context_length, dropout_ratio)
context_vecs = ca(batch)
print("context_vecs.shape", context_vecs.shape)
    # [0] - number of inputs ("batch size")
    # [1] - number of tokens per input
    # [2] - size of the output embedding

batch.shape: torch.Size([2, 6, 3])
context_vecs.shape torch.Size([2, 6, 2])


## Approach #4: Multi-head attention

Intuitively, our multi-head attention model can be thought of as several CausalAttention modules stacked. We execute then all and concat their results...

In [82]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for n in range(num_heads)]

    def forward(self, x):
        return torch.cat([h.forward(x) for h in self.heads], dim = -1)
            

In [83]:
torch.manual_seed(123)
context_length = batch.shape[1] # number of tokens
d_in = batch.shape[2] # input embedding size
d_out = 2
droput_ratio = 0.0
num_heads = 2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, droput_ratio, num_heads)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape torch.Size([2, 6, 4])


The final dim of the context vecs is now 4. Because: `4 = (2 heads) * (output embedding dim of 2)`

The remaining work is to take this idea and make it efficient by combining the iterative `forward()` steps into a single one, with one matrix mult

## Approach #4 (variant): Efficient multi-head attention

In [133]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout: float, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            f"d_out ({d_out}) must be divisible by num_heads ({num_heads})"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        # b is number inputs in a batch
        b, num_tokens, d_in = x.shape

        # after multiplying by weights, the shape of these is: 
        # (b, num_tokens, d_out)
        queries = self.W_query(x) 
        keys = self.W_key(x)
        values = self.W_value(x)

        # view() aka "reshape"
        # We can think of this as breaking up the single matrix into multiple, one per head (num_heads)
        # recall that `d_out = num_heads * head_dim`
        # after running view(), the new dimensions are: (b, num_tokens, num_heads, head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) 
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim) 
        
        # transpose results in shape 
        # before: (b, num_tokens, num_heads, head_dim)
        #                 -> 
        # after:  (b, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1,2) 
        keys = keys.transpose(1,2) 
        values = values.transpose(1,2)

        # 1st, we called .T 
        # 2nd, we called .transpose(1,2) to handle batches (idx=0)
        # now, we call   .transpose(2,3) to handle batches (idx=0) with multiple heads (idx=1)
        attn_scores = queries @ keys.transpose(2,3)
        
        # apply mask for causal attention
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # normalize
        d_k = keys.shape[-1]
        attn_weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)

        # dropout
        attn_weights = self.dropout(attn_weights)
        
        # compute context vector
        # 
        # we transpose to convert from:
        # before: (b, num_heads, num_tokens, head_dim)
        #     ->
        # after: (b, num_tokens, n_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1,2) 

        # combine the heads
        # before: (b, num_heads, num_tokens, head_dim)
        #     ->
        # after: (b, num_tokens, d_out)
        #                        d_out = num_tokens * head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        # combine the heads through a linear layer
        # this is considered optional... why?
        # TODO: appendix B for more details
        context_vec = self.out_proj(context_vec)

        return context_vec

In [134]:
# reminding myself of behavior.. 
# ... diagonal=0 means diagonal values are RETAINED
print("\ntriu(..., diagonal=0) retains the diag")
print(torch.triu(torch.ones(context_length, context_length), diagonal=0))

# ... diagonal=1 means diagonal values are OMITTED (0's)
print("\ntriu(..., diagonal=1) omits the diag")
print(torch.triu(torch.ones(context_length, context_length), diagonal=1))


triu(..., diagonal=0) retains the diag
tensor([[1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.]])

triu(..., diagonal=1) omits the diag
tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


Let's review the intuition behind batch matrix multiplication...

In [135]:
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],   
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])
print(a.shape)

torch.Size([1, 2, 3, 4])


In [136]:
print(a.transpose(2,3))

tensor([[[[0.2745, 0.8993, 0.7179],
          [0.6584, 0.0390, 0.7058],
          [0.2775, 0.9268, 0.9156],
          [0.8573, 0.7388, 0.4340]],

         [[0.0772, 0.4066, 0.4606],
          [0.3565, 0.2318, 0.5159],
          [0.1479, 0.4545, 0.4220],
          [0.5331, 0.9737, 0.5786]]]])


In [137]:
print(a @ a.transpose(2,3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


This above single matrix multiplication is equal to computing each head independently. For example:

In [138]:
first_head = a[0,0,:,:]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0,1,:,:]
second_res = second_head @ second_head.T
print("Second head\n", second_res)


First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
Second head
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


Let's try running it!

In [139]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape torch.Size([2, 6, 2])


_Exercise 3.3: Initializing GPT-2 size attention modules_

Using the MultiHeadAttention class, initialize a multi-head attention module that has the same number of attention heads as the smallest GPT-2 model (12 attention heads). Also ensure that you use the respective input and output embedding sizes similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a context length of 1,024 tokens.

In [180]:
d_in = 768
d_out = 768
context_length = 1024
dropout_ratio = 0.0
num_heads = 12

gpt2 = MultiHeadAttention(d_in, d_out, context_length, dropout_ratio, num_heads=num_heads)

In [181]:
torch.manual_seed(123)
gpt2_input = torch.rand((1024,768))
print(gpt2_input.shape)
print(gpt2_input[0].shape)

torch.Size([1024, 768])
torch.Size([768])


In [182]:
gpt2_batch = torch.stack((gpt2_input, gpt2_input), dim=0)
print(gpt2_batch.shape)

torch.Size([2, 1024, 768])


In [183]:
gpt2(gpt2_batch)

tensor([[[-0.0685,  0.0206, -0.3180,  ...,  0.1948, -0.1485, -0.2868],
         [-0.1240, -0.0154, -0.2823,  ...,  0.1384, -0.1513, -0.2925],
         [-0.0590, -0.0165, -0.2872,  ...,  0.0375, -0.1005, -0.3074],
         ...,
         [ 0.0014, -0.0153, -0.2097,  ...,  0.0947, -0.0958, -0.3008],
         [ 0.0013, -0.0153, -0.2097,  ...,  0.0941, -0.0960, -0.3007],
         [ 0.0014, -0.0155, -0.2098,  ...,  0.0944, -0.0956, -0.3011]],

        [[-0.0685,  0.0206, -0.3180,  ...,  0.1948, -0.1485, -0.2868],
         [-0.1240, -0.0154, -0.2823,  ...,  0.1384, -0.1513, -0.2925],
         [-0.0590, -0.0165, -0.2872,  ...,  0.0375, -0.1005, -0.3074],
         ...,
         [ 0.0014, -0.0153, -0.2097,  ...,  0.0947, -0.0958, -0.3008],
         [ 0.0013, -0.0153, -0.2097,  ...,  0.0941, -0.0960, -0.3007],
         [ 0.0014, -0.0155, -0.2098,  ...,  0.0944, -0.0956, -0.3011]]],
       grad_fn=<ViewBackward0>)