First cell is just code I imported from my solution of the previous exercises, which loads and creates the dataset for the MLP.

In [1]:
import random 
import string
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm

chars = sorted(set(string.ascii_lowercase))
stoi = {ch: i+1 for i, ch in enumerate(chars)}
stoi['.'] = 0
itos = {i: ch for ch, i in stoi.items()}

with open('names.txt', 'r') as f:
    names = f.read().splitlines()

random.shuffle(names)
n_train = int(0.8*len(names))
n_dev = int(0.9*len(names))
trainset = names[:n_train]
devset = names[n_train:n_dev]
testset = names[n_dev:]
len(trainset), len(devset), len(devset)

def prepare_data(names, cl):
    xs = []
    ys = []
    for name in names:
        name = name + '.'
        context = [stoi['.']]*cl
        for ch in name:
            ix = stoi[ch]
            xs.append(context)
            ys.append(ix)
            context = context[1:] + [ix]
    return torch.tensor(xs), torch.tensor(ys)

def print_context(X, y):
    for context, label in zip(X,y):
        con_str = ''.join(itos[ix.item()] for ix in context)
        print(f'{con_str} --> {itos[label.item()]}')

context_length = 4
X_train, y_train = prepare_data(trainset, context_length)
X_dev, y_dev = prepare_data(devset, context_length)
X_test, y_test = prepare_data(testset, context_length)
print_context(X_train[:10], y_train[:10])

.... --> c
...c --> a
..ca --> e
.cae --> l
cael --> i
aeli --> .
.... --> z
...z --> a
..za --> y
.zay --> i


In [75]:
g = torch.Generator().manual_seed(42)

class Linear:
  
  def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out), generator=g) / fan_in**0.5
    self.bias = torch.zeros(fan_out) if bias else None
  
  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out
  
  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])


class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with a running 'momentum update')
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
  
  def __call__(self, x):
    # calculate the forward pass
    if self.training:
      xmean = x.mean(0, keepdim=True) # batch mean
      xvar = x.var(0, keepdim=True) # batch variance
    else:
      xmean = self.running_mean
      xvar = self.running_var
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    # update the buffers
    if self.training:
      with torch.no_grad():
        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
    return self.out
  
  def parameters(self):
    return [self.gamma, self.beta]

class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []

## Initialize Weights and Biases with Zero

In [33]:

vocab_size = len(stoi)
emb_size = 10
n_hidden = 100
C = torch.randn((vocab_size, emb_size), requires_grad=True)
layers = [Linear(context_length*emb_size, n_hidden), Tanh(),
          Linear(n_hidden, n_hidden), Tanh(),
          Linear(n_hidden, n_hidden), Tanh(),
          Linear(n_hidden, n_hidden), Tanh(),
          Linear(n_hidden, vocab_size)]

with torch.no_grad():
    #layers[-1].gamma *= 0.1 #decrease confidence before softmax to improve initial loss
    #layers[-1].weight *= 0.1
    #layers[-1].bias *= 0.1
    for layer in layers:
        if isinstance(layer, Linear):
            #layer.weight *= 5/3 # apply gain due to squishing of tanh
            layer.weight *= 0
            layer.bias *= 0

parameters = [C]
for layer in layers:
    for p in layer.parameters():
        parameters.append(p)
        p.requires_grad = True
print(sum(p.numel() for p in parameters))

37397


In [46]:
train_steps = 100_000
lr = 0.1
batch_size = 64

for steps in range(train_steps):
    # ---forward pass---
    ixs = torch.randint(low=0, high=len(X_train), size=(batch_size,))
    x = C[X_train[ixs]].view(batch_size, emb_size*context_length)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, y_train[ixs])

    # ---backward pass---
    for layer in layers:
        layer.out.retain_grad()

    for p in parameters:
        p.grad = None

    loss.backward()

    if steps % 10_000 == 0:
        print(f'{steps:7d}/{train_steps:7d}: {loss.item():.4f}')

    for p in parameters:
        p.data += -lr * p.grad

      0/ 100000: 2.9255
  10000/ 100000: 2.8040
  20000/ 100000: 2.8569
  30000/ 100000: 2.8903
  40000/ 100000: 2.8335
  50000/ 100000: 2.8809
  60000/ 100000: 2.8329
  70000/ 100000: 2.7204
  80000/ 100000: 2.8652
  90000/ 100000: 2.9339


