# A tiny language model
- Adapted from https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py
- Generates simple sentences like "the white dog chased the small cat".
- Has fewer than 2000 parameters but still seems to take nearly 2 minutes to train on a three-year-old, entry-level graphics card.
- As part of the overall training it seems to learn word embeddings in which related words end up close together, e.g. dog ends up close to cat.

In [1]:
# Generate the training data: 100 random sentences, one per line, each of the
# form " the <adjective> <noun> <verb> the <adjective> <noun>".
import random

nouns = ['dog', 'cat']
adjectives = ['white', 'black', 'small']
verbs = ['saw', 'chased']
r = random.choice


def _random_sentence():
    """Return one randomly assembled sentence, including its trailing newline."""
    return f' the {r(adjectives)} {r(nouns)} {r(verbs)} the {r(adjectives)} {r(nouns)}\n'


with open('animals.txt', 'w', encoding="utf-8") as f:
    f.writelines(_random_sentence() for _ in range(100))

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337);

In [3]:
# hyperparameters
batch_size = 4  # number of sequences sampled per batch in get_batch
block_size = 12  # context length: tokens per training window and number of position embeddings
max_iters = 5000  # number of optimizer steps in the training loop
eval_interval = 500  # train/val loss is estimated every this many steps
learning_rate = 3e-4  # learning rate passed to the AdamW optimizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when CUDA is available
eval_iters = 200  # number of batches averaged per split in estimate_loss
n_embd = 8  # width of the token and position embeddings
n_head = 4  # attention heads per transformer block (each head has size n_embd // n_head)
n_layer = 2  # number of transformer blocks stacked in the model
dropout = 0.2  # dropout probability used by every nn.Dropout in the model

In [4]:
# Load the corpus and build a word-level vocabulary from it.
with open('animals.txt', 'r', encoding='utf-8') as f:
    text = f.read()
lines = text.splitlines()
allWords = [line.split() for line in lines]  # one list of words per sentence
# A set comprehension de-duplicates directly and sorted() accepts any iterable,
# so no intermediate list is needed.
distinctWords = sorted({word for line in allWords for word in line})
# Token 0 is the end-of-line marker. Every word token carries a leading space,
# so decoding is a plain concatenation of token strings.
tokens = ['\n'] + [' ' + w for w in distinctWords]
vocab_size = len(tokens)
stoi = {tok: i for i, tok in enumerate(tokens)}  # token string -> integer id
itos = dict(enumerate(tokens))  # integer id -> token string


def decode(l):
    """Turn a sequence of token ids back into text by joining their token strings."""
    return ''.join(itos[i] for i in l)
def encodedLine(line):
    """Encode one sentence, given as a list of words, into a list of token ids.

    Each word is looked up with the leading space that word tokens carry, and
    the id of the end-of-line token is appended. For example, in a vocabulary
    where ' black' has id 1 and ' dog' has id 4:
    ['black', 'dog'] --> [1, 4, 0].

    Raises:
        KeyError: if a word is not part of the vocabulary.
    """
    ids = [stoi[' ' + w] for w in line]
    # Look the terminator up rather than hard-coding its id (0).
    ids.append(stoi['\n'])
    return ids
def flatten(a):
    """Concatenate the rows of ``a`` into a single flat list, preserving order."""
    flat = []
    for row in a:
        flat.extend(row)
    return flat
# The whole corpus as one flat stream of token ids, sentence after sentence.
encodedText = [token_id for words in allWords for token_id in encodedLine(words)]

In [5]:
# Train/validation split: the first 90% of the token stream is used for
# training, the remainder for validation.
data = torch.tensor(encodedText, dtype=torch.long)
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

In [6]:
vocab_size

9

In [7]:
# data loading
def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    ``split == 'train'`` selects the training data; any other value selects
    the validation data. Returns ``(x, y)``, each of shape
    (batch_size, block_size) and moved to ``device``; ``y`` is the same window
    as ``x`` shifted one position later in the stream.
    """
    source = train_data if split == 'train' else val_data
    # Random start offsets, chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the training and validation splits.

    For each split, ``eval_iters`` random batches are drawn with ``get_batch``
    and their losses are averaged. The model is switched to eval mode for the
    measurement and back to train mode afterwards, including when the
    evaluation raises or is interrupted part-way.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Always restore train mode; otherwise an interrupted evaluation would
        # leave dropout disabled for all subsequent training steps.
        model.train()
    return out

In [8]:
class Head(nn.Module):
    """A single head of causal self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # Bias-free projections from the embedding space into this head's space.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular matrix of ones, registered as a buffer (not a
        # trainable parameter); used to mask out attention to later positions.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (B, T, C) to attended values of shape (B, T, head_size)."""
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        # Scaled dot-product affinities between every pair of positions.
        scale = k.shape[-1] ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale  # (B, T, T)
        # Positions above the diagonal lie in the future: give them zero weight.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T)
        # Weighted aggregation of the values.
        return weights @ self.value(x)  # (B, T, hs)

In [9]:
class MultiHeadAttention(nn.Module):
    """Several self-attention heads run side by side, then mixed by a linear projection."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, concatenate along the last dim, project and apply dropout."""
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        projected = self.proj(concatenated)
        return self.dropout(projected)

