# Batch normalization Task

In [1]:
import numpy as np

# Define some helper classes used by the code below

In [2]:
from typing import Any


class MyMean():
    """Mean along a single axis, with a matching backward pass.

    ``forward`` caches the reduced axis (``self.axis``), the number of
    elements that were averaged (``self.M``) and the shape of the result
    (``self.N``). ``backward`` uses them to spread an upstream gradient
    evenly back over the reduced axis.
    """

    def __call__(self, X, axis=0) -> np.ndarray:
        """Shorthand for :meth:`forward`."""
        # Pass the caller's axis through instead of silently forcing axis=0.
        return self.forward(X, axis=axis)

    def forward(self, X, axis=0) -> np.ndarray:
        """Return the mean of ``X`` along ``axis``.

        Args:
            X: array-like input with at least one dimension.
            axis: the axis to average over (negative values count from the end).

        Returns:
            Array with ``axis`` removed from the shape of ``X``.
        """
        X = np.asarray(X)
        # Summing first lets numpy reject an out-of-range axis before we use it.
        total = np.sum(X, axis=axis)
        axis = axis % X.ndim  # normalise negative axes for the bookkeeping below
        self.axis = axis
        self.M = X.shape[axis]  # number of averaged elements
        self.N = X.shape[:axis] + X.shape[axis + 1:]  # shape of the output
        return total / self.M

    def backward(self, dz) -> np.ndarray:
        """Propagate ``dz`` (shape ``self.N``) back to the shape of ``X``.

        Every element along the reduced axis contributed ``1 / M`` to the
        mean, so each one receives ``dz / M``.
        """
        dz = np.asarray(dz)
        assert dz.shape == self.N, "dz shape doesn't equal the number of features passed in forward propagation"
        # Re-insert the reduced axis, then copy the gradient M times along it.
        expanded = np.expand_dims(dz, self.axis)
        return np.repeat(expanded, self.M, axis=self.axis) / self.M

class MyPow():
    """Power node: raises a stored input to a fixed exponent.

    The input and the exponent are fixed at construction time; ``forward``
    takes no arguments and ``backward`` applies the power rule to them.
    """

    def __init__(self, x, pow=2) -> None:
        self.x, self.pow = x, pow

    def forward(self):
        """Return ``x ** pow``."""
        base, exponent = self.x, self.pow
        return base ** exponent

    def __call__(self):
        """Shorthand for :meth:`forward`."""
        return self.forward()

    def backward(self, dout):
        """Return ``dout * pow * x ** (pow - 1)``, the gradient w.r.t. ``x``."""
        # Scale by the exponent first, then by the lowered power of x
        # (same left-to-right evaluation order as the plain product).
        scaled = dout * self.pow
        return scaled * self.x ** (self.pow - 1)

class MyBroadcasting:
    """Broadcast an array to a target shape, with a matching backward pass.

    ``forward`` remembers both the input shape and the broadcast shape so that
    ``backward`` can sum the upstream gradient over exactly the axes that were
    created or stretched by broadcasting.
    """

    def __call__(self, x:np.ndarray, shape:tuple) -> np.ndarray:
        """Shorthand for :meth:`forward`."""
        return self.forward(x, shape)

    def forward(self, x:np.ndarray, shape:tuple):
        """Return a view of ``x`` broadcast to ``shape``.

        Args:
            x: array-like input that is broadcast-compatible with ``shape``.
            shape: target shape (anything ``np.broadcast_to`` accepts).
        """
        x = np.asarray(x)
        out = np.broadcast_to(x, shape)
        self.in_shape = x.shape
        # Take the shape from the result so it is always a tuple, whatever
        # form (int, list, tuple) the caller used.
        self.shape = out.shape
        return out

    def backward(self, dz:np.ndarray):
        """Reduce ``dz`` (broadcast shape) to the shape of the forward input.

        Each input element was copied to several output positions, so its
        gradient is the sum of the gradients at all of those positions.
        """
        dz = np.asarray(dz)
        assert dz.shape == self.shape, "dz shape doesn't equal the shape of z passed in forward propagation"
        # Axes that broadcasting prepended in front of the input's own axes.
        n_leading = dz.ndim - len(self.in_shape)
        if n_leading:
            dz = np.sum(dz, axis=tuple(range(n_leading)))
        # Input axes of length 1 that were stretched to a longer length.
        stretched = tuple(
            i for i, size in enumerate(self.in_shape)
            if size == 1 and dz.shape[i] != 1
        )
        if stretched:
            dz = np.sum(dz, axis=stretched, keepdims=True)
        return dz

