In [1]:
# %pip install tqdm
# %pip install ipywidgets
# %pip install tensorboard

In [2]:
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import pandas as pd
import tqdm.auto as tqdm
from torch.utils.tensorboard import SummaryWriter

## Downloading data

In [3]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
# !curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt > input.txt

In [4]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [5]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"vocab: {''.join(chars)}")
print(f"{len(chars)}")

vocab: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


## Making our own encoder & decoder

In [7]:
token2idx = {}
idx2token = {}

for idx, token in enumerate(sorted(list(set(text)))):
    token2idx[token] = idx
    idx2token[idx] = token

def encode_one(token):
    return token2idx[token]

def decode_one(idx):
    return idx2token[idx]

def encode(text):
    return [encode_one(char) for char in text]

def decode(code):
    return "".join([decode_one(idx) for idx in code])

vocab_size = len(token2idx)
vocab_size

65

In [8]:
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [9]:
decode(encode(text[:100]))

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

### Alternative Tokenizers

#### SentencePiece

https://github.com/google/sentencepiece - Subword Unit Level Tokenizer

#### TikToken - Byte Pair Token 

In [10]:
# import tiktoken

In [11]:
# enc = tiktoken.get_encoding("gpt2")

In [12]:
# enc.n_vocab

In [13]:
# enc.encode("hii there")

In [14]:
# enc.decode([71, 4178, 612])

In [15]:
# len(enc.encode(text))

### Loading into Torch

In [16]:
import torch
text_tensor = torch.tensor(encode(text))
text_tensor

tensor([18, 47, 56,  ..., 45,  8,  0])

In [17]:
print(text_tensor.shape, text_tensor.dtype)
print(text_tensor[:10]) # the 1000 characters we looked at earier will to the GPT look like this

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [18]:
n = int(0.9 * len(text_tensor))
train_data = text_tensor[:n]
val_data = text_tensor[n:]

## Playing around with transformer

We never give all data, only work on chunks of data, sampled from our dataset.

In [19]:
block_size = 8
# aka context_size = 8

In [20]:
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [21]:
decode(train_data[:block_size+1].tolist())

'First Cit'

We're gonna train our transformer to simultaneously predict all next characters, so previous block gives us 8 different examples.
We want our transformer to see different contexts up to and including block_size

In [22]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target is {target}")

when input is tensor([18]) the target is 47
when input is tensor([18, 47]) the target is 56
when input is tensor([18, 47, 56]) the target is 57
when input is tensor([18, 47, 56, 57]) the target is 58
when input is tensor([18, 47, 56, 57, 58]) the target is 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is 58


In [23]:
torch.manual_seed(1337)

<torch._C.Generator at 0x117a29f10>

In [24]:
batch_size = 4
block_size = 8

# def get_batch(split):
#     data = train_data if split == "train" else val_data


In [25]:
torch.randint(high=5, size=(10,))

tensor([0, 2, 2, 0, 0, 3, 0, 0, 4, 0])

In [26]:
def get_batch(split="train", batch_size=4, block_size=8, device=None):
    data = train_data if split == "train" else val_data
    batch_offsets = torch.randint(high=len(data)-block_size, size=(batch_size,))
    xs = torch.stack([data[idx:idx+block_size] for idx in batch_offsets])
    ys = torch.stack([data[idx+1:idx+block_size+1] for idx in batch_offsets])
    if device:
        xs, ys = xs.to(device), ys.to(device)
    return xs, ys

xs, ys = get_batch()
print(f"{xs.shape=}")
print(f"{ys.shape=}")

xs.shape=torch.Size([4, 8])
ys.shape=torch.Size([4, 8])


In [27]:
xs

tensor([[39, 58, 47, 53, 52, 12,  1, 37],
        [53, 56, 43,  1, 21,  1, 41, 39],
        [50, 39, 52, 63,  1, 47, 58, 57],
        [56, 53, 63,  1, 42, 47, 42,  1]])

In [28]:
ys

tensor([[58, 47, 53, 52, 12,  1, 37, 53],
        [56, 43,  1, 21,  1, 41, 39, 51],
        [39, 52, 63,  1, 47, 58, 57, 43],
        [53, 63,  1, 42, 47, 42,  1, 57]])

## Implementing Bigram Language Model

In [29]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

<torch._C.Generator at 0x117a29f10>

In [30]:
emb = nn.Embedding(vocab_size, vocab_size)
emb

Embedding(65, 65)

In [31]:
emb(torch.tensor([5]))

tensor([[-0.1338,  0.3899, -0.2884, -1.4651,  0.0101, -0.3004, -1.5733,  0.0148,
         -0.0447, -0.5367, -0.5223, -0.2181, -2.1608,  0.7865,  0.6854, -1.2576,
          0.6094, -2.0551, -0.4431, -0.6499, -0.6870,  0.2567, -1.2669,  0.2645,
         -0.6445,  1.0834, -0.7995,  0.2922,  1.3143,  1.2607, -0.3505, -2.0660,
          1.0575, -1.0572,  0.9911, -0.0797,  1.0751,  0.2381,  0.5757,  1.6685,
          0.5976, -1.8736,  1.2910, -0.3753, -1.8943,  0.5557,  0.8567, -0.8461,
          0.5015, -0.9656, -0.7255,  0.0990,  0.5928, -0.0422, -0.9566,  1.4424,
          0.4341, -0.4292,  0.3666,  0.1275, -0.0560,  0.8315, -0.5512,  1.0477,
          1.6187]], grad_fn=<EmbeddingBackward0>)

In [32]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets): ## inputs.shape=(B,T), targets.shape=(B,T)
        logits = self.token_emb(inputs) ## logits.shape=(B,T,C)
        # logits = torch.reshape(logits, (-1, logits.shape[-1]))
        # targets = torch.reshape(targets, (-1,))
        (B,T,C) = logits.shape
        loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
        return logits, loss

In [33]:
model = BigramLanguageModel()
out, loss = model(xs, ys)
print(f"{out.shape=}")
print(f"{loss=}")

out.shape=torch.Size([4, 8, 65])
loss=tensor(4.6267, grad_fn=<NllLossBackward0>)


### Coding a generation from a model

