In [7]:
#from jaxtyping import Array, Float, PyTree
import torch
import torch.nn
import torch.random
from typing import List, Tuple

# Read the whole corpus into memory as a single string. The encoding is passed
# explicitly so the decoded text (and therefore the vocabulary below) does not
# depend on the platform's default locale encoding.
# NOTE(review): assumes the file is UTF-8 encoded -- confirm.
with open('crime_and_punishment.txt', encoding='utf-8') as f:
    text = f.read()
print(len(text))

# Character-level vocabulary: every distinct character of the corpus, sorted so
# that a character's position in the list is the same on every run.
vocab = sorted(set(text))

# Cache for c2i: the `vocab` list the table was built from, and the table itself.
_c2i_cache = {"vocab": None, "table": {}}

def c2i(c: str) -> int:
    """Return the integer id of character `c`: its position in the module-level `vocab`.

    The lookup goes through a dict built once from `vocab`, rather than a
    linear `vocab.index(c)` scan on every call. The dict is rebuilt whenever
    the name `vocab` is bound to a different list object; in-place mutation of
    the same list is not detected.

    Args:
        c: the character to encode.

    Returns:
        The index of the first occurrence of `c` in `vocab`.

    Raises:
        ValueError: if `c` is not in `vocab` (the same exception type that
            `vocab.index(c)` raises).
    """
    if _c2i_cache["vocab"] is not vocab:
        table = {}
        for idx, ch in enumerate(vocab):
            # setdefault keeps the first occurrence, matching list.index.
            table.setdefault(ch, idx)
        _c2i_cache["vocab"] = vocab
        _c2i_cache["table"] = table
    try:
        return _c2i_cache["table"][c]
    except KeyError:
        raise ValueError(f"{c!r} is not in vocab") from None

def i2c(i: int) -> str:
    """Return the character stored at position `i` of the module-level `vocab`.

    This is plain sequence indexing, so an out-of-range `i` raises IndexError
    and a negative `i` counts from the end of `vocab`.
    """
    char = vocab[i]
    return char

# Encode the entire corpus: one integer id per character, in reading order.
text_ids = list(map(c2i, text))

# Seed torch's global random number generator so that random draws made after
# this point are repeatable. Kept as the cell's last expression, so the
# returned Generator is what the cell displays.
torch.random.manual_seed(1337)

1154391


<torch._C.Generator at 0x111b86c30>

In [20]:
# Sequential split of the encoded corpus: the first 90% of the ids are used for
# training and the remaining tail is held out for evaluation.
k = int(len(text_ids) * 0.9)
train, test = text_ids[:k], text_ids[k:]

In [21]:
# Let's construct our training examples.
# We want to take our text, some block size, and generate a random sample.
# How did Karpathy's approach ensure the whole data was used in a given epoch?
# I seem to recall him randomly choosing an offset. Well let's just go with that for now
# and update it later when we want to do training
def create_batch(batch_size: int, block_size: int, split: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (input, target) windows from one data split.

    Each row of `x` is a window of `block_size` consecutive ids starting at a
    uniformly random offset in the chosen split; the matching row of `y` is the
    same window shifted one position to the right.

    Args:
        batch_size: number of windows to sample.
        block_size: length of each window.
        split: 'train' to sample from the module-level `train` list, 'test' to
            sample from `test`.

    Returns:
        A tuple `(x, y)` of tensors, each built from `batch_size` rows of
        `block_size` ids.

    Raises:
        ValueError: if `split` is neither 'train' nor 'test', or if the chosen
            split is not longer than `block_size` (no window plus its shifted
            target would fit).
    """
    if split == 'train':
        data = train
    elif split == 'test':
        data = test
    else:
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")

    if len(data) <= block_size:
        # Without this guard torch.randint fails on an empty range with an
        # error that does not mention the real cause.
        raise ValueError(
            f"split {split!r} has {len(data)} ids, which is not more than "
            f"block_size={block_size}"
        )

    # randint's upper bound is exclusive, so the largest start offset is
    # len(data) - block_size - 1 and the shifted target window still ends
    # inside `data`.
    starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
    x = torch.tensor([data[start:start + block_size] for start in starts])
    y = torch.tensor([data[start + 1:start + block_size + 1] for start in starts])
    return x, y

# Draw one small batch from the held-out split; `x` and `y` are reused by later
# cells as sample inputs and targets.
x, y = create_batch(batch_size=4, block_size=8, split='test')

In [24]:
def estimate_loss(model, eval_iters, batch_size, block_size):
    """Estimate the mean loss of `model` on the 'train' and 'test' splits.

    For each split, `eval_iters` random batches are drawn with `create_batch`
    and the values of `model.loss(x, y)` are averaged.

    The model is switched to eval mode for the duration of the call and then
    returned to whichever mode it was in on entry, so calling this in the
    middle of training does not leave the model stuck in eval mode. Autograd is
    disabled while evaluating because only the loss values are used.

    Args:
        model: a torch.nn.Module that provides `loss(x, y)` returning a scalar
            tensor.
        eval_iters: number of batches to average per split.
        batch_size: forwarded to `create_batch`.
        block_size: forwarded to `create_batch`.

    Returns:
        A dict with keys 'train' and 'test', each mapping to a 0-dim tensor
        holding the mean loss for that split.
    """
    was_training = model.training
    model.eval()
    result = {}
    try:
        with torch.no_grad():
            for split in ['train', 'test']:
                losses = torch.zeros(eval_iters)
                for step in range(eval_iters):
                    x, y = create_batch(batch_size, block_size, split)
                    losses[step] = model.loss(x, y).item()
                result[split] = losses.mean()
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)
    return result

In [27]:
class GPT(torch.nn.Module):
    """Character-level model: a token embedding followed by one linear layer.

    `forward` maps token ids to per-position logits over the vocabulary, and
    `loss` scores those logits against target ids with cross-entropy. Each
    position's logits are computed from the token at that position alone.
    """

    def __init__(self, vocab: List[str], hdim: int):
        """Build the layers.

        Args:
            vocab: list of tokens; its length sets the embedding table size and
                the number of output logits.
            hdim: width of the embedding vectors.
        """
        super().__init__()
        n_tokens = len(vocab)
        self.vocab = vocab
        # Layers are created in this order (embedding, then projection) so that
        # seeded weight initialisation draws from the RNG in the same sequence.
        self.embedding = torch.nn.Embedding(num_embeddings=n_tokens, embedding_dim=hdim)
        self.final_proj = torch.nn.Linear(in_features=hdim, out_features=n_tokens)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for every token id in `x`."""
        return self.final_proj(self.embedding(x))

    def loss(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy between the logits for `x` and the targets `y`.

        The logits must be 3-dimensional (unpacked below as batch, steps,
        classes), and `y` must hold batch * steps target ids.
        """
        logits = self(x)
        batch, steps, n_classes = logits.shape
        n_rows = batch * steps
        # Flatten batch and step dimensions so each position is one row.
        return self.loss_fn(logits.view(n_rows, n_classes), y.view(n_rows))

# Instantiate the model with 32-dimensional embeddings.
gpt = GPT(vocab, hdim=32)

# Smoke-test the forward pass and the loss on the sample batch; the results are
# discarded, these calls only confirm that both run.
gpt(x)
gpt.loss(x, y)

# Mean loss per split over 500 random batches of 4 windows of length 8. Kept as
# the cell's last expression so the resulting dict is displayed.
estimate_loss(gpt, eval_iters=500, batch_size=4, block_size=8)

{'train': tensor(4.8465), 'test': tensor(4.8414)}