class Center:
    """Subtract the mean from the input, with a matching backward pass.

    The forward pass is built from a ``MyMean`` node followed by a
    ``MyBroadcasting`` node; the backward pass routes the gradient through
    both of them and adds the direct path.
    """

    def __init__(self) -> None:
        self.mymean = MyMean()
        self.mybrod = MyBroadcasting()

    def __call__(self, X, axis=0) -> Any:
        """Shorthand for :meth:`forward`."""
        return self.forward(X, axis=axis)

    def forward(self, X, axis=0):
        """Return ``X`` minus its mean, broadcast back to the shape of ``X``.

        NOTE(review): the broadcast step expects the mean to be
        broadcast-compatible with ``X.shape``, which holds for ``axis=0``;
        other axes are unverified.
        """
        mu = self.mymean(X, axis)
        mu_brod = self.mybrod(mu, X.shape)
        x_centered = X - mu_brod
        return x_centered

    def backward(self, dz):
        """Return the gradient w.r.t. ``X`` given ``dz`` w.r.t. the output."""
        dz = np.asarray(dz)
        # Direct path: d(X - mu) / dX = 1.
        dx1 = dz
        # Mean path: d(X - mu) / dmu = -1, so the sign must be flipped before
        # the gradient flows back through the broadcast and the mean.
        dmu_brod = -dz
        dmu = self.mybrod.backward(dmu_brod)
        dx2 = self.mymean.backward(dmu)
        return dx1 + dx2

class Mul:
    """Element-wise multiplication node with a backward pass."""

    def __call__(self, x1, x2):
        """Shorthand for :meth:`forward`."""
        return self.forward(x1, x2)

    def forward(self, x1, x2):
        """Cache both operands and return ``x1 * x2``."""
        self.x1 = x1
        self.x2 = x2
        return x1 * x2

    def backward(self, dz):
        """Return ``(grad_x1, grad_x2)`` in the order of the forward operands.

        For ``z = x1 * x2`` the partial derivative w.r.t. one operand is the
        *other* operand, so ``grad_x1 = dz * x2`` and ``grad_x2 = dz * x1``.
        """
        return dz * self.x2, dz * self.x1
    
class Sum:
    """Element-wise addition node with a backward pass."""

    def forward(self, x1, x2):
        """Cache both operands and return ``x1 + x2``."""
        self.x1, self.x2 = x1, x2
        total = x1 + x2
        return total

    def __call__(self, x1, x2):
        """Shorthand for :meth:`forward`."""
        return self.forward(x1, x2)

    def backward(self, dz):
        """Return the upstream gradient unchanged for both operands."""
        # Addition has a local derivative of 1 for each input.
        upstream = dz
        return upstream, upstream

### Define the BatchNorm (my implementation) and BatchNormalization (the tutorial's implementation) classes

In [3]:
class BatchNorm: # class of mine
    """Batch normalization assembled from small differentiable nodes.

    ``forward`` builds the computation graph node by node and keeps every node
    on ``self`` so that ``backward`` can walk the same graph in reverse.
    ``backward`` must therefore be called after ``forward``.
    """

    def forward(self, x:np.ndarray, gamma, beta, eps):
        """Normalize ``x`` per column, then scale by ``gamma`` and shift by ``beta``.

        Args:
            x: 2-D array; statistics are taken over the first axis.
            gamma: scale, either a scalar or one value per column.
            beta: shift, either a scalar or one value per column.
            eps: constant added to the variance before the square root.

        Returns:
            Array with the same shape as ``x``.
        """
        _, N = x.shape
        # Treat gamma and beta identically: one value per column, whether the
        # caller passed a scalar or a per-column vector.
        gamma = np.broadcast_to(np.asarray(gamma), (N,))
        beta = np.broadcast_to(np.asarray(beta), (N,))

        # xmu = x - mean(x)
        self.center = Center()
        xmu = self.center(x)

        # var = mean(xmu ** 2)
        self.pow2 = MyPow(xmu, 2)
        xmu2 = self.pow2()

        self.mymean = MyMean()
        var = self.mymean(xmu2)

        # istd = 1 / sqrt(var + eps)
        self.pow05 = MyPow(var+eps, 0.5)
        std = self.pow05()

        self.pow_1 = MyPow(std, -1)
        istd = self.pow_1()

        self.mybrod_istd = MyBroadcasting()
        istd_brod = self.mybrod_istd(istd, xmu.shape)

        # x_norm = xmu * istd
        self.mul_norm = Mul()
        x_norm = self.mul_norm(xmu, istd_brod)

        # out = gamma * x_norm + beta
        self.mybrod_gamma = MyBroadcasting()
        gamma_brod = self.mybrod_gamma(gamma, x_norm.shape)

        self.mul_gamma = Mul()
        gamma_x_norm = self.mul_gamma(x_norm, gamma_brod)

        self.mybrod_beta = MyBroadcasting()
        beta_brod = self.mybrod_beta(beta, gamma_x_norm.shape)

        self.sum = Sum()
        out = self.sum(gamma_x_norm, beta_brod)

        return out

    def backward(self, dout):
        """Return ``(grad_x, grad_gamma, grad_beta)`` for upstream gradient ``dout``.

        The nodes created in ``forward`` are visited in reverse order. ``xmu``
        feeds two branches (the normalization product and the variance), so
        the gradients from both are added before going through the centering.
        """
        grad_gamma_x_norm, grad_beta_brod = self.sum.backward(dout)
        grad_beta = self.mybrod_beta.backward(grad_beta_brod) # final

        grad_x_norm, grad_gamma_brod = self.mul_gamma.backward(grad_gamma_x_norm)
        grad_gamma = self.mybrod_gamma.backward(grad_gamma_brod) # final

        # Branch 1: xmu used directly in x_norm = xmu * istd.
        grad_x1mu, grad_istd_brod = self.mul_norm.backward(grad_x_norm)

        # Branch 2: xmu used through the variance.
        grad_istd = self.mybrod_istd.backward(grad_istd_brod)
        grad_std = self.pow_1.backward(grad_istd)
        grad_var = self.pow05.backward(grad_std)
        grad_xmu2 = self.mymean.backward(grad_var)
        grad_x2mu = self.pow2.backward(grad_xmu2)

        grad_xmu = grad_x1mu + grad_x2mu
        grad_x = self.center.backward(grad_xmu) # final

        return grad_x, grad_gamma, grad_beta

