In [164]:
import os
import urllib
import urllib.request

In [165]:
# Fetch the training corpus (saved locally as `input.txt`).
# Done with the standard library instead of `!wget` so the cell also works
# where the wget binary is not installed, and guarded so that re-running the
# cell does not download again (wget would save a second copy as `input.txt.1`).
DATA_URL = 'https://raw.githubusercontent.com/karpathy/ng-video-lecture/master/input.txt'
DATA_PATH = 'input.txt'

if not os.path.exists(DATA_PATH):
    # Download under a temporary name and rename afterwards, so an interrupted
    # transfer never leaves a truncated `input.txt` that the guard would trust.
    partial_path = DATA_PATH + '.part'
    urllib.request.urlretrieve(DATA_URL, partial_path)
    os.replace(partial_path, DATA_PATH)

In [166]:
# Load the whole corpus file into memory as a single string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [167]:
# Size of the corpus, in characters.
len(text)

1115393

In [168]:
# Peek at the first 1000 characters of the corpus.
text[:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [169]:
# Vocabulary: every distinct character that occurs in the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [170]:
# Lookup tables between characters and integer ids; an id is the character's
# position in `chars`.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Return the list of integer ids for the characters of string `s`.

    Raises KeyError for a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the sequence of integer ids `l`.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l)

In [171]:
# Round trip: encode a sample string, then decode the ids back to text.
sample_ids = encode('saluu')
print(sample_ids)
print(decode(sample_ids))

[57, 39, 50, 59, 59]
saluu


In [172]:
import torch
# Encode the entire corpus as a 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
# Last expression of the cell: display the first 36 token ids.
data[:36]

torch.Size([1115393]) torch.int64


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63])

In [173]:
# Train/validation split: the first 90% of the tokens are used for training,
# the remaining tokens are held out for validation.
n = int(0.9 * len(data))
data_train, data_val = data[:n], data[n:]

In [174]:
# One chunk of block_size + 1 tokens yields block_size training examples:
# each prefix of `x` is a context and the token right after it is the target.
block_size = 8
x, y = data_train[:block_size], data_train[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f'Context: {context} | Target: {target}')

Context: tensor([18]) | Target: 47
Context: tensor([18, 47]) | Target: 56
Context: tensor([18, 47, 56]) | Target: 57
Context: tensor([18, 47, 56, 57]) | Target: 58
Context: tensor([18, 47, 56, 57, 58]) | Target: 1
Context: tensor([18, 47, 56, 57, 58,  1]) | Target: 15
Context: tensor([18, 47, 56, 57, 58,  1, 15]) | Target: 47
Context: tensor([18, 47, 56, 57, 58,  1, 15, 47]) | Target: 58


In [175]:
# Data batching
# Seed PyTorch's RNG so the randomly sampled batches are reproducible.
torch.manual_seed(1337)
batch_size = 4  # number of sequences sampled per batch (see get_batch)
block_size = 8  # number of tokens in each sequence (see get_batch)

def get_batch(split):
    """Sample a random batch of input/target sequences.

    Reads the module-level `batch_size`, `block_size`, `data_train` and
    `data_val`.

    Args:
        split: 'train' draws from `data_train`; any other value draws from
            `data_val`.

    Returns:
        (x, y): two tensors of shape (batch_size, block_size). `y` holds the
        same windows as `x` shifted one position to the right, so `y[b, t]`
        is the token that follows `x[b, t]` in the source data.
    """
    source = data_val if split != 'train' else data_train
    # randint's upper bound is exclusive, so the largest start index still
    # leaves room for the one-token shift of the target window.
    starts = torch.randint(len(source) - block_size, (batch_size,)).tolist()
    x = torch.stack([source[s:s + block_size] for s in starts])
    y = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return x, y

In [176]:
# Draw one training batch; the cell displays the inputs `xb`.
xb, yb = get_batch('train')
xb

tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])

In [177]:
# Spell out every (context, target) pair contained in the batch: for each row,
# the first t + 1 inputs are the context and the t-th entry of yb is the target.
for b in range(batch_size):
    row_inputs, row_targets = xb[b], yb[b]
    for t in range(block_size):
        context = row_inputs[:t + 1]
        target = row_targets[t]
        print(f'Input: {context} | Target: {target}')