In [34]:
inputs = xs
print(f"{inputs.shape=}")

inputs.shape=torch.Size([4, 8])


In [35]:
logits, loss = model(xs, ys)
print(f"{logits.shape=}")

logits.shape=torch.Size([4, 8, 65])


In [36]:
probs = torch.softmax(logits[:, -1, :], dim=-1)
print(f"{probs.shape=}")

probs.shape=torch.Size([4, 65])


In [37]:
sample = torch.multinomial(probs, 1)
print(f"{sample.shape=}")
print(f"{sample=}")

sample.shape=torch.Size([4, 1])
sample=tensor([[12],
        [42],
        [60],
        [55]])


In [38]:
inputs = torch.concat([inputs, sample], dim=-1)
inputs

tensor([[39, 58, 47, 53, 52, 12,  1, 37, 12],
        [53, 56, 43,  1, 21,  1, 41, 39, 42],
        [50, 39, 52, 63,  1, 47, 58, 57, 60],
        [56, 53, 63,  1, 42, 47, 42,  1, 55]])

In [39]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        logits = self.token_emb(inputs) ## logits.shape=(B,T,C)
        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [40]:
def generate(model, inputs, sequence_length = 100, context_size = None):
    """
    inputs: input sequence that is used to start generation
    size: length of a generated sequence
    """
    result = inputs
    for _ in range(sequence_length):
        ## Get the predictions over the next character
        if context_size is None:
            cond = result
        else:
            cond = result[:, -context_size:] # Trim the context
        logits = model(cond)
        probs = torch.softmax(logits[:, -1, :], dim=-1)
        sample = torch.multinomial(probs, num_samples=1)
        result = torch.concat([result, sample], dim=-1)
    return result


In [41]:
model = BigramLanguageModel()
logits, loss = model(xs, ys)
print(f"{logits.shape=}")
print(f"{loss=}")

logits.shape=torch.Size([4, 8, 65])
loss=tensor(4.2725, grad_fn=<NllLossBackward0>)


In [42]:
xs

tensor([[39, 58, 47, 53, 52, 12,  1, 37],
        [53, 56, 43,  1, 21,  1, 41, 39],
        [50, 39, 52, 63,  1, 47, 58, 57],
        [56, 53, 63,  1, 42, 47, 42,  1]])

In [43]:
generate(model, xs, 5)[0]

tensor([39, 58, 47, 53, 52, 12,  1, 37,  2, 18, 36, 11, 21])

In [44]:
print(decode(generate(model, torch.zeros((1,1), dtype=torch.long), 500)[0].tolist()))


!bPk ;M-.
,ruyhZoO:SVV:VVOt:e$Ie,&3Wr!dhTx;ldKBNL:d,MbBAOe JYR
&:rNHUESAbIfa!S
h;q:kUoKYy
QlmsKuMul3U
sjOdXEBFckMbBu&Ud$oP?fNDMCkgcRPgY,Zk
BSCsAOE:rnhDw.Sam$,&EFARCXzKGWkb'K'3Xl.:fwBFQpKB;-tHMRf!LcXHRmR$,Qxkm:ffwlxud
dDmRDnM;HNB'K?XKSUfcpE-SChhuKF
EtO;P?Wy,Qb'KMWcZmomfqubhumeqaVHuxbJQlKyDTA!oS
EGDMmOF'l;hYgK;ibFMRf!y'NlDMe ?HfEN.P--cXl;Wklz.3j.knpEDRTGhuq aBB,,qnwO&RnhB,d
ELV-qsEG!TBIjnR KvW!b
,
Fkb.k,I-JhRS.
MxmU,jojC.SAS
MUCFiL!OviAFkgxV-IjNJj,MdXatxVhyk,;dQvXnZVO?esMuGk3-mpg;e,Ltxmh3G?e:e?WAp


## Try to train the model

In [45]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) ## Much more advanced than SGD; Typical lr is 1e-4, but for smaller NNs we can use much higher value

In [46]:
batch_size = 32

num_steps = 100
pbar = tqdm.tqdm(desc="train", total=num_steps)
for step in range(num_steps):
    xb, yb = get_batch("train", batch_size)
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    pbar.set_postfix({"loss": loss.detach().item()})
    pbar.update(1)

train:   0%|          | 0/100 [00:00<?, ?it/s]

## Entire training loop

In [47]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import pandas as pd
import tqdm.auto as tqdm
from torch.utils.tensorboard import SummaryWriter

In [48]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [49]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"vocab: {''.join(chars)}")
# print(f"{len(chars)}")

vocab: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [50]:
token2idx = {}
idx2token = {}

for idx, token in enumerate(sorted(list(set(text)))):
    token2idx[token] = idx
    idx2token[idx] = token

def encode_one(token):
    return token2idx[token]

def decode_one(idx):
    return idx2token[idx]

def encode(text):
    return [encode_one(char) for char in text]

def decode(code):
    return "".join([decode_one(idx) for idx in code])

vocab_size = len(token2idx)
vocab_size

65

In [51]:
data = torch.tensor(encode(text))
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [52]:
def get_batch(split="train", batch_size=4, block_size=8, device=None):
    data = train_data if split == "train" else val_data
    batch_offsets = torch.randint(high=len(data)-block_size, size=(batch_size,))
    xs = torch.stack([data[idx:idx+block_size] for idx in batch_offsets])
    ys = torch.stack([data[idx+1:idx+block_size+1] for idx in batch_offsets])
    if device:
        xs, ys = xs.to(device), ys.to(device)
    return xs, ys

In [53]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        logits = self.token_emb(inputs) ## logits.shape=(B,T,C)
        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [54]:
run_id = 0

In [55]:
run_id += 1
exp_name = f"shakespeare-bigram-{run_id}"
writer = SummaryWriter(f"runs/{exp_name}")

In [56]:
batch_size = 32
model = BigramLanguageModel()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) ## Much more advanced than SGD; Typical lr is 1e-4, but for smaller NNs we can use much higher value

num_steps = 10000
pbar = tqdm.tqdm(desc="train", total=num_steps)
for step in range(num_steps):
    xb, yb = get_batch("train", batch_size)
    logits, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    loss_val = loss.item()
    optimizer.step()
    pbar.set_postfix({"loss": loss_val})
    writer.add_scalar("Loss/train", loss_val, step)
    pbar.update(1)
