In [34]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [35]:
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [36]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(' '.join(chars))
print(vocab_size)


   ! $ & ' , - . 3 : ; ? A B C D E F G H I J K L M N O P Q R S T U V W X Y Z a b c d e f g h i j k l m n o p q r s t u v w x y z
65


In [37]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

In [38]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:20])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56])


In [39]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [40]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

#### Why `block_size + 1`?:
Apparently each chunk will have `n-1` contiguous "subchunks" that can be trained on. First take character 1, then 1-2, then 1-3, etc.

In [41]:
batch_size = 4
block_size = 8


def get_batch(split="train"):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y

In [42]:
xb, yb = get_batch()
print(f"Input: {xb.shape}\n{xb}")
print(f"Output: {yb.shape}\n{yb}")

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, : t + 1]
        output = yb[b, t]
        print(f"When the input is {context.tolist()}, output is {output}")

Input: torch.Size([4, 8])
tensor([[ 1, 47, 52,  1, 39, 47, 56,  1],
        [53, 56, 58, 59, 52, 43,  5, 57],
        [53, 53, 49, 57,  1, 44, 53, 56],
        [ 0, 15, 27, 30, 21, 27, 24, 13]])
Output: torch.Size([4, 8])
tensor([[47, 52,  1, 39, 47, 56,  1, 53],
        [56, 58, 59, 52, 43,  5, 57,  1],
        [53, 49, 57,  1, 44, 53, 56,  1],
        [15, 27, 30, 21, 27, 24, 13, 26]])
When the input is [1], output is 47
When the input is [1, 47], output is 52
When the input is [1, 47, 52], output is 1
When the input is [1, 47, 52, 1], output is 39
When the input is [1, 47, 52, 1, 39], output is 47
When the input is [1, 47, 52, 1, 39, 47], output is 56
When the input is [1, 47, 52, 1, 39, 47, 56], output is 1
When the input is [1, 47, 52, 1, 39, 47, 56, 1], output is 53
When the input is [53], output is 56
When the input is [53, 56], output is 58
When the input is [53, 56, 58], output is 59
When the input is [53, 56, 58, 59], output is 52
When the input is [53, 56, 58, 59, 52], outpu

In [43]:
import torch
import torch.nn as nn
from torch.nn import functional as F


class Bigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_tokens):
        for _ in range(max_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [44]:
m = Bigram(vocab_size)
logits, loss = m(xb, yb)
print(loss)

idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_tokens=100)[0].tolist()))

tensor(4.1941, grad_fn=<NllLossBackward0>)

Z.VoLNMxWIOn;;zi&3?guLmaxDeiywhzcuEPl-':XY
y&aCxHJCI,n' or?xkjkvgsY,MRDhDpL,dAUyWWinsxK-rSm&F$kkiCBL


In [None]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [53]:
batch_size = 32
for step in range(1000):
    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    print(loss.item())

2.4293129444122314
2.502713203430176
2.544260263442993
2.5500729084014893
2.5360677242279053
2.4658143520355225
2.4520323276519775
2.580437660217285
2.4628562927246094
2.4875290393829346
2.510249376296997
2.518958330154419
2.562567710876465
2.6080660820007324
2.5171234607696533
2.638807773590088
2.627624750137329
2.5467443466186523
2.439547300338745
2.5945003032684326
2.5450594425201416
2.545886754989624
2.5978844165802
2.6864511966705322
2.5687379837036133
2.5803730487823486
2.4028139114379883
2.6482603549957275
2.721215009689331
2.689544439315796
2.50581693649292
2.405122756958008
2.4964003562927246
2.6119656562805176
2.5305709838867188
2.6006698608398438
2.532158613204956
2.5906784534454346
2.531851053237915
2.516901731491089
2.4475953578948975
2.5750503540039062
2.576291561126709
2.590029001235962
2.4917986392974854
2.4956490993499756
2.3312010765075684
2.39487624168396
2.502662420272827
2.5073940753936768
2.4980697631835938
2.4845476150512695
2.60402512550354
2.419222116470337
2.5

In [56]:
print(decode(m.generate(idx, max_tokens=500)[0].tolist()))


A:
By gre.
G
UCoaditr I pemend litho omoubusilelourumank,
YO an me tavemy cr heswin; orintordey wldamy gsen on l gl y, wo ber fe sef! at ber hay st s be tighou y By iat t bl fot s anjof t: anconomyothandisures yohimy ty, t Ore alert ts m'zear il
FFL:

BESe rikedomy-th Ra los m armomatha s br,
Sheay tse EAnelllome burwishro the ba t my;
Go t chenoure arondo tous othanigrushes ishen owik,
NENo ls:
MES wis ous! un y arth yo GR:
Ticr d thenoikn ar;

Fitsthan netndokfo inche thy?
THAndirs, wirut hous


#### Mathematical trick for self-attention


In [74]:
torch.manual_seed(1337)

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 32])

In [None]:
# xbow = torch.zeros((B, T, C))
# for b in range(B):
#     for t in range(T):
#         xprev = x[b, : t + 1]
#         xbow[b, t] = torch.mean(xprev, 0)

In [None]:
# wei = torch.tril(torch.ones(T, T))
# wei /= wei.sum(1, keepdim=True)
# xbow = wei @ x

In [None]:
# tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
# wei = wei.masked_fill(tril == 0, float("-inf"))
# wei = F.softmax(wei, dim=-1)
# xbow = wei @ x

torch.Size([4, 8, 2])

#### This is attention

In [None]:
head_size = 16  # attention
key = nn.Linear(C, head_size, bias=False)  # "Here's what I am"
query = nn.Linear(C, head_size, bias=False)  # "Who should I focus on?"
value = nn.Linear(C, head_size, bias=False)  # "This is what I can give you"
k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1)


tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1) * (head_size**-0.5)

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])