In [1]:
import numpy as np

def batchnorm_forward(x, gamma, beta, eps):
 
  N, C, H, W = x.shape
  #step1: calculate mean
  mu = 1./N/H/W * np.sum(x, axis = (0,2,3), keepdims=True)
  #step2: subtract mean vector of every trainings example
  xmu = x - mu
  #step3: following the lower branch - calculation denominator
  sq = xmu ** 2
  #step4: calculate variance
  var = 1./N/H/W * np.sum(sq, axis = (0,2,3), keepdims=True)
  #step5: add eps for numerical stability, then sqrt
  sqrtvar = np.sqrt(var + eps)
  #step6: invert sqrtwar
  ivar = 1./sqrtvar
  #step7: execute normalization
  xhat = xmu * ivar
  #step8: Nor the two transformation steps
  gammax = gamma * xhat
  #step9
  out = gammax + beta
  #store intermediate
  cache = (xhat,gamma,xmu,ivar,sqrtvar,var,eps)
 
  return out, cache

def batchnorm_backward(dout, cache):
 
  #unfold the variables stored in cache
  xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache
 
  #get the dimensions of the input/output
  N,C,H,W = dout.shape
 
  #step9
  dbeta = np.sum(dout, axis=(0,2,3))
  dgammax = dout #not necessary, but more understandable
 
  #step8
  dgamma = np.sum(dgammax*xhat, axis=(0,2,3))
  dxhat = dgammax * gamma
 
  #step7
  divar = np.sum(dxhat*xmu, axis=(0,2,3), keepdims=True)
  dxmu1 = dxhat * ivar
 
  #step6
  dsqrtvar = -1. /(sqrtvar**2) * divar   # 1,c,1,1
 
  #step5
  dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar  #1,c,1,1
 
  #step4
  dsq = 1. /N/H/W * np.ones((N,C,H,W)) * dvar
 
  #step3
  dxmu2 = 2 * xmu * dsq
 
  #step2
  dx1 = (dxmu1 + dxmu2)
  dmu = -1 * np.sum(dxmu1+dxmu2, axis=(0,2,3), keepdims=True)
 
  #step1
  dx2 = 1. /N/H/W * np.ones((N,C,H,W)) * dmu
 
  #step0
  dx = dx1 + dx2
 
  return dx, dgamma, dbeta
 

In [14]:
import torch as t
x = np.random.rand(2,2,2,2).astype(np.float32)
gamma = np.array([1,1]).astype(np.float32)
beta = np.array([0,0]).astype(np.float32)
eps = 1e-5
output, cache = batchnorm_forward(x, gamma, beta, eps)

x_t = t.tensor(x.tolist(), requires_grad=True)
bn2d = t.nn.BatchNorm2d(2)
output_t = bn2d(x_t)

dout = np.ones(x.shape)
dx,dgamma,dbeta = batchnorm_backward(dout, cache)
print(dx)
print(dgamma)
print(dbeta)


bn2d = t.nn.BatchNorm2d(2)
output_t = bn2d(x_t)
f_t = output_t.sum()
f_t.retain_grad()
output_t.retain_grad()
x_t.retain_grad()
f_t.backward()
print(x_t.grad)
print(bn2d.weight.grad)
print(bn2d.bias.grad)



[[[[-2.80549015e-08 -1.23177513e-09]
   [-2.00540158e-08  1.83476256e-08]]

  [[-2.32424240e-08 -9.26448829e-08]
   [-1.12036563e-07 -4.19860831e-08]]]


 [[[ 2.87658239e-08  3.91537083e-08]
   [-3.16735553e-08 -5.25291144e-09]]

  [[ 7.98415041e-08  5.84881761e-08]
   [ 1.18420812e-07  1.31594646e-08]]]]
[-6.70552254e-08 -4.47034836e-08]
[8. 8.]
tensor([[[[-2.8055e-08, -1.2318e-09],
          [-2.0054e-08,  1.8348e-08]],

         [[-2.3242e-08, -9.2645e-08],
          [-1.1204e-07, -4.1986e-08]]],


        [[[ 2.8766e-08,  3.9154e-08],
          [-3.1674e-08, -5.2529e-09]],

         [[ 7.9842e-08,  5.8488e-08],
          [ 1.1842e-07,  1.3159e-08]]]])
tensor([-5.4370e-08, -1.3550e-07])
tensor([8., 8.])


In [15]:
xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache
print(xhat.dtype, gamma.dtype, xmu.dtype, dout.dtype, dgamma.dtype)
print(np.sum(dout*xhat, axis=(0,2,3)))

float32 int64 float32 float64 float64
[-6.70552254e-08 -4.47034836e-08]


In [16]:
# Forward pass 
N,C,H,W = x.shape
mu = 1.0/N/H/W*np.sum(x,axis = (0,2,3),keepdims=True)
sigma2 = 1/N/H/W*np.sum((x-mu)**2,axis=(0,2,3), keepdims=True)# Size (H,) 
hath = (x-mu)/np.sqrt(sigma2+eps)#**(-1./2.)
y = gamma*hath+beta
print(y.round(4))