pbar.close()

train:   0%|          | 0/10000 [00:00<?, ?it/s]

In [57]:
print(decode(generate(model, torch.zeros((1,1), dtype=torch.long), 500)[0].tolist()))


BuORI ck lifad:
Opicey helersurke ngorasechiveend ofoxFr seis are,'D:
I y es ast py tXJus at ADre crwchiuthe t
GEWBUMapre t fellrerin ton'semye, yome and th horsthuthoghegis hes tind, o.
Wheneel wooe it as d 'd vo's y ge,
Thioutong d,
S ty ildeanst, ntichand aran'd itas foy t, w pring the s ind:
BRDI yorig thal ar VI for tharey PXEre
DUK:
S:
Ha w f y adotrs the lyous y y, athee igVimoro-have t
E:
IOLed:
Plontor, mey e nd coro co shabe y athat! sd waut ju w'd nm mey he ffone yo by curoftha CHAN I


## Improving the model

### Defining a training loop

In [58]:
import torch

In [59]:
from dataclasses import dataclass

In [60]:
@dataclass
class TrainConfig:
    device: str = "cpu"
    batch_size: int = 32
    block_length: int = 8
    num_steps: int = 10000
    learning_rate: float = 1e-3
    eval_freq: int = 1000
    eval_batches: int = 100

In [61]:
class BigramLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()

        self.config = config
        self.token_emb = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        logits = self.token_emb(inputs) ## logits.shape=(B,T,C)
        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [62]:
def get_device(config):
    if config.device == "gpu" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
    elif config.device == "gpu" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device

In [63]:
@torch.no_grad()
def estimate_loss(model, num_batches):
    result = {}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(num_batches)
        for batch_id in range(num_batches):
            xs, ys = get_batch(split)
            _, loss = model(xs, ys)
            losses[batch_id] = loss.item()
        result[split] = losses.mean()
    model.train()
    return result

In [64]:
def run_train(model, run_name, config: TrainConfig):
    print(f"executing train run: {run_name}")
    device = get_device(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    pbar = tqdm.tqdm(desc="train", total=config.num_steps)
    writer = SummaryWriter(f"runs/{run_name}")
    for step in range(config.num_steps):
        xb, yb = get_batch("train", batch_size=config.batch_size, block_size=config.block_length, device=device)
        logits, loss = model(xb, yb)
        optimizer.zero_grad()
        loss.backward()
        loss_val = loss.detach().item()
        optimizer.step()
        pbar.set_postfix({"loss": loss_val})

        writer.add_scalar("Loss/train", loss_val, step)

        if step % config.eval_freq == 0:
            losses_smooth = estimate_loss(model, num_batches=config.eval_batches)
            writer.add_scalar("AvgLoss/train", losses_smooth["train"], step)
            writer.add_scalar("AvgLoss/val", losses_smooth["val"], step)
        
        pbar.update(1)
    pbar.close()

In [65]:
run_id += 1
run_name = f"shakespeare-bigram-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001)
model = BigramLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-2


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Math trick to quickly calculate self-attention via MatMul + Softmax
Allows tokens to start talking to eachother

In [66]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [67]:
x[0,:,0]

tensor([ 0.1808, -0.3596,  0.6258,  0.9545,  0.3612, -1.3499,  0.2360, -0.9211])

In [68]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xmean = torch.mean(xprev, 0) # (C,)
        xbow[b,t] = xmean

In [69]:
xbow[0, :, 0]

tensor([ 0.1808, -0.0894,  0.1490,  0.3504,  0.3525,  0.0688,  0.0927, -0.0341])

In [70]:
ones = torch.triu(torch.ones((T,T)))
ones

tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.]])

In [71]:
# batch way to get moving average
xbow2 = ((x.permute((0,2,1)) @ ones) / ones.sum(0)).permute(0,2,1)
xbow2[0,:,0]

tensor([ 0.1808, -0.0894,  0.1490,  0.3504,  0.3525,  0.0688,  0.0927, -0.0341])

In [72]:
xbow.shape

torch.Size([4, 8, 2])

In [73]:
xbow[0,:,0]

tensor([ 0.1808, -0.0894,  0.1490,  0.3504,  0.3525,  0.0688,  0.0927, -0.0341])

In [74]:
xbow2.shape

torch.Size([4, 8, 2])

In [75]:
xbow2[0,:,0]

tensor([ 0.1808, -0.0894,  0.1490,  0.3504,  0.3525,  0.0688,  0.0927, -0.0341])

In [76]:
torch.isclose(xbow[0,:,0], xbow2[0,:,0])

tensor([True, True, True, True, True, True, True, True])

In [77]:
torch.allclose(xbow, xbow2)

True

Weighted aggregation via broadcasted batch multiplication:

In [78]:

weights = torch.tril(torch.ones((T,T)))
weights = weights / weights.sum(dim=1, keepdim=True)
xbow3 = weights @ x # (T,T) @ (B,T,C) => Broadcasts weights so that it has B dimension: (B,T,T) @ (B,T,C) => (B,T,C)


In [79]:
torch.allclose(xbow, xbow3)

True

Make attention mask via softmax. 
Currnetly weigiths are initialized with zeros, but next they are going to be learned from data

In [80]:
tril = torch.tril(torch.ones(T,T))
weights = torch.zeros(T,T)
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
xbow4 = weights @ x # (T,T) @ (B,T,C) => Broadcasts weights so that it has B dimension: (B,T,T) @ (B,T,C) => (B,T,C)
torch.allclose(xbow, xbow4)

True

### Plugging it back to the model:

In [81]:
@dataclass
class TrainConfig:
    device: str = "cpu"
    batch_size: int = 32
    block_length: int = 8
    num_steps: int = 10000
    learning_rate: float = 1e-3
    eval_freq: int = 1000
    eval_batches: int = 100
    token_emb_size: int = 32   ## Added this one

In [82]:
class BigramLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        # This creates one spurious level of interaction through the linear layer:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)
        logits = self.lm_head(tok_emb) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [83]:
