## References
- [A Recipe for Training Neural Networks](https://karpathy.github.io/2019/04/25/recipe/)
- [Harvard CS197 AI Research Experiences](https://docs.google.com/document/d/1uvAbEhbgS_M-uDMTzmOWRlYxqCkogKRXdbKYYT98ooc/edit#heading=h.2z3yllpny6or)
- [Unit tests for machine learning research](https://semla.polymtl.ca/wp-content/uploads/2022/11/Pablo-Unit-tests-for-ML-code-SEMLA-talk.pdf)
- [CS 329S: Machine Learning Systems Design](https://stanford-cs329s.github.io/syllabus.html)

## Set up the end-to-end training/evaluation skeleton + get dumb baselines

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Seed PyTorch's random number generator so that parameter initialisation and
# sampling start from the same state every time this notebook is run.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Minimal next-token model: embed each token, then project to vocabulary logits.

    Every position is processed independently (an embedding lookup followed by
    a single linear layer), so the prediction at a position depends only on
    the token at that position.

    Args:
        vocab_size: Number of distinct token ids; also the size of the logit
            dimension produced by the model.
        embedding_dim: Width of the token embedding. Defaults to 16.
    """

    def __init__(self, vocab_size, embedding_dim=16):
        super().__init__()
        # The logits are factorised through a small embedding rather than a
        # full nn.Embedding(vocab_size, vocab_size) table, which keeps the
        # parameter count linear in vocab_size.
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        print('number of parameters:', sum(p.numel() for p in self.parameters()))

    def forward(self, token_indexes):
        """Return next-token logits for every position.

        Args:
            token_indexes: Integer token ids, (batch_size, sequence_length).

        Returns:
            Logits of shape (batch_size, sequence_length, vocab_size).
        """
        embedding = self.token_embedding_table(token_indexes)
        return self.linear(embedding)

    def _cross_entropy(self, token_indexes, targets, reduction):
        """Run the model and compute cross-entropy over all flattened positions.

        Returns the logits alongside the loss so callers can recover the
        leading (batch, sequence) shape without a second forward pass.
        """
        logits = self(token_indexes)
        vocab_size = logits.shape[-1]
        # reshape (not view) so that non-contiguous targets, e.g. a transposed
        # or sliced tensor, are accepted instead of raising a RuntimeError.
        loss = F.cross_entropy(
            logits.reshape(-1, vocab_size),
            targets.reshape(-1),
            reduction=reduction,
        )
        return logits, loss

    def loss_per_token(self, token_indexes, targets):
        """Return the unreduced cross-entropy for every position.

        Args:
            token_indexes: Integer token ids, (batch_size, sequence_length).
            targets: Expected next-token ids, (batch_size, sequence_length).

        Returns:
            Tensor of shape (batch_size, sequence_length).
        """
        logits, loss = self._cross_entropy(token_indexes, targets, reduction='none')
        # loss is flat (batch_size*sequence_length); restore the input layout.
        return loss.view(logits.shape[:-1])

    def loss(self, token_indexes, targets):
        """Return the mean cross-entropy over all positions as a scalar tensor.

        Args:
            token_indexes: Integer token ids, (batch_size, sequence_length).
            targets: Expected next-token ids, (batch_size, sequence_length).
        """
        _, loss = self._cross_entropy(token_indexes, targets, reduction='mean')
        return loss

    @torch.no_grad()
    def generate(self, token_indexes, max_new_tokens):
        """Sample `max_new_tokens` tokens and append them to the prompt.

        Sampling is done under ``torch.no_grad()``: the sampled ids are
        discrete, so tracking gradients here would only cost memory.

        Args:
            token_indexes: Prompt token ids, (batch_size, sequence_length).
            max_new_tokens: Number of tokens to append to each sequence.

        Returns:
            Tensor of shape (batch_size, sequence_length + max_new_tokens)
            holding the prompt followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            logits = self(token_indexes)
            # Only the last position predicts the token that comes next.
            next_token_logits = logits[:, -1, :]
            # next_token_probs: (batch_size, vocab_size)
            next_token_probs = F.softmax(next_token_logits, dim=-1)
            # next_token: (batch_size, 1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            token_indexes = torch.cat([token_indexes, next_token], dim=1)
        return token_indexes


In [2]:
def rand_int_test(cls, low, high, shape, kwargs, device=None):
    """Build a module and push one batch of random integers through it.

    Prints the input and output shapes as a quick shape check.

    Args:
        cls: Callable that builds the module under test, called as ``cls(**kwargs)``.
        low: Inclusive lower bound of the random integers.
        high: Exclusive upper bound of the random integers.
        shape: Shape of the random input tensor.
        kwargs: Keyword arguments forwarded to ``cls``.
        device: Device to run on. Defaults to CUDA when it is available and
            to the CPU otherwise, so the check also runs without a GPU.

    Returns:
        The module's output for the random input.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    layer = cls(**kwargs).to(device)
    # Draw on the CPU first and then move, so the values come from the same
    # (seeded) CPU generator regardless of the target device.
    random_input = torch.randint(low, high, shape).to(device)
    print('input shape:', random_input.shape)
    output = layer(random_input)
    print('output shape:', output.shape)
    return output

In [3]:
# Shape check: feed random token ids through the model and confirm that the
# output has one logit per vocabulary entry at every position.
batch_size, context_length, vocab_size = 4, 1024, 256
test_cls = BigramLanguageModel

kwargs = dict(vocab_size=vocab_size)
output = rand_int_test(
    test_cls,
    low=0,
    high=vocab_size,
    shape=(batch_size, context_length),
    kwargs=kwargs,
)

number of parameters: 8448
input shape: torch.Size([4, 1024])
output shape: torch.Size([4, 1024, 256])


In [4]:
from data import get_batch, enc
import tiktoken
import math

# Sanity check at initialisation: a uniform guess over the vocabulary has a
# cross-entropy of -log(1/vocab_size), so the loss of the freshly created
# model is printed next to that value for comparison.
x, y = get_batch(batch_size, context_length, 'train')
vocab_size = tiktoken.get_encoding("gpt2").n_vocab
model = BigramLanguageModel(vocab_size).cuda()

# Move the batch to the GPU once and reuse it for both loss computations
# instead of copying it again for each call. `x` and `y` are left untouched.
x_gpu, y_gpu = x.cuda(), y.cuda()

loss = model.loss(x_gpu, y_gpu)
print('random guess loss:', -math.log(1/vocab_size))
print(loss)

# The per-token losses should average to the scalar loss printed above.
loss_per_token = model.loss_per_token(x_gpu, y_gpu)
print(loss_per_token.shape, loss_per_token.mean())
print(loss_per_token)

number of parameters: 1658481
random guess loss: 10.82490511970208
tensor(10.9950, device='cuda:0', grad_fn=<NllLossBackward0>)
torch.Size([4, 1024]) tensor(10.9950, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([[11.4311, 11.6552, 10.2010,  ..., 10.5417, 10.6344, 11.4137],
        [11.0913, 10.3161, 10.8965,  ..., 11.6884, 11.4491, 10.4440],
        [12.3048, 10.9655, 10.6260,  ..., 10.9756, 11.2433, 10.6060],
        [10.5069, 10.6218, 11.0385,  ..., 11.9397, 10.6035, 10.4034]],
       device='cuda:0', grad_fn=<ViewBackward0>)


In [5]:
# Sample a short continuation from the model (no training step has run yet at
# this point in the notebook) and print it next to the real text.
max_new_token = 8
input_tokens = x[0, :4].unsqueeze(0).cuda()  # first 4 tokens, with a batch axis
generated_tokens = model.generate(input_tokens, max_new_token)

# Decode token by token so each list element is the text of a single token.
for label, token_ids in (
    ('input', input_tokens[0]),
    ('output', generated_tokens[0]),
    ('Gold label', x[0]),
):
    print(label, [enc.decode([token_id.item()]) for token_id in token_ids])

input [' must', '\n', 'In', ' that']
output [' must', '\n', 'In', ' that', ' Calculator', ' HPV', 'Empty', ' LW', ' Seconds', ' Infinite', ' payoff', 'ste']
Gold label [' must', '\n', 'In', ' that', ' be', ' made', ' more', ' bitter', '.', ' Fear', ' o', "'", 'ers', 'h', 'ades', ' me', ':', '\n', 'Good', ' expedition', ' be', ' my', ' friend', ',', ' and', ' comfort', '\n', 'The', ' gracious', ' queen', ',', ' part', ' of', ' his', ' theme', ',', ' but', ' nothing', '\n', 'Of', ' his', ' ill', '-', 'ta', "'", 'en', ' suspicion', '!', ' Come', ',', ' Cam', 'illo', ';', '\n', 'I', ' will', ' respect', ' thee', ' as', ' a', ' father', ' if', '\n', 'Th', 'ou', ' bear', "'s", 't', ' my', ' life', ' off', ' hence', ':', ' let', ' us', ' avoid', '.', '\n', '\n', 'C', 'AM', 'ILL', 'O', ':', '\n', 'It', ' is', ' in', ' mine', ' authority', ' to', ' command', '\n', 'The', ' keys', ' of', ' all', ' the', ' post', 'ern', 's', ':', ' please', ' your', ' high', 'ness', '\n', 'To', ' take', ' the', '

In [6]:
# Train the model with AdamW on randomly drawn training batches, logging the
# loss every 100 steps.
batch_size, context_length, iterations = 32, 1024, 500
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for steps in range(iterations):
    x, y = get_batch(batch_size, context_length, 'train')
    x = x.cuda()
    y = y.cuda()

    # Clear the gradients left over from the previous step, then do the usual
    # forward / backward / update sequence.
    optimizer.zero_grad()
    loss = model.loss(x, y)
    loss.backward()
    optimizer.step()

    if steps % 100 == 0:
        print('steps:', steps, 'loss:', loss.item())

# Report the last step as well; it does not fall on a multiple of 100.
print('steps:', steps, 'loss:', loss.item())

steps: 0 loss: 11.002180099487305
steps: 100 loss: 10.177762985229492
steps: 200 loss: 9.035687446594238
steps: 300 loss: 7.741647243499756
steps: 400 loss: 6.681490898132324
steps: 499 loss: 6.057015895843506


In [7]:
# Sample again after the training loop, prompting with the first 4 tokens of
# the most recent training batch, and print the result next to the real text.
max_new_token = 8
input_tokens = x[0, :4].unsqueeze(0).cuda()  # first 4 tokens, with a batch axis
generated_tokens = model.generate(input_tokens, max_new_token)

# Decode token by token so each list element is the text of a single token.
for label, token_ids in (
    ('input', input_tokens[0]),
    ('output', generated_tokens[0]),
    ('Gold label', x[0]),
):
    print(label, [enc.decode([token_id.item()]) for token_id in token_ids])

input ['I', ' have', ' a', ' motion']
output ['I', ' have', ' a', ' motion', ' we', ' Glou', ' depletion', 'Pont', 'ndra', ' vividly', 'object', 'Red']
Gold label ['I', ' have', ' a', ' motion', ' much', ' imports', ' your', ' good', ';', '\n', 'Whe', 'ret', 'o', ' if', ' you', "'ll", ' a', ' willing', ' ear', ' incl', 'ine', ',', '\n', 'What', "'s", ' mine', ' is', ' yours', ' and', ' what', ' is', ' yours', ' is', ' mine', '.', '\n', 'So', ',', ' bring', ' us', ' to', ' our', ' palace', ';', ' where', ' we', "'ll", ' show', '\n', 'What', "'s", ' yet', ' behind', ',', ' that', "'s", ' meet', ' you', ' all', ' should', ' know', '.', '\n', '\n', 'SL', 'Y', ':', '\n', 'I', "'ll", ' p', 'hee', 'ze', ' you', ',', ' in', ' faith', '.', '\n', '\n', 'Host', 'ess', ':', '\n', 'A', ' pair', ' of', ' stocks', ',', ' you', ' rogue', '!', '\n', '\n', 'SL', 'Y', ':', '\n', 'Ye', ' are', ' a', ' baggage', ':', ' the', ' S', 'lys', ' are', ' no', ' rog', 'ues', ';', ' look', ' in', '\n', 'the', ' chr

In [8]:
# Token budget of the training loop: tokens per batch times number of steps.
tokens_seen = batch_size * context_length * iterations
print('seen tokens: ', tokens_seen)

seen tokens:  16384000


In [9]:
from ngram import Ngram
from data import text, enc
import torch
# Count-based bigram baseline, evaluated on the first tokens of the corpus
# before and after training.
vocab = list(range(enc.n_vocab))

# NOTE: `context_lengh` (sic) is the 16-token evaluation window. It is kept
# under this name because it is a different quantity from `context_length`,
# the training context size used in the epoch computation below.
context_lengh = 16

# Tokenise the corpus once and reuse the result; the tokenisation does not
# change between uses, so there is no need to redo it for every epoch.
# Assumes Ngram.train does not modify the list it is given - TODO confirm.
text_tokens = enc.encode(text)

ngram = Ngram(2, vocab)
inputs = [text_tokens[:context_lengh]]
# Targets are the inputs shifted by one position (next-token prediction).
targets = torch.LongTensor([text_tokens[1:context_lengh+1]]).cuda()
loss = ngram.loss(inputs, targets)
print(loss)

# Number of whole passes over the corpus that fit in the token budget
# batch_size * context_length * iterations (values set in an earlier cell).
epochs = (batch_size * context_length * iterations) // len(text_tokens)
ngram = Ngram(2, vocab)
print(epochs)
for epoch in range(epochs):
    ngram.train(text_tokens)
loss = ngram.loss(inputs, targets)
print(loss)

tensor(10.8249, device='cuda:0')
48
tensor(10.0816, device='cuda:0')


In [10]:
# Same before/after comparison with a third constructor argument of 1e-3
# (presumably a smoothing constant - confirm against ngram.py) and a single
# training pass over the corpus.
ngram = Ngram(2, vocab, 1e-3)
loss = ngram.loss(inputs, targets)
print(loss)

corpus_tokens = enc.encode(text)
ngram.train(corpus_tokens)

loss = ngram.loss(inputs, targets)
print(loss)

tensor(10.8249, device='cuda:0')
tensor(10.5330, device='cuda:0')
