# Dev notebook for data exploration, model baselines, and experimentation with self-attention

Preliminary exploration of Russian literature text dataset.

In [544]:
# path to data file
data_path = 'data/tiny-russian-lit/very_clean_tiny_russian_lit.txt'

In [545]:
# read it in for inspection
with open(data_path, 'r', encoding='utf-8') as f:
    text = f.read()

In [546]:
print(f'Length of dataset at {data_path} is {len(text)} characters')

Length of dataset at data/tiny-russian-lit/very_clean_tiny_russian_lit.txt is 34824628 characters


In [547]:
print(f'First 1000 characters of the dataset:\n {text[:1000]}')

First 1000 characters of the dataset:
 Михаил Лермонтов
  

Выхожу один я на дорогу;
Сквозь туман кремнистый путь блестит;
Ночь тиха. Пустыня внемлет богу,
И звезда с звездою говорит.

В небесах торжественно и чудно!
Спит земля в сиянье голубом...
Что же мне так больно и так трудно?
Жду ль чего? жалею ли о чем?

Уж не жду от жизни ничего я,
И не жаль мне прошлого ничуть;
Я ищу свободы и покоя!
Я б хотел забыться и заснуть!

Но не тем холодным сном могилы...
Я б желал навеки так заснуть,
Чтоб в груди дремали жизни силы,
Чтоб, дыша, вздымалась тихо грудь;

Чтоб всю ночь, весь день мой слух лелея,
Про любовь мне сладкий голос пел,
Надо мной чтоб, вечно зеленея,
Темный дуб склонялся и шумел.
Михаил Лермонтов
ВАЛЕРИК
Я к вам пишу случайно; право,
Не знаю как и для чего.
Я потерял уж это право.
И что скажу вам? — ничего!
Что помню вас? — но, боже правый,
Вы это знаете давно;
И вам, конечно, все равно.
И знать вам также нету нужды,
Где я? что я? в какой глуши?
Душою мы друг другу чужды,
Да вр

In [548]:
# find the unique characters that occur in the text
chars = sorted(list(set(text)))
vocab = ''.join(chars)
vocab_size = len(chars)
print(f'Text vocabulary: {vocab}\nVocabulary size: {vocab_size}')

Text vocabulary: 
 !&,-.:;?i ̀́ЁІЉАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяёєі–—’
Vocabulary size: 87


Now, we need to be able to tokenize our input: convert raw string text into a sequence of integers according to our vocabulary of possible elements.

For a character-level language model, each character in the vocabulary is its own token, identified by its index in the sorted vocabulary.

In [549]:
# create a simple character-level tokenizer: a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] # encoder: convert string to list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: convert list of integers to string

In [550]:
def verify(string):
    print(f"The string '{string}' has the encoding {encode(string)}")
    print(decode(encode(string)) == string)

In [551]:
char = decode([1])
utf8_encoded = char.encode('utf-8')
print(utf8_encoded.hex())

20


In [552]:
verify(' ')
verify('\n')
verify('и')
verify('Мой дядя самых честных правил')

The string ' ' has the encoding [1]
True
The string '
' has the encoding [0]
True
The string 'и' has the encoding [57]
True
The string 'Мой дядя самых честных правил' has the encoding [29, 63, 58, 1, 53, 80, 53, 80, 1, 66, 49, 61, 76, 70, 1, 72, 54, 66, 67, 62, 76, 70, 1, 64, 65, 49, 51, 57, 60]
True


In [553]:
# encode the entire text dataset and store in a tensor
import torch

data = torch.tensor(encode(text), dtype=torch.long)
print(f'Input data tensor has shape {data.shape} and type {data.dtype}')
print(f'First 1000 elements of data tensor:\n {data[:1000]}')

Input data tensor has shape torch.Size([34824628]) and type torch.int64
First 1000 elements of data tensor:
 tensor([29, 57, 70, 49, 57, 60,  1, 28, 54, 65, 61, 63, 62, 67, 63, 51,  0,  1,
         1,  0,  0, 19, 76, 70, 63, 55, 68,  1, 63, 53, 57, 62,  1, 80,  1, 62,
        49,  1, 53, 63, 65, 63, 52, 68,  8,  0, 34, 59, 51, 63, 56, 77,  1, 67,
        68, 61, 49, 62,  1, 59, 65, 54, 61, 62, 57, 66, 67, 76, 58,  1, 64, 68,
        67, 77,  1, 50, 60, 54, 66, 67, 57, 67,  8,  0, 30, 63, 72, 77,  1, 67,
        57, 70, 49,  6,  1, 32, 68, 66, 67, 76, 62, 80,  1, 51, 62, 54, 61, 60,
        54, 67,  1, 50, 63, 52, 68,  4,  0, 25,  1, 56, 51, 54, 56, 53, 49,  1,
        66,  1, 56, 51, 54, 56, 53, 63, 79,  1, 52, 63, 51, 63, 65, 57, 67,  6,
         0,  0, 19,  1, 62, 54, 50, 54, 66, 49, 70,  1, 67, 63, 65, 55, 54, 66,
        67, 51, 54, 62, 62, 63,  1, 57,  1, 72, 68, 53, 62, 63,  2,  0, 34, 64,
        57, 67,  1, 56, 54, 61, 60, 80,  1, 51,  1, 66, 57, 80, 62, 77, 54,  1,
        52,

In [554]:
# split data into train and validation sets to test for overfitting
split = 0.8
n = int(split*len(data))
train_data = data[:n]
val_data = data[n:]

Block size, or context length, is the max length of any individual chunk of text that the transformer is trained on. A chunk of text of length `block_size + 1` has `block_size` individual training examples. This also means that the size of the input to the transformer at sampling time will never exceed `block_size`.

In [555]:
block_size = 8
first_block = train_data[:block_size + 1]
print(f'First block of the training data, + 1 character: {first_block}')

First block of the training data, + 1 character: tensor([29, 57, 70, 49, 57, 60,  1, 28, 54])


For a given block of text with length `block_size + 1`, we train the transformer on every sequence/target pair with sequence lengths from 1 to `block_size` (where the target is the character immediately following the last character in the sequence). This way the transformer gets 'used' to predicting the next token from contexts as short as 1 and as long as `block_size`. This matters at sampling time, when the transformer has to begin generating from a context that may be shorter than `block_size`.

In [556]:
print(f'Training examples/sequences in first block of data')
for i in range(1, block_size + 1):
    print(f'{i}/{block_size}: When input is, {first_block[:i]} target is {first_block[i]}')

Training examples/sequences in first block of data
1/8: When input is, tensor([29]) target is 57
2/8: When input is, tensor([29, 57]) target is 70
3/8: When input is, tensor([29, 57, 70]) target is 49
4/8: When input is, tensor([29, 57, 70, 49]) target is 57
5/8: When input is, tensor([29, 57, 70, 49, 57]) target is 60
6/8: When input is, tensor([29, 57, 70, 49, 57, 60]) target is 1
7/8: When input is, tensor([29, 57, 70, 49, 57, 60,  1]) target is 28
8/8: When input is, tensor([29, 57, 70, 49, 57, 60,  1, 28]) target is 54


In [557]:
torch.manual_seed(3)
batch_size = 4  # the number of independent sequences that we will process in parallel
block_size = 8  # maximum context length for predictions

def get_batch(split):
    # generate a batch of data consisting of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))  # batch_size random offsets in the interval [0, len(data) - block_size)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('-' * 10)

for b in range(batch_size): # batch dimension
    print(f'Batch {b + 1}/{batch_size}')
    for t in range(block_size): # time/position dimension
        context = xb[b, : t+1]
        target = yb[b, t]
        print(f'When input is {context.tolist()}, target is {target}')

inputs:
torch.Size([4, 8])
tensor([[70,  6,  1, 31, 67, 65, 80, 53],
        [49,  1, 72, 54, 60, 63, 51, 54],
        [66, 54, 50, 54,  1, 57, 61, 57],
        [65, 62, 63, 54,  4,  1, 55, 53]])
targets:
torch.Size([4, 8])
tensor([[ 6,  1, 31, 67, 65, 80, 53,  1],
        [ 1, 72, 54, 60, 63, 51, 54, 59],
        [54, 50, 54,  1, 57, 61, 57,  1],
        [62, 63, 54,  4,  1, 55, 53, 54]])
----------
Batch 1/4
When input is [70], target is 6
When input is [70, 6], target is 1
When input is [70, 6, 1], target is 31
When input is [70, 6, 1, 31], target is 67
When input is [70, 6, 1, 31, 67], target is 65
When input is [70, 6, 1, 31, 67, 65], target is 80
When input is [70, 6, 1, 31, 67, 65, 80], target is 53
When input is [70, 6, 1, 31, 67, 65, 80, 53], target is 1
Batch 2/4
When input is [49], target is 1
When input is [49, 1], target is 72
When input is [49, 1, 72], target is 54
When input is [49, 1, 72, 54], target is 60
When input is [49, 1, 72, 54, 60], target is 63
When input is [4

Probably the simplest language model is a bi-gram with character-based tokens. Given a single character, it predicts the next character in the sequence. I now implement a bi-gram as a baseline for our Russian text generation task.

In [558]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads off the logits (input to softmax) for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensors of integers (B = batch size, T = number of timesteps, at most block_size)
        # we are essentially predicting the next character based on the embedding of a single token
        logits = self.token_embedding_table(idx)  # (B, T, C) : batch, time, channels
        
        if targets is None:
            loss = None
        else:
            # flatten to (B*T, C): cross_entropy expects the class dimension second, i.e. (N, C) or (B, C, T)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)  # equivalently, targets.view(-1)

            # negative log likelihood loss - calculates quality of our logits with respect to the true targets
            # a 'good' logit will have a high value in the target dimension and low values in other dimensions
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        # the bigram only uses the last char as the context
        # we pass in the full context here as practice for generation using transformer
        for _ in range(max_new_tokens):
            # get predictions
            logits, loss = self(idx)  # calls the forward function
            # retrieve only final timestep
            logits = logits[:, -1, :] # (B, T, C) -> (B, C)
            # apply softmax to get probability distribution
            dist = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(dist, num_samples=1) # (B, 1)
            # append new sample to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T + 1)
        return idx


In [559]:
model = BigramLanguageModel(vocab_size)
logits, loss = model(xb, yb)
print(logits.shape)  # (B*T, C) = (4*8, vocab_size): forward() flattens the logits when targets are given
print(loss)

torch.Size([32, 87])
tensor(4.8334, grad_fn=<NllLossBackward0>)
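
As a sanity check on this initial loss, a model that spreads probability uniformly over the vocabulary would score a cross-entropy of ln(vocab_size). The sketch below computes that reference value for the 87-character vocabulary found earlier. The untrained loss of about 4.83 sits somewhat above it, which is plausible because `nn.Embedding` starts from random weights rather than from a uniform prediction. The name `uniform_loss` is only illustrative.

```python
import math

import torch
from torch.nn import functional as F

vocab_size = 87  # vocabulary size reported earlier in this notebook

# cross-entropy of a uniform prediction: -ln(1 / vocab_size) = ln(vocab_size)
uniform_loss = math.log(vocab_size)
print(f'Uniform-prediction loss: {uniform_loss:.4f}')

# the same value obtained through F.cross_entropy with all-equal logits
logits = torch.zeros(1, vocab_size)
target = torch.tensor([0])
print(F.cross_entropy(logits, target).item())
```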


In [560]:
torch.manual_seed(3)

def sample(context, new_tokens=100):
    print(f'Context: {decode(context[0].tolist())}')
    sample = model.generate(context, new_tokens)
    text = decode(sample[0].tolist())
    print(f'Sample: {text}')


# as the model's starting context for sampling, let's provide a newline character
blank_context = torch.tensor([encode('\n')])
sample(blank_context, 250)

Context: 

Sample: 
ПхчМнЭХ–г!,цЙщЫ,рАНвщiсЛжрЭтЗБИзняН—́’зЩЯщБги;ыРуШжмгЕЙСыТПУг–ПФЦырщп́ФЕпЫЧпо т!Их-фУЬюш–лрёФъЛшШi:пМш’̀Гб—мСЁМвєГчВ
Х̀Ю&ТЁбєЙлкiыубц
ЛИЭЫтущп́оНд:р:мвущфНдеєоЪЬЁІЙёЯэбИ?кІЙнЉк’є РЙ!Ип’м-оф&ЙЦкЭ’,Хп́каыцУуйНБ
Н’лЯШшРОящ
СБШОЮЁфу;—––тЯэ?ЁеЛзІыцфшАзП;є


The above sampled text is gibberish. Let's train the model so it can produce something that looks more reasonable.

In [561]:
# typical lr setting is 3e-4, but for small models we can use a much higher lr
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [562]:
batch_size = 32
num_steps = 10000
for step in range(num_steps):
    # sample a batch of data
    xb, yb = get_batch('train')
    
    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    if step == 0 or step == num_steps - 1:
        print(f'Step {step + 1}/{num_steps}: loss={loss.item()}')

Step 1/10000: loss=4.9720001220703125
Step 10000/10000: loss=2.5729243755340576
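
The two losses printed above each come from a single training batch, so they are noisy and say nothing about the validation split. A minimal sketch of a less noisy estimate follows, assuming a model whose forward pass returns `(logits, loss)` and a `get_batch(split)` function like the ones defined earlier. The helper `estimate_loss`, its `eval_iters` argument, and the toy stand-ins (`TinyBigram`, `toy_get_batch`, random `toy_data`) are hypothetical and exist only so the block runs on its own.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

@torch.no_grad()  # no gradients are needed for evaluation
def estimate_loss(model, get_batch, eval_iters=20):
    """Average the loss over eval_iters random batches for each split."""
    was_training = model.training
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        averages[split] = losses.mean().item()
    model.train(was_training)  # restore the previous mode
    return averages

# --- toy stand-ins so the sketch is self-contained (not the notebook's data) ---
torch.manual_seed(0)
toy_vocab_size, toy_block_size, toy_batch_size = 10, 8, 4
toy_data = torch.randint(0, toy_vocab_size, (1000,))
toy_splits = {'train': toy_data[:800], 'val': toy_data[800:]}

def toy_get_batch(split):
    data = toy_splits[split]
    ix = torch.randint(len(data) - toy_block_size, (toy_batch_size,))
    x = torch.stack([data[i:i + toy_block_size] for i in ix])
    y = torch.stack([data[i + 1:i + toy_block_size + 1] for i in ix])
    return x, y

class TinyBigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.table(idx)  # (B, T, C)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

toy_model = TinyBigram(toy_vocab_size)
losses = estimate_loss(toy_model, toy_get_batch)
print(f"train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
```

Calling such a helper every few hundred steps inside the training loop would show whether the train and validation losses start to diverge.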


After optimization, let's see if we can sample something more reasonable.

In [563]:
sample(blank_context, 250)

Context: 

Sample: 
О—  -тх Вск убрдцелящушь уря,  кивобрия инет, стоси иск ветоешать кРОният, всгосска  М.
у. в Бетокудпе чене в о серо тыхусви.Са выки сяЫ руга ожал зврегше нарого- споцостстошичимот скусяЖГц вс нени фи онослорь этолс вел. Аако прой  н Ктс ниц стренене


This is starting to look more like Russian text, but it is still pretty much gibberish. This is because the bigram predicts the next token only by looking at the last token in the context window. I.e. our 'context' is just one token. We're not learning any complex language patterns this way. With a transformer, we can enable the tokens to 'talk' to each other over longer ranges to learn more complex dependencies.

# Self-attention

In [564]:
torch.manual_seed(3)

# here's a toy example
B, T, C = 4, 8, 2 # batch, time, channels
x = torch.randn(B, T, C)  # C is the dim of features we track at each timestep in the sequence
x.shape

torch.Size([4, 8, 2])

We want information to flow from the start of the sequence to the current timestep; we don't want the current token to communicate with future tokens in the sequence, because at inference time we don't know what these will be yet. In short, we want tokens to communicate only with past tokens (and themselves).

A very simple form of this communication is a "BOW," or Bag of Words, approach: for each token at the `t`th timestep, we take the mean of the feature vectors of the timesteps up to and including `t`. This is very lossy because averaging discards the order of the tokens.

In [565]:
# xbow[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t+1, C)
        xbow[b, t] = torch.mean(xprev, 0)  # C
        
xbow.shape

torch.Size([4, 8, 2])

In [566]:
x[0]

tensor([[-0.0766,  0.3599],
        [-0.7820,  0.0715],
        [ 0.6648, -0.2868],
        [ 1.6206, -1.5967],
        [-0.0517, -0.3060],
        [ 0.2485, -0.2226],
        [ 0.9132,  0.2043],
        [ 0.5740,  0.4163]])

In [567]:
xbow[0]

tensor([[-0.0766,  0.3599],
        [-0.4293,  0.2157],
        [-0.0646,  0.0482],
        [ 0.3567, -0.3630],
        [ 0.2750, -0.3516],
        [ 0.2706, -0.3301],
        [ 0.3624, -0.2538],
        [ 0.3888, -0.1700]])

Let's use matrix multiplication to make the computations more efficient (avoiding the for loops, and taking advantage of any GPU resources)

In [568]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'{a = }')
print(f'{b = }')
print(f'{c = }')

a = tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c = tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


We can use `torch.tril` to effectively 'mask' tokens after a timestep `t`; we obtain a lower triangular matrix of ones, where the upper half is 0s, and then matrix multiply by our training sample.

In [569]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'{a = }')
print(f'{b = }')
print(f'{c = }')

a = tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c = tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


Now, to capture the `mean` functionality we used in the bag of words approach, instead of getting a lower triangular matrix of just 1s, we normalize each row so the values sum to 1. Then, when we take a dot product with each row, we are effectively getting the average of the non-zeroed out tokens.

In [570]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / a.sum(1, keepdim=True)  # normalize each row of a to enable averaging during matrix multiplication
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f'{a = }')
print(f'{b = }')
print(f'{c = }')

a = tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c = tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Let's apply this vectorization approach to our initial example with a batch of dim (B, T, C).

In [571]:
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
# batched matrix multiply - pytorch applies the matrix multiplication to each mat in B dim in parallel
xbow2 = weights @ x  # (T, T) @ (B, T, C) --> pytorch creates B dim for weights --> (B, T, T) @ (B, T, C) --> (B, T, C)
print(xbow2.shape)
print(xbow[0], xbow2[0])
torch.allclose(xbow, xbow2)


torch.Size([4, 8, 2])
tensor([[-0.0766,  0.3599],
        [-0.4293,  0.2157],
        [-0.0646,  0.0482],
        [ 0.3567, -0.3630],
        [ 0.2750, -0.3516],
        [ 0.2706, -0.3301],
        [ 0.3624, -0.2538],
        [ 0.3888, -0.1700]]) tensor([[-0.0766,  0.3599],
        [-0.4293,  0.2157],
        [-0.0646,  0.0482],
        [ 0.3567, -0.3630],
        [ 0.2750, -0.3516],
        [ 0.2706, -0.3301],
        [ 0.3624, -0.2538],
        [ 0.3888, -0.1700]])


True
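
The same causal average can also be cross-checked without building a (T, T) weight matrix, by dividing a running sum by the number of tokens seen so far. This is only a verification sketch: the names `xbow_matmul`, `xbow_cumsum` and `counts` are illustrative, and `x` is regenerated here so the block runs on its own.

```python
import torch

torch.manual_seed(3)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# reference: normalized lower-triangular weights, as in the cell above
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
xbow_matmul = weights @ x  # (B, T, C)

# alternative: running sum over time divided by the count of tokens seen so far
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)  # broadcasts over B and C
xbow_cumsum = x.cumsum(dim=1) / counts  # (B, T, C)

print(torch.allclose(xbow_matmul, xbow_cumsum, atol=1e-5))
```

The matrix form is the one that generalizes to attention, because its weights can later be made data-dependent.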

One more equivalent way to write this batched multiplication: we use softmax to normalize our weights matrix and achieve the averaging mask.

In [572]:
tril = torch.tril(torch.ones(T, T))

In [573]:
weights = torch.zeros((T, T))
# replace elements in weights where tril has zeroes with -inf
weights = weights.masked_fill(tril == 0, float('-inf'))
# take softmax along each row
weights = F.softmax(weights, dim=-1)
# this turns each row into a probability dist
# i.e. each row sums to 1 like we want
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

The initial values in the `weights` matrix before applying softmax can be thought of as 'affinities' or interaction strength; i.e. how much of each token from the past do we want to use in our aggregation for the current token, or how much of each past token to 'pay attention' to. Then, masking future tokens with `-inf` in a triangular fashion ensures that the current token does not interact with the future tokens, since softmax will bring these weights to 0.

The 'affinities' in `weights` will not stay constant at 0 as in this toy example (which gives the current token equal 'affinity' for each of the previous tokens in the sequence). They will be data-dependent, and the self-attention mechanism is how we achieve this.
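
To see what non-constant affinities do, here is a small sketch with hand-picked (not learned) values on a 3-token sequence. The numbers in `affinities` are made up purely for illustration; in an attention head they would come from the data.

```python
import torch
from torch.nn import functional as F

T = 3
tril = torch.tril(torch.ones(T, T))

# made-up affinities: entry [t, i] is how strongly token t is drawn to token i
affinities = torch.tensor([[0.0, 0.0, 0.0],
                           [2.0, 0.0, 0.0],
                           [0.0, 1.0, 3.0]])

# mask the future, then normalize each row into a probability distribution
weights = affinities.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
print(weights)
```

Row 1 now puts most of its weight on token 0 and row 2 mostly on itself, instead of the uniform averages obtained when every affinity is 0.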

To compute these 'affinities' in a data-dependent way, instead of making the `weights` a matrix of all zeros, we will have each token emit a query and a key vector. Roughly speaking, the query vector encodes what the token at that timestep is 'looking for.' The key vector encodes what the token 'contains.' Now, to get the 'affinity' that, say, the token at timestep 4 (call it t4) has for the token at timestep 1 (t1), we take the dot product of t4's query vector and t1's key vector. That dot product now becomes an entry in `weights`. This works because if the key and query vector are 'aligned,' their dot product will be high, giving a high affinity.

In [574]:
torch.manual_seed(3)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# a single head of self-attention
head_size = 16  # the dim of the key and query vectors
key = nn.Linear(C, head_size, bias=False) # maps tokens to key vectors
query = nn.Linear(C, head_size, bias=False) # maps tokens to query vectors
k, q = key(x), query(x)  # (B, T, head_size)
weights = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ x

out.shape

torch.Size([4, 8, 32])

Now, in the same way we produced a `key` and `query` from each token, we will produce a `value` vector. Then, instead of aggregating our raw input `x` according to the affinities in `weights`, we will aggregate the value vectors that the raw tokens map to.

In [575]:
torch.manual_seed(3)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# a single head of self-attention
head_size = 16  # the dim of the key, query, and value vectors
key = nn.Linear(C, head_size, bias=False) # maps tokens to key vectors
query = nn.Linear(C, head_size, bias=False) # maps tokens to query vectors
value = nn.Linear(C, head_size, bias=False) # maps tokens to value vectors
k, q = key(x), query(x)  # (B, T, head_size)
weights = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

v = value(x)
out = weights @ v

out.shape

torch.Size([4, 8, 16])

Things to note:
- Attention is simply a communication mechanism. We can interpret the mechanism as a directed graph in which the nodes communicate with each other via edges with data-dependent weights. Each node aggregates the values of the nodes pointing to it via a weighted sum. In language modeling, the directed graph just has a specific structure, where each node is pointed to only by itself and the nodes at previous timesteps.
- We need the position embedding in order to encode the notion of position into the tokens. This is because attention doesn't capture any notion of position on its own: it just acts over a set of vectors, so we need to encode the position information into the vectors ourselves.
- Each training sample across the batch dimension `B` is processed independently. Vectors don't 'talk' to each other across batches.
- To make an encoder attention block, we just delete the line that masks with `tril` and prevents communication with future tokens. This allows each token to communicate with all other tokens. With the masking, we have a decoder attention block since we use triangular masking, preventing future tokens from 'giving away the answer' to past tokens; this is typically used in autoregressive settings like language modeling (where we apply inference to generate a token, append this new token to the sequence, and repeat).
- Self-attention refers to the choice of producing key and value vectors from the same source we use to produce query vectors; i.e. our raw tokens `x` are the source for all of these. In cross-attention, we produce the keys and values from a different source, which can encode some sort of context we would like to condition on; i.e. external nodes containing info we would like to bring in.
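
The point about position in the list above can be checked numerically. A minimal sketch, assuming an unmasked single head with freshly initialized projections: shuffling the tokens of the input just shuffles the outputs in the same way, so the head by itself cannot tell one ordering from another. The helper `unmasked_head` and the permutation `perm` are illustrative names.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(3)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def unmasked_head(inp):
    # encoder-style attention: every token may attend to every other token
    k, q, v = key(inp), query(inp), value(inp)
    weights = F.softmax(q @ k.transpose(-2, -1), dim=-1)  # (B, T, T)
    return weights @ v  # (B, T, head_size)

perm = torch.randperm(T)  # a random reordering of the T positions
out = unmasked_head(x)
out_shuffled = unmasked_head(x[:, perm])

# permuting the input permutes the output identically (up to float error)
print(torch.allclose(out[:, perm], out_shuffled, atol=1e-5))
```

Adding a position embedding to each token before the head breaks this symmetry, which is why it is needed.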

Example: machine translation. Machine translation employs cross-attention between a transformer encoder and transformer decoder; i.e. it uses a transformer encoder-decoder architecture.

In machine translation, we have some text input in a foreign language that we want to translate into a target language. This text in the foreign language is the input to the transformer encoder, which has no `tril` masking, allowing all tokens to communicate with all tokens; we thus encode the entire content of the foreign sentence. The output of the encoder is then used to condition the decoder via cross-attention. In particular, the output of the encoder is used to get the key and value vectors for the decoder's cross-attention, while the queries come from the decoder's own tokens. The decoder's task is then to generate text in the target language using this information, as well as any previously generated tokens in the context window.
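
The conditioning step described above can be sketched as a single cross-attention head. This is an illustrative sketch, not the architecture of any particular translation system: `enc_out` stands in for the encoder output, `dec_x` for the decoder's token representations, and the sequence lengths `T_enc` and `T_dec` are arbitrary.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(3)
B, T_enc, T_dec, C, head_size = 4, 10, 8, 32, 16

enc_out = torch.randn(B, T_enc, C)  # stand-in for the encoder output (source sentence)
dec_x = torch.randn(B, T_dec, C)    # stand-in for the decoder's current representations

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(dec_x)    # queries come from the decoder: (B, T_dec, head_size)
k = key(enc_out)    # keys come from the encoder:    (B, T_enc, head_size)
v = value(enc_out)  # values come from the encoder:  (B, T_enc, head_size)

# no tril mask: each decoder position may look at the whole source sentence
weights = q @ k.transpose(-2, -1)     # (B, T_dec, T_enc)
weights = F.softmax(weights, dim=-1)
out = weights @ v                     # (B, T_dec, head_size)

print(out.shape)
```

Note that the weight matrix is no longer square: it has one row per decoder position and one column per encoder position.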

In the 'Attention Is All You Need' paper, the authors make one more key choice: implementing 'scaled' attention. With scaled attention, we divide each of our query-key dot products by the square root of our `head_size`. This turns out to be an important normalization of the values in our weights matrix. Let's see why.

In [576]:
# let our keys and queries be generated from the standard normal distribution (mean 0, variance 1)
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
weights = q @ k.transpose(-2, -1)

In [577]:
# verify that the variances of k and q are ~1
print(k.var())
print(q.var())

tensor(0.9535)
tensor(1.0524)


In [578]:
weights.var()

tensor(15.1501)

We see that the variance of our weights matrix is on the order of `head_size`.
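
This agrees with a short calculation, under the assumption that the components of $q$ and $k$ are independent with mean 0 and variance 1 (as with `torch.randn` above). Writing $d$ for `head_size`:

$$
\operatorname{Var}(q \cdot k)
= \operatorname{Var}\Big(\sum_{i=1}^{d} q_i k_i\Big)
= \sum_{i=1}^{d} \operatorname{Var}(q_i k_i)
= \sum_{i=1}^{d} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]
= d
$$

So a variance near 16 is what we would expect for `head_size = 16`.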

But if we multiply `weights` by $1/\sqrt{\text{head\_size}}$, we see...

In [579]:
weights *= head_size ** -0.5
weights.var()

tensor(0.9469)

that now the variance of `weights` is ~1.

This is important because we use `weights` as an input to `softmax`. At initialization, it is important that the `softmax` output across each row is fairly diffuse/spread out; for that, we want the values of `weights` to all be pretty close to 0, i.e. have low variance. We do not want some values to be very positive and some to be very negative, which a high variance indicates. With a high variance, the output of `softmax` applied to `weights` would converge toward one-hot vectors.

In short, we do not want the probabilities outputted by `softmax` to be too extreme, especially at initialization; this would lead us to aggregate information from just one or two nodes, so we'd miss out on learning from other nodes that were driven to 0 by softmax.
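
A small sketch of this effect, using made-up logits (the values and the factor 8 are arbitrary illustrations): the same vector gives a fairly flat distribution at small scale and a much more sharply peaked one after scaling up.

```python
import torch
from torch.nn import functional as F

logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])  # arbitrary small values

diffuse = F.softmax(logits, dim=-1)     # low-variance input
peaked = F.softmax(logits * 8, dim=-1)  # same values with 64x the variance

print(diffuse)
print(peaked)
print(f'max prob: {diffuse.max().item():.3f} vs {peaked.max().item():.3f}')
```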

Scaled attention prevents this by controlling the variance of our affinities at initialization.
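
Putting the pieces of this section together, a single head of scaled, causal self-attention might be sketched as the module below. This is one possible arrangement under the assumptions used above (no bias, a fixed maximum context of `block_size`); the class name `Head` and its constructor arguments are illustrative rather than taken from the notebook.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class Head(nn.Module):
    """One head of scaled, causal (decoder-style) self-attention."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is a constant mask, not a parameter, so register it as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)  # each (B, T, head_size)
        # scaled affinities: multiply by 1/sqrt(head_size) to keep the variance near 1
        weights = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # block communication with future tokens
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        return weights @ v  # (B, T, head_size)

torch.manual_seed(3)
head = Head(n_embd=32, head_size=16, block_size=8)
x = torch.randn(4, 8, 32)
print(head(x).shape)
```

Slicing the mask as `self.tril[:T, :T]` lets the same head handle contexts shorter than `block_size`, which is the situation at the start of sampling.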