# Imports

In [24]:
import os
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Load environment variables
# Registers the `dotenv` IPython extension, then runs its line magic so that
# variables are read into this kernel's environment before any later cell needs them.
# NOTE(review): presumably this supplies the W&B credentials used by `wandb.init` — confirm.
%load_ext dotenv
%dotenv

# Setting Up Vocabulary

In [33]:
# Read the whole training corpus into memory as a single string
with open('../data/input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

print(f'Length of dataset in characters: {len(text)}')

Length of dataset in characters: 1115394


In [34]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)

print(''.join(chars))
print(f'Vocabulary size: {vocab_size}')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocabulary size: 65


In [41]:
# Set up the encoder and decoder
# Both lookup tables are derived from `chars`, so a token id is simply the
# character's position in the sorted vocabulary.
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> token id
itos = dict(enumerate(chars))                 # token id  -> character


def encode(s):
    """Convert a string into a list of integer token ids.

    Raises KeyError if `s` contains a character that is not in the vocabulary.
    """
    return [stoi[ch] for ch in s]


def decode(l):
    """Convert an iterable of integer token ids back into a string.

    Raises KeyError if an id is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l)

# Loading Data

In [3]:
# Start a W&B run for this training session
run = wandb.init(name='train-bigram-baseline', job_type='training')
# Reference the latest version of the pre-tokenized dataset artifact from this run
artifact = run.use_artifact('nanogpt/mini-shakespeare-tensors:latest')
# Fetch the artifact's files; `data_dir` is the local directory joined with the .pt filenames when loading
data_dir = artifact.download()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmartmichals[0m ([33mmartymcfly[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   2 of 2 files downloaded.  


In [9]:
# Load the datasets
# `train` and `val` are read from the files downloaded into `data_dir`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# That is acceptable only while the artifact comes from a trusted project;
# consider passing weights_only=True if the installed torch version supports it — confirm.
train = torch.load(os.path.join(data_dir, 'train.pt'))
val   = torch.load(os.path.join(data_dir, 'val.pt'))

In [13]:
# Context window size: the maximum number of tokens the model sees as context
block_size = 8
# Display block_size + 1 tokens: the first block_size are inputs, and each
# position's target is the token that follows it, hence the extra token.
train[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

Below are examples of training samples for the transformer. We arrange the training samples this way so that, at inference time, the model is already used to seeing contexts of varying lengths, from 1 token of context all the way up to `block_size` tokens of context.

In [16]:
# Walk through one chunk: position t contributes the example (x[:t+1] -> y[t])
x, y = train[:block_size], train[1:block_size+1]  # y is x shifted one token ahead
for t in range(block_size):
    context, target = x[:t+1], y[t]
    print(f'When input is {context} the target is: {target}')

When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58


Here we define a function to sample batches from the training or validation datasets. Each returned sample is `block_size` tokens long, and its targets are the same tokens shifted one position ahead. Keep this in mind for the training process: as shown above, every sample contains one prediction task per context length, so training has to learn across input sequences of varying lengths.

In [22]:
# Seed torch's RNG so the batches sampled below are reproducible
torch.manual_seed(1337)
# Number of independent sequences per batch (read as a global by get_batch)
batch_size = 4
# Context length of each sequence (read as a global by get_batch); same value as set earlier
block_size = 8

def get_batch(split: str):
    """Sample a random batch of input/target sequences.

    Reads the module-level `train`, `val`, `batch_size` and `block_size`.

    Args:
        split: 'train' selects the training data; any other value selects
            the validation data.

    Returns:
        A tuple (x, y) of stacked tensors with `batch_size` rows each. Row k of
        x is a `block_size`-long slice of the data and row k of y is the same
        slice shifted one position ahead.
    """
    source = train if split == 'train' else val
    # One random start offset per sequence; the upper bound leaves room for the shifted target slice
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch and print the inputs and targets
xb, yb = get_batch('train')
for label, batch in (('Input:', xb), ('Output:', yb)):
    print(label)
    print(batch.shape)
    print(batch)
print('-------\nEnumeration of all training samples')

# xb is batch_size x block_size (4 x 8 here), i.e. 32 independent (context, target) examples
for b in range(batch_size):
    inputs_row, targets_row = xb[b], yb[b]
    for t in range(block_size):
        context = inputs_row[:t+1]
        target = targets_row[t]
        print(f'When input is {context} the target is: {target}')

Input:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Output:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-------
Enumeration of all training samples
When input is tensor([24]) the target is: 43
When input is tensor([24, 43]) the target is: 58
When input is tensor([24, 43, 58]) the target is: 5
When input is tensor([24, 43, 58,  5]) the target is: 57
When input is tensor([24, 43, 58,  5, 57]) the target is: 1
When input is tensor([24, 43, 58,  5, 57,  1]) the target is: 46
When input is tensor([24, 43, 58,  5, 57,  1, 46]) the target is: 43
When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]) the target is: 39
When input is tensor([44]) the target is: 53
When input is tensor([44, 53]) the target is: 56
W

In [23]:
# Show sample: the input batch drawn by get_batch, which is fed to the model below
print(xb)

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


What are logits?

Logits are the raw, unnormalized scores produced by a neural network's final layer, before they are passed through a softmax activation function to become probabilities. In machine learning and deep learning, the term refers to these last-layer outputs; in the model below, each token's row of the embedding table is used directly as the logits for the next token.

In [44]:
# Re-seed so the model's initial weights and the sampled text below are reproducible
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    The whole model is a single (vocab_size x vocab_size) embedding table whose
    row for token i holds the logits over the token that follows i.
    """

    def __init__(self, vocab_size):
        """Create the lookup table.

        Args:
            vocab_size: number of distinct tokens; used both as the number of
                rows (current token) and columns (next-token logits).
        """
        super().__init__()
        # Each token reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, C) and loss is
            None. With targets, logits is flattened to (B*T, C) and loss is the
            scalar cross-entropy between logits and the flattened targets.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C) with C == vocab_size

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # F.cross_entropy expects (N, C) logits and (N,) class indices
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so don't build an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Extend each sequence in `idx` by sampling `max_new_tokens` tokens.

        Args:
            idx: (B, T) integer tensor of token ids forming the current context.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            (B, T + max_new_tokens) integer tensor: the input context followed
            by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # get the predictions
            logits, _ = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            # apply softmax for probability distribution
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled idx to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

