In [39]:
import torch
import torch.nn.functional as F
from torch import nn
from pathlib import Path

torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
with Path('input.txt').open("r", encoding='utf-8') as f:
    text = f.read()

vocab = sorted(list(set(text)))
stoi = { ch: i for i, ch in enumerate(vocab) }
itos = { i: ch for i, ch in enumerate(vocab) }
stoi['h']

46

In [3]:
encode = lambda x: [stoi[s] for s in x]
decode = lambda x: [itos[s] for s in x]
print("".join(decode(encode("hii there"))))

hii there


In [4]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [5]:
block_size = 8
batch_size = 4
head_size = 16
vocab_size = len(vocab)
n_embd = 16

In [22]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


In [70]:
class Head(nn.Module):
    def __init__(self, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)

        wei = q @ k.transpose(-2, -1) * C**-.5 # (B, T, C)x(B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        v = self.value(x) # (B, T, C)
        out = wei @ v # (B, T, T)x(B, T, C) -> (B, T, C)
        return out

In [78]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd=n_embd, head_size=n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.sa_head(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens=100):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

In [79]:
model = LanguageModel(vocab_size, n_embd)
m = model.to(device)
    
xb, yb = get_batch('train')
logits, loss = m(xb, yb)

In [80]:
start_idx = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(start_idx, max_new_tokens=100)

In [81]:
print("".join(decode(out[0].tolist())))


jFK ;aFbvVL&CjHVBwphXKuLzBs?t n?L;SNnAWvp'Og$YVdaw'PDT .ZVjqILwZn$uEiWxBk.UqDdKWcMhGcN'ZquFaE,RawrFB


In [82]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(len(losses)):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [83]:
max_iters = 4000
eval_iters = 200

model = LanguageModel(vocab_size, n_embd)
m = model.to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
    
for iter_ in range(max_iters):
    if iter_ % eval_iters == 0:
        out = estimate_loss(m, eval_iters)
        print(f"train loss {out['train']} val loss {out['val']}")
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

train loss 4.234309673309326 val loss 4.248541355133057
train loss 3.3896186351776123 val loss 3.3828630447387695
train loss 3.1888492107391357 val loss 3.21234130859375
train loss 3.1197478771209717 val loss 3.1506528854370117
train loss 3.020824670791626 val loss 3.0481066703796387
train loss 2.9578046798706055 val loss 2.961676836013794
train loss 2.881498336791992 val loss 2.8840441703796387
train loss 2.8322219848632812 val loss 2.8188843727111816
train loss 2.7620062828063965 val loss 2.822767972946167
train loss 2.8167638778686523 val loss 2.7403478622436523
train loss 2.7083499431610107 val loss 2.719287633895874
train loss 2.66115665435791 val loss 2.693411350250244
train loss 2.6839959621429443 val loss 2.67233943939209
train loss 2.65138578414917 val loss 2.7023422718048096
train loss 2.6598258018493652 val loss 2.6544039249420166
train loss 2.6279218196868896 val loss 2.62431001663208
train loss 2.6370041370391846 val loss 2.646200656890869
train loss 2.6679248809814453 val

In [84]:
start_idx = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(start_idx, max_new_tokens=100)
print("".join(decode(out[0].tolist())))


IZ nd lor btherusay; ETper:
HAcGly, dat co wato: bh,
Ouricoun, anen dad hary ho?

O:
OOS th lillowrd
