In [8]:
import torch
import torch.nn.functional as F
from torch import nn
from pathlib import Path

# Fix PyTorch's global RNG seed so parameter init and sampling are repeatable.
torch.manual_seed(1337)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [9]:
# Read the whole corpus into memory as a single string.
with Path('input.txt').open(mode="r", encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
vocab = sorted(set(text))
# id -> character, then invert it to get character -> id.
itos = dict(enumerate(vocab))
stoi = {ch: i for i, ch in itos.items()}
stoi['h']

46

In [10]:
def encode(x):
    """Return the list of integer ids (via ``stoi``) for each character of ``x``."""
    return [stoi[s] for s in x]


def decode(x):
    """Return the list of characters (via ``itos``) for each id in ``x``.

    The result is a list, not a string; callers join it themselves.
    """
    return [itos[s] for s in x]


# Round-trip check: encoding then decoding should reproduce the input text.
print("".join(decode(encode("hii there"))))

hii there


In [11]:
# Encode the full corpus once as a 1-D tensor of integer token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# The first 90% of the tokens are for training, the remainder for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [17]:
# Batch geometry used by get_batch: window length and sequences per batch.
block_size, batch_size = 8, 4

# Model sizes: one logit per vocabulary entry, and the token-embedding width.
n_embd = 16
vocab_size = len(vocab)

In [22]:
def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        A pair ``(x, y)`` of tensors moved to ``device``, each of shape
        ``(batch_size, block_size)``. ``y`` is ``x`` shifted one position to
        the right in the underlying data.
    """
    source = val_data if split != 'train' else train_data
    # One random start offset per sequence; the exclusive upper bound leaves
    # room for the target window, which ends one position later.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for s in starts:
        inputs.append(source[s:s + block_size])
        targets.append(source[s + 1:s + block_size + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)


In [44]:
class LanguageModel(nn.Module):
    """Minimal language model: a token embedding followed by a linear head.

    ``forward`` maps each position's token id to logits independently; nothing
    in this module mixes information across positions.
    """

    def __init__(self, vocab_size, n_embd):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: Number of distinct token ids (rows of the embedding
                table and width of the output logits).
            n_embd: Width of each token embedding.
        """
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional integer tensor with ``B * T`` elements holding
                the expected next token for every position.

        Returns:
            ``(logits, loss)``. Without ``targets`` the logits have shape
            ``(B, T, C)`` and ``loss`` is ``None``. With ``targets`` the logits
            are flattened to ``(B * T, C)`` and ``loss`` is a scalar tensor.
        """
        tok_emb = self.token_embedding_table(idx)
        logits = self.lm_head(tok_emb)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # cross_entropy expects (N, C) logits against (N,) class indices.
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous targets, e.g. a transposed
        # tensor, are accepted instead of raising a RuntimeError.
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoid building a graph
    def generate(self, idx, max_new_tokens=100):
        """Extend ``idx`` by sampling ``max_new_tokens`` tokens one at a time.

        Args:
            idx: Integer tensor of shape ``(B, T)`` used as the starting context.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape ``(B, T + max_new_tokens)``.
        """
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)
            # Only the prediction at the last position drives the next token.
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

In [45]:
# Build a fresh, untrained model and place it on the selected device.
model = LanguageModel(vocab_size, n_embd)
m = model.to(device)

# Sanity check: one forward pass on a training batch before any optimization.
xb, yb = get_batch('train')
logits, loss = m(xb, yb)
print(logits, loss)

tensor([[ 0.1449,  0.2313,  0.9192,  ..., -0.0208,  0.3893,  0.0451],
        [ 0.6140, -0.3954, -0.6476,  ...,  0.3876, -0.1631, -0.2325],
        [ 0.6969, -0.4061, -0.9308,  ..., -0.2372, -0.7925, -0.1318],
        ...,
        [-0.8424,  1.0103, -0.9083,  ..., -0.1128, -0.2465, -0.7655],
        [ 0.3486, -0.0026,  0.0323,  ..., -0.2791,  0.1204,  0.3467],
        [ 0.1071, -0.2266,  0.8529,  ..., -0.8352, -0.8122,  1.1223]],
       grad_fn=<ViewBackward0>) tensor(4.4882, grad_fn=<NllLossBackward0>)


In [46]:
# Start generation from a single sequence holding only token id 0.
# The context must live on the same device as the model, otherwise the
# embedding lookup raises a device-mismatch error when running on CUDA.
start_idx = torch.zeros((1, 1), dtype=torch.long, device=device)
out = m.generate(start_idx, max_new_tokens=100)

In [47]:
# Decode the first (only) generated sequence and print its characters back to back.
print(*decode(out[0].tolist()), sep="")


vuDZN.Q&.bz!DUNvCEGBt.jFBhUco'Y,GrCyWZ. IM-dIM-$LY:?By&HSr
DL;vvOshgNPtJf::abSzIo?'hPnOmfjNFsNFUI:Jl


In [56]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
    """Estimate the mean loss on the train and val splits.

    Args:
        model: Module called as ``model(X, Y)`` and returning ``(logits, loss)``.
        eval_iters: Number of random batches averaged per split.

    Returns:
        Dict with keys ``'train'`` and ``'val'``, each a 0-dim tensor holding
        the mean loss over ``eval_iters`` batches drawn with ``get_batch``.
    """
    # Remember the caller's mode so it can be restored afterwards, rather than
    # unconditionally forcing the model into training mode.
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for i in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[i] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the mode even if a batch fails, so the model is never left
        # stuck in eval mode.
        model.train(was_training)
    return out

In [58]:
max_iters = 4000      # total number of optimizer steps
eval_interval = 200   # report losses every this many steps
eval_iters = 200      # number of batches averaged per loss estimate

# Start from a fresh model so this cell does not depend on earlier training state.
model = LanguageModel(vocab_size, n_embd)
m = model.to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

for iter_ in range(max_iters):
    # Periodic evaluation; its cadence is kept separate from the number of
    # evaluation batches so the two can be tuned independently.
    if iter_ % eval_interval == 0:
        out = estimate_loss(m, eval_iters)
        print(f"train loss {out['train']} val loss {out['val']}")

    # One optimization step on a random training batch.
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

train loss 4.33693790435791 val loss 4.320051193237305
train loss 3.6961019039154053 val loss 3.7138495445251465
train loss 3.314929723739624 val loss 3.3365960121154785
train loss 3.077657699584961 val loss 3.096342086791992
train loss 2.8986334800720215 val loss 2.9349753856658936
train loss 2.8346869945526123 val loss 2.8470304012298584
train loss 2.7784552574157715 val loss 2.792571783065796
train loss 2.7229669094085693 val loss 2.7633180618286133
train loss 2.7235612869262695 val loss 2.7153449058532715
train loss 2.650404453277588 val loss 2.6887402534484863
train loss 2.6732606887817383 val loss 2.6649930477142334
train loss 2.6559348106384277 val loss 2.638273239135742
train loss 2.603419542312622 val loss 2.6435353755950928
train loss 2.5942888259887695 val loss 2.6315701007843018
train loss 2.6217739582061768 val loss 2.6362555027008057
train loss 2.56425404548645 val loss 2.6281909942626953
train loss 2.5767405033111572 val loss 2.620434045791626
train loss 2.61979460716247

In [None]:
# Sample from the trained model, starting from a single sequence holding only
# token id 0. The context is created on the model's device; a CPU tensor
# would raise a device-mismatch error when the model is on CUDA.
start_idx = torch.zeros((1, 1), dtype=torch.long, device=device)
out = m.generate(start_idx, max_new_tokens=100)
print("".join(decode(out[0].tolist())))