Input: tensor([53]) | Target: 59
Input: tensor([53, 59]) | Target: 6
Input: tensor([53, 59,  6]) | Target: 1
Input: tensor([53, 59,  6,  1]) | Target: 58
Input: tensor([53, 59,  6,  1, 58]) | Target: 56
Input: tensor([53, 59,  6,  1, 58, 56]) | Target: 47
Input: tensor([53, 59,  6,  1, 58, 56, 47]) | Target: 40
Input: tensor([53, 59,  6,  1, 58, 56, 47, 40]) | Target: 59
Input: tensor([49]) | Target: 43
Input: tensor([49, 43]) | Target: 43
Input: tensor([49, 43, 43]) | Target: 54
Input: tensor([49, 43, 43, 54]) | Target: 1
Input: tensor([49, 43, 43, 54,  1]) | Target: 47
Input: tensor([49, 43, 43, 54,  1, 47]) | Target: 58
Input: tensor([49, 43, 43, 54,  1, 47, 58]) | Target: 1
Input: tensor([49, 43, 43, 54,  1, 47, 58,  1]) | Target: 58
Input: tensor([13]) | Target: 52
Input: tensor([13, 52]) | Target: 45
Input: tensor([13, 52, 45]) | Target: 43
Input: tensor([13, 52, 45, 43]) | Target: 50
Input: tensor([13, 52, 45, 43, 50]) | Target: 53
Input: tensor([13, 52, 45, 43, 50, 53]) | Targe

In [178]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Re-seed PyTorch's RNG so later random draws (weight init, sampling) are repeatable.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram language model backed by a single embedding table.

    The table has shape (vocab_size, vocab_size): looking up token id `i`
    returns the logits over the next token given that the current token is
    `i`. No other context is used.
    """

    def __init__(self, vocab_size):
        """Create the (vocab_size, vocab_size) next-token logit table."""
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of next-token ids, shape (B, T).

        Returns:
            (logits, loss). Without `targets`, `logits` has shape (B, T, C)
            and `loss` is None. With `targets`, `logits` is flattened to
            (B*T, C) (the layout `F.cross_entropy` expects) and `loss` is the
            scalar mean cross-entropy.
        """
        logits = self.token_embedding_table(idx)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (not view): targets sliced from a larger batch, e.g.
        # yb[:, :k], are non-contiguous and would make .view() raise.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling is not differentiable; skip autograd bookkeeping
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens autoregressively.

        Args:
            idx: integer tensor of token ids, shape (B, T), used as the prompt.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the prompt
            followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            # Only the last time step predicts the token to append.
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [179]:
# Instantiate the model and run one forward pass on the current batch to
# check the logits shape and the initial loss.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, targets=yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.8948, grad_fn=<NllLossBackward0>)


In [180]:
# Starting from a single token with id 0 (a batch of one), sample 100 more
# tokens from the model and print them as text.
start_idx = torch.zeros((1, 1), dtype=torch.long)
sample = m.generate(idx=start_idx, max_new_tokens=100)
print(decode(sample[0].tolist()))


Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [187]:
# Train the bigram model with AdamW on randomly sampled training batches.
batch_size = 32  # module-level value read by get_batch; replaces the demo value of 4
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
for steps in range(10000):
    xb, yb = get_batch('train')
    # Drop the gradients of the previous step before computing new ones.
    optimizer.zero_grad(set_to_none=True)
    logits, loss = m(xb, yb)
    loss.backward()
    optimizer.step()
# Loss of the last batch only, not an average over the run.
print(loss.item())

2.468960762023926


In [188]:
# Starting from a single token with id 0 (a batch of one), sample 300 more
# tokens from the model and print them as text.
start_idx = torch.zeros((1, 1), dtype=torch.long)
sample = m.generate(idx=start_idx, max_new_tokens=300)
print(decode(sample[0].tolist()))


CENO,
Youthouth ngre mir
LYRD:
O: tines o ehofa liteloroulis de anod ct mer t losthowarkerolin:
IO:
He he faregur t ayorff houndfis cod sws h ovig nd fa his dve aclidipo her d
PSALedrs fureauplf tton meas offe, feeay ous?
an stheh mis I maky,
Whery wanotinabusown anithy g, tom o whoull ou teedow tip
