In [None]:
# Fetch the Tiny Shakespeare corpus with wget, but only when input.txt is not already present in the working directory.
![ -f input.txt ] || wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# load dataset
# The encoding is explicit so the character vocabulary does not depend on the platform's default locale.
with open('input.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# every distinct character of the corpus, in sorted order, is one token
vocab = sorted(set(text))

print("--- vocabulary of dataset ---")
print(repr(''.join(vocab)))
print("vocab size:", len(vocab))
print()

# character <-> integer id lookup tables
stoi = {s: i for i, s in enumerate(vocab)}
itos = {i: s for s, i in stoi.items()}


def encode(seq):
    """Map a string (or any iterable of characters) to a list of integer token ids.

    Raises:
        KeyError: if a character does not occur in the corpus vocabulary.
    """
    return [stoi[ch] for ch in seq]


def decode(tokens):
    """Map an iterable of integer token ids back to the string they spell.

    Raises:
        KeyError: if an id is not a valid vocabulary index.
    """
    return ''.join(itos[token] for token in tokens)


print("--- testing encoding and decoding (string: hello world) ---")
print("encoding:", encode('hello, world!'))
print("decoding:", decode(encode('hello, world!')))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# hyperparameters
SEED = 1337  # fixed seed so batch sampling, weight init, dropout and generation can be repeated
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SZ = 256  # number of sequences drawn per batch
BLOCK_SIZE = 256  # context length, in tokens
VOCAB_SIZE = len(vocab)
EMBEDDING_DIM = 288
NUM_HEADS = 9
NUM_DECODER_LAYERS = 6
HEAD_SIZE = EMBEDDING_DIM // NUM_HEADS
EVAL_ITERS = 200  # batches averaged per split when estimating the loss
EVAL_INTERVAL = 200  # training iterations between two loss estimates
DROPOUT = 0.2
LR = 3e-4
NUM_TRAIN_ITERS = 5000

# The concatenated heads feed a Linear(EMBEDDING_DIM, EMBEDDING_DIM) projection,
# so the floor division that defines HEAD_SIZE must be exact.
assert EMBEDDING_DIM % NUM_HEADS == 0, "EMBEDDING_DIM must be divisible by NUM_HEADS"

# Seed PyTorch's random number generators (some GPU kernels may still be nondeterministic).
torch.manual_seed(SEED)

print(f"--- running on device {DEVICE} ---")

In [None]:
# encode the whole corpus once and keep the token ids on the training device
data = torch.tensor(encode(text), device=DEVICE)

# the first 90% of the tokens are used for training, the remainder is held out for validation
N = int(0.9 * len(data))
train_data, validation_data = data[:N], data[N:]

In [None]:
# draw one random mini-batch of (context, next-token target) windows
def get_batch(split: str):
    """Sample BATCH_SZ random windows of BLOCK_SIZE tokens from one data split.

    Args:
        split: 'train' selects the training data; any other value selects
            the validation data.

    Returns:
        (contexts, targets): two tensors of shape (BATCH_SZ, BLOCK_SIZE) on
        DEVICE. Each row of targets is the matching row of contexts shifted
        one position further along the source data.
    """
    source = validation_data if split != 'train' else train_data

    # randint's upper bound is exclusive, so start + BLOCK_SIZE + 1 never runs past the end
    starts = torch.randint(0, len(source) - BLOCK_SIZE, (BATCH_SZ,)).tolist()

    contexts = torch.stack([source[start : start + BLOCK_SIZE] for start in starts])
    targets = torch.stack([source[start + 1 : start + BLOCK_SIZE + 1] for start in starts])

    return contexts.to(DEVICE), targets.to(DEVICE)

In [None]:
# Draw one training batch and print it as a sanity check on get_batch.
sample_context, sample_target = get_batch('train')
print("--- sample batch ---")
print(f"context:\n{sample_context}")
print(f"target:\n{sample_target}")

# Decode the first row of each tensor back to text: the target should read as the context shifted by one character.
print(f"context[0] decoded: {repr(decode(sample_context[0].tolist()))}, target[0] is {repr(decode(sample_target[0].tolist()))}")

In [None]:
class Head(nn.Module):
    """A single causal self-attention head.

    Projects EMBEDDING_DIM-wide inputs to HEAD_SIZE-wide queries, keys and
    values, and lets every position attend only to itself and to earlier
    positions.
    """

    def __init__(self):
        super().__init__()

        self.query = nn.Linear(EMBEDDING_DIM, HEAD_SIZE)
        self.key = nn.Linear(EMBEDDING_DIM, HEAD_SIZE)
        self.value = nn.Linear(EMBEDDING_DIM, HEAD_SIZE)

        # Lower-triangular ones: row t marks the positions (<= t) that step t may look at.
        # Registered as a buffer, so it is part of the module's state but not a trainable parameter.
        causal = torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE))
        self.register_buffer('mask', causal)

    def forward(self, x):
        """Apply masked scaled dot-product attention.

        Args:
            x: tensor of shape (B, T, EMBEDDING_DIM).

        Returns:
            Tensor of shape (B, T, HEAD_SIZE).
        """
        _, seq_len, _ = x.shape

        queries = self.query(x) # (B, T, HEAD_SIZE)
        keys = self.key(x) # (B, T, HEAD_SIZE)
        values = self.value(x) # (B, T, HEAD_SIZE)

        # pairwise query/key affinities, scaled by 1/sqrt(HEAD_SIZE)
        scale = HEAD_SIZE**-0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale # (B, T, T)

        # future positions are set to -inf so that softmax gives them zero weight
        hidden = self.mask[:seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(hidden, float('-inf')), dim=-1)

        return weights @ values # (B, T, HEAD_SIZE)