class BatchNormalization: # class of tutorial
    """Batch normalization written as one explicit chain of array operations.

    ``forward`` stores its intermediates in ``self.cache``; ``backward`` reads
    them back and reverses the chain step by step.
    """

    def forward(self, x, gamma, beta, eps):
        """Normalize ``x`` per column, scale by ``gamma`` and shift by ``beta``.

        ``x`` must be 2-D; statistics are taken over the first axis.
        """
        batch_size, _ = x.shape
        inv_n = 1. / batch_size

        # (1) per-column mean
        mean = inv_n * np.sum(x, axis=0)

        # (2) centered input
        centered = x - mean

        # (3) squared deviations
        squared = centered ** 2

        # (4) per-column variance
        variance = inv_n * np.sum(squared, axis=0)

        # (5) standard deviation, with eps added for numerical stability
        std = np.sqrt(variance + eps)

        # (6) inverse standard deviation
        inv_std = 1. / std

        # (7) normalized input
        normalized = centered * inv_std

        # (8) scale, (9) shift
        scaled = gamma * normalized
        result = scaled + beta

        # Intermediates needed by backward, in the order it unpacks them.
        self.cache = (normalized, gamma, centered, inv_std, std, variance, eps)

        return result

    def backward(self, dout):
        """Return ``(dx, dgamma, dbeta)`` for the upstream gradient ``dout``."""
        normalized, gamma, centered, inv_std, std, variance, eps = self.cache

        rows, cols = dout.shape
        # Each of the `rows` samples contributes 1/rows to a column mean.
        uniform = 1. / rows * np.ones((rows, cols))

        # (9) shift: beta is shared by every row
        grad_beta = np.sum(dout, axis=0)
        grad_scaled = dout

        # (8) scale
        grad_gamma = np.sum(grad_scaled * normalized, axis=0)
        grad_normalized = grad_scaled * gamma

        # (7) product centered * inv_std
        grad_inv_std = np.sum(grad_normalized * centered, axis=0)
        grad_centered_direct = grad_normalized * inv_std

        # (6) reciprocal
        grad_std = -1. / (std ** 2) * grad_inv_std

        # (5) square root
        grad_variance = 0.5 * 1. / np.sqrt(variance + eps) * grad_std

        # (4) mean of the squared deviations
        grad_squared = uniform * grad_variance

        # (3) square
        grad_centered_via_var = 2 * centered * grad_squared

        # (2) subtraction x - mean: both branches of `centered` meet here
        grad_x_direct = (grad_centered_direct + grad_centered_via_var)
        grad_mean = -1 * np.sum(grad_centered_direct + grad_centered_via_var, axis=0)

        # (1) mean over the batch
        grad_x_via_mean = uniform * grad_mean

        # (0) combine both paths into x
        grad_x = grad_x_direct + grad_x_via_mean

        return grad_x, grad_gamma, grad_beta

In [5]:
# Create the data
arr = np.arange(20).reshape(5, 4)
gamma = 2
beta = 1
eps = 0.1

# Run the forward and backward pass of both classes on the same input
mybn = BatchNorm() # class of mine
myout = mybn.forward(arr, gamma, beta, eps)
mygrad_arr, mygrad_gamma, mygrad_beta = mybn.backward(myout)

ttbn = BatchNormalization() # class of tutorial
ttout = ttbn.forward(arr, gamma, beta, eps)
# The tutorial gradients must come from the tutorial instance; calling
# mybn.backward here would only compare BatchNorm with itself.
ttgrad_arr, ttgrad_gamma, ttgrad_beta = ttbn.backward(ttout)

# Make sure that both classes produce the same outputs. The two
# implementations order their floating-point operations differently, so
# compare with a tolerance rather than exact equality.
print(np.allclose(myout, ttout))
print(np.allclose(mygrad_arr, ttgrad_arr))
print(np.allclose(mygrad_beta, ttgrad_beta))
print(np.allclose(mygrad_gamma, ttgrad_gamma))

True
True
True
True
