In [65]:
import random
import torch
from torch.nn import functional as F
random.seed(0x1337_b00b)

# Context length -> How many characters we take as input for the network to predict
# the next
block_size = 3

class Dataset:
    def __init__(self):
        with open('names.txt', 'r') as f:
            names = f.read().split('\n')
        self.names = names

        self.build_vocab()

        # Shuffle names in place
        random.shuffle(self.names)
        # Training set and dev/validation set last index
        # First 80% is used for training, 10% percent for validation, 10% for test
        train_set_idx = int(0.8 * len(names))
        validation_idx = int(0.9 * len(names))
        
        self.X, self.Y = {}, {}
        self.X["train"], self.Y["train"] = self.build_dataset(self.names[:train_set_idx])
        self.X["valid"], self.Y["valid"] = self.build_dataset(self.names[train_set_idx:validation_idx])
        self.X["test"], self.Y["test"] = self.build_dataset(self.names[validation_idx:])
        

    def build_vocab(self):
        # Build vocabulary
        vocab = []
        for name in self.names:
            vocab += name
        self.vocab = sorted(set(vocab))
        
        # Build mapping from letter to integer id and for id to letter
        # Leave the `0` key for `.` (dot) which new treat as a null / terminating char
        self.itos = { i+1:l for i, l in enumerate(self.vocab)}
        self.itos[0] = '.'
        # Build the inverse mapping -> from character to integer id
        self.stoi = { l:i for i, l in self.itos.items()}
                

    def build_dataset(self, words):
        global block_size
        # Inputs
        X = []
        # Targets
        Y = []
        
        # For each name
        for word in words:
            # The start is an empty new context (which contains our designed dot special character)
            context = [0] * block_size
            # For each character in the name (adding dot as a stopping token)
            for ch in word + '.':
                # We add the current context and as an input to the dataset
                X.append(context)
                # Get the index of the current character and add it as a target for a potential
                # generated new character that could follow this context
                idx_ch = self.stoi[ch]
                Y.append(idx_ch)
                # Slide the context window and add the new character to it
                context = context[1:] + [idx_ch]
    
        X = torch.Tensor(X).long()
        Y = torch.Tensor(Y).long()
        return (X, Y)


    def dataset_demo(self, split, count = 10):
        for i, p in zip(self.X[split][:count], self.Y[split][:count]):
            print([self.itos[c.item()] for c in i], "-->", self.itos[p.item()])

In [66]:
# Parameters setup
import torch
from torch.nn import functional as F

batch_size = 32
emb_size = 10
g = torch.Generator().manual_seed(0x1337_b00b)
# Create the embedding
C = torch.randn((27, emb_size), generator=g)

# Initializing the model parameters
# First layer
# We scale the weights to control what is the standard deviation. We set it
# according to a tanh gain (5/3) and the fan_in (square root of the input size).
# source is the kaimin paper initialisation and torch docs
tanh_gain = 5/3
n_hidden = 200
sqrt_of_fan_in = (block_size * emb_size) ** 0.5
W1 = torch.randn((block_size * emb_size, n_hidden), generator=g) * (tanh_gain / sqrt_of_fan_in) # 0.3
# We multiply by 0.01 to make sure the bias is not high and gives us a
# uniform distribution as in layer 2
# Because we are computing the batch norm bias, this is not very useful and is replaced
# by the batch normalisation bias.
b1 = torch.randn(n_hidden, generator=g) * 0.01
# Second layer
# We multiply by 0.01 (or a small scalar) in order to reduce and more uniformly
# distribute the weights for the first training pass and loss calculation, such that
# the logits that we get at the end of the network are giving roughly the same
# probability to any of the characters. This is the same concept as the first layer
# but with an instinctive value instead of something empirical taken from docs
W2 = torch.randn((n_hidden, 27), generator=g) * 0.01
# We multiply by zero to make sure intialization give uniform distribution to the logits
b2 = torch.randn(27, generator=g) * 0

