In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
with open("data/shakespeares.txt", mode='r', encoding="utf-8") as f:
    text = f.read()

In [3]:
chars = list(set(text))
vocab_size = len(chars)

stoi = {c:i for i,c in enumerate(chars)}
itos = {i:c for i,c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda ids: "".join([itos[idx] for idx in ids])

In [4]:
" ".join(chars), vocab_size, decode(encode("I love deeplearning"))

("y t V E X o b h R M T   G B c d u v I N e - $ D n A F ; x . s z \n ? J , ! H W Z K w i ' Q Y : P S m & p U f k r L O l 3 C q a j g",
 65,
 'I love deeplearning')

In [5]:
data = torch.tensor(encode(text))
n1 = int(0.9 * len(data))
train_data = data[:n1]
val_data = data[n1:]
len(train_data), len(val_data)

(1003854, 111540)

In [6]:
torch.manual_seed(1337)
batch_size = 32
block_size = 8 # context size

def get_batch(split):
    data = train_data if split == "train" else val_data

    ix = torch.randint(len(data)-block_size, (batch_size,))
    x = torch.stack([data[i: i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x,y

In [7]:
xb, yb = get_batch("train")
xb.shape, yb.shape

(torch.Size([32, 8]), torch.Size([32, 8]))

In [19]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_token_table = nn.Embedding(vocab_size,vocab_size)
    def __call__(self, idx:Tensor, target:Tensor=None): # idx:(B,T), target:(B,T)
        logits:Tensor = self.embedding_token_table(idx) # (B,T,C)
        loss=None
        if target is not None:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(-1)
            loss = F.cross_entropy( logits, target)
        return logits, loss

    def generate(self, idx:Tensor, max_token): # idx:(B,T)
        for _ in range(max_token):
            logits, _ = self(idx)
            logits = logits[:,-1,:] # (B,C)
            probs = F.softmax(logits,dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx,idx_next),dim=1) # (B,T+1)
        return idx

model = BigramLanguageModel()

idx = torch.zeros((1, 1),dtype=torch.long)
print(decode(model.generate(idx, 100)[0].tolist()))

yUuvT;pfC'P'v'Av
FM-vRm
B,n
hbiwfRRoNJRHGyC3eG B!zhfwbn'NiZUkUhIyU;CbQLK-',HwQspR .ncqWK:IykfUDNb.mdI


In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [25]:
for i in range(10000):
    xb,yb = get_batch("train")
    logits, loss = model(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i % 1000 ==0:
        print(loss.item())
print(decode(model.generate(idx, 100)[0].tolist()))

2.580059051513672
2.358429431915283
2.4657866954803467
2.4264559745788574
2.4315245151519775
2.4069695472717285
2.415640354156494
2.519814968109131
2.4235448837280273
2.429150104522705
y de
TI air thind,

CKENINCe hatier t halequr

UScthoind ie, igeall:

ARid wha st,
LAnotikikemeout mp


## The mathematical trick in self-attention

In [4]:
torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])