In [None]:
"""
model: A Neural Probabilistic Language Model (Bengio et al., 2003)
URL: https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf

Dimension key:
# batch / context window
B: batch size
T: sequence length (number of context tokens per example)

# input/output
V: vocabulary size
E: embedding dimension
D: model dimension
"""
import picograd
import matplotlib.pyplot as plt
%matplotlib inline
# from jaxtyping import ...
# Seeded generator passed to the picograd.randn() calls below so that
# parameter initialization is reproducible from run to run.
g = picograd.Generator().manual_seed(1337)

B = 32   # batch size
T = 3    # sequence length: context tokens per example
V = 27   # vocabulary size
E = 10   # embedding dimension
D = 200  # model (hidden) dimension

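# step: 0/200000, loss 27.63208770751953   (observed at step 0 -- far above the expected value below)
# -> expected initial loss when predicting uniformly over V=27 characters:
#    nll = -log p(c) = -picograd.tensor(1/27).log() = log(27) ≈ 3.2958
# -> self.W = picograd.randn(...) samples from N(0, 1), so each pre-activation is a sum
#    of D_in unit-variance terms and its scale grows with fan-in (overconfident logits)
# -> fix: scale self.W by gain/sqrt(D_in) (Kaiming-normal init, cf. picograd.init_kaimingnormal())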

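# Residual connections, normalization layers (e.g. BatchNorm) and adaptive optimizers
# (Adam/RMSprop) have made training far less sensitive to initialization
# -> this matters because hand-tuning initialization is fragile, and quickly becomes
#    intractable, for *deep* neural networks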

class Linear:
    """Fully-connected layer computing ``X @ W`` (plus ``b`` when bias=True).

    Args:
        D_in: number of input features (rows of W).
        D_out: number of output features (columns of W).
        bias: if True, add a learnable bias initialised to zeros.
        gain: multiplier applied to the N(0, 1) weight sample before dividing
            by sqrt(D_in). The default 5/3 is the gain commonly recommended for
            a tanh nonlinearity; pass e.g. 1.0 for a layer not followed by tanh.

    The result of the most recent call is cached on ``self.out`` (and
    ``self.X_Do``) so activations can be inspected after a forward pass.
    """

    def __init__(self, D_in, D_out, bias=True, gain=5/3):
        # Kaiming-style init (He et al. 2015): std = gain / sqrt(fan_in) keeps the
        # scale of the outputs roughly independent of D_in.
        self.W_DiDo = picograd.randn((D_in, D_out), generator=g) * gain / D_in**0.5
        self.b_Do = picograd.zeros(D_out) if bias else None

    def __call__(self, X_Di):
        """Apply the affine map to ``X_Di``, cache the result and return it."""
        X_Do = X_Di @ self.W_DiDo
        if self.b_Do is not None:
            # Out-of-place add: do not mutate the matmul result in place (an
            # autograd engine may not track in-place ops, and in-place adds
            # fail when the bias needs a wider dtype than the product).
            X_Do = X_Do + self.b_Do
        self.X_Do = X_Do
        self.out = X_Do
        return X_Do

    def parameters(self):
        """Return the learnable tensors: ``[W]`` or ``[W, b]``."""
        return [self.W_DiDo] + ([] if self.b_Do is None else [self.b_Do])

class BatchNorm1D:
    """Batch normalization (Ioffe & Szegedy, 2015) over the batch axis (dim 0).

    In training mode (``self.training = True``, the default) each feature is
    normalized with the mean/variance of the current batch, and those
    statistics are folded into exponential running averages. In eval mode
    (``self.training = False``) the running averages are used instead. The
    normalized input is then scaled by ``gamma`` and shifted by ``beta``.

    Args:
        dim: number of features; sizes gamma, beta and the running buffers.
        eps: added to the variance before the square root to avoid division
            by zero.
        momentum: weight given to the newest batch statistic when updating
            the running averages.

    NOTE(review): in this file the input is presumably (B, dim), coming from a
    Linear layer. If picograd's ``var`` is unbiased like torch's, a training
    batch of size 1 yields an undefined variance -- confirm.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # parameters (trained with backprop)
        self.gamma = picograd.ones(dim)
        self.beta = picograd.zeros(dim)
        # buffers (not trained by backprop; updated as exponential moving averages)
        self.running_mean = picograd.zeros(dim)
        self.running_var = picograd.ones(dim)

    def __call__(self, x):
        """Normalize ``x``, cache the result on ``self.out`` and return it."""
        if self.training:
            # Reduce over the batch axis *without* keepdim: the statistics then
            # have the same shape as the running buffers, so the buffers keep
            # their initial (dim,) shape instead of becoming (1, dim) after the
            # first update. Broadcasting against x is unaffected.
            xmean = x.mean(0, keepdim=False)  # batch mean
            xvar = x.var(0, keepdim=False)  # batch variance
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / picograd.sqrt(xvar + self.eps)  # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        if self.training:
            # Buffers are statistics, not parameters: update them outside the graph.
            with picograd.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        """Return the learnable tensors ``[gamma, beta]`` (buffers excluded)."""
        return [self.gamma, self.beta]

class Tanh:
    """Parameter-free layer applying ``picograd.tanh`` to its input.

    The most recent output is cached on ``self.out`` (and ``self.X_BD``) so
    it can be inspected after a forward pass. For example, plot its histogram,
    or look for saturated units where |tanh| > 0.99 and the local gradient
    1 - tanh^2 is close to zero.
    """

    def __call__(self, X_BD):
        activated = picograd.tanh(X_BD)
        # Diagnostics (uncomment to inspect):
        # plt.hist(activated.view(-1).tolist(), 50);  # distribution of activations
        # plt.imshow(activated.abs() > 0.99, cmap='gray', interpolation='nearest')  # saturated units (vanishing gradients)
        self.X_BD = self.out = activated
        return activated

    def parameters(self):
        """A nonlinearity has nothing to learn."""
        return list()

# MLP over T context embeddings of size E, flattened to T*E inputs:
#   (T*E) -> D -> D -> V
# Each Linear omits its bias: the BatchNorm1D right after it subtracts the
# batch mean (cancelling any bias) and adds its own learnable shift (beta).
model = [
    Linear(T * E, D, bias=False),
    BatchNorm1D(D),
    Tanh(),
    Linear(D, D, bias=False),
    BatchNorm1D(D),
    Tanh(),
    Linear(D, V, bias=False),
    BatchNorm1D(V),
]

# Embedding table with one E-dimensional row per vocabulary entry. It is drawn
# after the layers above, so the seeded generator `g` is consumed in the same
# order and initialization stays reproducible.
C = picograd.randn((V, E), generator=g)

# Collect every learnable tensor and mark it for gradient tracking.
params = [C, *(p for layer in model for p in layer.parameters())]
for p in params:
    p.requires_grad = True