In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [6]:
g =23

In [2]:
class Linear:

    def __init__(self,fan_in,fan_out,bias=True):
        self.weight = torch.randn(fan_in,fan_out)/fan_in**0.5
        self.bias = torch.zeros([fan_out]) if bias else None
    
    def __call__(self,x):
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out
    def parameters(self):
        return [self.weight] +([] if self.bias is None else [self.bias])
    




        



In [11]:
class BatchNorm1D:
    def __init__(self,dim,eps=1e-5,momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training= True

        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self,x):
        if self.training:
            xmean = x.mean(0,keepdim= True)
            xvar = x.var(0,keepdim = True)
        else:
            xmean = self.running_mean
            xvar = self.running_var
            
        xhat = (x - xmean)/torch.sqrt(self.eps + xvar)
        self.out = self.gamma * xhat + self.beta
        
        if self.training:
            with torch.no_grad():
                self.running_mean = (1-self.momentum)*self.running_mean + self.momentum * xmean
                self.running_var = (1-self.momentum)*self.running_var + self.momentum * xvar

        return self.out

    def parameters(self):
        return [self.gamma,self.beta]
        
class Tanh:
    def __call__(self,x):
        self.out = torch.tanh(x)
        return self.out
    def parameters(self):
        return []







            


In [7]:
words = open('names.txt','r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.']=0
itos = {i:s for s,i in stoi.items()}
vocab_size=(len(itos))
vocab_size


27

In [15]:
n_embd = 10
n_hidden = 100
block_size = 3
g = torch.Generator().manual_seed(2147483647)

C = torch.randn((vocab_size,n_embd))
layers = [
    Linear(n_embd*block_size,n_hidden), BatchNorm1D(n_hidden), Tanh(),
    Linear(n_hidden,n_hidden),BatchNorm1D(n_hidden), Tanh(),
    Linear(n_hidden,n_hidden),BatchNorm1D(n_hidden), Tanh(),
    Linear(n_hidden,n_hidden), BatchNorm1D(n_hidden),Tanh(),
    Linear(n_hidden,n_hidden),BatchNorm1D(n_hidden), Tanh(),
    Linear(n_hidden,vocab_size), BatchNorm1D(vocab_size),
]

with torch.no_grad():
    layers[-1].gamma *= 0.1
    for layer in layers[:-1]:
        if isinstance(layer,Linear):
            layer.weight *= 5/3

parameters = [C] + [p for layer in layers for p in layer.parameters()]

print(sum(p.nelement() for p in parameters))

for p in parameters:
    p.requires_grad = True


    

47551


In [13]:
#buidling the dataset
block_size = 3
def build_dataset(words):
    X,Y = [],[]

    for w in words:
        context = [0]*block_size
        for ch in w + '.':
            ix = stoi[ch]
            Y.append(ix)
            X.append(context)
            context = context[1:] +[ix]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X,Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr, Ytr = build_dataset(words[:n1])
Xdev, Ydev = build_dataset(words[n1:n2])
Xte,Yte = build_dataset(words[n2:])

In [16]:
max_steps = 20000
batch_size = 32
lossi = []

for i in range(max_steps):
    #making a batch
    ix = torch.randint(0,Xtr.shape[0],(batch_size,),generator=g)
    Xb,Yb = Xtr[ix], Ytr[ix]

    #forward pass
    emb = C[Xb]
    x = emb.view(emb.shape[0],-1)
    for layer in layers:
        x = layer(x)

   
    
    
    loss = F.cross_entropy(x,Yb)
    for layer in layers:
        layer.out.retain_grad()
    #backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    #update
    lr = 0.1 if i<10000 else 0.01
    for p in parameters:
        p.data += -lr*p.grad

    #stats
    if i %1000 ==0:
        print(f'{loss.item():4f}')
    lossi.append(loss.log10().item())

    

3.287393


2.246829
2.114303
2.731234
2.386299
2.241330
2.016474
2.467661
2.457840
2.296380
1.915739
2.505448
2.269818
1.738957
2.316391
2.484395
1.925723
2.258933
2.174850
2.134204


statistics to keep track of: activations in the forward pass, gradients in the backward pass, update to data ratio of weights this ratio should be typically --3 ont he log scale.

areas of research: best ways to initialise, best ways to normlaise