[[[[-1.1314 -0.0497]
   [-0.8087  0.7399]]

  [[-0.3018 -1.2031]
   [-1.4549 -0.5452]]]


 [[[ 1.16    1.5789]
   [-1.2773 -0.2118]]

  [[ 1.0368  0.7595]
   [ 1.5378  0.1709]]]]


# The forward flow of batchnorm2d
![image.png](attachment:image.png)

# The forward formula of batchnorm2d
![image.png](attachment:image.png)

# The derivative formula of batchnorm2d
![image.png](attachment:image.png)
The reference url[https://blog.csdn.net/Janet_Hu/article/details/78215951]

In [17]:
dy = np.ones(x.shape)
mu = 1.0/N/H/W*np.sum(x, axis = (0,2,3), keepdims=True) # Size (H,) 
var = 1.0/N/H/W*np.sum((x-mu)**2, axis=(0,2,3), keepdims=True)# Size (H,) 
dbeta1 = np.sum(dy, axis=(0,2,3))
dgamma1 = np.sum((x - mu) * (var + eps)**-0.5 * dy, axis=(0,2,3))

dvar = np.sum(dy*(x-mu)*(-1./2.)*(var+eps)**(-3./2.), axis=(0,2,3), keepdims=True)
dmu = np.sum(dy*-1*(var+eps)**(-1./2.), axis=(0,2,3), keepdims=True)+\
      dvar*np.sum(-2.*(x-mu), axis=(0,2,3), keepdims=True)*1.0/N/H/W
dx1 = dy*(var+eps)**(-1./2.)+dvar*2.0*(x-mu)/N/H/W+dmu*1./N/H/W


print(dx1)
print(dgamma1)
print(dbeta1)
print(np.allclose(dx, dx1))
print(np.allclose(dgamma, dgamma1))
print(np.allclose(dbeta, dbeta1))
print(np.allclose(dx, x_t.grad.detach().numpy()))
print(np.allclose(dgamma, bn2d.weight.grad.detach().numpy()))
print(np.allclose(dbeta, bn2d.bias.grad.detach().numpy()))

[[[[-2.80549011e-08 -1.23177513e-09]
   [-2.00540158e-08  1.83476256e-08]]

  [[-2.32424240e-08 -9.26448847e-08]
   [-1.12036566e-07 -4.19860831e-08]]]


 [[[ 2.87658239e-08  3.91537078e-08]
   [-3.16735549e-08 -5.25291144e-09]]

  [[ 7.98415076e-08  5.84881787e-08]
   [ 1.18420816e-07  1.31594664e-08]]]]
[-6.70552254e-08 -4.47034836e-08]
[8. 8.]
True
True
True
True
False
True


In [None]:
[[[[-2.80549015e-08 -1.23177513e-09]
   [-2.00540158e-08  1.83476256e-08]]

  [[-2.32424240e-08 -9.26448829e-08]
   [-1.12036563e-07 -4.19860831e-08]]]


 [[[ 2.87658239e-08  3.91537083e-08]
   [-3.16735553e-08 -5.25291144e-09]]

  [[ 7.98415041e-08  5.84881761e-08]
   [ 1.18420812e-07  1.31594646e-08]]]]
[-6.70552254e-08 -4.47034836e-08]
[8. 8.]
tensor([[[[-2.8055e-08, -1.2318e-09],
          [-2.0054e-08,  1.8348e-08]],

         [[-2.3242e-08, -9.2645e-08],
          [-1.1204e-07, -4.1986e-08]]],


        [[[ 2.8766e-08,  3.9154e-08],
          [-3.1674e-08, -5.2529e-09]],

         [[ 7.9842e-08,  5.8488e-08],
          [ 1.1842e-07,  1.3159e-08]]]])
tensor([-5.4370e-08, -1.3550e-07])
tensor([8., 8.])

In [18]:
dout = np.ones(x.shape, dtype=np.float32)
dgamma2 = np.sum((x - mu) * (var + eps)**(-1. / 2.) * dout, axis=(0,2,3))
dout = np.ones(x.shape, dtype=np.float64)
dgamma3 = np.sum((x - mu) * (var + eps)**(-1. / 2.) * dout, axis=(0,2,3))
print('torch:',bn2d.weight.grad)
print('others:',dgamma)
print('my float32:',dgamma2)   #float32
print('my float64:',dgamma3)   #float64
print(np.allclose(dgamma, dgamma2))
print(np.allclose(dgamma, dgamma3))

torch: tensor([-5.4370e-08, -1.3550e-07])
others: [-6.70552254e-08 -4.47034836e-08]
my float32: [-2.3841858e-07  2.3841858e-07]
my float64: [-6.70552254e-08 -4.47034836e-08]
False
True
