In [48]:
from jaxtyping import Float, Integer
import torch
import torch.nn
import torch.random
from typing import Dict, List, Tuple, Union, Optional

# Load the whole corpus into memory. The encoding is passed explicitly so the
# decoded characters (and therefore the vocabulary below) do not depend on the
# platform's default encoding.
# NOTE(review): assumes the file is UTF-8 — confirm against the data source.
with open('crime_and_punishment.txt', encoding='utf-8') as f:
    text = f.read()
print(len(text))

# Character-level vocabulary: every distinct character in the corpus, sorted so
# that a character's id (its position in this list) is the same on every run.
vocab = sorted(set(text))

def c2i(c: str) -> int:
    """Map a character to its integer id, i.e. its position in ``vocab``.

    Raises:
        ValueError: If ``c`` does not occur in ``vocab`` (raised by ``list.index``).
    """
    char_id = vocab.index(c)
    return char_id

def i2c(i: int) -> str:
    """Map an integer id back to its character: the ``i``-th entry of ``vocab``."""
    char = vocab[i]
    return char

# Encode the whole corpus as integer ids. A dict built once gives O(1) lookups;
# going through c2i (vocab.index, a linear scan of the vocabulary) for every one
# of the corpus characters is needlessly slow. The ids are identical either way
# because vocab holds each character exactly once.
char_to_id = {ch: idx for idx, ch in enumerate(vocab)}
text_ids = [char_to_id[ch] for ch in text]

# Seed torch's global RNG so that random draws made after this point are repeatable.
torch.random.manual_seed(1337)

1154391


<torch._C.Generator at 0x113c661f0>

In [28]:
# Split the encoded corpus by position: the first 90% of the ids are used for
# training and the remaining tail is held out for evaluation.
k = int(0.9*len(text_ids))  # index at which the held-out tail begins
train, test = text_ids[:k], text_ids[k:]

In [29]:
# Let's construct our training examples.
# Given the encoded text and a block size, we draw a random sample of blocks.
# Open question: how did Karpathy's approach ensure the whole dataset was used in a given epoch?
# I seem to recall him choosing the offsets at random. Let's go with that for now
# and revisit it later when we set up training properly.
def create_batch(batch_size: int, block_size: int, split: str) -> Tuple[Integer[torch.Tensor, 'B T'], Integer[torch.Tensor, 'B T']]:
    """Sample a random batch of (input, target) windows of token ids.

    Each row of ``x`` holds ``block_size`` consecutive ids starting at a
    uniformly random offset in the chosen split. The matching row of ``y`` is
    the same window shifted one position to the right, i.e. the next-token
    target for every position of ``x``.

    Args:
        batch_size: Number of windows to sample (B).
        block_size: Length of each window (T).
        split: ``'train'`` or ``'test'``; selects the module-level list of ids.

    Returns:
        ``(x, y)``: two integer tensors, each of shape ``(batch_size, block_size)``.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'test'``, or if the
            chosen split is too short to hold ``block_size + 1`` ids.
    """
    if split == 'train':
        data = train
    elif split == 'test':
        data = test
    else:
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")

    # Every window needs block_size inputs plus one extra id for the final target,
    # so valid start offsets are 0 .. len(data) - block_size - 1.
    num_starts = len(data) - block_size
    if num_starts <= 0:
        raise ValueError(
            f"split {split!r} has {len(data)} ids; need at least {block_size + 1} for block_size={block_size}"
        )

    starts = torch.randint(num_starts, (batch_size,)).tolist()
    x = torch.tensor([data[start:start + block_size] for start in starts])
    y = torch.tensor([data[start + 1:start + block_size + 1] for start in starts])
    return x, y

x, y = create_batch(4, 8, 'test')

In [30]:
def estimate_loss(model: torch.nn.Module, eval_iters: int, batch_size: int, block_size: int) -> Dict[str, Float[torch.Tensor, '...']]:
    """Estimate the mean loss of ``model`` on the 'train' and 'test' splits.

    For each split, ``eval_iters`` random batches are drawn with
    ``create_batch`` and the per-batch losses are averaged.

    The model is evaluated in eval mode with gradient tracking disabled, and
    its previous train/eval mode is restored before returning, so calling this
    between optimization steps does not leave the model in eval mode.

    Args:
        model: Module called as ``model(x, y)`` and returning ``(logits, loss)``.
        eval_iters: Number of random batches to average over per split.
        batch_size: Batch size passed to ``create_batch``.
        block_size: Block size passed to ``create_batch``.

    Returns:
        Dict mapping ``'train'`` and ``'test'`` to a scalar tensor holding the
        mean loss over the sampled batches.
    """
    was_training = model.training
    model.eval()
    result: Dict[str, Float[torch.Tensor, '...']] = {}
    try:
        # No backward pass happens here, so skip building autograd graphs.
        with torch.no_grad():
            for split in ['train', 'test']:
                losses = torch.zeros(eval_iters)
                for i in range(eval_iters):
                    x, y = create_batch(batch_size, block_size, split)
                    _, loss = model(x, y)
                    losses[i] = loss.item()
                result[split] = losses.mean()
    finally:
        # Restore whichever mode the caller had the model in, even on error.
        model.train(was_training)
    return result