# Parameters to allow the batch normalisation to scale and shift the uniform gaussian
# of a neuron's inputs in a batch, such that it allows the backpropagation step
# to optimise the neuron's shape because the neuron is modified with these 2 values
# which become parameters and have gradients computed as well and are optimised
bn_gain = torch.ones((1, n_hidden))
bn_bias = torch.zeros((1, n_hidden))

# Initialize the mean and the std that will be accumulating the mean and the standard
# deviation of batch normalisation at each training step in order to be used later in
# production when we are feeding a single example without a batch
# Because preactivations are standard gaussian we expect the mean to be close to zero
# and the std to be close to one.
hpreact_mean_final = torch.zeros((1, n_hidden))
hpreact_std_final = torch.ones((1, n_hidden))

parameters = [C, W1, b1, W2, b2, batch_norm_gain, batch_norm_bias]

for p in parameters:
    p.requires_grad = True

param_count = sum(p.nelement() for p in parameters)

In [67]:
steps = []
losses = []

In [68]:
# Training loop
max_steps = 1
for idx in range(max_steps):
    # Minibatch construction
    # Sample indexes from X (minibatch of 32 examples)
    idxs = torch.randint(0, d.X["train"].shape[0], (batch_size,))
    
    # Forward pass, only with the minibatch
    emb = C[d.X["train"][idxs]]
    # Compute the preactivation of the first hidden layer
    emb_cat = emb.view(emb.shape[0], block_size * emb_size) 
    
    hpre_bn = emb_cat @ W1 + b1

    # Compute the mean and the std for the current batch
    # Batch normalisation, before activation, to make every neuron uniform gaussian
    # across/on these 32 examples (the entire batch)
    bn_mean_idx = 1/batch_size * hpre_bn.sum(0, keepdim=True)
    bn_diff = hpreact - hpreact_mean_running_idx
    bn_diff2 = bn_diff**2
    bn_var = 1/batch_size*(bn_diff2).sum(0, keepdim=True)
    bn_var_inv_sqrt = (bn_var + 1e-5)**-0.
    bn_raw = bn_diff * bn_var_inv_sqrt
    # Scale and shift this normalisation such that during the backpropagation step
    # the network will be able to tweak and update this neuron by the means of the
    # operations on this neuron using the batch_norm gain and bias
    hpreact = bn_gain * bn_raw + bn_bias

    # Activate with tanh
    h = torch.tanh(hpreact)
    # Compute the logits -> the values of the last preactivation before the probabilities
    logits = h @ W2 + b2 # log-counts
    # Compute the loss
    # loss = F.cross_entropy(logits, d.Y["train"][idxs])
    # Normalise the rows to convert logits into probabilities
    logits_maxes = logits.max(1, keepdim=True).values
    norm_logits = logits - logits_maxes
    counts = norm_logits.exp()
    counts_sum = counts.sum(1, keepdims=True)
    counts_sum_inv = counts_sum**0.5
    probs = counts * counts_sum_inv
    # Compute the loss over all the 32 inputs
    logprobs = probs.log()
    loss_manual = -logprobs[range(batch_size), d.Y["train"][idxs]].mean()
    print(loss_manual)
    #print(loss.item())
    # Reset the gradients
    for p in parameters:
        p.grad = None

    for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, norm_logits,
              logits_maxes, logits, h, hpreact, bn_raw, bn_var_inv_sqrt, bn_var,
              bn_diff2, bn_diff, hpre_bn, bn_mean_idx, emb_cat, emb]:
        t.retain_grad()
    # Compute the backward pass
    loss.backward()

loss

tensor(-1.1770, grad_fn=<NegBackward0>)


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [35]:
# Utility function to compare manual gradients to torch gradients
def cmp(manual_dt, dt):
    # Check for perfect equality
    are_equal = torch.all(manual_dt == dt).item()
    # Check for equality up to a certail decimal (usefull for unstable floats)
    close_eq = torch.allclose(manual_dt, dt)
    # Check the maximum difference between close equality and perfect equality
    max_diff = (manual_dt - dt).abs().max().item()
    print(f"exact {str(are_equal):.5s}, approx {str(close_eq):.5s}, max diff {max_diff}")