# Demo: GPT from scratch

This notebook walks step by step through some of the code used in the final GPT model, and covers the fundamental mathematical trick needed to calculate self-attention.

In [1]:
from typing import List
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
with open("artifacts/input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [3]:
print(f"Amount of characters in dataset: {len(text)}")

Amount of characters in dataset: 1115394


In [4]:
# The first 1000 characters of the dataset
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# Get all unique characters from the dataset
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(chars)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


The following `encode` and `decode` functions together act as our character-level tokenizer.

In [6]:
string_to_int = {ch:i for i, ch in enumerate(chars)}
int_to_string = {i:ch for i, ch in enumerate(chars)}

def encode(string: str) -> List[int]:
    """Encoder: input a string, output a list of integers"""
    return [string_to_int[c] for c in string]

def decode(ids: List[int]) -> str:
    """Decoder: take a list of integers, output a string"""
    return "".join(int_to_string[i] for i in ids)

print(encode("Hello World!"))
print(decode(encode("Hello World!")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Hello World!
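
As a quick sanity check of the idea, the sketch below rebuilds the same kind of character-level tokenizer on a tiny toy string and confirms the round trip. The names `sample_text`, `encode_sample` and `decode_sample` are hypothetical and exist only for this illustration. It also shows one limitation of this scheme: a character that was not seen when the vocabulary was built cannot be encoded.

```python
from typing import List

# Hypothetical toy corpus, used only for illustration (not the real dataset)
sample_text = "hello world"

sample_chars = sorted(set(sample_text))
stoi = {ch: i for i, ch in enumerate(sample_chars)}
itos = {i: ch for i, ch in enumerate(sample_chars)}

def encode_sample(string: str) -> List[int]:
    """Map every character to its integer id."""
    return [stoi[c] for c in string]

def decode_sample(ids: List[int]) -> str:
    """Map every integer id back to its character."""
    return "".join(itos[i] for i in ids)

ids = encode_sample("hello")
round_trip = decode_sample(ids)
print(ids, round_trip)

# A character outside the vocabulary raises a KeyError
try:
    encode_sample("Hello")  # capital 'H' does not occur in sample_text
    unknown_ok = True
except KeyError:
    unknown_ok = False
print(unknown_ok)
```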


In [7]:
# Tokenize the entire dataset
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)
print(data[:1000])

torch.Size([1115394])
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57

In [8]:
# Split up data to better identify overfitting
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

We're not going to feed the entire dataset into the Transformer all at once; that would be far too expensive computationally. Instead, we work with chunks of data. During training, we sample **random** chunks and feed them to the Transformer. The length of these chunks is called the `block_size` (sometimes called `context_length`).

In [9]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

When the above chunk is fed into a Transformer, we are simultaneously training the model to make a prediction at every position. The chunk holds 9 tokens, which yields 8 training examples: the first token only ever serves as context, and each of the remaining 8 tokens is the target for the context that precedes it. So, for the above example:
- In the context of 18, 47 should be predicted
- In the context of 18 and 47, 56 should be predicted
- ...

In [10]:
x = train_data[:block_size]
y = train_data[1: block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is: {target}")

When input is tensor([18]), the target is: 47
When input is tensor([18, 47]), the target is: 56
When input is tensor([18, 47, 56]), the target is: 57
When input is tensor([18, 47, 56, 57]), the target is: 58
When input is tensor([18, 47, 56, 57, 58]), the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]), the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is: 58


Training on chunks is not only done for computational reasons. It also teaches the Transformer to handle contexts of every length from 1 up to `block_size`, so that it is used to seeing every context length in between.

In [11]:
batch_size = 4 # Number of independent sequences that will be processed in parallel by the Transformer
block_size = 8 # Max context length for predictions

def get_batch(split):
    """Generate a small batch of data of inputs x and targets y"""
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

# xb: input batches into Transformer (processed in parallel), yb: desired targets
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[57,  1, 58, 46, 43,  1, 51, 39],
        [58, 46, 47, 57,  1, 58, 56, 59],
        [50, 42,  6,  0, 13, 52, 42,  1],
        [ 1, 49, 52, 53, 61,  6,  0, 27]])
targets:
torch.Size([4, 8])
tensor([[ 1, 58, 46, 43,  1, 51, 39, 42],
        [46, 47, 57,  1, 58, 56, 59, 52],
        [42,  6,  0, 13, 52, 42,  1, 61],
        [49, 52, 53, 61,  6,  0, 27, 59]])
----
when input is [57] the target: 1
when input is [57, 1] the target: 58
when input is [57, 1, 58] the target: 46
when input is [57, 1, 58, 46] the target: 43
when input is [57, 1, 58, 46, 43] the target: 1
when input is [57, 1, 58, 46, 43, 1] the target: 51
when input is [57, 1, 58, 46, 43, 1, 51] the target: 39
when input is [57, 1, 58, 46, 43, 1, 51, 39] the target: 42
when input is [58] the target: 46
when input is [58, 46] the target: 47
when input is [58, 46, 47] the target: 57
when input is [58, 46, 47, 57] the target: 1
when input is [58, 46, 47, 57, 1] the target: 58
when input is [58, 46,

In [12]:
print(xb) # Input into Transformer

tensor([[57,  1, 58, 46, 43,  1, 51, 39],
        [58, 46, 47, 57,  1, 58, 56, 59],
        [50, 42,  6,  0, 13, 52, 42,  1],
        [ 1, 49, 52, 53, 61,  6,  0, 27]])


## Simplest possible model

A bigram model looks at a single token and tries to predict the next one. Each token is predicted based ONLY on the value of the previous token. Tokens do NOT communicate with each other.

- B: Batch (`batch_size = 4`)
- T: Time (`block_size = 8`)
- C: Channel (`vocab_size = 65`)

In [13]:
class BigramLanguageModel(nn.Module):
    """Class that implements a BigramLanguageModel"""

    def __init__(self, vocab_size):
        """Initialization of the BigramLanguageModel"""
        super().__init__()
        # Create a lookup table that maps the current token to the logits of the next token (bigram)
        # Each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Forward propagation for the BigramLanguageModel"""

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # Reshape logits to satisfy pytorch cross_entropy
            targets = targets.view(B*T) # Reshape targets to satisfy pytorch cross_entropy

            # Cross entropy attempts to calculate how well the predicted logits correspond to the actual target
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Generate inference results for the BigramLanguageModel"""
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx) # loss is None here, because no targets are passed during generation
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.9512, grad_fn=<NllLossBackward0>)

c.rzPGkyG v
fp-J kT!sOioFLftAnYbh!tUP-eKT3crUNejQ',OAkp&J
jQ3
foz?;XEjqtdkftG:AlyvmiUnzsxTNI.mWp pTB
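
To put the initial loss in perspective, here is a minimal sketch of the loss a perfectly uniform predictor would get. It assumes the same `vocab_size = 65` and the same flattened `(B*T, C)` layout that `F.cross_entropy` receives above; `uniform_logits` and `targets` are made-up tensors for illustration. A uniform distribution gives a loss of -ln(1/65) = ln(65), roughly 4.17, so the reported 4.9512 suggests that the randomly initialised lookup table is somewhat worse than a uniform guess.

```python
import math
import torch
from torch.nn import functional as F

vocab_size = 65          # number of unique characters, as printed earlier
B, T = 4, 8              # batch and time dimensions used in this notebook

# All-zero logits correspond to a uniform distribution over the vocabulary
uniform_logits = torch.zeros(B * T, vocab_size)       # (B*T, C)
targets = torch.randint(vocab_size, (B * T,))         # (B*T,) arbitrary targets

uniform_loss = F.cross_entropy(uniform_logits, targets).item()
expected = math.log(vocab_size)                       # -ln(1/65)
print(uniform_loss, expected)
```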


In [14]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [15]:
batch_size = 32
for steps in range(1000): # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

3.749495267868042


In [16]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


cD.Wobzpj'lUHAnky!NA;byG f!CAH?NyLfWAG
-Wft

VUW!iene
qJBlXoaMKkLOner?qld-
'KW,j;uP piE:DAotJWUnhuSZNkccetUiRY n&'bm3$-cJh$cCsv&,AhrEye,ffwrd
F;UVAlLZRNak!cur&vV'ZHivoV&MHoziRD.N?qZwU,tIWUvAfqw;!wsY!:
S;UHdRsz;pleSgWmVHA' D3d-wLfc.mWh ;KhLGIAlSBilGtHere;;!LfBEXstsar3hXv&uVOV'Ar hxTUHBELOJPXmaKCzEPkvisDU cmSXE,$aTanTboulZNuTWEFfcur3YckLm3z&mpiv
VoWGROA,CarEna:3Gjx-.NL?tCsUad;UWoZwuvp&cmoE&??N!wzWoHiEiVofoocyvBmtxpAn$w;UV?Y.ro

  SHisOT&er uRHlnI.kLPFiRELZiROKewoNeilozG3QI'xIqB-3doFRtZTmOgWrtKVMur


This output is still far from good, but it is already better than the output of the untrained model. The main limitation is that the tokens are not talking to each other: we only look at the value of the previous token to predict the next one. Next we will make the tokens start talking to each other, so that each prediction can use the whole context (up to `block_size` tokens).

## The mathematical trick in self-attention

In [17]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

This tensor holds a batch of 4 sequences; each sequence consists of 8 tokens with 2 channels.

In the code snippet below we define a bag-of-words variable `xbow`. For every timestep t it contains the average of all tokens in the sequence up to and including the token at timestep t (where we are now).

For now we use `for`-loops, later we will optimise this.

`xprev = x[b,:t+1] # (t+1,C)`: At the current batch index, take all tokens from the sequence up to and including the current token at timestep t. (So it contains t+1 elements, with C channels.)

`xbow[b,t] = torch.mean(xprev, 0)`: Calculate the mean over dimension 0, which is the time dimension in this case. This results in a 1D vector with C (=2) values.

In [18]:
# We want xbow[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t+1,C)
        xbow[b,t] = torch.mean(xprev, 0)
print(x[0])
print(xbow[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


As you can see, at the first timestep the vector is the same in both tensors. This is because xbow is averaging over only 1 word. At the second timestep, xbow averages out the tokens at both time steps, and at the third timestep it averages tokens at the first 3 time steps, ... At the final timestep the entire sequence gets averaged.
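
As an aside, the same running average can also be written without loops as a cumulative sum divided by the number of tokens seen so far. This is not the approach this notebook builds on (that is matrix multiplication), and the names `xbow_loop`, `counts` and `xbow_cumsum` are introduced here only for illustration. It is just a sketch to make the definition of the running mean concrete.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Reference: explicit loops, as in the cell above
xbow_loop = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow_loop[b, t] = x[b, :t + 1].mean(dim=0)

# Running sum along the time dimension, divided by the number of tokens seen so far
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)   # (1, T, 1)
xbow_cumsum = torch.cumsum(x, dim=1) / counts                   # (B, T, C)

# Compare with an explicit absolute tolerance to allow for float32 rounding
same = torch.allclose(xbow_loop, xbow_cumsum, atol=1e-5)
print(same)
```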

The mathematical trick that makes attention possible: **matrix multiplication** (`@` in python)

In [19]:
# Version 2: using matrix multiplication for a weighted aggregation
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True) # (T, T)
# weights is (T, T) and x is (B, T, C): PyTorch broadcasts weights across the batch dimension
xbow2 = weights @ x # (T, T) @ (B, T, C) ----> (B, T, C)
print(xbow2[0]) # Same as xbow[0] above
torch.allclose(xbow, xbow2)

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


False

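`weights` holds the weight that each output position gives to every position in the sequence. Since we want to calculate averages here, each row of the lower-triangular matrix is divided by its row sum, i.e. the number of tokens seen so far (t + 1).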

In [20]:
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

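- `torch.allclose()`: checks whether 2 tensors are equal within a small tolerance (they don't have to be EXACTLY equal). The printed values above agree to four decimals, so the `False` most likely reflects floating-point rounding differences that exceed the default tolerances (`rtol=1e-05`, `atol=1e-08`).
- `torch.tril()`: returns the lower-triangular part of a matrix.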

**Note**: We use this `torch.tril()` function because we employ **masking** in text generation, where we hide information in the future and focus on (generated) information from the past!

In [21]:
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
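
To see why a normalised lower-triangular matrix produces running averages, here is a small sketch with hand-checkable numbers. The matrix `b` is a made-up toy input (3 time steps, 2 channels), chosen only for illustration.

```python
import torch

# Lower-triangular ones, normalised so that every row sums to 1
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)

# Hypothetical toy "token" values: 3 time steps, 2 channels
b = torch.tensor([[2.0, 7.0],
                  [6.0, 4.0],
                  [6.0, 5.0]])

c = a @ b  # row t of c is the mean of rows 0..t of b
print(a)
print(c)
```

Row 0 of `c` is just row 0 of `b`, row 1 is the mean of the first two rows, `[4.0, 5.5]`, and row 2 is the mean of all three rows.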

We can also build the same weights in a different way (other than dividing each row by its row sum).

First, we initialise the matrix with all zeros and set every value in the upper triangle to negative infinity (only the lower triangle, including the diagonal, keeps the value 0).

Softmax will go over the rows and treat zero values equally, while -inf values will result in 0 (see Softmax function). 

Since softmax sums up to 1 for each row, we get exactly the same weights output as before.
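
Written out for a single row, this is a short derivation rather than anything new. It assumes 0-indexed rows $t$ and columns $j$ of the masked `weights` matrix $w$, and uses $e^{0} = 1$ and $e^{-\infty} = 0$, so the denominator is a sum of exactly $t + 1$ ones:

```latex
\mathrm{softmax}(w_t)_j
  = \frac{e^{w_{t,j}}}{\sum_{k=0}^{T-1} e^{w_{t,k}}}
  = \begin{cases}
      \dfrac{1}{t+1} & \text{if } j \le t \quad (w_{t,j} = 0) \\[6pt]
      0              & \text{if } j > t \quad (w_{t,j} = -\infty)
    \end{cases}
```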

In [22]:
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T,T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [23]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T,T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
xbow3 = weights @ x
torch.allclose(xbow, xbow3)

False
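
The `False` results above deserve a second look, since the printed values of the different versions agree. A plausible explanation (an assumption, not something the notebook output confirms) is float32 rounding: the three versions add and scale the numbers in a different order, and the default tolerances of `torch.allclose` are tight. The sketch below recomputes all three versions, prints the largest element-wise difference, and compares with a looser, explicitly chosen `atol=1e-5`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Version 1: explicit loops
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(dim=0)

# Version 2: normalised lower-triangular matrix
tril = torch.tril(torch.ones(T, T))
xbow2 = (tril / tril.sum(dim=1, keepdim=True)) @ x

# Version 3: masked softmax
masked = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
xbow3 = F.softmax(masked, dim=-1) @ x

# Largest element-wise differences, then a comparison with a looser tolerance
max_diff_2 = (xbow - xbow2).abs().max().item()
max_diff_3 = (xbow - xbow3).abs().max().item()
close_2 = torch.allclose(xbow, xbow2, atol=1e-5)
close_3 = torch.allclose(xbow, xbow3, atol=1e-5)
print(max_diff_2, max_diff_3, close_2, close_3)
```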

Now we put this all together to get **self-attention**!

The `weights` matrix can no longer be a fixed, uniform average as before, because different tokens will find different other tokens more or less interesting. So we want this `weights` matrix to be data-dependent. The requirement is two-fold:
1) I want to **gather information from the past** (aka earlier on in the sequence)
2) I want to do this data gathering **in a data-dependent way**, because I care more about some tokens than others.

Self-attention solves this issue; every single token at each position will emit 2 vectors:
1) `query` vector: roughly speaking says "What am I looking for?"
2) `key` vector: roughly speaking says "What do I contain?"

The affinity between different tokens (**attention**) is the result of the dot product between the `keys` and the `queries`. This dot product then becomes the new `weights` matrix. A high value in this new `weights` matrix means I will pay a lot of **attention** to that token during the generation of the token at timestep t.

In [24]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)
weights =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

tril = torch.tril(torch.ones(T, T))
#weights = torch.zeros((T,T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

v = value(x)
out = weights @ v
# out = weights @ x

out.shape

torch.Size([4, 8, 16])

`k = key(x)` and `q = query(x)`: In this code snippet the key and query matrices are created. All tokens in the (B, T) arrangement create a key and a query **in parallel and independently**. No communication between tokens has happened yet.

`weights =  q @ k.transpose(-2, -1)`: Now the communication starts. We define the affinity between tokens by taking the dot product between the `query` and `key` matrices (as we can see from the output below, the result is no longer a plain average but is instead **data-dependent**). Before we can do this, however, we must swap the `-2` and `-1` dimensions of `k`, so that (B, T, 16) @ (B, 16, T) gives (B, T, T).

In [25]:
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

If we look at the last row (timestep t = 7), we can see that the tokens at positions 3 and 6 have high values (0.2297 and 0.2423), alongside the token's own position 7 (0.2391). This means the current token cares more about these tokens than about the other ones.

Because softmax normalises each row so that it sums to 1, these weights are the fractions in which information from each position is aggregated into the current position. High-weight positions contribute the most, so the current position will learn a lot about them.

This tells us in a **data-dependent** manner how much information to aggregate from the tokens in the past.
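
The properties described above can be checked directly. The following is a minimal sketch that rebuilds the `query`/`key` part of the self-attention cell; the names `scores`, `manual`, `row_sums` and `future_mass` are introduced here for illustration only. It verifies that one entry of the affinity matrix is indeed the dot product of one query with one key, that every row of `weights` sums to 1, and that no weight is placed on future tokens.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k, q = key(x), query(x)                      # each (B, T, head_size)

scores = q @ k.transpose(-2, -1)             # (B, T, T) raw affinities

# Entry [b, i, j] is the dot product of query i with key j
manual = torch.dot(q[0, 7], k[0, 3])
dot_matches = torch.allclose(scores[0, 7, 3], manual, atol=1e-5)

tril = torch.tril(torch.ones(T, T))
weights = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

row_sums = weights.sum(dim=-1)                       # every row should sum to 1
future_mass = (weights * (1 - tril)).sum().item()    # total weight on future tokens
print(dot_matches, row_sums[0], future_mass)
```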

In [26]:
v[0]

tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.8321, -0.8144, -0.3242,  0.5191, -0.1252, -0.4898, -0.5287, -0.0314,
          0.1072,  0.8269,  0.8132, -0.0271,  0.4775,  0.4980, -0.1377,  1.4025],
        [ 0.6035, -0.2500, -0.6159,  0.4068,  0.3328, -0.3910,  0.1312,  0.2172,
         -0.1299, -0.8828,  0.1724,  0.4652, -0.4271, -0.0768, -0.2852,  1.3875],
        [ 0.6657, -0.7096, -0.6099,  0.4348,  0.8975, -0.9298,  0.0683,  0.1863,
          0.5400,  0.2427, -0.6923,  0.4977,  0.4850,  0.6608,  0.8767,  0.0746],
        [ 0.1536,  1.0439,  0.8457,  0.2388,  0.3005,  1.0516,  0.7637,  0.4517,
         -0.7426, -1.4395, -0.4941, -0.3709, -1.1819,  0.1000, -0.1806,  0.5129],
        [-0.8920,  0.0578, -0.3350,  0.8477,  0.3876,  0.1664, -0.4587, -0.5974,
          0.4961,  0.6548,  0.0548,  0.9468,  0.4511,  0.1200,  1.0573, -0.2257],
        [-0.4849,  0.1

When the actual aggregation is done in practice, we do not aggregate the raw tokens directly (`out = weights @ x`). Instead, each token produces a third vector called `value`, and we aggregate those: `out = weights @ v`.

This means that the token at position t does not actually communicate the raw token values contained in `x` (you can consider this vector private to the token), but rather communicates the values in its corresponding `v` vector.

In summary:
- `key` and `query` vectors are used to compute `weights`, which determines the attention the token at the current timestep will pay to previous time steps.
- The `value` vectors are what actually gets aggregated: the output at the current timestep is the sum of the `value` vectors of the current and previous tokens, weighted by `weights`. The past tokens communicate through their `value` vectors, not their raw `x` tokens.

Notes:
- Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- The examples across the batch dimension are of course processed completely independently and never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
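
The last note can be made concrete with a small sketch that wraps the self-attention cell in a module. The class name `Head`, the `causal` flag and the parameter names `n_embd` and `block_size` are assumptions made for this illustration, not code taken from the final model. With `causal=True` the triangular mask is applied (a "decoder" block); with `causal=False` that single masking line is skipped (an "encoder" block).

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One attention head; causal=True applies the triangular mask (decoder-style)."""

    def __init__(self, n_embd: int, head_size: int, block_size: int, causal: bool = True):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.causal = causal
        # The mask is not a trainable parameter, so it is registered as a buffer
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)   # each (B, T, head_size)
        weights = q @ k.transpose(-2, -1)                     # (B, T, T)
        if self.causal:
            weights = weights.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        weights = F.softmax(weights, dim=-1)
        return weights @ v                                    # (B, T, head_size)

torch.manual_seed(1337)
x = torch.randn(4, 8, 32)
x_changed = x.clone()
x_changed[:, -1] += 1.0   # perturb only the last token of every sequence

decoder_head = Head(32, 16, block_size=8, causal=True)
encoder_head = Head(32, 16, block_size=8, causal=False)

out = decoder_head(x)
# With masking, earlier positions cannot see the changed last token
decoder_unaffected = torch.allclose(out[:, :-1], decoder_head(x_changed)[:, :-1])
# Without masking, earlier positions do depend on the last token
encoder_unaffected = torch.allclose(encoder_head(x)[:, :-1], encoder_head(x_changed)[:, :-1])
print(out.shape, decoder_unaffected, encoder_unaffected)
```

Changing only the last token leaves all earlier outputs of the masked head untouched, while the unmasked head's earlier outputs change, which is exactly the difference between the two kinds of block.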