In [76]:
class GPT(torch.nn.Module):
    """Minimal character-level language model: a token embedding followed by a
    linear projection to vocabulary logits.

    Every position is processed independently (nothing mixes information
    across the T axis), so the prediction at a position depends only on the
    token at that position.
    """

    def __init__(self, vocab: List[str], hdim: int):
        """
        Args:
            vocab: The vocabulary; its length fixes the number of embeddings
                and the number of output logits.
            hdim: Size of the embedding vectors.
        """
        super().__init__()
        self.vocab = vocab
        self.embedding = torch.nn.Embedding(len(vocab), hdim)
        self.final_proj = torch.nn.Linear(hdim, len(vocab))
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x: Integer[torch.Tensor, 'B T'], y: Optional[Integer[torch.Tensor, 'B T']] = None) -> Tuple[Float[torch.Tensor, 'B T C'], Optional[Float[torch.Tensor, '...']]]:
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            x: Token ids of shape (B, T).
            y: Optional target ids of shape (B, T).

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            (B, T, C) and ``loss`` is ``None``. With targets, ``loss`` is the
            scalar cross-entropy and ``logits`` is returned flattened to
            (B*T, C), the shape in which the loss was computed.
        """
        hidden = self.embedding(x)
        logits = self.final_proj(hidden)

        if y is None:
            return logits, None

        B, T, C = logits.shape
        # CrossEntropyLoss expects (N, C) logits against (N,) class indices.
        # reshape (rather than view) also accepts non-contiguous targets.
        logits = logits.reshape(B*T, C)
        targets = y.reshape(B*T)
        loss = self.loss_fn(logits, targets)
        return logits, loss

    def generate(self, context: Integer[torch.Tensor, 'B T'], max_output_length: int) -> Integer[torch.Tensor, "B T_out"]:
        """Extend ``context`` by sampling ``max_output_length`` tokens autoregressively.

        Args:
            context: Starting token ids of shape (B, T).
            max_output_length: Number of tokens to append to each row.

        Returns:
            Token ids of shape (B, T + max_output_length): the original context
            followed by the sampled tokens.
        """
        # Sampling never backpropagates, so don't build an autograd graph per step.
        with torch.no_grad():
            for _step in range(max_output_length):
                logits, _ = self(context)
                # Only the prediction at the last position is used for the next token.
                last_logits = logits[:, -1, :]
                # Sample the next token from the softmax distribution (not argmax).
                probs = torch.nn.functional.softmax(last_logits, dim=-1)
                idx = torch.multinomial(probs, 1)
                # Append and feed the grown sequence back in. The context is never
                # cropped here; a model with a fixed context size would need a
                # sliding window at this point.
                context = torch.cat((context, idx), dim=1)
        return context


# Build the model over the character vocabulary with an embedding size of 32.
gpt = GPT(vocab, 32)
# Smoke-test the forward pass without targets; the returned value is discarded.
# NOTE(review): `x` is presumably the batch from create_batch(4, 8, 'test') in an earlier cell — confirm.
gpt(x, None)
# Start generation from a single newline character, reshaped to (1, 1).
context = torch.tensor(c2i('\n')).view(1, 1)
print(context)
print(context.shape)
# Sample 100 characters from the untrained model. This is not the last expression
# of the cell, so the decoded string is computed but not displayed. The `x` inside
# the generator expression is local to it and does not rebind the module-level `x`.
''.join(i2c(x) for x in gpt.generate(context, 100)[0].tolist())
# Loss estimate before training; the returned dict is likewise discarded here.
# Both this call and generate() above draw from torch's global RNG, so they affect
# the random state seen by later cells.
estimate_loss(gpt, 500, 4, 8)
# Optimizer over all model parameters.
optim = torch.optim.AdamW(gpt.parameters(), lr=1e-3)

tensor([[0]])
torch.Size([1, 1])


In [77]:
# Put the model in training mode explicitly rather than relying on whatever
# mode an earlier cell happened to leave it in.
gpt.train()
for step in range(10000):
    # Clear gradients left over from the previous step before accumulating new ones.
    gpt.zero_grad()
    batch = create_batch(4, 8, 'train')  # batch_size=4, block_size=8
    logits, loss = gpt(*batch)
    loss.backward()
    optim.step()
# Mean loss on both splits after training (displayed as the cell's output).
estimate_loss(gpt, 500, 4, 8)

{'train': tensor(2.4882), 'test': tensor(2.5584)}

In [78]:
# Sample 100 characters from the model, starting from `context`, and decode the ids to text.
''.join(map(i2c, gpt.generate(context, 100)[0].tolist()))

'\nHal,\ntexod, s sur lloughag pe, Acuromeckie o, iere harionkee gethay orrosng way, itin lkigazs, fece '