In [33]:
import torch 
import torch.nn.functional as F
import matplotlib.pyplot as plt 

# Shared random-number generator, seeded once so that Linear's weight
# initialisation is reproducible across runs.
g = torch.Generator()
g.manual_seed(1)

class Linear:
    """Fully connected layer computing ``x @ weight`` (plus ``bias`` if enabled).

    The weight is drawn from the module-level generator ``g`` and divided by
    ``sqrt(fan_in)``; the optional bias starts at zero. The most recent
    forward result is kept on ``self.out``.
    """

    def __init__(self, fan_in, fan_out, bias = True):
        init_scale = fan_in**0.5
        self.weight = torch.randn((fan_in, fan_out), generator=g) / init_scale
        self.bias = None
        if bias:
            self.bias = torch.zeros(fan_out)

    def __call__(self, x):
        self.out = x @ self.weight
        if self.bias is None:
            return self.out
        # In-place add onto the freshly computed product.
        self.out += self.bias
        return self.out

    def parameters(self):
        """Return the layer's tensors: weight first, then bias if present."""
        params = [self.weight]
        if self.bias is not None:
            params.append(self.bias)
        return params

    
class Tanh:
    """Element-wise tanh activation; remembers its last output on ``self.out``."""

    def __call__(self, x):
        activated = torch.tanh(x)
        self.out = activated
        return activated

    def parameters(self):
        # Stateless activation: nothing to train.
        return []

class BatchNorm1d:
    """Batch normalisation with trainable scale/shift and running statistics.

    In training mode (``self.training = True``, the default), each call
    normalises ``x`` with the mean and variance taken over dimension 0 of the
    current input. It then folds those statistics into momentum-weighted
    running averages. In eval mode (``self.training = False``), the stored
    running averages are used instead, and the buffers are left unchanged.

    Presumably ``x`` is (batch, dim): statistics are reduced over dim 0, and
    ``gamma``/``beta`` of shape (dim,) broadcast against the last dimension.

    Args:
        dim: number of features; the size of ``gamma``, ``beta`` and the buffers.
        eps: added to the variance before the square root, to avoid division
            by zero.
        momentum: weight given to the current batch's statistics when
            updating the running buffers.
    """

    def __init__(self, dim, eps = 1e-5, momentum = 0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True

        # Affine parameters returned by parameters(). requires_grad is not
        # set here, so presumably the caller enables it before training.
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

        # Buffers updated by a running (momentum) average, not by gradients.
        # This must be ``running_var``: __call__ accumulates batch variances
        # under that name.
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        # Choose the statistics to normalise with.
        if self.training:
            xmean = x.mean(0, keepdim=True)  # batch mean
            # Unbiased batch variance; NOTE(review): a batch of one row yields NaN.
            xvar = x.var(0, keepdim=True)
        else:
            xmean = self.running_mean
            xvar = self.running_var

        # Normalise and apply the affine transform in BOTH modes.
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # unit variance
        self.out = self.gamma * xhat + self.beta

        # Update the running buffers outside autograd. Because keepdim=True,
        # the buffers take the batch statistics' shape (1, dim) after the
        # first update; this still broadcasts the same way in eval mode.
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        """Return the trainable tensors ``[gamma, beta]``."""
        return [self.gamma, self.beta]





