In [28]:
import numpy as np

def batchnorm_forward(x, gamma, beta, eps):
  """Spatial batch-normalization forward pass using batch statistics.

  Each channel of ``x`` is normalized with the mean and the biased variance
  taken over axes (0, 2, 3), then scaled by ``gamma`` and shifted by ``beta``.

  Args:
    x: array of shape (N, C, H, W).
    gamma: per-channel scale. A 1-D array of length C is reshaped to
      (1, C, 1, 1) so it lines up with the channel axis; scalars and arrays
      that are already broadcastable against ``x`` are used as given.
    beta: per-channel shift, handled like ``gamma``.
    eps: constant added to the variance before the square root.

  Returns:
    out: normalized, scaled and shifted array with the shape of ``x``.
    cache: tuple ``(xhat, gamma, xmu, ivar, sqrtvar, var, eps)`` holding the
      intermediates needed by ``batchnorm_backward``; ``gamma`` is stored in
      its broadcast-ready form.
  """
  # A bare (C,) vector would broadcast against the trailing W axis of an
  # (N, C, H, W) tensor -- silently wrong when C == W, an error otherwise --
  # so move 1-D parameters onto the channel axis explicitly.
  gamma = np.asarray(gamma)
  if gamma.ndim == 1:
    gamma = gamma.reshape(1, -1, 1, 1)
  beta = np.asarray(beta)
  if beta.ndim == 1:
    beta = beta.reshape(1, -1, 1, 1)

  reduce_axes = (0, 2, 3)

  #step1: per-channel mean, kept as (1, C, 1, 1) for broadcasting
  mu = np.mean(x, axis=reduce_axes, keepdims=True)
  #step2: subtract the channel mean from every element
  xmu = x - mu
  #step3: lower branch - squared deviations
  sq = xmu ** 2
  #step4: biased per-channel variance
  var = np.mean(sq, axis=reduce_axes, keepdims=True)
  #step5: add eps for numerical stability, then sqrt
  sqrtvar = np.sqrt(var + eps)
  #step6: invert sqrtvar
  ivar = 1. / sqrtvar
  #step7: execute normalization
  xhat = xmu * ivar
  #step8: scale by gamma
  gammax = gamma * xhat
  #step9: shift by beta
  out = gammax + beta
  #store intermediates for the backward pass
  cache = (xhat, gamma, xmu, ivar, sqrtvar, var, eps)

  return out, cache

def batchnorm_backward(dout, cache):
  """Spatial batch-normalization backward pass via the staged graph.

  Walks the forward computation in reverse (step9 down to step0) and
  accumulates the gradients with respect to the input and the affine
  parameters.

  Args:
    dout: upstream gradient of shape (N, C, H, W).
    cache: tuple ``(xhat, gamma, xmu, ivar, sqrtvar, var, eps)`` as returned
      by ``batchnorm_forward``.

  Returns:
    dx: gradient with respect to the input, same shape as ``dout``.
    dgamma: gradient with respect to gamma, reduced over axes (0, 2, 3).
    dbeta: gradient with respect to beta, reduced over axes (0, 2, 3).
  """
  #unfold the variables stored in cache
  xhat, gamma, xmu, ivar, sqrtvar, var, eps = cache

  # A 1-D gamma would broadcast against the trailing W axis instead of the
  # channel axis; put it on the channel axis explicitly.
  gamma = np.asarray(gamma)
  if gamma.ndim == 1:
    gamma = gamma.reshape(1, -1, 1, 1)

  #number of elements each channel statistic was averaged over
  N, _, H, W = dout.shape
  m = N * H * W
  reduce_axes = (0, 2, 3)

  #step9: out = gammax + beta
  dbeta = np.sum(dout, axis=reduce_axes)

  #step8: gammax = gamma * xhat
  dgamma = np.sum(dout * xhat, axis=reduce_axes)
  dxhat = dout * gamma

  #step7: xhat = xmu * ivar
  divar = np.sum(dxhat * xmu, axis=reduce_axes, keepdims=True)
  dxmu1 = dxhat * ivar

  #step6: ivar = 1 / sqrtvar
  dsqrtvar = -divar / sqrtvar ** 2

  #step5: sqrtvar = sqrt(var + eps); the cached sqrtvar is that same value
  dvar = 0.5 * dsqrtvar / sqrtvar

  #step4: var = mean(sq); stays (1, C, 1, 1) and broadcasts in step3
  dsq = dvar / m

  #step3: sq = xmu ** 2
  dxmu2 = 2 * xmu * dsq

  #step2: xmu = x - mu feeds both the direct path and the mean
  dxmu = dxmu1 + dxmu2
  dmu = -np.sum(dxmu, axis=reduce_axes, keepdims=True)

  #step1 + step0: mu = mean(x), then add both paths into x
  dx = dxmu + dmu / m

  return dx, dgamma, dbeta
 

In [32]:
import torch as t

# Fixed seed so a fresh Restart & Run All reproduces the same numbers.
np.random.seed(0)
x = np.random.rand(2,2,2,2)
gamma = np.array([1,1])
beta = np.array([0,0])
eps = 1e-5
output, cache = batchnorm_forward(x, gamma, beta, eps)

# Reference implementation: a freshly constructed BatchNorm2d has weight = 1
# and bias = 0, matching gamma/beta above; eps is passed so both sides agree.
x_t = t.tensor(x.tolist(), requires_grad=True)
bn2d = t.nn.BatchNorm2d(2, eps=eps)
output_t = bn2d(x_t)

print(output.round(4))
print(output_t)
# torch runs in float32 here, hence the loose tolerance.
np.testing.assert_allclose(output, output_t.detach().numpy(), rtol=1e-4, atol=1e-5)

