In [1]:
from torch import nn
from torch.nn import functional as F
from pathlib import Path

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1337)

<torch._C.Generator at 0x1087f5e90>

In [2]:
with Path('../input.txt').open('r', encoding='utf-8') as f:
    text = f.read()

vocab = sorted(list(set(text)))
stoi = { ch: i for i, ch in enumerate(vocab) }
itos = { i: ch for i, ch in enumerate(vocab) }
encode = lambda x: [stoi[s] for s in x]
decode = lambda x: [itos[s] for s in x]

print("".join(decode(encode("hii there"))))

hii there


In [3]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[n:]
val_data = data[:n]

In [4]:
batch_size = 4
block_size = 8
n_embd = 12

eval_iters = 200
max_iters = 4000

vocab_size = len(vocab)

In [5]:
# torch.randint

In [6]:
def get_batch(split, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [23]:
# B, T, C = 4, 8, 2
# x = torch.randn(B, T, C)

# tril = torch.tril(torch.ones((T, T)))
# wei = torch.zeros(T, T)
# wei = wei.masked_fill(tril == 0, float('-inf'))
# wei = F.softmax(wei, dim=-1)
# out = wei @ x
# out

In [27]:
class Head(nn.Module):
    def __init__(self, head_size, block_size):
        super().__init__()
        self.query = nn.Linear(head_size, head_size, bias=False)
        self.key = nn.Linear(head_size, head_size, bias=False)
        self.value = nn.Linear(head_size, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)
        k = self.key(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        
        v = self.value(x) # (B, T, C)
        out = wei @ v
        return out

In [51]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(head_size=n_embd, block_size=block_size)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.sa_head(x)
        logits = self.lm_head(x)
        
        if targets is None:
            losses = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            losses = F.cross_entropy(logits, targets)
        return logits, losses

    def generate(self, idx, max_new_tokens=100):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

In [59]:
m = LanguageModel(vocab_size, n_embd, block_size)
m = m.to(device)

xb, yb = get_batch('train', batch_size)
logits, losses = m(xb, yb)
print(logits, losses)

_in = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(_in, max_new_tokens=1000)
print("".join(decode(out[0].tolist())))

tensor([[-0.4109,  1.2377,  1.6360,  ..., -0.6036, -0.8290,  0.1405],
        [-0.1910,  0.8041,  0.7763,  ..., -0.1787, -0.2842,  0.1111],
        [ 0.0637,  0.7521,  0.7351,  ..., -0.3728, -0.5637,  0.0292],
        ...,
        [-0.0553,  0.2214,  0.0936,  ...,  0.0092,  0.0863, -0.2287],
        [ 0.1482,  0.5251,  0.0561,  ..., -0.0180,  0.0078, -0.2500],
        [ 0.1655,  0.3600,  0.0456,  ..., -0.0674, -0.0370, -0.0994]],
       grad_fn=<ViewBackward0>) tensor(4.1167, grad_fn=<NllLossBackward0>)

AaT-gwVAI3v.cjVMQc J-MUwt,FD?phP:-JCn-gd!ik!lLmz

Smhvld;QkeqDChkSdySHwdiChqStcqFKpGWdVj!G-U.O3zYF3tnKzUjAwdVlZvvBKCgO &u-zS'qmXd:qRytqBAxe!dTF?fBgBbK:Npm '
Y3-T.LXXE&hWBaVUV?U'K!o:?cC.QL!FsGnhYJBAhNyxzcYrKmZ3VYdtX
urtbamRwWkL&  JhyRvIz?
PkyTiWdJx,P-I?idRKV!tpf: mESinpG&IRH;PGGzdywzM3qQGehEd?suUaXdk.LJJ'j,lX'u-x.XLf3rmM,IPB
fP? D!Iq.KDDgidhTBQM!OUrjzT&o3FdwR-qweIn.WBjCaWY
:QlbzMa ;vKsHpL,g3wx,V!sRMfV sJJZJ?Uk'oT zIkGXS
ZJn-xe&umvFLHMSwxPh-JA
ALf'e;dLdCjjzxlc
TQp
bj-!kh:ZzdZF:diCIciWf,M

In [65]:
@torch.no_grad()
def estimate_loss(model, eval_iters, batch_size):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        Losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, losses = model(X, Y)
            Losses[i] = losses.item()
        out[split] = Losses.mean()
    model.train()
    return out

In [67]:
m = LanguageModel(vocab_size, n_embd, block_size)
m = m.to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
for iter_ in range(max_iters):
    if iter_ % eval_iters == 0:
        out = estimate_loss(m, eval_iters, batch_size)
        print(f"train loss {out['train']} val loss {out['val']}")
    xb, yb = get_batch('train', batch_size)
    logits, losses = m(xb, yb)
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()

print(f"final loss {losses.item()}")

train loss 4.2946038246154785 val loss 4.295632362365723
train loss 3.3749914169311523 val loss 3.327522039413452
train loss 3.2372970581054688 val loss 3.218822956085205
train loss 3.1244585514068604 val loss 3.166337251663208
train loss 3.0699524879455566 val loss 3.0952811241149902
train loss 3.029994010925293 val loss 3.071331262588501
train loss 2.981454372406006 val loss 3.0435445308685303
train loss 2.9409444332122803 val loss 2.983912944793701
train loss 2.912058115005493 val loss 2.9756081104278564
train loss 2.8543996810913086 val loss 2.9526641368865967
train loss 2.8422157764434814 val loss 2.939225196838379
train loss 2.8240556716918945 val loss 2.902628183364868
train loss 2.820958137512207 val loss 2.8182404041290283
train loss 2.7555530071258545 val loss 2.831132173538208
train loss 2.724179744720459 val loss 2.7914960384368896
train loss 2.706482172012329 val loss 2.8141391277313232
train loss 2.681138515472412 val loss 2.7671778202056885
train loss 2.612917184829712 v

In [68]:
_in = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(_in, max_new_tokens=1000)
print("".join(decode(out[0].tolist())))


Wh ithands.
Ho m sthe I, hatal ldh milothhe aig morer y th, te no t ith mrio ce touwESiticay chh atafun y; V
B

TIit.
Winre! mt hnote, nilly c.

e fapy igesurerh mhe lmairy tay gy yaafoprurn! hthe vavit:
Nir hotanr ms Ko m, teaincherhy urtan:
H
MO?
UWh whe
RHNSis wom ponl:
w
RI!:
Yst K: withe an z l.
Ty cavothor de mauno.
LRPRHUTZAROHWhoou y bwy.

O:
t:
Mo; yarh G, mer g:
O
B
B, o
NI:
M
I:
TI
Yho$nI:RA:
RH
Upopve wedr on!.
o
ZI a,
OPRFy: snege I ph! wirrre hs.


AS medheu ta fo we uane gknarb I thitamirhiin; heyad kosh nig ane h mar reshe;
Pr woyavaist, sy wk ty celge lel!lyox Pdd rhd ch I:ifo s wad issus wthacoul.
RO:


A en.
Tggbiife;
F:
I,
KTENO:
TWho a, th he itayithi telnid jsish imt, a yiv w'ian y a yA. cogte ns an pr? cy csed uvevor hyo aved eneul, werlourhe' oK, heofarml hs, thsy iidia wir et t awdanes itheandaim a anuche!:
sary ved, hoce

TGqho maf
y it PTwe n, othaithon ty thh wasyig p, willch'Uthor yO whe'ithef
Ta ct tLl web, way lryhavo mar I mtha
Vgh iche tupt at s bd
Ba-