In [None]:
class MultiHeadAttention(nn.Module):
    """NUM_HEADS attention heads run side by side, then merged.

    The per-head outputs are concatenated along the feature axis, passed
    through a linear projection of width EMBEDDING_DIM and regularised with
    dropout.
    """

    def __init__(self):
        super().__init__()

        self.heads = nn.ModuleList([Head() for _ in range(NUM_HEADS)])
        self.proj = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM)
        # dropout is applied at the very end of the sub-layer
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Run every head on x and return the projected, dropout-regularised concatenation."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)

        # project the concatenated heads, then regularise
        projected = self.proj(merged)
        return self.dropout(projected)

In [None]:
class MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4 * EMBEDDING_DIM, GELU, project back, dropout."""

    def __init__(self):
        super().__init__()

        hidden_dim = 4 * EMBEDDING_DIM
        layers = [
            nn.Linear(EMBEDDING_DIM, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, EMBEDDING_DIM),
            nn.Dropout(DROPOUT),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to the last dimension of x."""
        return self.mlp(x)

In [None]:
class Block(nn.Module):
    """One transformer layer: pre-norm multi-head attention, then a pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(EMBEDDING_DIM)
        self.mha = MultiHeadAttention()
        self.layer_norm2 = nn.LayerNorm(EMBEDDING_DIM)
        self.feed_forward = MLP()

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        # layer norm is applied to the input of each sub-layer, and the sub-layer output is added back onto the residual stream
        attended = x + self.mha(self.layer_norm1(x))
        return attended + self.feed_forward(self.layer_norm2(attended))

In [None]:
# character-level language model: a decoder-only transformer (the class keeps the name "BigramLM")
class BigramLM(nn.Module):
    """Token + position embeddings, NUM_DECODER_LAYERS transformer blocks, a final layer norm and a linear head over the vocabulary."""

    def __init__(self):
        super().__init__()

        self.token_embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.positional_encoding = nn.Embedding(BLOCK_SIZE, EMBEDDING_DIM)
        self.blocks = nn.Sequential(*[Block() for _ in range(NUM_DECODER_LAYERS)])
        self.layer_norm = nn.LayerNorm(EMBEDDING_DIM)
        self.lm_head = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: (B, T) tensor of integer token ids, with T <= BLOCK_SIZE.
            targets: optional (B, T) tensor of integer token ids to score against.

        Returns:
            (logits, loss). Without targets, logits has shape
            (B, T, VOCAB_SIZE) and loss is None. With targets, logits is
            returned flattened to (B*T, VOCAB_SIZE) and loss is the mean
            cross-entropy over all B*T positions.

        Raises:
            ValueError: if T exceeds BLOCK_SIZE (there is no positional
                embedding or mask entry for such positions).
        """
        loss = None
        B, T = x.shape # (B, T)

        if T > BLOCK_SIZE:
            raise ValueError(f"sequence length {T} exceeds BLOCK_SIZE {BLOCK_SIZE}")

        input_embedding = self.token_embedding(x) # (B, T, EMBEDDING_DIM)

        # Positions are created on the input's own device rather than the global DEVICE,
        # so the model keeps working wherever it (and its input) have been moved.
        positions = torch.arange(T, device=x.device)
        positional_encoding = self.positional_encoding(positions) # (T, EMBEDDING_DIM)

        x = input_embedding + positional_encoding # (B, T, EMBEDDING_DIM)
        x = self.blocks(x)
        x = self.layer_norm(x)

        logits = self.lm_head(x) # (B, T, VOCAB_SZ)

        if targets is not None:
            B, T, C = logits.shape

            # cross_entropy expects one row of logits per token, so flatten batch and time together
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, x, max_new_tokens):
        """Autoregressively sample max_new_tokens tokens after the prompt x.

        Sampling runs in eval mode (dropout off) and without autograd
        tracking; the module's previous train/eval mode is restored before
        returning, even if sampling raises.

        Args:
            x: (B, T) tensor of integer token ids used as the prompt.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.
        """
        was_training = self.training
        self.eval()

        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # the model can only look at the last BLOCK_SIZE tokens
                    x_trunc = x[:, -BLOCK_SIZE:]

                    logits, _ = self(x_trunc) # (B, T, VOCAB_SIZE)

                    # only the prediction at the last time step is needed
                    logits = logits[:, -1, :] # (B, VOCAB_SIZE)

                    # normalise to probabilities and sample one token per row
                    probs = F.softmax(logits, dim=-1)
                    gen = torch.multinomial(probs, num_samples=1) # (B, 1)

                    x = torch.cat((x, gen), dim=1)
        finally:
            self.train(was_training)

        return x

