In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
eval_iters = 200
n_emb = 32
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


In [3]:
x,y = get_batch('train')
x.shape

torch.Size([32, 8])

In [19]:
torch.manual_seed(1337)
B,T,C = 1,3,4 # batch size, sequence length, embedding size
# x = torch.randn(B,T,C)
x = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]]]
x = torch.tensor(x, dtype=torch.float32)

# model definition
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)
v = value(x)


tril = torch.tril(torch.ones(T, T)) 
wei = torch.matmul(q, k.transpose(-2,-1)) / (head_size ** 0.5)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
# print(wei.shape)

out = torch.matmul(wei, v)

In [46]:
class Head(nn.Module):
    def __init__(self,head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size)
        self.query = nn.Linear(n_emb, head_size)
        self.value = nn.Linear(n_emb, head_size)

        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # A buffer is used because it is not a parameter of the model, but we want it on the GPU

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        wei = torch.matmul(q, k.transpose(-2,-1)) / (C ** 0.5)
        wei = wei.masked_fill(self.tril == 0, float('-inf'))
        wei = F.softmax(wei, dim=1)

        out = torch.matmul(wei, v)
        return out

class MutliHeadAttention(nn.Module):
    def __init__(self, n_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])

    def forward(self, x):
        B, T, C = x.shape
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        return out

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb)
        self.position_embedding_table = nn.Embedding(block_size, n_emb)
        # self.lm_head = nn.Linear(n_emb, vocab_size)
        self.sa_heads = MutliHeadAttention(4, n_emb//4)
        self.lm_head = nn.Linear(n_emb, vocab_size) # converge back to vocab size so that we can predict the next token 

    def forward(self, idx, targets=None):

        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_heads(x) # (B,T,n_emb)
        logits = self.lm_head(x) # (B,T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Crop idx to block_size
            idx_cond = idx[:, -block_size:]

            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [47]:
model = BigramLanguageModel()
m = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [48]:
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


step 0: train loss 4.1682, val loss 4.1727
step 300: train loss 1.3057, val loss 1.2969
step 600: train loss 0.9667, val loss 1.0017
step 900: train loss 0.8189, val loss 0.8479
step 1200: train loss 0.7638, val loss 0.7973
step 1500: train loss 0.7025, val loss 0.7066
step 1800: train loss 0.6710, val loss 0.7045
step 2100: train loss 0.6654, val loss 0.6956
step 2400: train loss 0.5975, val loss 0.6304
step 2700: train loss 0.5705, val loss 0.6044


In [54]:
# generate from the model
context = torch.zeros((1, 8), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=5000)[0].tolist())[8:])

Bove is oodwer yourastheors sort, I lis be heseghoulcithe oto scitith the is that ls che bot ot ot ete do he whake odat mor noto these sou thes isger tohe the lirsth che in thate shou oke sh uptoke ando sts th of ast mobe Iose a yount kat od wha loode kse the jor bolad
That gst ias otth secod pod thaperok th thousp ato sh otheaghor th dicols whe lokm ss theo Is hem huspot ket thust the I the st 'sd tht yousciod do acheg kes, gott ordose whe 'swh the pok poufer mas
od he dithough the ! lomanete fs he andat po bithe reros chend as pousth the ghound owh cars gh wake ad weo lotthad Hough ot the potent he be st the ard thel rsuroxne rt tham whake agh mand'egenso of bet the etther mon ache mo soll he
Ars ood th pat qu,
STARDY loke morod blokter odefe rote rgerou che othurtessod fa ss se ar datho che whand che poto sher oo'ught th thentito pihe noo theteghoulgeo od whastakis oto 'rge fobe cot tak sthorove this om th geg. Wher oth't IcChatis o glor shey;
Tod omates th theanwhe otot aud ootth t