In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Load the whole training corpus into memory as a single string.
with open("input.txt", mode="r", encoding="utf-8") as f:
    text = f.read()

In [3]:
# Corpus size in characters (shown as this cell's output).
len(text)

1115394

In [4]:
# Peek at the first 1000 characters of the corpus.
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# The vocabulary is every distinct character of the corpus, in sorted order
# (sorted() already returns a list, so no explicit list() is needed).
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [6]:
# Lookup tables between characters and integer ids; ids follow the sorted order of `chars`.
itos = dict(enumerate(chars))  # id -> character
stoi = {ch: idx for idx, ch in itos.items()}  # character -> id


def encode(s: str) -> list[int]:
    """Map a string to the list of its character ids.

    Each character is looked up in the module-level ``stoi`` table; a character
    outside the vocabulary raises ``KeyError``.
    """
    ids: list[int] = []
    for ch in s:
        ids.append(stoi[ch])
    return ids


def decode(l: list[int]) -> str:
    """Map a sequence of character ids back to a string.

    Accepts a list of ints or any iterable of integer-like values, such as a 1-D
    ``torch.Tensor``. Each element is converted with ``int()`` before the lookup in
    the module-level ``itos`` table. This is needed because iterating a tensor yields
    0-d tensors, which hash by identity and would never match the plain-int keys of
    ``itos``. An id outside the vocabulary raises ``KeyError``.
    """
    return "".join(itos[int(i)] for i in l)


