<a href="https://colab.research.google.com/github/elliemci/building-LLM/blob/main/attention_comp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention Mechanism

RNNs work for translating short sentences but struggle with longer texts because the decoder has no direct access to earlier words in the input. The *Bahdanau attention* mechanism for RNNs modifies the decoder so that it can selectively access different parts of the input sequence at each decoding step. **Self-attention** allows each position in the input sequence to attend to all positions in the same sequence when computing the representation of that sequence.

## Setup

In [None]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/LLM
%ls

/content/drive/MyDrive/Colab Notebooks/LLM
attention_comp.ipynb  sliding_window_sampling.ipynb  token_embedding.ipynb
pytorch_wormup.ipynb  the-verdict.txt                tokenizing_text.ipynb


In [None]:
!pip install tiktoken



In [None]:
import torch
import tiktoken
import torch.nn as nn

## Text tokenization and embedding

In [None]:
# instantiate BPE tokenizer from tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

with open("the-verdict.txt", "r", encoding="utf-8") as f:
  raw_text = f.read()

enc_text = tokenizer.encode(raw_text)


context_size = 6
context = enc_text[:context_size]

print(f"tokenized input text:\n{context}")
print(f"detokenized text:\n{tokenizer.decode(context)}")

tokenized input text:
[40, 367, 2885, 1464, 1807, 3619]
detokenized text:
I HAD always thought Jack


In [None]:
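# NOTE: raw_text[:context_size] takes the first context_size characters, not tokens,
# so only the first two words are paired below; BPE tokens also need not align
# one-to-one with whitespace-separated words
word_token_dict = {k.strip() : v  for k,v in zip(raw_text[:context_size].split(), context)}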

for k in list(word_token_dict.keys())[:context_size]:
  print(f"{k} : {word_token_dict[k]}")

I : 40
HAD : 367


In [None]:
vocab_size = 50257
output_dim = 5
# max length of input text;
# the supported input size of the LLM is
# context_length = max_length

# embed each token into an output_dim-dimensional vector
torch.manual_seed(123)
embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
#print(f"randomly initialized weight matrix:\n{embedding_layer.weight}")
token_embeddings = embedding_layer(torch.tensor(context, dtype=torch.int))
print(token_embeddings)

tensor([[ 1.1721, -0.4372, -0.4053,  0.7086,  0.9533],
        [ 2.0478,  1.8619, -1.4766, -1.4558, -0.5568],
        [ 0.4197,  0.6117, -0.2094,  0.9823,  0.9884],
        [ 1.1062, -0.5667, -1.8651,  0.3535,  2.5554],
        [-0.3660,  1.7561,  0.8017,  0.9675,  1.7021],
        [-1.1976, -1.5655, -1.2657, -0.3559, -0.7629]],
       grad_fn=<EmbeddingBackward0>)


In [None]:
inputs = token_embeddings
query = inputs[0]

for i, x_i in enumerate(inputs):
  print(f"token number {i} with {output_dim}-dim embedding\n{x_i.data}")
  print(f"its attention score with respect to the first token 'I':\n{torch.dot(x_i, query)}")
  print()

token number 0 with 5-dim embedding
tensor([ 1.1721, -0.4372, -0.4053,  0.7086,  0.9533])
its attention score with respect to the first token 'I':
3.1400835514068604

token number 1 with 5-dim embedding
tensor([ 2.0478,  1.8619, -1.4766, -1.4558, -0.5568])
its attention score with respect to the first token 'I':
0.622178852558136

token number 2 with 5-dim embedding
tensor([ 0.4197,  0.6117, -0.2094,  0.9823,  0.9884])
its attention score with respect to the first token 'I':
1.947705626487732

token number 3 with 5-dim embedding
tensor([ 1.1062, -0.5667, -1.8651,  0.3535,  2.5554])
its attention score with respect to the first token 'I':
4.986784934997559

token number 4 with 5-dim embedding
tensor([-0.3660,  1.7561,  0.8017,  0.9675,  1.7021])
its attention score with respect to the first token 'I':
0.7866894602775574

token number 5 with 5-dim embedding
tensor([-1.1976, -1.5655, -1.2657, -0.3559, -0.7629])
its attention score with respect to the first token 'I':
-1.1858032941818237



## Simple Self-Attention

Self-attention computes attention weights by relating different positions within a single input sequence.

### Context vector for a query

#### Attention scores wrt a query

The dot product of two vectors is the sum of their element-wise products. Computed between each element of the sequence and the query, it serves as a similarity measure: it quantifies how closely the two vectors are aligned.
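
As a small illustration of the statement above, the sketch below uses made-up vectors (`a`, `b` and `u` are not part of the notebook's data) to show that the dot product equals the sum of element-wise products, and that its sign reflects how the two vectors are aligned.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([2.0, 0.0, -1.0])

# element-wise multiplication followed by a sum ...
manual = (a * b).sum()
# ... gives the same value as torch.dot
builtin = torch.dot(a, b)
print(manual, builtin)

# alignment of a fixed vector u with three other vectors
u = torch.tensor([1.0, 0.0])
print(torch.dot(u, torch.tensor([1.0, 0.0])))   # positive: same direction
print(torch.dot(u, torch.tensor([0.0, 1.0])))   # zero: orthogonal
print(torch.dot(u, torch.tensor([-1.0, 0.0])))  # negative: opposite direction
```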

In [None]:
# compute the intermediate attention scores between the query
# (the first token in the sequence) and every input token
# as the dot product of their embedding vectors
attn_scores_0 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    attn_scores_0[i] = torch.dot(x_i, query)

print(f"first token attention scores:\n{attn_scores_0}")


first token attention scores:
tensor([ 3.1401,  0.6222,  1.9477,  4.9868,  0.7867, -1.1858,  4.0621, -1.7895,
         3.2612,  0.4932, -0.8053, -1.4909, -0.9968,  0.3160, -0.7135, -0.8053,
        -2.5393,  1.2947, -1.5226,  0.3160], grad_fn=<CopySlices>)


#### Attention Weights wrt a query

To obtain attention weights, normalize the attention scores so that they sum to 1, either by dividing by their sum or by applying softmax. Dividing by the sum can produce negative weights when some scores are negative. Softmax is better at managing extreme values, offers more favorable gradient properties during training, and ensures that the attention weights are always positive. The output is then interpretable as probabilities or relative importance, where higher weights indicate greater importance.
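
The difference between the two normalizations can be sketched on a toy score vector. The helper `softmax_stable` below is a hypothetical name introduced only for this illustration. It subtracts the maximum before exponentiating, a common trick that leaves the result unchanged but avoids overflow. This is a sketch under those assumptions, not a description of how `torch.softmax` is implemented internally.

```python
import torch

def softmax_stable(x: torch.Tensor) -> torch.Tensor:
    # subtracting the max does not change the result but keeps exp() finite
    shifted = x - x.max()
    exp = torch.exp(shifted)
    return exp / exp.sum()

scores = torch.tensor([3.0, -1.0, 0.5])

# dividing by the sum keeps the sign of each score, so weights can be negative
sum_norm = scores / scores.sum()
# softmax weights are always positive and sum to 1
soft = softmax_stable(scores)
print(sum_norm, soft)

# with very large scores the naive formula overflows to inf / inf = nan
big = torch.tensor([1000.0, 999.0, 998.0])
naive = torch.exp(big) / torch.exp(big).sum()
print(naive, softmax_stable(big))
```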

In [None]:
# obtain attention weights by normalizing the attention scores,
# so that the weights sum up to 1:
attn_weights_0_tmp = attn_scores_0 / attn_scores_0.sum()
print("Attention weights from scores normalized with sum:\n", attn_weights_0_tmp)
print("Sum:", attn_weights_0_tmp.sum())

# or by using the softmax function (naive implementation, shown for illustration)
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

# PyTorch implementation of softmax, optimized for performance
attn_weights_0_naive = torch.softmax(attn_scores_0, dim=0)
print("Attention weights from scores with softmax:", attn_weights_0_naive)
print("Sum:", attn_weights_0_naive.sum())

# the context vector for the first token (the query) is calculated
# as a weighted sum of all input vectors
context_vec_0 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_0 += attn_weights_0_naive[i]*x_i

print(f"context vector for the first token:\n{context_vec_0}")

Attention weights from scores normalized with sum:
 tensor([ 0.3349,  0.0663,  0.2077,  0.5318,  0.0839, -0.1265,  0.4332, -0.1908,
         0.3478,  0.0526, -0.0859, -0.1590, -0.1063,  0.0337, -0.0761, -0.0859,
        -0.2708,  0.1381, -0.1624,  0.0337], grad_fn=<DivBackward0>)
Sum: tensor(1., grad_fn=<SumBackward0>)
Attention weights from scores with softmax: tensor([8.3840e-02, 6.7599e-03, 2.5445e-02, 5.3145e-01, 7.9687e-03, 1.1085e-03,
        2.1079e-01, 6.0610e-04, 9.4632e-02, 5.9422e-03, 1.6219e-03, 8.1700e-04,
        1.3391e-03, 4.9768e-03, 1.7776e-03, 1.6219e-03, 2.8635e-04, 1.3243e-02,
        7.9154e-04, 4.9768e-03], grad_fn=<SoftmaxBackward0>)
Sum: tensor(1., grad_fn=<SumBackward0>)
context vector for the first token:
tensor([ 1.1629, -0.4412, -1.4258,  0.2621,  1.9211], grad_fn=<AddBackward0>)


#### Context vector wrt a query

The context vector z_0 is calculated by multiplying the embedded input tokens, x_i, by the corresponding attention weights and then summing the resulting vectors, i.e. **the weighted sum of all input vectors**.
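
The weighted sum described above can also be written as a single vector-matrix product. The following sketch uses a small random tensor `x` (an assumption for illustration, not the notebook's embeddings) to check that the explicit loop and the product `w @ x` agree.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)                 # 4 toy tokens with 3-dim embeddings
w = torch.softmax(x @ x[0], dim=0)    # attention weights for the query x[0]

# weighted sum written as an explicit loop
z_loop = torch.zeros(3)
for i in range(x.shape[0]):
    z_loop += w[i] * x[i]

# the same weighted sum as one vector-matrix product
z_mat = w @ x
print(z_loop, z_mat)
```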

In [None]:
context_vec_0 = torch.zeros(query.shape)

for i,x_i in enumerate(inputs):
    context_vec_0 += attn_weights_0_naive[i]*x_i
print(context_vec_0)

tensor([ 1.1629, -0.4412, -1.4258,  0.2621,  1.9211], grad_fn=<AddBackward0>)


### Context vectors for all input tokens

#### Attention Scores

In [None]:
# compute the attention score between each pair of inputs
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])

# use an additional for-loop to compute the pairwise dot products
for i, x_i in enumerate(inputs):
  for j, x_j in enumerate(inputs):
    attn_scores[i, j] = torch.dot(x_i, x_j)

# matrix multiplication gives the same result and is much faster than the for-loops
attn_scores = inputs @ inputs.T

print(f"All pairs attention scores:\n{attn_scores}")

#### Attention Weights

In [None]:
# normalize each row of attention scores so that the weights sum up to 1
attn_weights = torch.softmax(attn_scores, dim=1)

print(f"All rows sum to 1: {attn_weights.sum(dim=1)}")
print(f"Attention weights:\n{attn_weights}")

#### Context Vectors

In [None]:
all_context_vecs = attn_weights @ inputs
print(f"All context vectors:\n{all_context_vecs}")

All context vectors:
tensor([[ 1.1629, -0.4412, -1.4258,  0.2621,  1.9211],
        [ 2.0427,  1.8600, -1.4764, -1.4530, -0.5568],
        [ 0.5325,  0.4804, -0.4147,  0.5660,  1.4415],
        [ 1.1028, -0.5601, -1.8310,  0.3006,  2.4950],
        [-0.1962,  1.6651,  0.9834,  0.9394,  1.5002],
        [-1.1453, -1.4533, -1.2700, -0.3885, -0.6535],
        [ 1.1886, -0.5651, -1.6610,  0.1374,  2.3083],
        [ 1.0457,  0.2737, -0.2043, -0.5802, -2.0543],
        [ 1.0403, -0.3959, -1.8066,  0.3355,  2.0152],
        [-0.8372, -0.0726, -2.0837, -1.6954,  1.9354],
        [-0.0320,  0.7286,  1.4154,  0.5579,  0.0824],
        [-0.5520,  0.4166,  0.0678, -0.0263,  0.3155],
        [ 1.0582, -0.0538,  0.1632, -0.2290, -1.9805],
        [ 1.1630,  1.2001,  2.8794,  0.8324,  0.0437],
        [-0.4042, -0.0782,  0.2316,  0.1970, -0.3493],
        [-0.0320,  0.7286,  1.4154,  0.5579,  0.0824],
        [ 0.2798,  0.7391,  2.6261,  0.6391, -0.6024],
        [ 1.7956, -0.1916,  0.1868, -0.1356,

In [None]:
print(f"previously calculated context_vec_0:\n{context_vec_0}")
print(f"\nFirst row of all context vectors tensor:\n{all_context_vecs[0]}")

previously calculated context_vec_0:
tensor([ 1.1629, -0.4412, -1.4258,  0.2621,  1.9211], grad_fn=<AddBackward0>)

First row of all context vectors tensor:
tensor([ 1.1629, -0.4412, -1.4258,  0.2621,  1.9211],
       grad_fn=<SelectBackward0>)


## Self-attention with trainable weights

### Attention scores wrt query

In [None]:
print(f"query is the first sequence token: {inputs[0]}")
print(f"input sequence consists of {inputs.shape[0]} tokens.")

query is the first sequence token: tensor([ 1.1721, -0.4372, -0.4053,  0.7086,  0.9533],
       grad_fn=<SelectBackward0>)
input sequence consists of 13 tokens.


In [None]:
# set the query to the first token in the text
x_0 = inputs[0]

d_in = 5
d_out = 5

In [None]:
# initialize the weight matrices Wq, Wk, and Wv;
# use requires_grad=False for less cluttered output, =True for model training
torch.manual_seed(123)

Wq = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
Wk  = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
Wv = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

The three trainable weight matrices Wq, Wk and Wv are used to project an embedded input token into query, key and value vectors. The terms **key**, **query** and **value** are borrowed from information retrieval and databases:

* the query is used to probe the other parts of the input sequence to determine how much attention to pay to them

* the key is used for indexing and searching; each item in the input sequence has an associated key

* the value is the representation of the input item

When the model determines which keys, and thus which parts of the input, are most relevant to the query (the current focus item), it retrieves the corresponding values.
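
The database analogy can be sketched as a "soft lookup". In the toy example below, the keys, values and query are hand-picked numbers (assumptions for illustration, not learned projections). The query matches the second key most strongly, so the retrieved result is dominated by the second value rather than being an exact hit as in a real database.

```python
import torch

# three stored items, each with a 2-dim key and a 1-dim value
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [-1.0, 0.0]])
values = torch.tensor([[10.0], [20.0], [30.0]])

# a query that points in the direction of the second key
query = torch.tensor([0.0, 5.0])

scores = keys @ query                    # similarity of the query to each key
weights = torch.softmax(scores, dim=0)   # how much attention each item receives
retrieved = weights @ values             # weighted mix of the stored values

print(f"scores:    {scores}")
print(f"weights:   {weights}")
print(f"retrieved: {retrieved}")
```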

In [None]:
inputs.shape

torch.Size([13, 5])

In [None]:
# compute query, key and value vectors for the first token
query_0 = x_0 @ Wq
key_0 = x_0 @ Wk
value_0 = x_0 @ Wv

print(query_0)

tensor([0.9389, 0.9168, 1.4128, 1.8544, 0.3706], grad_fn=<SqueezeBackward4>)


In [None]:
# compute all keys and values
keys = inputs @ Wk
values = inputs @ Wv

# both projections map every token from the d_in to the d_out embedding space
print("all keys shape:", keys.shape)
print("all values shape:", values.shape)

all keys shape: torch.Size([13, 5])
all values shape: torch.Size([13, 5])


In [None]:
# compute the attention score of the query with the first key
keys_0 = keys[0]
attn_score_00 = query_0.dot(keys_0)

print(f"unnormalized attention score for the first element: {attn_score_00}")

unnormalized attention score for the first element: 7.827261924743652


In [None]:
# compute all attention scores with matrix multiplication
attn_scores_0 = query_0 @ keys.T # All attention scores for given query

print(attn_scores_0)

tensor([  7.8273,   6.6440,  10.7050,   9.7635,  16.6555, -16.9381,   7.2245,
         -7.0322,   5.7098,  -4.8935,   1.2303,  -3.0847,  -5.4073],
       grad_fn=<SqueezeBackward4>)


### Attention weights

Weight parameters are the fundamental, learned coefficients that define the network's connections, while attention weights are dynamic, context-specific values.

The attention scores are scaled by the square root of the key embedding dimension to improve training performance: large dot products push softmax into regions where the gradients are very small, which can slow down learning.
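
A rough numerical sketch of why the scaling helps, assuming queries and keys with independent, unit-variance entries (an idealization, not a property of the notebook's embeddings): the spread of raw dot products grows roughly with the square root of the dimension, while the scaled scores stay near unit spread. The helper `score_std` is a hypothetical name introduced for this illustration.

```python
import torch

torch.manual_seed(0)

def score_std(d_k, n=10_000):
    """Std of dot(q, k) for random unit-variance vectors, before and after scaling."""
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    raw = (q * k).sum(dim=-1)      # n independent dot products
    scaled = raw / d_k**0.5        # scaled as in the attention formula
    return raw.std().item(), scaled.std().item()

for d_k in (4, 64, 512):
    raw_std, scaled_std = score_std(d_k)
    print(f"d_k={d_k:4d}  raw std={raw_std:6.2f}  scaled std={scaled_std:5.2f}")

# large scores make softmax nearly one-hot, scaled scores keep it softer
scores = torch.tensor([8.0, 4.0, 0.0])
sharp = torch.softmax(scores, dim=-1)
soft = torch.softmax(scores / 64**0.5, dim=-1)
print(sharp, soft)
```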

In [None]:
# scale the attention scores by dividing them by the square root (**0.5)
# of the embedding dimension of the keys
d_k = keys.shape[-1]

attn_weights_0 = torch.softmax(attn_scores_0 / d_k**0.5, dim=-1)
print(f"attention weights for the first element in input sequence:\n{attn_weights_0}")

attention weights for the first element in input sequence:
tensor([1.6490e-02, 9.7141e-03, 5.9721e-02, 3.9199e-02, 8.5478e-01, 2.5540e-07,
        1.2593e-02, 2.1437e-05, 6.3965e-03, 5.5791e-05, 8.6285e-04, 1.2528e-04,
        4.4338e-05], grad_fn=<SoftmaxBackward0>)


### Context Vector

In [None]:
# compute the context vector as a weighted sum over the value vectors
context_vec_0 = attn_weights_0 @ values

print(context_vec_0)

tensor([2.2702, 2.3605, 1.9509, 0.9412, 2.0605], grad_fn=<SqueezeBackward4>)


### Self-attention class

In [None]:
inputs = token_embeddings

d_in = 5
d_out = 5

In [None]:
class SelfAttention_v1(nn.Module):
  """ Class which implements the self-attention mechanism by
  transforming the input matrix X with the three weight
  matrices Wq, Wk and Wv, which are initialized for queries,
  keys and values, and each of which transforms the input dimension d_in to an
  output dimension d_out. The forward method computes the
  attention scores by multiplying queries and keys,
  normalizes them using softmax to get attention weights and
  creates the context vectors by weighting the values with the
  attention weights. """

  def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.Wq = nn.Parameter(torch.rand(d_in, d_out))
        self.Wk = nn.Parameter(torch.rand(d_in, d_out))
        self.Wv = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
        keys = x @ self.Wk
        queries = x @ self.Wq
        values = x @ self.Wv

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

In [None]:
torch.manual_seed(123)

sa_v1 = SelfAttention_v1(d_in, d_out)
print(f"the context vectors for the embedded input vectors:\n{sa_v1(inputs)}")

the context vectors for the embedded input vectors:
tensor([[ 2.2702,  2.3605,  1.9509,  0.9412,  2.0605],
        [ 0.2346, -0.0417,  0.2971,  0.1990,  0.2329],
        [ 2.3596,  2.5487,  2.0782,  0.9772,  2.1461],
        [ 2.3313,  2.4761,  2.0233,  0.9595,  2.1185],
        [ 2.3786,  2.5853,  2.1017,  0.9847,  2.1628],
        [-2.5004, -2.9339, -2.4945, -1.9440, -1.8519],
        [ 2.1642,  2.1103,  1.7717,  0.8904,  1.9541],
        [-2.4981, -2.9302, -2.4886, -1.9387, -1.8508],
        [ 1.9528,  1.7526,  1.5367,  0.8310,  1.7487],
        [-1.8738, -2.1373, -1.4870, -1.0815, -1.4365],
        [ 1.5137,  1.3669,  1.3375,  0.7416,  1.3573],
        [-1.1438, -1.3443, -0.7745, -0.5407, -0.9008],
        [-2.4802, -2.9036, -2.4500, -1.9045, -1.8412]], grad_fn=<MmBackward0>)


In [None]:
class SelfAttention_v2(nn.Module):
  """ Self-attention mechanism class that uses nn.Linear, which has an
  optimized weight initialization scheme, leading to more stable
  and effective model training. """
  def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
        keys = self.Wk(x)
        queries = self.Wq(x)
        values = self.Wv(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)
        context_vec = attn_weights @ values

        return context_vec

In [None]:
torch.manual_seed(789)

sa_v2 = SelfAttention_v2(d_in, d_out)
print(f"inputs context vectors:\n{sa_v2(inputs)}")

inputs context vectors:
tensor([[ 0.0805, -0.3446, -0.4164,  0.3569,  0.2441],
        [ 0.2634, -0.3499, -0.3964,  0.5085,  0.3916],
        [ 0.1661, -0.3426, -0.4020,  0.4278,  0.3110],
        [ 0.0424, -0.4120, -0.3911,  0.3255,  0.2813],
        [ 0.2608, -0.3202, -0.3885,  0.5088,  0.3700],
        [ 0.1498, -0.3763, -0.2494,  0.3326,  0.3329]], grad_fn=<MmBackward0>)


**Note:** nn.Linear in SelfAttention_v2 uses a different weight initialization scheme than the nn.Parameter(torch.rand(d_in, d_out)) used in SelfAttention_v1, which causes the two mechanisms to produce different results.
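
One way to check that the initialization is the only difference is to copy the weights from one variant into the other. The sketch below uses two compact stand-in classes, `SelfAttentionParam` and `SelfAttentionLinear` (hypothetical names mirroring the two classes above), so that it runs on its own. Since nn.Linear stores its weight with shape (d_out, d_in) and computes `x @ weight.T`, the weights are transposed when copied.

```python
import torch
import torch.nn as nn

class SelfAttentionParam(nn.Module):
    """Stand-in for the nn.Parameter-based variant."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.Wq = nn.Parameter(torch.rand(d_in, d_out))
        self.Wk = nn.Parameter(torch.rand(d_in, d_out))
        self.Wv = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        q, k, v = x @ self.Wq, x @ self.Wk, x @ self.Wv
        w = torch.softmax(q @ k.T / k.shape[-1]**0.5, dim=-1)
        return w @ v

class SelfAttentionLinear(nn.Module):
    """Stand-in for the nn.Linear-based variant."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.Wq = nn.Linear(d_in, d_out, bias=False)
        self.Wk = nn.Linear(d_in, d_out, bias=False)
        self.Wv = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        q, k, v = self.Wq(x), self.Wk(x), self.Wv(x)
        w = torch.softmax(q @ k.T / k.shape[-1]**0.5, dim=-1)
        return w @ v

torch.manual_seed(0)
x = torch.randn(6, 5)
sa_param = SelfAttentionParam(5, 5)
sa_linear = SelfAttentionLinear(5, 5)
differs_before = not torch.allclose(sa_param(x), sa_linear(x))

# copy the nn.Linear weights (shape d_out x d_in) into the parameter variant
with torch.no_grad():
    sa_param.Wq.copy_(sa_linear.Wq.weight.T)
    sa_param.Wk.copy_(sa_linear.Wk.weight.T)
    sa_param.Wv.copy_(sa_linear.Wv.weight.T)

print("differs before copy:", differs_before)
print("matches after copy: ", torch.allclose(sa_param(x), sa_linear(x), atol=1e-6))
```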

## Causal Attention

Causal attention is a specialized form of self-attention that restricts a model to only consider previous and current inputs in a sequence when processing any given token.

### Attention Mask

For each token processed, the future tokens, which come after the current token in the input text, are masked. Thus the LLM can't access future tokens when computing the context vectors using attention weights. This is implemented by masking out the attention weights above the diagonal: apply the softmax function to the attention scores, zero out the elements above the diagonal, and renormalize the non-masked attention weights so that they sum up to 1 in each row.

#### Compute Attention weights

In [None]:
# compute attention weights using softmax function
queries = sa_v2.Wq(inputs)
keys = sa_v2.Wk(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)

print(f"normalized attention weights matrix:\n{attn_weights}")


normalized attention weights matrix:
tensor([[0.1815, 0.1013, 0.1748, 0.2151, 0.1845, 0.1428],
        [0.2505, 0.2075, 0.1711, 0.1324, 0.1057, 0.1327],
        [0.1998, 0.1520, 0.1732, 0.1806, 0.1515, 0.1429],
        [0.1835, 0.0825, 0.1734, 0.2523, 0.2339, 0.0744],
        [0.2101, 0.2091, 0.1644, 0.1409, 0.1125, 0.1630],
        [0.1177, 0.1556, 0.1554, 0.1897, 0.2715, 0.1102]],
       grad_fn=<SoftmaxBackward0>)


#### Mask out

In [None]:
# create a mask where the values above the diagonal are zero
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))

print(f"mask with 0s above the diagonal:\n{mask_simple}")

mask with 0s above the diagonal:
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [None]:
# multiply the mask with the attention weights to zero out the values
# above the diagonal
masked_attn_weights = attn_weights * mask_simple
print(f"masked attention weight matrix:\n{masked_attn_weights}")

masked attention weight matrix:
tensor([[0.1815, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2505, 0.2075, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1998, 0.1520, 0.1732, 0.0000, 0.0000, 0.0000],
        [0.1835, 0.0825, 0.1734, 0.2523, 0.0000, 0.0000],
        [0.2101, 0.2091, 0.1644, 0.1409, 0.1125, 0.0000],
        [0.1177, 0.1556, 0.1554, 0.1897, 0.2715, 0.1102]],
       grad_fn=<MulBackward0>)


#### Renormalize

In [None]:
# renormalize the masked attention weights so that each row sums up to 1 again
row_sums = masked_attn_weights.sum(dim=1, keepdim=True)
masked_simple_norm = masked_attn_weights / row_sums
print(f"normalized attention weight matrix:\n{masked_simple_norm}")

normalized attention weight matrix:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5469, 0.4531, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3806, 0.2895, 0.3299, 0.0000, 0.0000, 0.0000],
        [0.2653, 0.1193, 0.2506, 0.3648, 0.0000, 0.0000],
        [0.2510, 0.2498, 0.1965, 0.1684, 0.1344, 0.0000],
        [0.1177, 0.1556, 0.1554, 0.1897, 0.2715, 0.1102]],
       grad_fn=<DivBackward0>)


#### Efficient masking

The softmax function converts its inputs into a probability distribution. When -∞ values are present in a row, the softmax function treats them as zero probability, because e^(-∞) approaches 0. More efficient masking is possible by creating a mask with 1's above the diagonal and then replacing the attention scores at these positions with -inf values.
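
The two masking routes should give the same weights: renormalizing the surviving softmax outputs divides out the masked terms, which is what softmax over -inf-masked scores does directly. The sketch below checks this on a small random score matrix (`scores` is toy data, not the notebook's attention scores).

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 4)   # toy attention scores for 4 tokens
n = scores.shape[0]

# route 1: softmax, zero out the future positions, renormalize each row
weights = torch.softmax(scores, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
renorm = masked / masked.sum(dim=-1, keepdim=True)

# route 2: set future positions to -inf, then apply softmax once
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
direct = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

print(renorm)
print(direct)
print("same result:", torch.allclose(renorm, direct, atol=1e-6))
```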

In [None]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)

print(f"mask with -inf values above diagonal:\n{masked}")

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)

print(f"\nnormalized attention weights:\n{attn_weights}")

mask with -inf values above diagonal:
tensor([[ 0.4379,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7773,  0.3562,    -inf,    -inf,    -inf,    -inf],
        [ 0.4701, -0.1414,  0.1503,    -inf,    -inf,    -inf],
        [ 0.8943, -0.8932,  0.7666,  1.6058,    -inf,    -inf],
        [ 0.4120,  0.4010, -0.1357, -0.4809, -0.9841,    -inf],
        [-0.3805,  0.2437,  0.2405,  0.6873,  1.4886, -0.5277]],
       grad_fn=<MaskedFillBackward0>)

normalized attention weights:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5469, 0.4531, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3806, 0.2895, 0.3299, 0.0000, 0.0000, 0.0000],
        [0.2653, 0.1193, 0.2506, 0.3648, 0.0000, 0.0000],
        [0.2510, 0.2498, 0.1965, 0.1684, 0.1344, 0.0000],
        [0.1177, 0.1556, 0.1554, 0.1897, 0.2715, 0.1102]],
       grad_fn=<SoftmaxBackward0>)


### Weights Dropout

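Dropout randomly ignores selected hidden layer units during training, which prevents overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units.
In attention, dropout is applied either after calculating the attention weights or after applying the attention weights to the value vectors. To keep the expected sum of the weights unchanged, nn.Dropout scales the surviving values by 1/(1 - p), which is why a dropout rate of 0.5 doubles the remaining weights in the output below.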
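
A short sketch of how nn.Dropout behaves in training and evaluation mode, using a tensor of ones (toy data chosen so the rescaling is easy to see). In training mode the surviving entries are scaled by 1/(1 - p). In evaluation mode the layer passes its input through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.5
dropout = nn.Dropout(p)
x = torch.ones(200, 500)

# training mode: roughly a fraction p of the entries is zeroed,
# the survivors are scaled by 1 / (1 - p)
dropout.train()
y_train = dropout(x)
kept = y_train[y_train != 0]
print("surviving values:", kept.unique())
print("mean after dropout:", y_train.mean().item())

# evaluation mode: dropout is a no-op
dropout.eval()
y_eval = dropout(x)
print("unchanged in eval mode:", torch.equal(y_eval, x))
```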

In [None]:
torch.manual_seed(123)
# masking out half of the attention weights
dropout = nn.Dropout(0.5)

print(f"attention weight matrix with drop outs:\n{dropout(attn_weights)}")

attention weight matrix with drop outs:
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.9061, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6598, 0.0000, 0.0000, 0.0000],
        [0.5307, 0.2386, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5020, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3112, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


### Causal Attention Class

In [None]:
class CausalAttention(nn.Module):
  """ A self-attention class with added dropout and a causal mask. """
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
      super().__init__()
      self.d_out = d_out
      self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
      self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
      self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)
      self.dropout = nn.Dropout(dropout)
      # with the use of register_buffer there is no need to manually
      # ensure the tensors are on the same device as the model parameters
      self.register_buffer(
          'mask',
          torch.triu(torch.ones(context_length, context_length),
          diagonal=1)
        )

  def forward(self, x):
      b, num_tokens, d_in = x.shape
      #new batch dimension b
      keys = self.Wk(x)
      queries = self.Wq(x)
      values = self.Wv(x)

      attn_scores = queries @ keys.transpose(1, 2)
      attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
      attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
      attn_weights = self.dropout(attn_weights)

      context_vec = attn_weights @ values
      return context_vec
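
The role of `register_buffer` in the class above can be sketched with a small stand-in module. `WithMask` is a hypothetical name used only here. A buffer is part of the module's state (it is saved in `state_dict` and converted together with the module), but it is not a trainable parameter.

```python
import torch
import torch.nn as nn

class WithMask(nn.Module):
    """Hypothetical minimal module holding a causal mask as a buffer."""
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

m = WithMask(3)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
print("parameters:", param_names)            # the mask is not listed here
print("buffers:   ", buffer_names)           # the mask is tracked as a buffer
print("in state_dict:", "mask" in m.state_dict())

# module-wide conversions also apply to buffers; the same holds for .to(device)
m = m.double()
print("mask dtype after m.double():", m.mask.dtype)
```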

In [None]:
# to simulate a batch of inputs, duplicate the input text
batch = torch.stack((inputs, inputs), dim=0)
print(f"{batch.shape[0]} input texts with {batch.shape[1]} tokens each, where each token is a {batch.shape[2]}-dimensional vector:\n{batch.shape}")

2 input texts with 6 tokens each, where each token is a 5-dimensional vector:
torch.Size([2, 6, 5])


In [None]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)

print("context_vecs.shape:", context_vecs.shape)
print(f"\n{context_vecs.shape[2]}-dim context vectors for {context_vecs.shape[0]} input sequences of {context_vecs.shape[1]} tokens each:\n{context_vecs}")

context_vecs.shape: torch.Size([2, 6, 5])

5-dim context vectors for 2 input sequences of 6 tokens each:
tensor([[[ 0.1682, -0.8474, -0.2907, -0.0228,  0.3905],
         [ 0.1450, -1.4204,  1.1329,  0.2614,  1.5192],
         [ 0.1730, -1.1179,  0.5192,  0.0413,  0.9298],
         [ 0.0618, -1.4199,  0.4697,  0.2299,  1.0985],
         [ 0.1228, -1.2190,  0.3619,  0.0986,  0.8890],
         [ 0.0015, -0.8100,  0.2812,  0.1129,  0.5850]],

        [[ 0.1682, -0.8474, -0.2907, -0.0228,  0.3905],
         [ 0.1450, -1.4204,  1.1329,  0.2614,  1.5192],
         [ 0.1730, -1.1179,  0.5192,  0.0413,  0.9298],
         [ 0.0618, -1.4199,  0.4697,  0.2299,  1.0985],
         [ 0.1228, -1.2190,  0.3619,  0.0986,  0.8890],
         [ 0.0015, -0.8100,  0.2812,  0.1129,  0.5850]]],
       grad_fn=<UnsafeViewBackward0>)


## Multi-head attention

Multi-head attention consists of stacking multiple instances of the causal self-attention mechanism, each with its own weights, and combining their outputs.

For example, a multi-head attention module with two single-head attention modules uses two Wv value matrices, two Wq query matrices and two Wk key matrices, obtaining two sets of context vectors, which are concatenated into a single context vector matrix where each context vector is (d_out * num_heads)-dimensional.
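
The concatenation step can be sketched with random tensors standing in for the per-head outputs (`head_outputs` is placeholder data, not the result of real attention heads). Concatenating along the last dimension keeps the batch and token dimensions and multiplies the embedding dimension by the number of heads.

```python
import torch

torch.manual_seed(0)
batch_size, num_tokens, d_out, num_heads = 2, 6, 5, 2

# placeholder for the context vectors produced by each head
head_outputs = [torch.randn(batch_size, num_tokens, d_out) for _ in range(num_heads)]

# concatenate along the embedding dimension
combined = torch.cat(head_outputs, dim=-1)
print("per-head shape:", head_outputs[0].shape)
print("combined shape:", combined.shape)

# the first d_out columns are exactly the first head's output
print(torch.equal(combined[..., :d_out], head_outputs[0]))
```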

### Sequential multiple single-head attention modules

In [None]:
class MultiHeadAttentionWrap(nn.Module):
  """A wrapper class that stacks multiple instances of the CausalAttention class,
     which are processed sequentially with a loop in the forward method"""
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    self.head = nn.ModuleList(
        [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])

  def forward(self, x):
    return torch.cat([head(x) for head in self.head], dim=-1)

In [None]:
context = 12
inputs = embedding_layer(torch.tensor(context, dtype=torch.int))
type(inputs)

torch.Tensor

In [None]:
torch.manual_seed(123)

context_size = 12
context = enc_text[:context_size]
inputs = embedding_layer(torch.tensor(context, dtype=torch.int))

# extract two input text sequences of length 6 tokens each into a batch
max_length = 6

first_chunk = inputs[:max_length]
second_chunk = inputs[max_length : 2*max_length]
batch = torch.stack((first_chunk, second_chunk), dim=0)
print(f"batch of two input sequences:\n{batch}")

# the number of input tokens
context_length = batch.shape[1]

d_in, d_out = 5, 5
num_heads = 2

mha = MultiHeadAttentionWrap(d_in, d_out, context_length, dropout=0.0, num_heads=num_heads)

context_vecs = mha(batch)

print("context_vecs.shape:", context_vecs.shape)
print(f"context vector tensor with {context_vecs.shape[0]} input texts with {context_vecs.shape[1]} token each, where each token is {context_vecs.shape[2]}-dimensional vector:\n{context_vecs}")

batch of two input sequences:
tensor([[[ 1.1721, -0.4372, -0.4053,  0.7086,  0.9533],
         [ 2.0478,  1.8619, -1.4766, -1.4558, -0.5568],
         [ 0.4197,  0.6117, -0.2094,  0.9823,  0.9884],
         [ 1.1062, -0.5667, -1.8651,  0.3535,  2.5554],
         [-0.3660,  1.7561,  0.8017,  0.9675,  1.7021],
         [-1.1976, -1.5655, -1.2657, -0.3559, -0.7629]],

        [[ 1.6024, -0.6694, -1.0676, -0.3782,  1.8112],
         [ 0.8424,  0.2846, -0.2265, -0.7069, -2.3531],
         [ 0.9281, -0.1081, -1.9729,  0.7692,  0.8198],
         [-0.8608, -0.0757, -2.0866, -1.7262,  1.9371],
         [-0.7313,  0.2600,  0.6740,  0.4362,  0.1359],
         [-1.0130,  0.1652,  0.3707, -0.4285,  0.2333]]],
       grad_fn=<StackBackward0>)
context_vecs.shape: torch.Size([2, 6, 10])
context vector tensor with 2 input texts with 6 token each, where each token is 10-dimensional vector:
tensor([[[ 1.6818e-01, -8.4739e-01, -2.9074e-01, -2.2761e-02,  3.9050e-01,
           8.2460e-03, -6.0613e-01, -5.5

### Multi-head attention with weight splits

In [None]:
class MultiHeadAttention(nn.Module):
  """ Starts with a single multi-head layer, which is split into multiple individual
      attention heads by reshaping and transposing the query, key and value tensors,
      computes the attention for each head and then combines the results."""
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
             torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

  def forward(self, x):
    batch, num_tokens, d_in = x.shape
    keys = self.Wk(x)
    queries = self.Wq(x)
    values = self.Wv(x)

    # reshape to represent multiple heads, splitting d_out into
    # num_heads chunks of size head_dim = d_out // num_heads
    keys = keys.view(batch, num_tokens, self.num_heads, self.head_dim)
    values = values.view(batch, num_tokens, self.num_heads, self.head_dim)
    queries = queries.view(batch, num_tokens, self.num_heads, self.head_dim)

    # transpose to bring the num_heads dim before the num_tokens dim,
    # needed for aligning the queries, keys and values across the different
    # heads and performing batched matrix multiplication
    keys = keys.transpose(1, 2)
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    # batched matrix multiplication between the queries and a view of the keys
    # with the last two dimensions, num_tokens and head_dim, transposed
    attn_scores = queries @ keys.transpose(2, 3)
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

    attn_scores.masked_fill_(mask_bool, -torch.inf)

    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)
    # the context vectors from all heads are transposed back to the shape (batch, num_tokens, num_heads, head_dim)
    context_vec = (attn_weights @ values).transpose(1, 2)
    # to combine the outputs from all heads, the context vectors are reshaped into (batch, num_tokens, d_out)
    context_vec = context_vec.contiguous().view(batch, num_tokens, self.d_out)
    # apply the output projection layer
    context_vec = self.out_proj(context_vec)

    return context_vec
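
The reshaping steps in the forward method can be followed on a small tensor whose entries are just running numbers (toy data, so slices are easy to compare). The sketch shows that each head receives a contiguous slice of the d_out columns, and why `.contiguous()` is called before the final `.view`: after the transpose, the memory layout no longer permits merging the last two dimensions with a plain view. The line `heads.contiguous()` below merely stands in for the freshly computed `attn_weights @ values` tensor.

```python
import torch

batch, num_tokens, num_heads, head_dim = 2, 6, 2, 4
d_out = num_heads * head_dim
x = torch.arange(batch * num_tokens * d_out, dtype=torch.float32)
x = x.view(batch, num_tokens, d_out)

# split d_out into (num_heads, head_dim), then move the head dimension forward
heads = x.view(batch, num_tokens, num_heads, head_dim).transpose(1, 2)
print("heads shape:", heads.shape)   # (batch, num_heads, num_tokens, head_dim)

# head 1 sees columns head_dim .. 2*head_dim - 1 of the original tensor
same_as_slice = torch.equal(heads[:, 1], x[:, :, head_dim:2 * head_dim])
print("head 1 equals column slice:", same_as_slice)

# stand-in for a freshly computed per-head result
context = heads.contiguous()

# merging the heads back: transpose, then flatten the last two dimensions
merged = context.transpose(1, 2)     # (batch, num_tokens, num_heads, head_dim)
try:
    merged.view(batch, num_tokens, d_out)
    view_failed = False
except RuntimeError:
    view_failed = True               # the transposed tensor is not contiguous
restored = merged.contiguous().view(batch, num_tokens, d_out)

print("plain view failed:", view_failed)
print("round trip restores x:", torch.equal(restored, x))
```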

In [None]:
tensor = torch.rand(size = [1,2,3,4])

tensor

tensor([[[[0.6918, 0.3545, 0.7969, 0.0061],
          [0.2528, 0.0882, 0.6997, 0.4855],
          [0.4067, 0.4168, 0.1092, 0.6418]],

         [[0.5125, 0.1549, 0.6881, 0.4900],
          [0.0164, 0.7690, 0.7674, 0.4058],
          [0.1548, 0.5201, 0.8773, 0.9577]]]])

In [None]:
# test batched matrix multiplication with example
b = torch.rand(size = [1,2,3,4])

print(f"batch:\n{b}")
# matrix multiplication between the tensor itself and a view of the tensor
# transposing the last two dimensions
print(f"\nbatch matrix multiplication:\n{b @ b.transpose(2, 3)}")

# the matrix multiplication is carried out over the last two dimensions,
# num_tokens and head_dim, and repeated for each individual head
first_head = b[0, 0, :, :]
first_res = first_head @ first_head.T
print("\nFirst head:\n", first_res)

second_head = b[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

batch:
tensor([[[[0.7239, 0.3604, 0.1829, 0.2956],
          [0.8646, 0.8010, 0.8044, 0.0733],
          [0.7355, 0.6248, 0.1638, 0.5158]],

         [[0.6000, 0.2299, 0.2890, 0.9078],
          [0.4596, 0.4947, 0.1836, 0.2010],
          [0.9603, 0.6861, 0.4209, 0.8046]]]])

batch matrix multiplication:
tensor([[[[0.7748, 1.0834, 0.9401],
          [1.0834, 2.0415, 1.3059],
          [0.9401, 1.3059, 1.2242]],

         [[1.3205, 0.6250, 1.5860],
          [0.6250, 0.5301, 1.0198],
          [1.5860, 1.0198, 2.2175]]]])

First head:
 tensor([[0.7748, 1.0834, 0.9401],
        [1.0834, 2.0415, 1.3059],
        [0.9401, 1.3059, 1.2242]])

Second head:
 tensor([[1.3205, 0.6250, 1.5860],
        [0.6250, 0.5301, 1.0198],
        [1.5860, 1.0198, 2.2175]])


In [None]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2 # d_out must be divisible by num_heads; with num_heads=2, head_dim = d_out // num_heads = 1

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(f"context vectors:\n{context_vecs}")
print("context_vecs.shape:", context_vecs.shape)

context vectors:
tensor([[[ 0.8227,  0.0246],
         [ 0.6647, -0.0818],
         [ 0.6566, -0.1010],
         [ 0.5320, -0.1808],
         [ 0.4877, -0.2188],
         [ 0.4077, -0.2375]],

        [[ 0.1708, -0.4087],
         [ 0.6776, -0.0564],
         [ 0.5588, -0.1505],
         [ 0.1270, -0.4430],
         [ 0.2802, -0.3071],
         [ 0.2610, -0.3173]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
