# Figure out spatial batch norm forward and backward calculations

## Spatial Batch Normalization
We already saw that batch normalization is a very useful technique for training deep fully-connected networks. Batch normalization can also be used for convolutional networks, but we need to tweak it a bit; the modification will be called "spatial batch normalization."

Normally, batch normalization accepts inputs of shape `(N, D)` and produces outputs of shape `(N, D)`, where we normalize across the minibatch dimension `N`. For data coming from convolutional layers, batch normalization needs to accept inputs of shape `(N, C, H, W)` and produce outputs of shape `(N, C, H, W)`, where the `N` dimension gives the minibatch size, `C` gives the number of feature channels, and the `(H, W)` dimensions give the spatial size of the feature map.

If the feature map was produced using convolutions, then we expect the statistics of each feature channel to be relatively consistent both between different images and between different locations within the same image. Therefore, spatial batch normalization computes a mean and variance for each of the `C` feature channels by computing statistics over both the minibatch dimension `N` and the spatial dimensions `H` and `W`.
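
As a minimal sketch of that idea (the array sizes and variable names here are illustrative assumptions, not part of the assignment code), the per-channel statistics can be computed either by reducing over axes `(0, 2, 3)` directly, or by moving the channel axis last and flattening to `(N * H * W, C)`. Both routes give one mean and one variance per channel.

```python
import numpy as np

# Illustrative sizes only; any (N, C, H, W) works.
N, C, H, W = 2, 3, 4, 5
rng = np.random.RandomState(0)
x = 4 * rng.randn(N, C, H, W) + 10

# Option 1: reduce over the minibatch and spatial axes directly.
mean_direct = x.mean(axis=(0, 2, 3))   # shape (C,)
var_direct = x.var(axis=(0, 2, 3))     # shape (C,)

# Option 2: move channels last, flatten to (N*H*W, C), reduce over rows.
x_flat = x.transpose(0, 2, 3, 1).reshape(-1, C)
mean_flat = x_flat.mean(axis=0)
var_flat = x_flat.var(axis=0)

print(mean_direct.shape)                     # (3,)
print(np.allclose(mean_direct, mean_flat))   # True
print(np.allclose(var_direct, var_flat))     # True
```

The second option is the reason an ordinary `(N, D)` batch normalization routine can be reused for the spatial case: after the transpose and reshape, every row is one spatial location of one image and every column is one channel.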

### Spatial batch normalization: forward

In [17]:
import numpy as np
from cs231n.gradient_check import eval_numerical_gradient_array, eval_numerical_gradient

def rel_error(x, y):
  """ returns relative error """
  return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

In [49]:
from cs231n.layers import batchnorm_forward


def spatial_batchnorm_forward(x, gamma, beta, bn_param):
    N, C, H, W = x.shape
    x = np.transpose(x, (0, 2, 3, 1))
    x = x.reshape((N * H * W, C))
    out, cache = batchnorm_forward(x, gamma, beta, bn_param)
    out = out.reshape((N, H, W, C))
    out = np.transpose(out, (0, 3, 1, 2))
    return (out, cache)
    
    

# Data
N, C, H, W = 2, 3, 4, 5
x = 4 * np.random.randn(N, C, H, W) + 10
gamma, beta = np.ones(C), np.zeros(C)
eps = 1e-7
bn_param = {
    'mode': 'train'
}

a = spatial_batchnorm_forward(x, gamma, beta, bn_param)[0]
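# NOTE: quick_spatial_batchnorm_forward is not defined in this section;
# it is assumed to be available from another cell of the notebook.
b = quick_spatial_batchnorm_forward(x, gamma, beta, bn_param)[0]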

print "err", rel_error(a, b)

 err 2.16160744521e-07


In [50]:
def spatial_batchnorm_forward_by_hand(x, gamma, beta, bn_param):
    """
    Computes the training-time forward pass for spatial batch normalization.

    Only the 'eps' entry of bn_param is used here; mode, momentum, and the
    running statistics are ignored, and the returned cache is empty.

    Inputs:
    - x: Input data of shape (N, C, H, W)
    - gamma: Scale parameter, of shape (C,)
    - beta: Shift parameter, of shape (C,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance. momentum=0 means that
        old information is discarded completely at every time step, while
        momentum=1 means that new information is never incorporated. The
        default of momentum=0.9 should work well in most situations.
      - running_mean: Array of shape (C,) giving running mean of features
      - running_var: Array of shape (C,) giving running variance of features

    Returns a tuple of:
    - out: Output data, of shape (N, C, H, W)
    - cache: Values needed for the backward pass (empty in this version)
    """
    eps = bn_param.get('eps', 1e-7)

    # Collect per channel stats
    mean = np.mean(x, axis=(0, 2, 3))
    var =   np.var(x, axis=(0, 2, 3))

    # Get ready to broadcast per channel stats(C) to x(N, C, H, W)
    C = x.shape[1]
    mean =   mean.reshape(1, C, 1, 1)
    var =     var.reshape(1, C, 1, 1)
    gamma = gamma.reshape(1, C, 1, 1)
    beta =   beta.reshape(1, C, 1, 1)
    
    xhat = (x - mean) / np.sqrt(var + eps)
    y = gamma * xhat + beta
    
    return (y, ())
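
The by-hand function above only covers the training-time normalization. A rough sketch of how the `mode`, `momentum`, `running_mean`, and `running_var` entries described in the docstring might be used is shown below. The function name `spatial_bn_forward_sketch` and the default values are assumptions for illustration; this is not the `cs231n.layers` implementation.

```python
import numpy as np

def spatial_bn_forward_sketch(x, gamma, beta, bn_param):
    """Hypothetical train/test forward pass for (N, C, H, W) inputs."""
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-7)            # assumed default
    momentum = bn_param.get('momentum', 0.9)   # assumed default
    C = x.shape[1]
    running_mean = bn_param.get('running_mean', np.zeros(C, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(C, dtype=x.dtype))

    if mode == 'train':
        # Normalize with the statistics of the current minibatch.
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))
        # momentum=0 discards old information; momentum=1 never updates.
        running_mean = momentum * running_mean + (1 - momentum) * mean
        running_var = momentum * running_var + (1 - momentum) * var
        bn_param['running_mean'] = running_mean
        bn_param['running_var'] = running_var
    elif mode == 'test':
        # Normalize with the stored running statistics instead.
        mean, var = running_mean, running_var
    else:
        raise ValueError('Invalid mode "%s"' % mode)

    xhat = (x - mean.reshape(1, C, 1, 1)) / np.sqrt(var.reshape(1, C, 1, 1) + eps)
    return gamma.reshape(1, C, 1, 1) * xhat + beta.reshape(1, C, 1, 1)

# Usage: with momentum=0 the running statistics equal the last batch's
# statistics, so a test-mode pass on the same batch reproduces the
# train-mode output.
rng = np.random.RandomState(0)
x = 4 * rng.randn(2, 3, 4, 5) + 10
gamma, beta = np.ones(3), np.zeros(3)
bn_param = {'mode': 'train', 'momentum': 0.0}
out_train = spatial_bn_forward_sketch(x, gamma, beta, bn_param)
bn_param['mode'] = 'test'
out_test = spatial_bn_forward_sketch(x, gamma, beta, bn_param)
print(np.allclose(out_train, out_test))  # True
```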



In [47]:
def sbf(x, gamma, beta, bn_param):
    # Staged forward pass: every intermediate value below has a matching
    # step in the backward pass that follows.
    eps = bn_param.get('eps', 1e-7)
    C = x.shape[1]
    mean = np.mean(x, axis=(0, 2, 3))
    x_minus_mean = x - mean.reshape(1, C, 1, 1)
    x_minus_mean_sqr = x_minus_mean * x_minus_mean
    var = np.mean(x_minus_mean_sqr, axis=(0, 2, 3))
    sqrt_var = np.sqrt(var + eps)
    one_over_sqrt_var = 1. / sqrt_var
    xhat = x_minus_mean * one_over_sqrt_var.reshape(1, C, 1, 1)
    y = gamma.reshape(1, C, 1, 1) * xhat + beta.reshape(1, C, 1, 1)
    return y

# Numerical gradient
dout = 5 * np.random.randn(N, C, H, W)
fx = lambda x: spatial_batchnorm_forward(x, gamma, beta, bn_param)[0]
fgamma = lambda gamma: spatial_batchnorm_forward(x, gamma, beta, bn_param)[0]
fbeta = lambda beta: spatial_batchnorm_forward(x, gamma, beta, bn_param)[0]
dx_num = eval_numerical_gradient_array(fx, x, dout)
dgamma_num = eval_numerical_gradient_array(fgamma, gamma, dout)
dbeta_num = eval_numerical_gradient_array(fbeta, beta, dout)

# Calculated gradient

# forward
mean = np.mean(x, axis=(0, 2, 3))
x_minus_mean = x - mean.reshape(1, C, 1, 1)
x_minus_mean_sqr = x_minus_mean * x_minus_mean
var = np.mean(x_minus_mean_sqr, axis=(0, 2, 3))
sqrt_var = np.sqrt(var + eps)
one_over_sqrt_var = 1. / sqrt_var
xhat = x_minus_mean * one_over_sqrt_var.reshape(1, C, 1, 1)
y = gamma.reshape(1, C, 1, 1) * xhat + beta.reshape(1, C, 1, 1)

# Backprop across beta and gamma to dxhat

# 9) + beta
dbeta = np.sum(dout, axis=(0, 2, 3))
dgamma_xhat = dout

# 8) xhat * gamma
dgamma = np.sum(dgamma_xhat * xhat, axis=(0, 2, 3))
dxhat = dgamma_xhat * gamma.reshape(1, C, 1, 1)

# 7) backprop dxhat to done_over_sqrt_var and dx_minus_mean
done_over_sqrt_var = np.sum(x_minus_mean * dxhat, axis=(0, 2, 3))
dx_minus_mean_top = dxhat * one_over_sqrt_var.reshape(1, C, 1, 1)

# 6) Backprop 1/x 
dsqrt_var = -(sqrt_var ** (-2)) * done_over_sqrt_var

# 5) backprop sqrt(var + eps)
dvar = 0.5 * ((var + eps) ** (-0.5)) * dsqrt_var

# 4) Backprop mean((x - mean)^2)
dx_minus_mean_sqr = (1. / (N * H * W)) * np.ones((N, C, H, W)) * dvar.reshape(1, C, 1, 1)

# 3) Back prop x^2
dx_minus_mean_bottom = 2 * x_minus_mean * dx_minus_mean_sqr

# 2) backprop dx_minus_mean to input x
# collect both inbound gradients
dx_minus_mean = dx_minus_mean_bottom + dx_minus_mean_top
dx_input = dx_minus_mean
dx_mean = -np.sum(dx_minus_mean, axis=(0, 2, 3))

# 1) back prop dmean to dx
dx_via_mean = (1. / (N * H * W)) * np.ones((N, C, H, W)) * dx_mean.reshape(1, C, 1, 1)

dx = dx_input + dx_via_mean

print "dx rel_error", rel_error(dx, dx_num)
print "dgamma rel_error", rel_error(dgamma, dgamma_num)
print "dbeta rel_error", rel_error(dbeta, dbeta_num)

# print "dx_num\n", dx_num
# print "dx\n", dx

dx rel_error 9.2792051324e-07
dgamma rel_error 1.67667433096e-11
dbeta rel_error 7.91093235756e-13
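
The staged backward pass above can be collapsed algebraically into a single expression for `dx`. The sketch below shows one way to write that, assuming the forward pass is the train-mode normalization used earlier. The names `spatial_bn_forward_train`, `spatial_bn_backward_compact`, and `numerical_grad` are hypothetical helpers introduced here, and the result is checked against a central-difference gradient rather than taken on trust.

```python
import numpy as np

def spatial_bn_forward_train(x, gamma, beta, eps=1e-7):
    """Train-mode forward pass; returns the output and a small cache."""
    C = x.shape[1]
    mean = x.mean(axis=(0, 2, 3)).reshape(1, C, 1, 1)
    var = x.var(axis=(0, 2, 3)).reshape(1, C, 1, 1)
    inv_std = 1.0 / np.sqrt(var + eps)
    xhat = (x - mean) * inv_std
    out = gamma.reshape(1, C, 1, 1) * xhat + beta.reshape(1, C, 1, 1)
    return out, (xhat, inv_std, gamma)

def spatial_bn_backward_compact(dout, cache):
    """Backward pass with the staged steps folded into one formula for dx."""
    xhat, inv_std, gamma = cache
    N, C, H, W = dout.shape
    M = N * H * W          # number of elements normalized per channel
    axes = (0, 2, 3)
    dbeta = dout.sum(axis=axes)
    dgamma = (dout * xhat).sum(axis=axes)
    dxhat = dout * gamma.reshape(1, C, 1, 1)
    dx = (inv_std / M) * (
        M * dxhat
        - dxhat.sum(axis=axes, keepdims=True)
        - xhat * (dxhat * xhat).sum(axis=axes, keepdims=True)
    )
    return dx, dgamma, dbeta

def numerical_grad(f, x, dout, h=1e-5):
    """Central-difference gradient of sum(f(x) * dout) with respect to x."""
    grad = np.zeros_like(x)
    for ix in np.ndindex(*x.shape):
        old = x[ix]
        x[ix] = old + h
        pos = f(x).copy()
        x[ix] = old - h
        neg = f(x).copy()
        x[ix] = old
        grad[ix] = np.sum((pos - neg) * dout) / (2 * h)
    return grad

rng = np.random.RandomState(0)
N, C, H, W = 2, 3, 4, 5
x = 4 * rng.randn(N, C, H, W) + 10
gamma, beta = rng.randn(C), rng.randn(C)
dout = rng.randn(N, C, H, W)

out, cache = spatial_bn_forward_train(x, gamma, beta)
dx, dgamma, dbeta = spatial_bn_backward_compact(dout, cache)
dx_num = numerical_grad(
    lambda v: spatial_bn_forward_train(v, gamma, beta)[0], x, dout)
print(np.max(np.abs(dx - dx_num)))  # should be very small
```

Here the analytic and numerical gradients go through the same forward function with the same `eps`, so any remaining difference comes from finite-difference error rather than from mismatched constants.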


In [48]:
# Check the training-time forward pass by checking means and variances
# of features both before and after spatial batch normalization

N, C, H, W = 2, 3, 4, 5
x = 4 * np.random.randn(N, C, H, W) + 10
x = x.astype(np.float64)

print 'Before spatial batch normalization:'
print '  Shape: ', x.shape
print '  Means: ', x.mean(axis=(0, 2, 3))
print '  Stds: ', x.std(axis=(0, 2, 3))

# Means should be close to zero and stds close to one
gamma, beta = np.ones(C), np.zeros(C)
bn_param = {'mode': 'train'}
out, _ = spatial_batchnorm_forward(x, gamma, beta, bn_param)
print 'After spatial batch normalization:'
print '  Shape: ', out.shape
print '  Means: ', out.mean(axis=(0, 2, 3))
print '  Stds: ', out.std(axis=(0, 2, 3))

# Means should be close to beta and stds close to gamma
gamma, beta = np.asarray([3, 4, 5]), np.asarray([6, 7, 8])
out, _ = spatial_batchnorm_forward(x, gamma, beta, bn_param)
print 'After spatial batch normalization (nontrivial gamma, beta):'
print '  Shape: ', out.shape
print '  Means: ', out.mean(axis=(0, 2, 3))
print '  Stds: ', out.std(axis=(0, 2, 3))

Before spatial batch normalization:
  Shape:  (2, 3, 4, 5)
  Means:  [ 10.05246971   9.25304674   8.73905878]
  Stds:  [ 5.02652078  3.73102912  4.3636944 ]
After spatial batch normalization:
  Shape:  (2, 3, 4, 5)
  Means:  [ -1.02695630e-16   2.22044605e-16   3.05311332e-17]
  Stds:  [ 1.  1.  1.]
After spatial batch normalization (nontrivial gamma, beta):
  Shape:  (2, 3, 4, 5)
  Means:  [ 6.  7.  8.]
  Stds:  [ 2.99999999  3.99999999  4.99999999]
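
The stds in the last block come out as `2.99999999` rather than exactly `3`, most likely because the denominator is `sqrt(var + eps)`: each channel's output std is then `gamma * sqrt(var / (var + eps))`, which is slightly below `gamma`. The sketch below illustrates this with an exaggerated `eps`. The value `1e-2` is an assumption chosen only to make the effect visible and is not the value used above.

```python
import numpy as np

rng = np.random.RandomState(0)
N, C, H, W = 2, 3, 4, 5
x = 4 * rng.randn(N, C, H, W) + 10
gamma = np.array([3.0, 4.0, 5.0])
beta = np.array([6.0, 7.0, 8.0])
eps = 1e-2  # deliberately large, for illustration only

mean = x.mean(axis=(0, 2, 3)).reshape(1, C, 1, 1)
var = x.var(axis=(0, 2, 3)).reshape(1, C, 1, 1)
xhat = (x - mean) / np.sqrt(var + eps)
out = gamma.reshape(1, C, 1, 1) * xhat + beta.reshape(1, C, 1, 1)

out_std = out.std(axis=(0, 2, 3))
# Predicted shrinkage of the std caused by eps in the denominator.
predicted_std = gamma * np.sqrt(var.ravel() / (var.ravel() + eps))

print(np.allclose(out_std, predicted_std))           # True
print(np.all(out_std < gamma))                       # True
print(np.allclose(out.mean(axis=(0, 2, 3)), beta))   # True
```

The means are unaffected because `eps` only rescales the centered data; it does not shift it.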