In [None]:
model = BigramLM().to(DEVICE)

# The fused AdamW kernel is only requested on CUDA. On the CPU fallback the optimizer's default
# implementation is used, because older PyTorch releases reject fused=True for non-CUDA parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    fused=True if DEVICE == 'cuda' else None,
)

print(model)
print("Number of parameters in model:", sum(parameter.numel() for parameter in model.parameters()))

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the training and validation splits.

    For each split, EVAL_ITERS random batches are evaluated with the model in
    eval mode and their losses averaged. The model's previous train/eval mode
    is restored afterwards, including when evaluation is interrupted or raises.

    Returns:
        dict with keys 'train' and 'validation', each a 0-d tensor on DEVICE
        holding the mean loss for that split.
    """
    losses = {}
    was_training = model.training
    model.eval()

    try:
        for split in ['train', 'validation']:
            losses_for_split = torch.zeros(EVAL_ITERS, device=DEVICE)

            for i in range(EVAL_ITERS):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)

                # Store the 0-d loss tensor directly: .item() would force a device-to-host
                # sync on every iteration only to copy the value straight back.
                losses_for_split[i] = loss

            losses[split] = losses_for_split.mean()
    finally:
        model.train(was_training)

    return losses

In [None]:
# print("--- model generation before training ---")
# print(decode(model.generate((torch.zeros((1, 1), dtype=torch.long, device=DEVICE)), max_new_tokens=1000)[0].tolist()))

In [None]:
from torch.amp import autocast, GradScaler

In [None]:
# Gradient scaling is only enabled on CUDA, where autocast defaults to float16. On the CPU fallback
# (bfloat16 autocast) the scaler is disabled and the scaler calls below pass straight through.
scaler = GradScaler(device=DEVICE, enabled=(DEVICE == 'cuda'))

for i in range(NUM_TRAIN_ITERS):
    # Periodic evaluation only (it costs 2 * EVAL_ITERS forward passes). It runs before the
    # training forward pass so that pass's activations are not held in memory meanwhile.
    if i % EVAL_INTERVAL == 0:
        with autocast(device_type=DEVICE):
            estimated_loss = estimate_loss()

        print(f"--- estimated loss at epoch {i}. train: {estimated_loss['train']}, validation: {estimated_loss['validation']} ---")

    xb, yb = get_batch('train')

    optimizer.zero_grad(set_to_none=True)

    # forward pass in mixed precision; autocast needs the device type spelled out
    with autocast(device_type=DEVICE):
        logits, loss = model(xb, yb)

    # backward pass and optimizer step go through the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

print(f"--- loss immediately after training loop: {loss} ---")

In [None]:
print("--- model generation after training ---")

# prompt with a single token of id 0 and let the model extend it by 10000 tokens
prompt = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
generated = model.generate(prompt, max_new_tokens=10000)
print(decode(generated[0].tolist()))