In [1]:
from torch import nn
from torch.nn import functional as F
from pathlib import Path

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(1337)

<torch._C.Generator at 0x103c45eb0>

In [2]:
with Path('../input.txt').open('r', encoding='utf-8') as f:
    text = f.read()

vocab = sorted(list(set(text)))
stoi = { ch: i for i, ch in enumerate(vocab) }
itos = { i: ch for i, ch in enumerate(vocab) }
encode = lambda x: [stoi[s] for s in x]
decode = lambda x: [itos[s] for s in x]

print("".join(decode(encode("hii there"))))

hii there


In [3]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[n:]
val_data = data[:n]

In [4]:
batch_size = 4
block_size = 8
n_embd = 12

eval_iters = 200
max_iters = 4000

vocab_size = len(vocab)

In [5]:
# torch.randint

In [6]:
def get_batch(split, batch_size):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [19]:
B, T, C = 4, 8, 6
x = torch.randn(B, T, C)
tril = torch.tril(torch.ones(T, T))
tril[:T, :T]

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [None]:
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei @ x

In [None]:
class Head(nn.Module):
    def __init__(self, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(head_size, head_size, bias=False)
        self.query = nn.Linear(head_size, head_size, bias=False)
        self.value = nn.Linear(head_size, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # (B, T, C)
        q = self.query(x) # (B, T, C)

        wei = k @ q.transpose(-2, -1) * C**-0.5 # (B, T, T)

        # Why we need `self.tril[:T, :T] == 0?
        # Because in the forward pass, \(T\) (the current sequence length)
        # can be less than or equal to `block_size`. The buffer `self.tril` is sized
        # `(block_size, block_size)`, while `wei` is `(B, T, T)`.
        # If you do `self.tril == 0`, you get a `(block_size, block_size)` mask, which
        # will not match `(T, T)` for smaller \(T\).
        # Using `self.tril[:T, :T] == 0` ensures the mask also has
        # shape \((T, T)\) to match `wei`. This way, the model only attends up
        # to the current sequence length \(T\), rather than the full `block_size`.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        v = self.value(x) # (B, T, C)
        out = wei @ v # (B, T, T)x(B, T, C) = (B, T, C)
        
        return out

In [13]:
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(head_size=n_embd, block_size=block_size)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_head(x) # (B, T, C)
        logits = self.lm_head(x) # (B, T, vocab_size)
        
        if targets is None:
            losses = None
        else:
            # targets.shape when cross_entropy turns them
            # to OHE --> (B, T, vocab_size).
            # cross_entropy needs shape (minibatch, C).
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            losses = F.cross_entropy(logits, targets)
            # F.softmax(logits, dim=-1)
        return logits, losses

    def generate(self, idx, max_new_tokens=1000):
        for i in range(max_new_tokens):
            # if i == 20:
            #     break
            # print(f"idx.shape {"idx.shape})
            
            # The context window is block_size,
            # so we have to truncate the sequence to make sure
            # we're not extendnig past block_size
            idx_cond = idx[:, -block_size:]
            # print(f"idx_cond.shape {idx_cond.shape}")
            logits, _ = self(idx_cond)
            # print(f"before indexing {logits.shape}")
            logits = logits[:,-1,:]
            # print(f"after indexing {logits.shape}")
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
            # print()
        return idx

In [14]:
m = LanguageModel(vocab_size, n_embd)
m = m.to(device)

xb, yb = get_batch('train', batch_size)
logits, losses = m(xb, yb)
logits, losses

(tensor([[ 0.2544, -0.1150, -0.1789,  ..., -0.2726,  0.7882, -0.3389],
         [-0.2856, -0.1458, -0.0049,  ..., -0.0748,  0.4821,  0.1438],
         [-0.1902, -0.1198,  0.0850,  ...,  0.0093,  0.3267,  0.4185],
         ...,
         [-0.1761,  0.1390,  0.1561,  ...,  0.1782,  0.1724,  0.1801],
         [-0.4078, -0.0443,  0.0874,  ...,  0.0681,  0.1808,  0.2982],
         [-0.6289, -0.0441,  0.0464,  ...,  0.0414,  0.1434,  0.2973]],
        grad_fn=<ViewBackward0>),
 tensor(4.2712, grad_fn=<NllLossBackward0>))

In [15]:
_in = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(_in)

In [16]:
print("".join(decode(out[0].tolist())))


;GH?YDQXXHvC'efXg!'q!GQudrZs&&z?jRje;ubDXlyLQfFGKQXlzID-D3rEj&!ympinKW:XgZ,O,ZfIwkAumJsq
.BTQlcMDE$LvErXtsp:?WE33Ni:Tlcl.LHjGvrs$rf;DT-q,nIERtN?;ujh'x!YHUiiNTNQ-dhtBI:G$zyPloIzZ&YQ.XyCOLLlvcPhyVqYzuejj?KCbDQJEjL.nL$d&3kbN3V
WBR!ojpeqS'qD,KipfOgDa$PJWCZg;S'-z3hq-ngREqhtTNg3E3esYX;sqDMrPysuMyyOL$qm?&.vUwl znrR$HHhmlw.'CsQwBLm'sgTcGICQ'.BzFNF$Srx3HKEIsMgsslFGkWrVugRIdX
hYpv!-jviiZTeRV MLCUliA!;UOYt!-ApoMyeYAmwDB?Ns!3sZyP$HIeWrcns:PBCOuF&&rib&Og NYOJV-ddsIuABSEXsW.L$&
Urf3lAY;&W
Ypkbg&zELg;Y yX'yt-F,su&lXza!Sx3nI?MW!IHi3kE.!EcxC,i3d.iEzz3RwKCpCZ&nOvufAQ BDEvPC$;m3!FN-PM?$Y3XMg,Ee-?A'kh,-us.xz'.UgUz&wCmBKtnGzZg&v;?TOv'P?qDCETzBd:$ w -LuwgiYXhh:E-wFnCJ;fL;gzDyLhrwXl,b,LoNh3W
yvw3grfWa$EMVsYg,N&IU$UBdtLMKSNMuxB,P-33,?w-3QMlMWJya!,KY;m$f,INttRBZFsds$hFDcY;yg3BlbyunJEjznvHIRAQ.wYhxfgyxIUrHC.y UWv YJt,K
Z
J;3VLe-CT:VZm?! dTCyVVHeRSqDEXPpohX&v
Q.UZ.
$eh&qzRHIsjdisRr
yjuxvvMzSW:n$
3C&ym;ANWhElSmx.TDmo$MK
SnfHyV:t&&We;DSX!VR-?r?yzvKGZZ!FPx'$SWjzsb:glDJsZDCwafMevqlmk&M!fKsC&hISnv?3IlI,SEklFqkfGbdnB

In [20]:
@torch.no_grad()
def estimate_loss(model, batch_size, eval_iters):
    out = {}
    model.eval()
    for split in['train', 'val']:
        losses_list = torch.zeros(eval_iters)
        for i in range(eval_iters):
            X, Y = get_batch(split, batch_size)
            logits, losses = model(X, Y)
            losses_list[i] = losses.item()
        out[split] = losses_list.mean()
    model.train()
    return out

In [21]:
m = LanguageModel(vocab_size, n_embd)
m = m.to(device)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

for iter_ in range(max_iters):
    if iter_ % eval_iters == 0:
        out = estimate_loss(m, batch_size, eval_iters)
        print(f"train loss {out['train']} val loss {out['val']}")
    xb, yb = get_batch('train', batch_size)
    logits, losses = m(xb, yb)
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
print(losses.item())

train loss 4.207266807556152 val loss 4.216193675994873
train loss 3.423068046569824 val loss 3.3658740520477295
train loss 3.293736457824707 val loss 3.2521278858184814
train loss 3.232297420501709 val loss 3.244847059249878
train loss 3.128462314605713 val loss 3.1404178142547607
train loss 3.0582430362701416 val loss 3.088736057281494
train loss 2.9825069904327393 val loss 3.023178815841675
train loss 2.9039318561553955 val loss 2.9639642238616943
train loss 2.835329532623291 val loss 2.879398822784424
train loss 2.765599250793457 val loss 2.863208293914795
train loss 2.7834486961364746 val loss 2.801379442214966
train loss 2.7111623287200928 val loss 2.7769906520843506
train loss 2.7035605907440186 val loss 2.7881929874420166
train loss 2.683370590209961 val loss 2.759216547012329
train loss 2.6446056365966797 val loss 2.7420058250427246
train loss 2.6271276473999023 val loss 2.727348566055298
train loss 2.6083714962005615 val loss 2.6915433406829834
train loss 2.6572649478912354 v

In [22]:
_in = torch.zeros((1, 1), dtype=torch.long)
out = m.generate(_in)
print("".join(decode(out[0].tolist())))


AVo; tse?

A!
Toous ti?
I-EI:
OU:
Avoce PUOyt h k; at apsn sishdta ty oey mf astl yroHhand mis raetanllles gehengyr shociteas gryaf I, ith fe yy it.
B:
Tcli?

L
Wyeean?


Pyorg th belreon h sandt inshesins heense chart dllg su.

Nathy theg?

Seo RNIO:
Wo sd se seml, pes hdsen o,

THT::
KP
OSI loflres tanwou Kwof ud wgo thu ms:
h sus becadesath dins
Bag ce iosnthe th lile kat:
Ehar;: m nthy.
Nhe o athl he ts, withe b.



O:e hyy y
TI
BSAAGL-O:
EThotecoussh ild f o, t,
:
SA,
ir t atese sry mlfee tho
Hcy y sm, b dhoknon tEGl- sorreiroked tedl winrthinurshol tyesriche ina; pify, okrt sse alsint licore xyo;
ANAh ste heuls, b be
Thers thinp nmard thavovisoca yondr nove sf msrud li wamt snri:
A
The y ilrg.

-
RUAITI
AO:
Punirfel b, woy atheentor,,, nben Bryy g y bsds gL,
Dorwig sout houlthd srhyomil, rro'ls he Sod-
Bh'omir mssme,
NI so, I'mir rerche chid lad, men msorep ons-.ON.


LHche o cisled ta wt y aible s;
Th gime I s;
Ae?
uguall:

WEAhitiyr mhpy myicak bonce mtlomis eeocray sottachoth