<h1>CNN Token-Level Language Model</h1>

In [4]:
import torch

In [5]:
# model classes

class Linear:
  """Affine layer computing ``x @ weight`` plus an optional bias.

  The result of the most recent call is also stored on ``self.out``.
  """

  def __init__(self, fan_in, fan_out, bias=True):
    """Create the layer's tensors.

    ``weight`` has shape ``(fan_in, fan_out)``: standard-normal samples
    divided by ``sqrt(fan_in)``. ``bias`` is a zero vector of length
    ``fan_out``, or ``None`` when ``bias`` is falsy.
    """
    gaussian = torch.randn((fan_in, fan_out))
    # 1/sqrt(fan_in) scaling (Kaiming-style initialization)
    self.weight = gaussian / fan_in**0.5
    if bias:
      self.bias = torch.zeros(fan_out)
    else:
      self.bias = None

  def __call__(self, x):
    """Return ``x @ weight`` (plus ``bias`` when present); cache it in ``self.out``."""
    self.out = x @ self.weight
    if self.bias is None:
      return self.out
    # in-place add on the freshly computed product
    self.out += self.bias
    return self.out

  def parameters(self):
    """Return ``[weight]``, followed by ``bias`` when the layer has one."""
    tensors = [self.weight]
    if self.bias is not None:
      tensors.append(self.bias)
    return tensors
  
  
class BatchNorm1d:
  
  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.momentum = momentum
    self.training = True
    # parameters (trained with backprop)
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)
    # buffers (trained with updates)
    self.running_mean = torch.zeros(dim)
    self.running_var = torch.ones(dim)
    
  def __call__(self, x):
    # calculate forward pass
    if self.training:
      x_mean = x.mean(0, keepdim=True)
      x_var = x.var(0, keepdim=True)
    else:
      x_mean = self.running_mean
      x_var = self.running_var
    
    x_hat = (x - x_mean) / torch.sqrt(x_var + self.eps)
    self.out = self.gamma * x_hat + self.beta
    
    if self.training:
      with torch.no_grad():
      	self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * x_mean
      	self.running_var = (1 - self.momentum) * self.running_var + self.momentum * x_var