In [10]:
class FeedFoward(nn.Module):
    """Two-layer MLP: expand to 4x the width, ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        out = self.net(x)
        return out

In [11]:
class Block(nn.Module):
    """Transformer block: self-attention, then feed-forward, each applied to a
    layer-normalised input and added back as a residual."""

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: number of attention heads
        super().__init__()
        # The embedding width is shared out among the heads (floor division).
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

In [12]:
class GPTLanguageModel(nn.Module):
    """Small decoder-only transformer language model over the token vocabulary.

    Token and position embeddings are summed, passed through ``n_layer``
    transformer blocks and a final layer norm, and projected to one logit per
    vocabulary entry.
    """

    def __init__(self):
        super().__init__()
        # Learned embeddings: one vector per token id and one per position in
        # the context window.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # Re-initialise every Linear/Embedding submodule with small normal weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02^2); zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: (B, T) integer tensor of token ids; T must not exceed
                ``block_size`` because only that many positions are embedded.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the cross-entropy
            against the flattened targets.
        """
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # Build the position indices on the input's own device instead of the
        # module-level `device` global, so the model keeps working after it
        # (and its inputs) are moved to a different device.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Sampling runs without gradient tracking and in eval mode, so dropout
        does not perturb the predicted distributions; the module's previous
        train/eval mode is restored before returning.

        Args:
            idx: (B, T) integer tensor holding the current context.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            (B, T + max_new_tokens) integer tensor: the context followed by
            the sampled tokens.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # crop idx to the last block_size tokens
                idx_cond = idx[:, -block_size:]
                # get the predictions
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :] # becomes (B, C)
                # apply softmax to get probabilities
                probs = F.softmax(logits, dim=-1) # (B, C)
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
                # append sampled index to the running sequence
                idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        finally:
            self.train(was_training)
        return idx

In [13]:
# Build the model and move it to the selected device.
model = GPTLanguageModel()
m = model.to(device)
# Report the total number of parameters in the model.
parameter_count = sum(p.numel() for p in m.parameters())
print(parameter_count, 'parameters')

1961 parameters


In [14]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop variable is called `step` rather than `iter` so that the builtin
# iter() is not shadowed for the rest of the session.
for step in range(max_iters):

    # every eval_interval steps, and on the final step, report the estimated
    # loss on the train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass, then one optimizer update on the resulting loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 2.2017, val loss 2.2000
step 500: train loss 1.1733, val loss 1.1718
step 1000: train loss 0.8567, val loss 0.8589
step 1500: train loss 0.7747, val loss 0.7841
step 2000: train loss 0.7467, val loss 0.7479
step 2500: train loss 0.6756, val loss 0.6779
step 3000: train loss 0.6057, val loss 0.6117
step 3500: train loss 0.5934, val loss 0.5982
step 4000: train loss 0.5886, val loss 0.5939
step 4500: train loss 0.5843, val loss 0.5850
step 4999: train loss 0.5851, val loss 0.5856


In [15]:
# Generate from the model, starting from a context that holds a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = m.generate(context, max_new_tokens=100)
sample_ids = generated[0].tolist()  # first (and only) sequence of the batch
print(decode(sample_ids))


 the white cat chased the white dog
 the the small cat saw the small cat
 the white dog chased the white dog
 the white cat saw the black cat
 the white cat chased the black cat
 the black dog chased the white dog
 the white dog saw the white cat
 the black cat saw the small dog
 the white dog saw the black cat
 the small dog chased the white dog
 the black dog chased the white cat
 the white dog saw the white cat
 the black cat


In [16]:
tokens

['\n', ' black', ' cat', ' chased', ' dog', ' saw', ' small', ' the', ' white']

In [17]:
# Learned token-embedding matrix: row i is the n_embd-dimensional vector for tokens[i].
w = model.token_embedding_table.weight

In [18]:
def f(i, j):
    """Return the dot product of the embeddings of tokens ``i`` and ``j`` as a Python float."""
    row_i, row_j = w[i], w[j]
    return torch.dot(row_i, row_j).item()

In [19]:
# Matrix of all pairwise dot products between token embeddings, scaled by 1000
# and cast to integers for a compact display.
(w @ w.T * 1000).int()

tensor([[ 29,   0,  -8,   6,  -9,   5,  -5,  -9,  -7],
        [  0,  41, -15,  -7, -12,  -8,  39,  -7,  40],
        [ -8, -15,  27,  -8,  24,  -5, -15, -15, -15],
        [  6,  -7,  -8,  23,  -8,  17,  -3,  -7,  -4],
        [ -9, -12,  24,  -8,  22,  -5, -12, -14, -12],
        [  5,  -8,  -5,  17,  -5,  21,  -7,  -9,  -2],
        [ -5,  39, -15,  -3, -12,  -7,  38,  -5,  40],
        [ -9,  -7, -15,  -7, -14,  -9,  -5,  51,  -8],
        [ -7,  40, -15,  -4, -12,  -2,  40,  -8,  47]], device='cuda:0',
       dtype=torch.int32)

In [20]:
# Print every pair of distinct tokens whose embeddings have a positive dot
# product; token 0 (the end-of-line token) is left out.
for i in range(1, vocab_size):
    for j in range(i + 1, vocab_size):
        similarity = f(i, j)
        if similarity > 0:
            print(tokens[i], tokens[j])

 black  small
 black  white
 cat  dog
 chased  saw
 small  white