run_id += 1
run_name = f"shakespeare-bigram-lmhead-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001)
model = BigramLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-lmhead-3


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Add positional embeddings:

In [84]:
class BigramLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [85]:
run_id += 1
run_name = f"shakespeare-bigram-posemb-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001)
model = BigramLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-posemb-4


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Implement self-attention for single individual "head"

Every single token (aka node) at each position will emit two vectors:
- Query: What am I looking for
- Key: What do I contain

The way we get affinities between tokens in the sequence is: we basically do the dot product between the keys and the queries - so my query dot products with all the keys of the other tokens in the sequence. And that dot product now becomes values in `weights` matrix (before we apply masking & softmax). So if the key and query are aligned, they interact in a very high amount, and then we get to learn more about that specific token as opposed to any other tokens in the sequence.

### Code a single head

In [86]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C) ## batch size, time, channels

head_size = 16
# bias=False is going to just apply some MatMul with some fixed weights
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

# when we forward this linear on top of my x, all the tokeins in all the positions in the (B,T) arrangement in parallel produce the key and a query
## We emit all the keys & queries in parallel (no communication happened just yet):
k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)

In [87]:
(k @ q.transpose(-2, -1))[0]

tensor([[-1.7629, -3.3334, -1.0226,  0.7836, -1.2566, -0.3126,  1.0876, -1.8044],
        [-1.3011, -1.6556, -1.2606, -0.8014,  0.0187,  2.4152,  1.9652, -0.4126],
        [ 0.5652,  0.1040,  0.0762, -0.3368, -0.7880, -0.1106, -0.2621, -0.8306],
        [ 2.1616,  3.3782, -0.3813, -0.8496, -1.3204, -0.9931, -0.3158,  0.5898],
        [-1.0674, -2.1825, -0.9843, -0.5602,  2.0363,  3.3449,  0.6091, -0.7987],
        [ 1.9632,  1.0415, -1.4303, -1.1701,  0.8638, -2.5229,  1.2616, -0.5856],
        [ 1.0765, -0.0557,  0.0749, -1.2927,  0.3719,  1.4187, -0.5484,  0.6433],
        [-0.4530,  0.2927, -0.9547, -1.0260,  0.9258,  1.2196,  0.8048,  0.6303]],
       grad_fn=<SelectBackward0>)

In [88]:
q.shape

torch.Size([4, 8, 16])

In [89]:
k.shape

torch.Size([4, 8, 16])

In [90]:
(q @ k.transpose(-2, -1))[0]

tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [91]:
(k @ q.transpose(-2, -1))[0]

tensor([[-1.7629, -3.3334, -1.0226,  0.7836, -1.2566, -0.3126,  1.0876, -1.8044],
        [-1.3011, -1.6556, -1.2606, -0.8014,  0.0187,  2.4152,  1.9652, -0.4126],
        [ 0.5652,  0.1040,  0.0762, -0.3368, -0.7880, -0.1106, -0.2621, -0.8306],
        [ 2.1616,  3.3782, -0.3813, -0.8496, -1.3204, -0.9931, -0.3158,  0.5898],
        [-1.0674, -2.1825, -0.9843, -0.5602,  2.0363,  3.3449,  0.6091, -0.7987],
        [ 1.9632,  1.0415, -1.4303, -1.1701,  0.8638, -2.5229,  1.2616, -0.5856],
        [ 1.0765, -0.0557,  0.0749, -1.2927,  0.3719,  1.4187, -0.5484,  0.6433],
        [-0.4530,  0.2927, -0.9547, -1.0260,  0.9258,  1.2196,  0.8048,  0.6303]],
       grad_fn=<SelectBackward0>)

In [92]:
## All the keys will dot-product with all the keys:
# weights = k @ q.permute(0, 2, 1)
weights = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) => (B, T, T)
tril = torch.tril(torch.ones(T,T))
weights = weights.masked_fill(tril == 0, float('-inf')) # implicit broadcast of (tril == 0) to to (B, T, T)
weights = F.softmax(weights, dim=-1) # (B, T, T)
out = weights @ x # (B,T,T) @ (B,T,C) => Broadcasts weights so that it has B dimension: (B,T,T) @ (B,T,C) => (B,T,C)
out.shape

torch.Size([4, 8, 32])

In [93]:
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [94]:
out[0, 0, :]

tensor([ 0.1808, -0.0700, -0.3596, -0.9152,  0.6258,  0.0255,  0.9545,  0.0643,
         0.3612,  1.1679, -1.3499, -0.5102,  0.2360, -0.2398, -0.9211,  1.5433,
         1.3488, -0.1396,  0.2858,  0.9651, -2.0371,  0.4931,  1.4870,  0.5910,
         0.1260, -1.5627, -1.1601, -0.3348,  0.4478, -0.8016,  1.5236,  2.5086],
       grad_fn=<SliceBackward0>)

In [95]:
out[1, 0, :]

tensor([ 0.4562, -1.0917, -0.8207,  1.8634,  0.8148, -0.0643,  1.4237,  0.2617,
        -1.8528,  0.2019, -1.1787, -0.1036, -1.7830, -0.8323, -0.4346, -1.2480,
        -0.2880,  0.8809, -0.7190,  0.1745,  0.7520, -0.0629, -0.7111,  0.9810,
        -0.7244, -1.5010, -2.8348, -2.8272, -0.1736,  0.0512, -0.6576, -2.5729],
       grad_fn=<SliceBackward0>)

So for every row of B we will have (T, T) matrix giving us the affinities, and we're using they as weights in the weighted aggregation.
So token knows what content it has and what position it is in. So it creates a query, and all the other tokens are emitting keys.

Token & query key & query embeddings may encode various information in their different channels. So the current token emits this information in the query. When all the nodes emit keys, they also advertise this information about themselves. And maybe one of the channels in the query can be, for example: I am a vowel in pos 8 and I am looking for a any conconsonant at position up to four. And now some token may say in it's key: hey, i'm a consonant and i'm in the position 3. And this key will have a higher number in that specific channel. Hence via dot-product the query and a key can find eachother and create a high affinity, and through the softmax we will end up aggregating a lot of it's information into my position, and so I'll get to learn a lot about it.