# Size the model from the vocabulary built above (65 characters for this corpus)
# instead of a hard-coded literal, so the embedding table always matches the
# `itos` table that `decode` uses on the sampled ids.
vocab_size = len(chars)
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss) # loss is expected to be roughly -ln(1/65), use the cross entropy loss fxn to estimate

# test generation: start every sample from a single token with id 0
start_context = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(start_context, max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [45]:
# Create an optimizer
# AdamW over all parameters of the bigram model `m`, with a learning rate of 1e-3
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [60]:
# Train the model
batch_size = 32  # get_batch reads this global, so the batches drawn below have 32 sequences
for steps in range(10000):
    # clear the gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)

    # forward pass on a freshly sampled training batch
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # backward pass, then update the parameters
    loss.backward()
    optimizer.step()

# loss of the final batch only (a single noisy sample, not an average)
print(loss.item())

2.375595808029175


In [61]:
# test generation: sample 500 new tokens from the trained model, starting from token id 0
start_context = torch.zeros((1, 1), dtype=torch.long)
generated = m.generate(start_context, max_new_tokens=500)
print(decode(generated[0].tolist()))


ABRGomyrothas ie lf IUCl t y d,
VOLUS:
Lo t! or yowou is weser ppt, ps f oucilbo, wofe: caisee rin ICHincoutos Pugibosser theyon, n, f thid qure aimeme, shoro, he tr elonchyin s f ond D ad todin come hther he.
Whtcigo athepe mo lane t?



AUThit a hifat ngrustt iunedines al d yofre ak wive as ty t:
A thathefango bllyesmeoule by ma
CARe th d.
Whate,
'junveres poe ty way
Th ph gon Ithe, t mar, ha, t nth ll hast ghatarld mersife mast'd fe heanksofthatheal! f dsalisthe n s uno inghe thorerdichailerd
