In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
with open('verne.txt', 'r') as f:
    text = f.read()

vocab_size = len(set(text))

In [3]:
# construct a character level tokenizer
ctoi = {c:i for i,c in enumerate(set(text))}
itoc = {i:c for i,c in enumerate(set(text))}
encode = lambda x: [ctoi[c] for c in x]
decode = lambda x: ''.join([itoc[i] for i in x])

In [24]:
data = torch.tensor(encode(text), dtype=torch.long)

n = int(len(data) *.9)
train_data = data[:n]
val_data = data[n:]

batch_size = 32
block_size = 8
device = torch.device('mps')

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x.to(device), y.to(device)

In [25]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        #construct a lookup table where each row corresponds to each token
        #and contains the logits for the next tokcn
        self.embedding_table= nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx:torch.Tensor, target:torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor | None]:
        #look up the logits for the next token
        logits = self.embedding_table(idx)

        if target is None:
            loss = None
        else:
            #compute the loss
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            loss = F.cross_entropy(logits, target.view(-1))
        return logits, loss

    def generate(self, idx: torch.Tensor, max_tokens:int) -> torch.Tensor:
        #generate tokens
        with torch.no_grad():
            for _ in range(max_tokens):
                logits, loss = self.forward(idx)
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                idx = torch.cat((idx, next_token), dim=1)
            return idx

In [26]:
torch.manual_seed(1337)
bigram = BigramLanguageModel(vocab_size).to(device) 
x, y = get_batch('train')

print(x.shape)

logits, loss = bigram(x,y)
print(loss)

print(decode(bigram.generate(torch.zeros(1,1, dtype=torch.long, device=device), 100)[0].tolist()))

torch.Size([32, 8])
tensor(5.1389, device='mps:0', grad_fn=<NllLossBackward0>)
0£h Œi((“WI_+z:YyNXn=-1”_Tr5i﻿£:oN“3$
°m/zfŒ"EfYM5>:3&OgPŒ,‘J-6i1/_V_″vfS7I@FnCé=—A
N2:i57ï/)X1!nEb,>


In [27]:
optimizer = torch.optim.AdamW(bigram.parameters(), lr=1e-3)

In [29]:
for i in range(10000):
    x,y = get_batch('train')
    logits, loss = bigram(x,y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(loss.item())

2.5752177238464355


In [37]:
print(decode(bigram.generate(torch.zeros(1,1, dtype=torch.long, device=device), 100)[0].tolist()))

0
The borven iove s theannokintwaim, we---trs to ar-o ad anted I ves, “Hibouerthedeloke ier Theatis l