# Round-trip check: decode(encode(s)) should reproduce s exactly.
print(encode("hi there"))
print(decode(encode("hi there")))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [7]:
# Encode the entire corpus into a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [8]:
# Contiguous (unshuffled) 90/10 split: the first 90% of tokens train, the rest validate.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [9]:
# Context length. A chunk of block_size + 1 tokens contains block_size
# (context, next-token) training examples.
block_size = 8
train_data[: block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [10]:
# Targets are the inputs shifted one position left: for every prefix x[:t+1]
# the token to predict is y[t], i.e. the token that follows the prefix.
x = train_data[:block_size]
y = train_data[1 : block_size + 1]
for t in range(block_size):
    context = x[: t + 1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [11]:
# Seed the global torch RNG so the batches sampled below are reproducible.
# get_batch reads batch_size and block_size from these module-level names.
torch.manual_seed(1337)
batch_size = 4  # independent sequences we process in parallel
block_size = 8  # maximum context length


def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Generate a small batch of data of inputs x and targets y"""
    data = train_data if split == "train" else val_data
    indices = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in indices])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in indices])
    return x, y


# Draw one training batch and show how it packs batch_size * block_size examples:
# for every prefix xb[b, :t+1] the target is yb[b, t].
xb, yb = get_batch("train")
print("inputs:")
print(xb.shape)
print(xb)
print("sample input:", decode(xb[0].tolist()))
print("targets:")
print(yb.shape)
print(yb)
print("sample output:", decode(yb[0].tolist()))

print("-----")

for b in range(batch_size):  # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, : t + 1]
        target = yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
sample input: Let's he
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
sample output: et's hea
-----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53,

In [12]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        # Each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(
        self, inputs: torch.Tensor, targets: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # idx and targets are both (B,T) tensor of integers
        # (B, T, C)
        logits = self.token_embedding_table(
            inputs
        )  # (idx.shape[0], idx.shape[1], vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        # idx is (B, T) array of indices in the current context
        for i in range(max_new_tokens):
            # predict
            # (B, T, C) = batch, time, channels
            logits, loss = self(idx)  # (idx.shape[0], i + 1, vocab_size)
            # focus only on the last time step
            # becomes (B, C)
            logits = logits[:, -1, :]  # (idx.shape[0], vocab_size)
            # apply softmax to get probabilities
            # (B, C)
            probs = F.softmax(logits, dim=-1)  # (idx.shape[0], vocab_size)
            # sample from the distribution
            # (B, 1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (idx.shape[0], 1)
            # append sampled index to the running sequence
            # (B, T+1)
            idx = torch.cat((idx, idx_next), dim=1)  # (idx.shape[0], i + 2)

        return idx


# Smoke-test the untrained model. Because targets are passed, logits come back
# flattened to (B * T, vocab_size). For reference, a uniform prediction over
# vocab_size = 65 characters would score -ln(1/65) ≈ 4.17.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# Start generation from token id 0 (the newline character, first in `chars`).
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(5.0364, grad_fn=<NllLossBackward0>)

lfJeukRuaRJKXAYtXzfJ:HEPiu--sDioi;ILCo3pHNTmDwJsfheKRxZCFs
lZJ XQc?:s:HEzEnXalEPklcPU cL'DpdLCafBheH


In [13]:
# AdamW over all model parameters (only the embedding table for the bigram model).
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [14]:
# NOTE(review): this rebinds the global batch_size that get_batch reads, so any
# earlier cell re-run after this one will sample 32 sequences instead of 4.
batch_size = 32
for steps in range(10000):
    # Sample a batch of data
    xb, yb = get_batch("train")

    # Evaluate the loss, then take one optimizer step
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only: a noisy single-batch estimate, not a full evaluation.
print(loss.item())

2.362440586090088


In [15]:
# Sample 500 tokens from the trained model, starting from the same single-token context idx.
print(decode(m.generate(idx, max_new_tokens=500)[0].tolist()))


M:
IUSh t,
F th he d ke alved.
Thupld, cipbll t
I: ir w, l me sie hend lor ito'l an e

I:
Gochosen ea ar btamandd halind
Aust, plt t wadyotl
I bel qunganonoth he m he de avellis k'l, tond soran:

WI he toust are bot g e n t s d je hid t his IAces I my ig t
Ril'swoll e pupat inouleacends-athiqu heamer te
Wht s

MI wect!-lltherotheve t fe;
WAnd py;

PO t s ld tathat, ir V
IO thesecin teot tit ado ilorer.
Ply, d'stacoes, ld omat mealellly yererer EMEvesas ie IZEd pave mautoofareanerllleyomerer but?


In [16]:
# Self-attention warm-up: a small random (B, T, C) input.
# NOTE: this rebinds the globals x, B, T and C used by earlier cells.
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [17]:
# We want xbow[b, t] = mean_{i<=t} x[b, i]: each position averages itself and all
# earlier positions (explicit, slow double loop).
xbow = torch.zeros((B, T, C))  # Bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # (t + 1, C)
        xbow[b, t] = torch.mean(xprev, 0)

print(x[0])
print(xbow[0])  # Each row is the cumulative (running) average of x[0]'s rows up to that row

tensor([[ 0.9412, -1.3355],
        [ 2.0440,  1.0032],
        [-1.0121, -1.5424],
        [-0.5439, -0.4893],
        [-1.2058,  0.4698],
        [ 0.9218, -1.0068],
        [-0.8110, -0.1862],
        [ 1.0543,  0.9421]])
tensor([[ 0.9412, -1.3355],
        [ 1.4926, -0.1661],
        [ 0.6577, -0.6249],
        [ 0.3573, -0.5910],
        [ 0.0447, -0.3788],
        [ 0.1909, -0.4835],
        [ 0.0478, -0.4410],
        [ 0.1736, -0.2681]])


In [18]:
# Lower-triangular mask: row i has ones in columns 0..i.
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [19]:
# Trick for computing the self-attention efficiently
# After normalising each row to sum to 1, row i of `a` averages rows 0..i of
# whatever it multiplies.
a = torch.tril(torch.ones(3, 3))
a /= a.sum(1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f"{a=}")
print(f"{b=}")
print(f"{c=}")  # Cumulative average of b's rows, per column

a=tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=tensor([[5., 7.],
        [5., 2.],
        [6., 6.]])
c=tensor([[5.0000, 7.0000],
        [5.0000, 4.5000],
        [5.3333, 5.0000]])


In [20]:
wei = torch.tril(torch.ones(T, T))
wei /= wei.sum(1, keepdim=True)
print(wei)
xbow2 = wei @ x  # (B, T, T) @ (B, T, C) --> (B, T, C)
torch.allclose(xbow, xbow2)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

In [21]:
# Using softmax: masking future positions with -inf and applying softmax gives the
# same row-normalised weights. Later the zeros are replaced by learned affinities.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
print(wei)
wei = F.softmax(wei, dim=-1)  # (T, T)
print(wei)
xbow3 = wei @ x  # (T, T) @ (B, T, C): wei is broadcast over B -> (B, T, C)
torch.allclose(xbow, xbow3)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

In [35]:
# Self-attention
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# Single self-attention head
head_size = 16
key = nn.Linear(C, head_size, bias=False)  # What do I contain?
query = nn.Linear(C, head_size, bias=False)  # What am I looking for?
value = nn.Linear(C, head_size, bias=False)  # What do I pass on if attended to?
k: torch.Tensor = key(x)  # (B, T, head_size)
q: torch.Tensor = query(x)  # (B, T, head_size)
# Affinities: for every batch element, a (T, T) matrix of query-key dot products.
# Not yet scaled by head_size**-0.5; the variance cells below show why that matters.
wei = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# Causal (decoder) mask: position t may only attend to positions <= t.
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)  # each row now sums to 1
v = value(x)  # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out.shape

torch.Size([4, 8, 16])

- Examples along the batch dimension are processed completely independently and in parallel; they never communicate with each other
- In an encoder attention block we simply delete the line that masks with `tril`, allowing every token to communicate with every other token, including later ones
- It's called a decoder when we have the masking
- Self-attention: The keys and values are produced from the same source as the queries
- Cross-attention: Queries get produced from x, but the keys and values come from a different source (could be an encoder module)
- Scaled attention multiplies `wei` by $\frac{1}{\sqrt{\text{head\_size}}}$ (i.e. divides it by $\sqrt{\text{head\_size}}$). When Q and K have unit variance, `wei` then also has roughly unit variance, so the softmax stays diffuse instead of saturating towards a one-hot vector


In [39]:
# Unscaled attention scores: with unit-variance q and k, each entry of wei is a sum of
# head_size products, so its variance grows to roughly head_size (16 here).
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1)
print(f"{k.var()=}")
print(f"{q.var()=}")
print(f"{wei.var()=}")

k.var()=tensor(1.0434)
q.var()=tensor(1.0521)
wei.var()=tensor(16.3439)


In [38]:
# Scaling by head_size**-0.5 brings the score variance back to roughly 1.
# If wei's values are too large in magnitude, softmax saturates and approaches a one-hot vector.
# NOTE(review): this cell's execution count (38) is lower than the previous cell's (39), so
# the displayed variance came from an earlier q/k; re-run top to bottom to reproduce.
wei = q @ k.transpose(-2, -1) * head_size**-0.5
print(f"{wei.var()=}")

wei.var()=tensor(1.0808)
