In [1]:
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt

%matplotlib inline 

<h1>MULTI-LAYER PERCEPTRON NEURAL NETWORK V4</h1>

Inspiration: "A Neural Probabilistic Language Model" (Bengio et al., 2003)

character-level language model

manual backward pass, checked tensor by tensor against PyTorch autograd

In [2]:
class Linear:
    """Affine layer computing ``x @ weight`` plus an optional bias.

    The weight is sampled from a standard normal and divided by
    ``sqrt(fan_in)``; the bias, when enabled, starts at zero.
    """

    def __init__(self, fan_in, fan_out, bias=True):
        scale = fan_in**0.5
        self.weight = torch.randn((fan_in, fan_out)) / scale
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        """Apply the layer to ``x``; the result is also cached on ``self.out``."""
        result = x @ self.weight
        if self.bias is not None:
            result += self.bias
        self.out = result
        return self.out

    def parameters(self):
        """Return the layer's tensors: the weight, followed by the bias if any."""
        params = [self.weight]
        if self.bias is not None:
            params.append(self.bias)
        return params

class BatchNorm1D:
    """Batch normalization with statistics taken along dimension 0.

    While ``training`` is True the current batch's mean and (unbiased) variance
    are used and folded into the running buffers; otherwise the running
    buffers themselves are used for normalization.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        # hyper-parameters and mode flag
        self.eps, self.momentum = eps, momentum
        self.training = True
        # learnable scale and shift
        self.batch_weight = torch.ones(dim)
        self.batch_bias = torch.zeros(dim)
        # running statistics, consulted when not training
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        """Normalize ``x``, apply scale/shift, cache and return the result."""
        if self.training:
            mean = x.mean(0, keepdim=True)
            var = x.var(0, keepdim=True, unbiased=True)
        else:
            mean, var = self.running_mean, self.running_var

        normalized = (x - mean) / torch.sqrt(var + self.eps)
        self.out = self.batch_weight * normalized + self.batch_bias

        if self.training:
            # Exponential moving average of the batch statistics; kept out of
            # the autograd graph.
            keep = 1 - self.momentum
            with torch.no_grad():
                self.running_var = keep * self.running_var + self.momentum * var
                self.running_mean = keep * self.running_mean + self.momentum * mean

        return self.out

    def parameters(self):
        """Return the learnable tensors: scale first, then shift."""
        return [self.batch_weight, self.batch_bias]


class Tanh:
    """Parameter-free element-wise tanh activation."""

    def __call__(self, x):
        """Return ``tanh(x)``; the result is also cached on ``self.out``."""
        activated = torch.tanh(x)
        self.out = activated
        return activated

    def parameters(self):
        """This layer has nothing to train, so the list is empty."""
        return list()

In [3]:
def cmp(s, dt, t):
    """Compare a manually derived gradient with the one autograd stored on ``t``.

    Prints the shape of the element-wise difference, followed by one report
    line containing the exact-match flag, the ``torch.allclose`` flag and the
    largest absolute difference.

    Args:
        s: label shown at the start of the report line.
        dt: manually computed gradient tensor.
        t: tensor whose ``.grad`` was populated by ``backward()``.

    Raises:
        ValueError: if ``t.grad`` is None (for example when ``retain_grad()``
            was not called on a non-leaf tensor before ``backward()``).
    """
    if t.grad is None:
        raise ValueError(f'{s}: tensor has no .grad; call retain_grad() before backward()')
    # `==`, `-` and allclose all broadcast, so a gradient with the wrong shape
    # (e.g. (32,) against (32, 1)) could otherwise be reported as a match.
    same_shape = dt.shape == t.grad.shape
    diff = (dt - t.grad).abs()
    ex = same_shape and torch.all(dt == t.grad).item()
    app = same_shape and torch.allclose(dt, t.grad)
    # Shape of the difference: differs from t's shape when broadcasting occurred.
    print(diff.shape)
    max_diff = diff.max().item()
    print(f'{s:15} | exact: {str(ex):5} | approximate: {str(app):5} | maximum difference: {max_diff}')

In [4]:
# Load the training names, one per line. The file is only read, so it is opened
# read-only, and the context manager closes the handle once the text is loaded.
with open("../data/names.txt", "r") as names_file:
    words = names_file.read().splitlines()

# Vocabulary: every distinct character in the names plus the '.' marker, sorted.
# NOTE(review): build_dataset pads contexts with index 0, which presumably is
# meant to be '.'; that holds only if '.' sorts first among the characters.
char_set = sorted(set(''.join(words) + '.'))
string_to_index = {char: ind for ind, char in enumerate(char_set)}
index_to_string = {ind: char for char, ind in string_to_index.items()}

In [5]:
# Number of preceding characters used as context for each prediction.
block_size = 3
# Number of distinct symbols the model can emit.
# NOTE(review): presumably equals len(char_set) (26 letters plus '.') - confirm.
vocab_size = 27

def build_dataset(words):
    """Turn a list of words into (context, next-character) index tensors.

    Each word is extended with a trailing '.' and scanned with a sliding
    window of `block_size` indices that starts out as all zeros. One example
    is emitted per character: the current window as input, the character's
    index as target.

    Returns:
        (X, Y): tensor of context windows and tensor of target indices.
        Their shapes are printed as a side effect.
    """
    contexts, targets = [], []
    for word in words:
        window = [0] * block_size
        for char in word + '.':
            target = string_to_index[char]
            contexts.append(window)
            targets.append(target)
            # slide the window: drop the oldest index, append the newest
            window = window[1:] + [target]
    X = torch.tensor(contexts)
    Y = torch.tensor(targets)
    print(X.shape, Y.shape)
    return X, Y

import random

# Shuffle with a fixed seed so that Restart & Run All always produces the same
# train/dev/test split. A dedicated Random instance is used so the global
# `random` state is not reseeded. The shuffle is still in place, so re-running
# this cell alone (without reloading `words`) permutes the list again.
SPLIT_SEED = 42
random.Random(SPLIT_SEED).shuffle(words)

# 80% / 10% / 10% split by word.
n1, n2 = int(len(words) * 0.8), int(len(words) * 0.9)
X_train, Y_train = build_dataset(words[:n1])
X_dev, Y_dev = build_dataset(words[n1:n2])
X_test, Y_test = build_dataset(words[n2:])

torch.Size([182366, 3]) torch.Size([182366])
torch.Size([22905, 3]) torch.Size([22905])
torch.Size([22875, 3]) torch.Size([22875])


In [6]:
n_embed = 10    # dimensionality of the character embedding vectors
n_hidden = 200  # number of neurons in the hidden layer
fan_in = n_embed * block_size  # width of the flattened context fed to layer 1

# embedding table: one row per vocabulary entry
C = torch.randn((vocab_size, n_embed))

# layer 1: weights scaled by (5/3) / sqrt(fan_in), small random bias
W1 = torch.randn((fan_in, n_hidden)) * (5/3) / (fan_in ** 0.5)
b1 = torch.randn(n_hidden) * 0.1

# layer 2: small random weights and bias
W2 = torch.randn((n_hidden, vocab_size)) * 0.1
b2 = torch.randn(vocab_size) * 0.1

# batch norm scale and shift
# NOTE(review): the gain is a constant 1.1 and the bias is all zeros (the
# `* 0.1` on zeros has no effect); if random initialisation was intended here,
# this should be torch.randn - confirm.
batch_norm_gain = torch.ones((1, n_hidden)) * 0.1 + 1
batch_norm_bias = torch.zeros((1, n_hidden)) * 0.1

# running statistics buffers (not part of `parameters`)
batch_norm_std_running = torch.ones((1, n_hidden))
batch_norm_mean_running = torch.zeros((1, n_hidden))

parameters = [C, W1, b1, W2, b2, batch_norm_gain, batch_norm_bias]
total_elements = sum(tensor.numel() for tensor in parameters)
print(total_elements)
# let autograd track every parameter
for parameter in parameters:
    parameter.requires_grad_(True)

12297


In [7]:
# Single training iteration, written out one elementary operation at a time.
# Every intermediate gets its own name so that its gradient can later be
# compared with a hand-derived one.
batch_size = 32
# draw batch_size random row indices (with replacement) into the training set
indexes = torch.randint(0, X_train.shape[0], (batch_size,))
X_batch, Y_batch = X_train[indexes], Y_train[indexes]

# embed: look up one row of C for every context index
embed = C[X_batch]

# linear layer 1: flatten each example's embeddings into one row, then affine map
embed_cat = embed.view(embed.shape[0], -1)
h_preact = embed_cat @ W1 + b1

# batch normalize (statistics are taken over the batch dimension)
batch_norm_mean = (1/batch_size) * h_preact.sum(0, keepdim=True) # batch mean
batch_norm_diff = h_preact - batch_norm_mean
batch_norm_diff_sqr = batch_norm_diff**2
batch_norm_var = (1/(batch_size-1)) * (batch_norm_diff_sqr).sum(0, keepdim=True) # batch variance, divided by n-1 (Bessel's correction)
# despite the name this is 1/sqrt(var + eps), i.e. the inverse standard deviation
batch_norm_var_inv = (batch_norm_var + 1e-5)**-0.5
batch_norm_raw = batch_norm_diff * batch_norm_var_inv
h_preact_norm = batch_norm_gain * batch_norm_raw + batch_norm_bias # scale and shift the normalized activations

# non-linearity
h = torch.tanh(h_preact_norm)

# linear layer 2
logits = h @ W2 + b2 

# cross-entropy, spelled out step by step
# subtract each row's max first: all shifted logits are <= 0, so exp() stays <= 1
logit_maxes = logits.max(1, keepdim=True).values
norm_logits = logits - logit_maxes
counts = norm_logits.exp()
sum_counts = counts.sum(1, keepdim=True)
sum_counts_inv = sum_counts**-1
probs = counts * sum_counts_inv 
log_probs = probs.log()
# mean negative log-probability assigned to the correct next character
loss = -log_probs[range(batch_size), Y_batch].mean()

# backward pass
# clear any gradients left on the parameters
for p in parameters:
    p.grad = None
# ask autograd to keep .grad on these intermediate tensors so they can be inspected afterwards
for t in [log_probs, probs, sum_counts_inv, sum_counts, counts, norm_logits, logit_maxes, logits, h, h_preact_norm, batch_norm_raw, batch_norm_var_inv, batch_norm_var, batch_norm_diff_sqr, batch_norm_diff, batch_norm_mean, h_preact, embed_cat, embed]:
    t.retain_grad()
loss.backward()
# display the scalar loss as the cell output
loss.item()

3.4842276573181152

In [8]:
# Inspect the shape of the log-probabilities: one row per example in the batch,
# one column per vocabulary entry.
log_probs.shape

torch.Size([32, 27])

In [36]:
# Exercise 1: backprop through the whole thing manually,
# backpropagating through exactly all of the variables
# as they are defined in the forward pass above, one by one.
# Each d_<name> is d(loss)/d(<name>) and is checked against autograd with cmp().

# loss = -mean(log_probs[i, Y_batch[i]]): only the selected entries receive gradient
d_log_probs = torch.zeros_like(log_probs)
d_log_probs[range(batch_size), Y_batch] = -1.0/batch_size
cmp('logprobs', d_log_probs, log_probs)

# log_probs = probs.log()
d_probs = (1 / probs) * d_log_probs
cmp('probs', d_probs, probs)

# probs = counts * sum_counts_inv; sum_counts_inv was broadcast along dim 1,
# so its gradient is the sum over that dimension
d_sum_counts_inv = (counts * d_probs).sum(1, keepdim=True)
cmp('counts_sum_inv', d_sum_counts_inv, sum_counts_inv)

# first contribution to d_counts: the direct path through probs
d_counts = sum_counts_inv * d_probs
# sum_counts_inv = sum_counts**-1
d_sum_counts = (-sum_counts**-2)* d_sum_counts_inv
cmp('counts_sum', d_sum_counts, sum_counts)

# second contribution to d_counts: the path through sum_counts
d_counts += torch.ones_like(counts) * d_sum_counts
cmp('counts', d_counts, counts)

# counts = norm_logits.exp()
d_norm_logits = norm_logits.exp() * d_counts
cmp('norm_logits', d_norm_logits, norm_logits)

# norm_logits = logits - logit_maxes, with logit_maxes broadcast along dim 1
d_logit_maxes = (-1 * d_norm_logits).sum(1, keepdim=True)
cmp('logit_maxes', d_logit_maxes, logit_maxes)

# logits feeds norm_logits directly and also through the row-wise max.
# Clone first so the in-place += below does not also modify d_norm_logits.
d_logits = d_norm_logits.clone()
# The max only depends on the element it selected in each row, so its gradient
# is routed to that position via a one-hot mask of the argmax.
d_logits += F.one_hot(logits.max(1).indices, num_classes=logits.shape[1]) * d_logit_maxes
cmp('logits' , d_logits, logits)

# TODO: remaining checks, to be enabled as each gradient is derived
# cmp('h', d_h, h)
# cmp('W2', d_W2, W2)
# cmp('b2', d_b2, b2)
# cmp('h_preact_norm', d_h_preact_norm, h_preact_norm)
# cmp('batch_norm_gain', d_batch_norm_gain, batch_norm_gain)
# cmp('batch_norm_bias', d_batch_norm_bias, batch_norm_bias)
# cmp('batch_norm_raw', d_batch_norm_raw, batch_norm_raw)
# cmp('batch_norm_var_inv', d_batch_norm_var_inv, batch_norm_var_inv)
# cmp('batch_norm_var', d_batch_norm_var, batch_norm_var)
# cmp('batch_norm_diff_sqr', d_batch_norm_diff_sqr, batch_norm_diff_sqr)
# cmp('batch_norm_diff', d_batch_norm_diff, batch_norm_diff)
# cmp('batch_norm_mean', d_batch_norm_mean, batch_norm_mean)
# cmp('h_preact', d_h_preact, h_preact)
# cmp('embed_cat', d_embed_cat, embed_cat)
# cmp('W1', d_W1, W1)
# cmp('b1', d_b1, b1)
# cmp('embed', d_embed, embed)
# cmp('C', d_C, C)

torch.Size([32, 27])
logprobs        | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 27])
probs           | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 1])
counts_sum_inv  | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 1])
counts_sum      | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 27])
counts          | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 27])
norm_logits     | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 1])
logit_maxes     | exact: True  | approximate: True  | maximum difference: 0.0
torch.Size([32, 27])
logits          | exact: False | approximate: True  | maximum difference: 6.05359673500061e-09