In [47]:
print("Non-zero Weights")
for i, layer in enumerate(layers):
    if torch.count_nonzero(layer.out) > 0:
        print(f'Layer {i}, {layer.__class__.__name__}, {layer.out.shape}')

print("Non-zero Gradients")
for i, layer in enumerate(layers):
    if torch.count_nonzero(layer.out.grad) > 0:
        print(f'Layer {i}, {layer.__class__.__name__}, {layer.out.shape}')

Non-zero Weights
Layer 8, Linear, torch.Size([64, 27])
Non-zero Gradients
Layer 8, Linear, torch.Size([64, 27])


Just the last Layer is learning, this is due to the softmax, which maps the all zero outputs of the last layer to a uniform distribution, for which the loss is then calculated and backpropagated. All the previous layers don't learn anything.

## Folding BatchNorms

In [76]:
C = torch.randn((vocab_size, emb_size), requires_grad=True)
layers = [Linear(context_length*emb_size, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
          Linear(n_hidden, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
          Linear(n_hidden, vocab_size, bias=False), BatchNorm1d(vocab_size)]

with torch.no_grad():
    layers[-1].gamma *= 0.1 #decrease confidence before softmax to improve initial loss
    for layer in layers:
        if isinstance(layer, Linear):
            layer.weight *= 5/3 # apply gain due to squishing of tanh

parameters = [C]
for layer in layers:
    for p in layer.parameters():
        parameters.append(p)
        p.requires_grad = True
print(sum(p.numel() for p in parameters))

17424


In [77]:
train_steps = 100_000
lr = 0.1
batch_size = 64

for steps in range(train_steps):
    # ---forward pass---
    ixs = torch.randint(low=0, high=len(X_train), size=(batch_size,))
    x = C[X_train[ixs]].view(batch_size, emb_size*context_length)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, y_train[ixs])

    # ---backward pass---
    for layer in layers:
        layer.out.retain_grad()

    for p in parameters:
        p.grad = None

    loss.backward()

    if steps % 10_000 == 0:
        print(f'{steps:7d}/{train_steps:7d}: {loss.item():.4f}')

    for p in parameters:
        p.data += -lr * p.grad

      0/ 100000: 3.3212
  10000/ 100000: 1.9722
  20000/ 100000: 2.1599
  30000/ 100000: 1.7938
  40000/ 100000: 1.6946
  50000/ 100000: 1.9471
  60000/ 100000: 2.1108
  70000/ 100000: 1.8634
  80000/ 100000: 1.9576
  90000/ 100000: 1.8849


In [78]:
@torch.no_grad()
def regular_inference(X, y, layers):
    x = C[X].view(X.shape[0], -1)
    for layer in layers:
        if isinstance(layer, BatchNorm1d):
            layer.training = False
        x = layer(x)
    return F.cross_entropy(x, y).item()

regular_inference(X_dev, y_dev, layers)

2.0743706226348877

In [82]:
# create new model without batchnorms
new_layers = [Linear(context_length*emb_size, n_hidden), Tanh(),
          Linear(n_hidden, n_hidden), Tanh(),
          Linear(n_hidden, vocab_size)]

i=0
for layer in new_layers:
    if isinstance(layer, Linear):
        old_lin = layers[i*3]
        bn = layers[i*3+1]
        what = bn.gamma/torch.sqrt(bn.eps+bn.running_var)
        bhat = (-bn.running_mean)/torch.sqrt(bn.eps+bn.running_var)
        layer.weight = old_lin.weight * what
        layer.bias = bn.gamma * bhat + bn.beta
        i+=1

regular_inference(X_dev, y_dev, new_layers)

2.0743706226348877