In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Read the dataset into a list with one entry per line (line terminators removed).
# The context manager closes the file handle as soon as the read completes,
# instead of leaving it open until garbage collection.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [None]:
class Linear:
    """Linear (fully connected) layer computing ``x @ weight + bias``.

    The weight matrix has shape ``(fan_in, fan_out)``. It is drawn from a
    standard normal distribution and divided by ``sqrt(fan_in)``
    (Kaiming-style scaling).
    """

    def __init__(self, fan_in, fan_out, bias=True):
        """Create the layer's parameters.

        Args:
            fan_in: number of input features (rows of ``weight``).
            fan_out: number of output features (columns of ``weight``).
            bias: when true, a zero-initialised bias of length ``fan_out``
                is created; otherwise ``self.bias`` is ``None``.
        """
        # NOTE: `g` is a module-level random generator that must be defined
        # before a layer is constructed; it is not created by this class.
        self.weight = torch.randn((fan_in, fan_out), generator=g) / fan_in**0.5
        self.bias = torch.zeros(fan_out) if bias else None

    def __call__(self, x):
        """Run the forward pass and return ``x @ weight`` (plus bias, if any).

        The result is also kept on ``self.out``.
        """
        # The product must be stored on `self.out`: the bias addition and the
        # return statement below both read that attribute.
        self.out = x @ self.weight
        if self.bias is not None:
            self.out += self.bias
        return self.out

    def parameters(self):
        """Return the layer's parameters: ``[weight]`` or ``[weight, bias]``."""
        return [self.weight] + ([] if self.bias is None else [self.bias])
    
# Implemented like PyTorch's BatchNorm1d, except that:
    # affine is always True and cannot be changed (meaning gamma and beta are always used)
    # track_running_stats is always True and cannot be changed
    # the device is the default one (CPU)
    # the datatype is the default one (float32)
class BatchNorm1d:
    """Batch normalisation with a learnable scale (gamma) and shift (beta).

    While ``self.training`` is true, inputs are normalised with the
    statistics of the current batch (taken over dimension 0) and the running
    estimates are updated. Otherwise the stored running estimates are used
    and left untouched.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        """Set up parameters and running-statistic buffers of length ``dim``.

        Args:
            dim: number of features being normalised.
            eps: constant added to the variance before the square root.
            momentum: weight given to the newest batch statistic when the
                running estimates are updated.
        """
        self.eps = eps
        self.momentum = momentum
        self.training = True  # switch to False for evaluation / inference

        # learnable parameters (updated via backprop)
        self.gamma = torch.ones(dim)   # scale, plays the role of a weight
        self.beta = torch.zeros(dim)   # shift, plays the role of a bias

        # buffers (updated by an exponential moving average, not by backprop)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        """Normalise ``x``, apply gamma/beta, and return the result.

        The result is also stored on ``self.out``.
        """
        use_batch_stats = self.training

        # choose which statistics drive the normalisation
        if use_batch_stats:
            mu = x.mean(0, keepdim=True)      # per-feature mean of this batch
            sigma2 = x.var(0, keepdim=True)   # per-feature variance of this batch
        else:
            mu, sigma2 = self.running_mean, self.running_var

        # standardise, then scale and shift
        normalised = (x - mu) / torch.sqrt(sigma2 + self.eps)
        self.out = self.gamma * normalised + self.beta

        if not use_batch_stats:
            return self.out

        # moving-average update of the buffers; kept out of the autograd graph
        with torch.no_grad():
            keep = 1 - self.momentum
            self.running_mean = keep * self.running_mean + self.momentum * mu
            self.running_var = keep * self.running_var + self.momentum * sigma2
        return self.out
    
    