Actually we're not aggregating tokens exactly, we're actually projecting into a value space:

In [96]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C) ## batch size, time, channels

head_size = 16
# bias=False is going to just apply some MatMul with some fixed weights
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
v = value(x) # (B, T, head_size)

weights = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) => (B, T, T)
## This matmul will run in parallel across batch dimension, so it is similar to looping over rows and multiplying (T, head_size) x (head_size, T)

tril = torch.tril(torch.ones(T,T))
weights = weights.masked_fill(tril == 0, float('-inf')) # implicit broadcast of (tril == 0) to to (B, T, T)
weights = F.softmax(weights, dim=-1) # (B, T, T)


out = weights @ v # (B,T,T) @ (B,T,C) => Broadcasts weights so that it has B dimension: (B,T,T) @ (B,T,C) => (B,T,C)
out.shape

torch.Size([4, 8, 16])

Attention is a communication mechanism - nodes in directed graph, every node has vector of information, and it aggregates all information from all nodes that point to it.
First node is only pointing to itself; Second points to self and prev node etc, all the way up to last node that points back.
That's the structure of our directed graph has in our auto-regressive scenario like language modelling.
In principle Attention can be applied to any arbitrary directed graph, and it's just a communication mechanism between the nodes.

There's no notion of space, nodes have no idea where are they located in space. That's why we need to encode them positionally and give them some information that anchors them to some position to let them know where they are. 

It's unlike convolutional filters, that are acting in space.
If we need it to have notion of space to Attention mechanism, we have to specifically add it.

In this case tokens don't talk to future tokens, but in principle we may allow them to, and have the full communication graph.
For example: in sentiment analysis we predict the sentiment of the sentence, so it's fine to let tokens talk to eachother. 
And in this case we just use encoder block of self-attention here: deleting the masked_fill line of code.
And what we've implemented here is actually a decoder block, where we want to use masking to avoid giving up the answer.
Attention doesn't care, as it supports arbitrary connectivity between nodes.

Q: What's the difference between Attention, Self-Attention and Cross-attention?
A: We just implemented Self-Attention. The reason it is called this way is because the keys, queries and values are all coming from the same source x.
So these nodes are self-attending. But in principle in encoder-decoder transformers we can have queries are produced from input x, but keys and values can come from whole separate external source, and sometimes from encoder blocks that encode some context that we like to condition on.
So cross-attention is used when there's a separate source of nodes we'd like to pull information from into our nodes. And it's self-attention if we just have nodes that would like to look to eachother and talk to eachother.

### Adding normalization (scaled attention)

It's important, especially at initialization, for `weight` to be fairly defused.  
If weight takes on very positive & very negative values, `softmax` would converge to one-hot vector.

In [97]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=0)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [98]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=0)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

In [99]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
weight = q @ k.transpose(-2, -1) / (head_size ** 0.5)
weight.var()

tensor(1.0918)

### Plugging self-attention into a model

In [100]:
@dataclass
class TrainConfig:
    device: str = "cpu"
    batch_size: int = 32
    block_length: int = 8
    num_steps: int = 10000
    learning_rate: float = 1e-3
    eval_freq: int = 1000
    eval_batches: int = 100
    token_emb_size: int = 32
    head_size: int = 16

In [101]:
class Head(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.keys = nn.Linear(config.token_emb_size, config.head_size, bias=False)
        self.queries = nn.Linear(config.token_emb_size, config.head_size, bias=False)
        self.values = nn.Linear(config.token_emb_size, config.head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(config.head_size, config.head_size)))

    def forward(self, inputs):
        B, T, C = inputs.shape

        k = self.keys(inputs) # (B,T,H)
        q = self.queries(inputs) # (B,T,H)
        
        ## Compute attention scores, aka affinities:
        weights = q @ k.transpose(-2, -1) # (B,T,H) @ (B,H,T) => (B,T,T)
        weights = weights / (self.config.token_emb_size ** 0.5) # (B,T,T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B,T,T) <= This makes it a decoder block
        weights = torch.softmax(weights, dim=-1) # (B,T,T)

        ## Perform the weighted aggregation of the values:
        v = self.values(inputs) # (B, T, head_size)
        out = weights @ v # (B,T,T) @ (B,T,H) => (B,T,H)
        return out

In [102]:
class BigramSelfAttentionLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_head = Head(config) ## Self-attention head
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_head(x) ## Apply self-attention head

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [103]:
run_id += 1
run_name = f"shakespeare-bigram-selfattention-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=32)
model = BigramSelfAttentionLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-selfattention-5


train:   0%|          | 0/10000 [00:00<?, ?it/s]

In [104]:
print(decode(
    generate(
        model, 
        torch.zeros((1,1), dtype=torch.long), 
        sequence_length=500,
        context_size = config.block_length
    )[0].tolist()
))


G RIED:
thounde
'Ythe tarcthirighran dsrgalle, dteschanee he gh alemealsle chi'sble ngoushe havele,
Nothigroveerce mure, therall, Moug brand Sod,
Why har,
Toonot; ngsilley arugenaccth
I:
Thoue tis scar dt yanegnshe? Sharis thild,'d thorsce, hod;
Whee din thee ndor theod fers hal I:
Cang dizorot nelerd quuinsork de it at, be boran becarlde grorne avesis ay lde ay ha'di fre to; ord wheatiflen o'dy I bulan sbenouve se dy.

OF sour the tath he;
Ot! waly mem rors me feer daredsoe ud mcow hin herigr, 


In [105]:
run_id += 1
run_name = f"shakespeare-bigram-selfattention-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=1e-3, token_emb_size=32, head_size=32)
model = BigramSelfAttentionLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-selfattention-6


train:   0%|          | 0/10000 [00:00<?, ?it/s]

So we got down to a loss of ~2.35 roughly (see tensorboard), let's try to add MultiHead Attention!

## Adding Multi-Head Attention

In [106]:
@dataclass
class TrainConfig:
    device: str = "cpu"
    batch_size: int = 32
    block_length: int = 8
    num_steps: int = 10000
    learning_rate: float = 1e-3
    eval_freq: int = 1000
    eval_batches: int = 100
    token_emb_size: int = 32
    head_size: int = 8
    num_heads: int = 4