# NOTE(review): with dout = ones, dx and dgamma are analytically zero, so this
# comparison mostly exercises dbeta; a random dout would be a stronger check.
dout = np.ones(x.shape)
dx,dgamma,dbeta = batchnorm_backward(dout, cache)
print(dx)
print(dgamma)
print(dbeta)

# Backpropagate sum(output) through the same forward pass; its upstream
# gradient is all ones, i.e. the dout used above. Non-leaf tensors need
# retain_grad() for .grad to be populated; x_t is a leaf and keeps its grad.
f_t = output_t.sum()
f_t.retain_grad()
output_t.retain_grad()
f_t.backward()
print(x_t.grad)
print(f_t.grad)
print(output_t.grad)
print(bn2d.weight.grad)
print(bn2d.bias.grad)

np.testing.assert_allclose(dx, x_t.grad.numpy(), atol=1e-5)
np.testing.assert_allclose(dgamma, bn2d.weight.grad.numpy(), atol=1e-4)
np.testing.assert_allclose(dbeta, bn2d.bias.grad.numpy(), rtol=1e-5)



[[[[ 1.0382 -0.0181]
   [-0.7247  0.7601]]

  [[ 0.2904  1.237 ]
   [ 0.6701 -0.1244]]]


 [[[-0.7049 -1.0384]
   [-1.0714  1.7593]]

  [[-1.4978  0.8259]
   [ 0.3025 -1.7037]]]]
tensor([[[[ 1.0382, -0.0181],
          [-0.7247,  0.7601]],

         [[ 0.2904,  1.2370],
          [ 0.6701, -0.1244]]],


        [[[-0.7049, -1.0384],
          [-1.0714,  1.7593]],

         [[-1.4978,  0.8259],
          [ 0.3025, -1.7037]]]], grad_fn=<NativeBatchNormBackward>)
[[[[ 4.4408921e-16  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00]]]


 [[[ 0.0000000e+00 -4.4408921e-16]
   [-4.4408921e-16  4.4408921e-16]]

  [[ 0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00]]]]
[-8.8817842e-16  0.0000000e+00]
[8. 8.]
tensor([[[[ 5.4929e-08, -9.5836e-10],
          [-3.8344e-08,  4.0215e-08]],

         [[-7.7724e-08, -3.3107e-07],
          [-1.7934e-07,  3.3305e-08]]],


        [[[-3.7295e-08, -5.4942e-08],
        

In [33]:
# Forward pass (closed form), per-channel statistics over axes (0, 2, 3)
N,C,H,W = x.shape
mu = np.mean(x, axis=(0,2,3), keepdims=True) # shape (1,C,1,1)
sigma2 = np.mean((x-mu)**2, axis=(0,2,3), keepdims=True) # shape (1,C,1,1), biased variance
hath = (x-mu)*(sigma2+eps)**(-1./2.)
# gamma and beta are (C,) vectors: reshape them onto the channel axis,
# otherwise they would broadcast against the trailing W axis.
y = np.reshape(gamma, (1,-1,1,1))*hath + np.reshape(beta, (1,-1,1,1))
print(y.round(4))

[[[[ 1.0382 -0.0181]
   [-0.7247  0.7601]]

  [[ 0.2904  1.237 ]
   [ 0.6701 -0.1244]]]


 [[[-0.7049 -1.0384]
   [-1.0714  1.7593]]

  [[-1.4978  0.8259]
   [ 0.3025 -1.7037]]]]


In [39]:
# Backward pass (closed form) for an upstream gradient of ones.
# Dimensions are re-derived from x so the cell does not depend on another
# cell having defined N, H, W.
N,C,H,W = x.shape
m = N*H*W  # elements averaged per channel
dy = np.ones(x.shape)
mu = np.mean(x, axis=(0,2,3), keepdims=True) # shape (1,C,1,1)
var = np.mean((x-mu)**2, axis=(0,2,3), keepdims=True) # shape (1,C,1,1), biased variance
inv_std = (var + eps)**(-1. / 2.)
xhat = (x - mu) * inv_std
# gamma is a (C,) vector: reshape it onto the channel axis, otherwise it
# would broadcast against the trailing W axis.
gamma_b = np.reshape(gamma, (1,-1,1,1))
dbeta = np.sum(dy, axis=(0,2,3))
dgamma = np.sum(xhat * dy, axis=(0,2,3))
dx = (1./m) * gamma_b * inv_std * (m*dy - np.sum(dy, axis=(0,2,3), keepdims=True)
    - xhat * np.sum(dy * xhat, axis=(0,2,3), keepdims=True))
print(dx)
print(dgamma)
print(dbeta)

[[[[ 2.72836982e-16 -4.76023829e-18]
   [-1.90456787e-16  1.99748982e-16]]

  [[ 0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00]]]


 [[[-1.85244992e-16 -2.72901185e-16]
   [-2.81584924e-16  4.62362163e-16]]

  [[ 0.00000000e+00  0.00000000e+00]
   [ 0.00000000e+00  0.00000000e+00]]]]
[-8.8817842e-16  0.0000000e+00]
[8. 8.]


In [37]:
[[[[ 4.4408921e-16  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00]]

  [[ 0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00]]]


 [[[ 0.0000000e+00 -4.4408921e-16]
   [-4.4408921e-16  4.4408921e-16]]

  [[ 0.0000000e+00  0.0000000e+00]
   [ 0.0000000e+00  0.0000000e+00]]]]
[-8.8817842e-16  0.0000000e+00]
[8. 8.]

SyntaxError: invalid syntax (<ipython-input-37-4a406111c088>, line 1)