In [107]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.heads = nn.ModuleList([Head(config) for _ in range(config.num_heads)])

    def forward(self, inputs):
        return torch.concat([head(inputs) for head in self.heads], dim=-1)

In [108]:
class BigramMultiHeadSelfAttentionLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_heads = MultiHeadAttention(config) ## Multiple Self-attention heads
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_heads(x) ## Apply multiple self-attention heads => (B,T,head_size*num_heads)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [109]:
run_id += 1
run_name = f"shakespeare-bigram-multihead-selfattention-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4)
model = BigramMultiHeadSelfAttentionLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-multihead-selfattention-7


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Adding FFN

When adding multi-head self-attention that did the communication, we went too fast to compute the logits.
So the tokens looked at eachother but didn't really had enough time to think about what they found from another tokens.
So it makes sense to add some simple feed-forward single layer: simple projection followed by non-linearity.
Note that this feed-forward is now applied on a per-token level: so all the tokens now do this independently; self-attention is a communication, but once we gathered all the data, we need to think on all this data individually.

In [110]:
class FeedForward(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.token_emb_size, config.token_emb_size),
            nn.ReLU()
        )
    def forward(self, inputs):
        return self.net(inputs)

In [111]:
class BigramMultiHeadSelfAttentionFFLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_heads = MultiHeadAttention(config) ## Multiple Self-attention heads
        self.ffwd = FeedForward(config)
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_heads(x) # Apply self-attention heads => (B,T,head_size*num_heads)

        x = self.ffwd(x) # (B,T,token_emb_size)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [112]:
run_id += 1
run_name = f"shakespeare-bigram-multihead-selfattention-ff-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4)
model = BigramMultiHeadSelfAttentionFFLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-multihead-selfattention-ff-8


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Intersperse communication and computation

Let's define a block that contains self-attention with feed-forward on top of that, and then repeat it multiple times

In [113]:
class Block(nn.Module):

    def __init__(self, config: TrainConfig):
        super().__init__()
        self.self_attention = MultiHeadAttention(config) ## Communication
        self.feed_forward = FeedForward(config) ## Computation

    def forward(self, inputs):
        return self.feed_forward(self.self_attention(inputs))

In [114]:
class BigramMultiHeadSelfAttentionBlocksLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_blocks = nn.Sequential(
            Block(config),
            Block(config),
            Block(config),
        )
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_blocks(x) # Apply self-attention blocks => (B,T,head_size*num_heads)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [115]:
run_id += 1
run_name = f"shakespeare-bigram-multihead-selfattention-blocks-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4)
model = BigramMultiHeadSelfAttentionBlocksLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-multihead-selfattention-blocks-9


train:   0%|          | 0/10000 [00:00<?, ?it/s]

As our network starts becoming deeper. And deep NNs are typically suffer from optimization issues, so we need one more idea from the original paper.

### Add Skip Connections aka Residual Connections, and Batch Normalization.

In a Resudual pathway we're forking an additional computation graph, perform some computations, and then project back to the residual pathway via addition.
So we go from the inputs to the targets only via plus ... plus ... plus ops. During back-prop addition distributes gradients equally to both of it's branches.
And this allows supervision - the gradients from the loss - hop through every addition node all the way to the input, and then also fork off into residual blocks. This is basically a gradient super-highway that goes all the way from loss to input unimpeded. And these residual blocks are typically initialized in a way so that they are contributing very little if anything to the residual pathway. During the optimization they come online over time and start to contribute, as they kick in.


In [116]:
class ResidualFeedForward(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.token_emb_size, config.token_emb_size),
            nn.ReLU(),
            ## We need to add this to have a projection back into a residual pathway:
            nn.Linear(config.token_emb_size, config.token_emb_size),
        )
    def forward(self, inputs):
        return self.net(inputs)

In [117]:
class ResidualMultiHeadAttention(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.heads = nn.ModuleList([Head(config) for _ in range(config.num_heads)])

        ## We need to add this to have a projection back into a residual pathway:
        self.proj = nn.Linear(config.token_emb_size, config.token_emb_size)

    def forward(self, inputs):
        x = torch.concat([head(inputs) for head in self.heads], dim=-1)
        x = self.proj(x)
        return x

In [118]:
class ResidualBlock(nn.Module):

    def __init__(self, config: TrainConfig):
        super().__init__()
        self.self_attention = ResidualMultiHeadAttention(config) ## Communication
        self.feed_forward = ResidualFeedForward(config) ## Computation

    def forward(self, inputs):
        x = inputs
        ## Adding a projection back into a residual pathways:
        x = x + self.self_attention(x)
        x = x + self.feed_forward(x)
        return x

In [119]:
class ResidualBigramMultiHeadSelfAttentionLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_blocks = nn.Sequential(
            ResidualBlock(config),
            ResidualBlock(config),
            ResidualBlock(config),
        )
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_blocks(x) # Apply self-attention blocks => (B,T,head_size*num_heads)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [120]:
run_id += 1
run_name = f"shakespeare-bigram-residual-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4)
model = ResidualBigramMultiHeadSelfAttentionLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-residual-10


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Multiply inner dim of FFN by four:

In [121]:
class ResidualFeedForwardWide(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.token_emb_size, config.token_emb_size * 4),
            nn.ReLU(),
            ## We need to add this to have a projection back into a residual pathway:
            nn.Linear(config.token_emb_size * 4, config.token_emb_size),
        )
    def forward(self, inputs):
        return self.net(inputs)

In [122]:
class ResidualBlockWide(nn.Module):

    def __init__(self, config: TrainConfig):
        super().__init__()
        self.self_attention = ResidualMultiHeadAttention(config) ## Communication
        self.feed_forward = ResidualFeedForwardWide(config) ## Computation

    def forward(self, inputs):
        x = inputs
        ## Adding a projection back into a residual pathways:
        x = x + self.self_attention(x)
        x = x + self.feed_forward(x)
        return x

In [123]:
class ResidualBigramMultiHeadSelfAttentionLanguageModelWide(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_blocks = nn.Sequential(
            ResidualBlockWide(config),
            ResidualBlockWide(config),
            ResidualBlockWide(config),
        )
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_blocks(x) # Apply self-attention blocks => (B,T,head_size*num_heads)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [124]:
run_id += 1
run_name = f"shakespeare-bigram-residual-wide-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4)
model = ResidualBigramMultiHeadSelfAttentionLanguageModelWide(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-residual-wide-11


train:   0%|          | 0/10000 [00:00<?, ?it/s]

In [125]:
print(decode(
    generate(
        model, 
        torch.zeros((1,1), dtype=torch.long), 
        sequence_length=500,
        context_size = config.block_length
    )[0].tolist()
))


FELeft loid.

CORIOLANUS:
What not be prasty?

Feep thou shalling-staling thinc.'

QUEEN ELIZABETHEN EDWARD:
Now her fie with is grone!
What unse shall to tue
trege, thee what speaster anher, brunsen is sto vander, 't prow, let I blawn to have than guay reenbre it and at blife macing have heaves my sucquee:
I stabled fabehen on hirse you ceady nock undo bown oath,
My say it? a laight toll flitield, of his heirle in him the repron, nowt a 'brow hir on ttewliftle.
I treath?

CLARYork, be kno, tit 


We managed to get training loss down to ~1.9, but validation loss is ~2.1, it means we're probably start overfitting a little bit, as the network gets big enough.

### LayerNorm Refresher

Batch Normalization makes sure that across batch dimension any individual neuron has unit gaussian distribution (mean=0, stddev=1).
So this normalizes every single column of the batch (by keeping track of the distribution statistics).
For the Layer Normalization we're going to normalize rows, not columns. This doesn't require to have any running buffers / moving averages to be calculated and stored to be re-used at inference time. Now our computation can just get applied.

In [126]:
class LayerNorm:
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
    
    def __call__(self, x):
        # calculate the forward pass
        xmean = x.mean(1, keepdim=True) # batch mean
        xvar = x.var(1, keepdim=True) # batch variance
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]

In [127]:
torch.manual_seed(1337)
module = LayerNorm(100)
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors
x = module(x)
x.shape

torch.Size([32, 100])

### Adding LayerNorm to our model

Also now it is more common to apply layernorm before the transformations, so the're is a re-shuffling of the layernorms.  
It's called a pre-norm formulation.


In [128]:
class LayerNormResidualBlockWide(nn.Module):

    def __init__(self, config: TrainConfig):
        super().__init__()
        self.self_attention = ResidualMultiHeadAttention(config) ## Communication
        self.feed_forward = ResidualFeedForwardWide(config) ## Computation
        # Mean and Variance is taken across last token_emb_size dimensions, So our B,T dimensions acts as a batch dimension.
        # So this is a per-token transformation that normalizes features and makes it unit-gaussian at initialization.
        # But because there are beta and gamma trainable parameters, tokens can become non-gaussian
        self.ln1 = nn.LayerNorm(config.token_emb_size)
        self.ln2 = nn.LayerNorm(config.token_emb_size)

    def forward(self, inputs):
        x = inputs
        ## Adding a projection back into a residual pathways:
        x = x + self.self_attention(self.ln1(x))
        x = x + self.feed_forward(self.ln2(x))
        return x

In [129]:
class LNResidualBigramMultiHeadSelfAttentionLanguageModelWide(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_blocks = nn.Sequential(
            LayerNormResidualBlockWide(config),
            LayerNormResidualBlockWide(config),
            LayerNormResidualBlockWide(config),
            nn.LayerNorm(config.token_emb_size)
        )
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_blocks(x) # Apply self-attention blocks => (B,T,head_size*num_heads)

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [130]:
run_id += 1
run_name = f"shakespeare-bigram-residual-ln-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4)
model = LNResidualBigramMultiHeadSelfAttentionLanguageModelWide(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-residual-ln-12


train:   0%|          | 0/10000 [00:00<?, ?it/s]

So we managed to get same results, but not much has changed

### Adding dropout + Cosmetic changes

Dropout is needed because as we're gonna scale up the number of parameters in our NN, we're starting to get concerned with overfitting  
So we're adding it to get some effect of regularization

In [131]:
@dataclass
class TrainConfig:
    device: str = "cpu"
    batch_size: int = 32
    block_length: int = 8
    num_steps: int = 10000
    learning_rate: float = 1e-3
    eval_freq: int = 1000
    eval_batches: int = 100
    token_emb_size: int = 32
    head_size: int = 8
    num_heads: int = 4
    num_blocks: int = 3
    dropout: float = 0.5

In [132]:
class DropoutHead(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.keys = nn.Linear(config.token_emb_size, config.head_size, bias=False, device=self.device)
        self.queries = nn.Linear(config.token_emb_size, config.head_size, bias=False, device=self.device)
        self.values = nn.Linear(config.token_emb_size, config.head_size, bias=False, device=self.device)
        self.register_buffer("tril", torch.tril(torch.ones(config.block_length, config.block_length, device=self.device)))
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, inputs):
        B, T, C = inputs.shape

        k = self.keys(inputs) # (B,T,H)
        q = self.queries(inputs) # (B,T,H)
        
        ## Compute attention scores, aka affinities:
        weights = q @ k.transpose(-2, -1) # (B,T,H) @ (B,H,T) => (B,T,T)
        weights = weights / (self.config.token_emb_size ** 0.5) # (B,T,T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # (B,T,T) <= This makes it a decoder block
        weights = torch.softmax(weights, dim=-1) # (B,T,T)

        weights = self.dropout(weights)

        ## Perform the weighted aggregation of the values:
        v = self.values(inputs) # (B, T, head_size)
        out = weights @ v # (B,T,T) @ (B,T,H) => (B,T,H)
        return out

In [133]:
class DropoutMultiHeadAttention(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.device = get_device(config)
        self.heads = nn.ModuleList([DropoutHead(config) for _ in range(config.num_heads)])

        ## We need to add this to have a projection back into a residual pathway:
        self.proj = nn.Linear(config.token_emb_size, config.token_emb_size, device=self.device)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, inputs):
        x = torch.concat([head(inputs) for head in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        return x

In [134]:
class DropoutFeedForward(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.device = get_device(config)
        self.net = nn.Sequential(
            nn.Linear(config.token_emb_size, config.token_emb_size * 4, device=self.device),
            nn.ReLU(),
            ## We need to add this to have a projection back into a residual pathway:
            nn.Linear(config.token_emb_size * 4, config.token_emb_size, device=self.device),
            nn.Dropout(config.dropout)
        )

    def forward(self, inputs):
        return self.net(inputs)

In [135]:
class DropoutBlock(nn.Module):

    def __init__(self, config: TrainConfig):
        super().__init__()
        self.device = get_device(config)
        self.self_attention = DropoutMultiHeadAttention(config) ## Communication
        self.feed_forward = DropoutFeedForward(config) ## Computation
        # Mean and Variance is taken across last token_emb_size dimensions, So our B,T dimensions acts as a batch dimension.
        # So this is a per-token transformation that normalizes features and makes it unit-gaussian at initialization.
        # But because there are beta and gamma trainable parameters, tokens can become non-gaussian
        self.ln1 = nn.LayerNorm(config.token_emb_size, device=self.device)
        self.ln2 = nn.LayerNorm(config.token_emb_size, device=self.device)

    def forward(self, inputs):
        x = inputs
        ## Adding a projection back into a residual pathways:
        x = x + self.self_attention(self.ln1(x))
        x = x + self.feed_forward(self.ln2(x))
        return x

In [136]:
class DropoutBigramLanguageModel(nn.Module):
    def __init__(self, config: TrainConfig):
        super().__init__()
        self.config = config
        self.device = get_device(config)
        self.token_embedding_table = nn.Embedding(vocab_size, config.token_emb_size, device=self.device)
        self.position_embedding_table = nn.Embedding(config.block_length, config.token_emb_size, device=self.device)
        self.sa_blocks = nn.Sequential(*[
            DropoutBlock(config) for _ in range(config.num_blocks)
        ])
        self.ln_f = nn.LayerNorm(config.token_emb_size, device=self.device)
        self.lm_head = nn.Linear(config.token_emb_size, vocab_size, device=self.device)

    def forward(self, inputs, targets=None): ## inputs.shape=(B,T), targets.shape=(B,T)
        
        B,T = inputs.shape
        
        ## Encoding the tokens:
        tok_emb = self.token_embedding_table(inputs) # (B,T,token_emb_size)

        ## Sometimes people also encode positions of the tokens
        pos_idx = torch.arange(T, device=self.device) # T
        pos_emb = self.position_embedding_table(pos_idx) # (T,token_emb_size)

        x = tok_emb + pos_emb # pos_emb gets broadcasted across batch => (B,T,token_emb_size)

        x = self.sa_blocks(x) # Apply self-attention blocks => (B,T,head_size*num_heads)
        x = self.ln_f(x) # Apply final layer normalization

        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is not None:
            (B,T,C) = logits.shape
            loss = F.cross_entropy(logits.view(B*T,C), targets.view(B*T))
            return logits, loss
        else:
            return logits

In [137]:
run_id += 1
run_name = f"shakespeare-bigram-dropout-05-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4, dropout=0.5)
model = DropoutBigramLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-dropout-05-13


train:   0%|          | 0/10000 [00:00<?, ?it/s]

In [138]:
run_id += 1
run_name = f"shakespeare-bigram-dropout-02-{run_id}"
config = TrainConfig(device="cpu", num_steps=10_000, learning_rate=0.001, token_emb_size=32, head_size=8, num_heads=4, dropout=0.2)
model = DropoutBigramLanguageModel(config)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-dropout-02-14


train:   0%|          | 0/10000 [00:00<?, ?it/s]

### Scaling up a network size

In [139]:
run_id += 1
run_name = f"shakespeare-bigram-big-{run_id}"
config = TrainConfig(
    device="gpu", 
    batch_size=64, ## Number of independent sequences to be processed in parallel
    block_length=256, ## Maximum context length for predictions
    num_steps=10_000, 
    eval_batches=200,
    eval_freq=500,
    learning_rate=3e-4,
    token_emb_size=384, 
    head_size=384 // 6, 
    num_heads=6,
    dropout=0.2
)
device = get_device(config)
model = DropoutBigramLanguageModel(config)
model.to(device)
run_train(model, run_name, config)

executing train run: shakespeare-bigram-big-15


train:   0%|          | 0/10000 [00:00<?, ?it/s]

RuntimeError: Placeholder storage has not been allocated on MPS device!

This fails on mac now, but on A100 could take ~15 minutes

In the original paper there's also an encoder part because it's a language translation paper. An encoder transformer gets an input (prompt) in a different language, embeds it's tokens via similar transformer, but with unmasked self-attention, and then via cross-attention allows decoder to attend to it.

To train smth similar to Chat-GPT there's a pre-training and fine-tuning steps. In a pre-training stage we are training on a large chunk of internet and just trying to get a decoder-only transformer to babble text. Similar to what we've done here. Differences: Our little Shakespeare transformer is about 10m parameters, and the dataset is about 1m characters, so roughly 1m tokens. OpenAI uses subword tokenizer, so their vocabulary is of ~50k elements. And their sequences are more condensed. In this vocabulary Shakespeare dataset would be ~ 300k tokens. 

So we trained 10m parameter model on 300k tokens. Biggest GPT3 transformer has:
- 175B parameters
- 96 layers
- token_emb_size: 128 * 96 = 12288
- head_size = 128
- num_heads = 96
- batch_size = 3.2m
- learning_rate = 6e-5

They trained on 300 billion tokens. And that is not considered big anymore.
Architecture is actually pretty similar to what we looked at it.

After pre-training stage we get a document completer. It tries to complete sequences, so from questions it would ask more questions.

Second fine-tuning stage is - align it. 3 steps:

Step 1. Fine-Tuning
- Collect training data that looks like a use-case (documents with questions on top, and answers in below)
- They then fine-tune the model to only focus on documents that look like these
- These very-very large models are very sample-efficient during fine-tuning

Step 2. Let model respond. Responses are getting ranked to collect a training set that is then used to train a different reward model that predicts how much a candidate response will be desirable.

Step 3. Once we have a reward model, we run PPO - form of a policy gradient reinforcement learning optimizer to fine-tune this sampling policy so that answers generated are expected to score high according to a reward model.
