In [20]:
# https://gluon.mxnet.io/chapter04_convolutional-neural-networks/cnn-batch-norm-scratch.html

In [1]:
import numpy as np
import mxnet as mx
from mxnet import nd, autograd
# Device context passed as `ctx=` to every array created in this notebook and
# used with `as_in_context` when batches are moved; mx.gpu() selects a GPU.
ctx = mx.gpu()

  from ._conv import register_converters as _register_converters


In [9]:
def transform(data, label):
    """Convert one dataset sample into the layout and dtype the network consumes.

    The image is cast to float32, its axes are permuted with (2, 0, 1) so the
    last axis comes first, and the values are divided by 255. The label is
    cast to float32.

    Returns:
        An ``(image, label)`` tuple.
    """
    pixels = data.astype(np.float32)
    channels_first = nd.transpose(pixels, (2, 0, 1))
    scaled = channels_first / 255
    return scaled, label.astype(np.float32)

# Mini-batch size shared by the training and test loaders.
batch_size = 64

# Training split, loaded with shuffle=True.
train_data = mx.gluon.data.DataLoader(
    mx.gluon.data.vision.MNIST(train=True, transform=transform),
    batch_size,
    shuffle=True,
)
# Test split, loaded with shuffle=False.
test_data = mx.gluon.data.DataLoader(
    mx.gluon.data.vision.MNIST(train=False, transform=transform),
    batch_size,
    shuffle=False,
)

In [3]:
def pure_batch_norm(X, gamma, beta, eps = 1e-5):
    """Normalize a mini-batch with its own statistics, then scale and shift.

    Args:
        X: input of shape (m, input feature size) or (m, C, H, W).
        gamma: scale, one value per feature (2-D input) or per channel
            (4-D input). For 4-D input it is reshaped to (1, C, 1, 1).
        beta: shift, shaped like ``gamma`` and reshaped the same way.
        eps: small constant added to the standard deviation so the division
            never hits zero.

    Returns:
        ``gamma * (X - mean) / (std + eps) + beta``, same shape as ``X``.

    Raises:
        ValueError: if ``X`` is neither 2-D nor 4-D.
    """
    ndim = len(X.shape)
    if ndim not in (2, 4):
        raise ValueError('Only supports matrix shape (m, input feature size) or tensor shape(m, C, H, W)')

    if ndim == 2:
        # matrix shape (m, input feature size): statistics per feature column
        mean = nd.mean(X, axis=0)
        # standard deviation of the mini-batch
        std = nd.sqrt(nd.mean(nd.square(X - mean), axis=0))
    else:
        # tensor shape (m, C, H, W): statistics per channel, kept as
        # (1, C, 1, 1) so they broadcast against X
        C = X.shape[1]
        mean = nd.mean(X, axis=(0, 2, 3), keepdims=True)
        # standard deviation of the mini-batch
        std = nd.sqrt(nd.mean(nd.square(X - mean), axis=(0, 2, 3), keepdims=True))
        # A flat (C,) vector would line up with the trailing W axis rather
        # than the channel axis, so give it an explicit channel layout.
        gamma = gamma.reshape((1, C, 1, 1))
        beta = beta.reshape((1, C, 1, 1))

    # normalize; note eps is added to the standard deviation itself here
    X_hat = (X - mean) / (std + eps)
    # scale and shift
    out = gamma * X_hat + beta

    return out

In [4]:
# Sanity check on a 3x2 matrix: with gamma = 1 and beta = 0 the call returns
# the normalized input unchanged by the scale/shift step.
A = nd.array([[1, 7], [5, 4], [6, 10]], ctx=ctx)
gamma = nd.array([1, 1], ctx=ctx)
beta = nd.array([0, 0], ctx=ctx)

pure_batch_norm(A, gamma, beta)


[[-1.3887237  0.       ]
 [ 0.4629079 -1.2247398]
 [ 0.9258158  1.2247398]]
<NDArray 3x2 @gpu(0)>

In [5]:
# Sanity check on a (m=2, C=2, H=2, W=2) tensor with the gamma/beta defined
# for the matrix example. The literal lists exactly 2*2*2*2 = 16 values so it
# matches the target shape; an extra trailing value does not fit the tensor.
B = nd.array([1,6,5,7,4,3,2,5,6,3,2,4,5,3,2,5], ctx=ctx).reshape((2,2,2,2))
pure_batch_norm(B, gamma, beta)


[[[[-1.6378378   0.88191265]
   [ 0.37796256  1.3858627 ]]

  [[ 0.30779096 -0.51298493]
   [-1.3337609   1.1285669 ]]]


 [[[ 0.88191265 -0.6299376 ]
   [-1.1338876  -0.12598751]]

  [[ 1.1285669  -0.51298493]
   [-1.3337609   1.1285669 ]]]]
<NDArray 2x2x2x2 @gpu(0)>

![](https://res.cloudinary.com/dqagyeboj/image/upload/v1537335619/Untitled27_xlpayu.png)

In [19]:
def _update_running_stat(store, key, batch_value, momentum):
    """Fold ``batch_value`` into the exponentially weighted average ``store[key]``."""
    if key not in store:
        # first observation seeds the average
        store[key] = batch_value
    else:
        store[key] = momentum * store[key] + (1 - momentum) * batch_value


def batch_norm(X, gamma, beta, bn_ewv_means, bn_ewv_vars, momentum=0.9, eps=1e-5, layer_name='', is_training=True, debug=False):
    """Batch normalization with running statistics kept in caller-owned dicts.

    Args:
        X: input of shape (m, input feature size) or (m, C, H, W).
        gamma: scale, one value per feature / channel.
        beta: shift, one value per feature / channel.
        bn_ewv_means: dict mapping ``layer_name`` to the exponentially
            weighted average of batch means. Mutated in place while training.
        bn_ewv_vars: same as ``bn_ewv_means`` but for batch variances.
        momentum: weight given to the previous running value on each update.
        eps: small constant added to the variance before the square root.
        layer_name: key under which this layer's running statistics live.
        is_training: if True, normalize with the statistics of ``X`` and
            update the running averages. If False, normalize with the stored
            running averages and leave them untouched.
        debug: if True, print the statistics and the output.

    Returns:
        The normalized, scaled and shifted array, same shape as ``X``.

    Raises:
        ValueError: if ``X`` is neither 2-D nor 4-D.
        KeyError: if ``is_training`` is False and no running statistics have
            been stored for ``layer_name`` yet.
    """
    ndim = len(X.shape)
    if ndim not in (2, 4):
        raise ValueError('Only supports matrix shape (m, input feature size) or tensor shape(m, C, H, W)')

    if ndim == 2:
        # per-feature statistics, shape (features,)
        reduce_kwargs = {'axis': 0}
        scale, shift = gamma, beta
    else:
        # per-channel statistics, kept as (1, C, 1, 1) so they broadcast
        C = X.shape[1]
        reduce_kwargs = {'axis': (0, 2, 3), 'keepdims': True}
        scale = gamma.reshape((1, C, 1, 1))
        shift = beta.reshape((1, C, 1, 1))

    # mini-batch mean and variance
    mean = nd.mean(X, **reduce_kwargs)
    variance = nd.mean(nd.square(X - mean), **reduce_kwargs)

    if is_training:
        # while training, normalize with the statistics of this batch ...
        X_hat = (X - mean) / nd.sqrt(variance + eps)
        # ... and only then fold them into the running averages; evaluation
        # batches must not shift the statistics used at test time.
        _update_running_stat(bn_ewv_means, layer_name, mean, momentum)
        _update_running_stat(bn_ewv_vars, layer_name, variance, momentum)
    else:
        if layer_name not in bn_ewv_means or layer_name not in bn_ewv_vars:
            raise KeyError(
                'no running statistics stored for layer {!r}; '
                'run a training pass first'.format(layer_name))
        # while testing, normalize with the pre-computed running statistics
        X_hat = (X - bn_ewv_means[layer_name]) / nd.sqrt(bn_ewv_vars[layer_name] + eps)

    # scale and shift
    out = scale * X_hat + shift

    if debug:
        print('== info start ==')
        print('layer_name = {}'.format(layer_name))
        print('mean = {}'.format(mean))
        print('var = {}'.format(variance))
        print('bn_ewv_means = {}'.format(bn_ewv_means[layer_name]))
        print('bn_ewv_vars = {}'.format(bn_ewv_vars[layer_name]))
        print('output = {}'.format(out))
        print('== info end ==')

    return out

In [11]:
# Hyper-parameters of the initialisation and the layer sizes.
weight_scale = 0.01
num_fc = 128
num_inputs = 784
num_outputs = 10
# Flattened size fed to the first dense layer: 50 channels of 4 x 4 (= 800).
num_flat = 50 * 4 * 4

# First convolution: 20 filters of 3x3 over 1 input channel, plus its batch-norm scale/shift.
W1 = nd.random_normal(shape=(20, 1, 3,3), scale=weight_scale, ctx=ctx)
b1 = nd.random_normal(shape=20, scale=weight_scale, ctx=ctx)

gamma1 = nd.random_normal(shape=20, loc=1, scale=weight_scale, ctx=ctx)
beta1 = nd.random_normal(shape=20, scale=weight_scale, ctx=ctx)

# Second convolution: 50 filters of 5x5 over 20 input channels, plus batch-norm scale/shift.
W2 = nd.random_normal(shape=(50, 20, 5, 5), scale=weight_scale, ctx=ctx)
b2 = nd.random_normal(shape=50, scale=weight_scale, ctx=ctx)

gamma2 = nd.random_normal(shape=50, loc=1, scale=weight_scale, ctx=ctx)
beta2 = nd.random_normal(shape=50, scale=weight_scale, ctx=ctx)

# Fully-connected hidden layer, plus batch-norm scale/shift.
W3 = nd.random_normal(shape=(num_flat, num_fc), scale=weight_scale, ctx=ctx)
b3 = nd.random_normal(shape=num_fc, scale=weight_scale, ctx=ctx)

gamma3 = nd.random_normal(shape=num_fc, loc=1, scale=weight_scale, ctx=ctx)
beta3 = nd.random_normal(shape=num_fc, scale=weight_scale, ctx=ctx)

# Output layer; the bias is sized from num_outputs so it always matches W4.
W4 = nd.random_normal(shape=(num_fc, num_outputs), scale=weight_scale, ctx=ctx)
b4 = nd.random_normal(shape=num_outputs, scale=weight_scale, ctx=ctx)

# Every trainable array, in the order it was created.
params = [W1, b1, gamma1, beta1, W2, b2, gamma2, beta2, W3, b3, gamma3, beta3, W4, b4]

In [12]:
# Attach a gradient buffer to every trainable array.
for p in params:
    p.attach_grad()

In [13]:
def relu(X):
    """Rectified linear unit: the element-wise maximum of ``X`` and 0."""
    floor = 0
    return nd.maximum(X, floor)

In [14]:
def softmax(y_linear):
    """Softmax over every axis except the first (one distribution per sample).

    Args:
        y_linear: raw scores; the partition is reshaped to (-1, 1), so the
            result broadcasts as written for (m, classes) input.

    Returns:
        ``exp(y_linear - max) / sum(exp(y_linear - max))`` per sample.
    """
    # Shift by each sample's own maximum rather than the global one: the
    # result is mathematically the same, but a sample whose scores all sit far
    # below the global maximum can no longer underflow to an all-zero partition.
    sample_max = nd.max(y_linear, axis=0, exclude=True, keepdims=True)
    exp = nd.exp(y_linear - sample_max)
    # axis=0 with exclude=True sums over everything except the sample axis
    partition = nd.nansum(exp, axis=0, exclude=True).reshape((-1,1))
    return exp / partition

In [15]:
def softmax_cross_entropy(yhat_linear, y):
    """Cross-entropy computed directly from raw scores.

    Multiplies ``y`` by ``nd.log_softmax(yhat_linear)``, sums the product over
    every axis except the first, and negates the result.
    """
    log_probs = nd.log_softmax(yhat_linear)
    weighted = y * log_probs
    return -nd.nansum(weighted, axis=0, exclude=True)

In [16]:
def net(X, bn_ewv_means, bn_ewv_vars, is_training = True, debug=False):
    """Forward pass: two conv/batch-norm/relu/avg-pool stages, one
    batch-normalized dense layer, and a linear output layer.

    Args:
        X: input batch passed to the first convolution.
        bn_ewv_means: dict of running batch-norm means, handed to ``batch_norm``.
        bn_ewv_vars: dict of running batch-norm variances, handed to ``batch_norm``.
        is_training: forwarded to every ``batch_norm`` call.
        debug: if True, print the shape produced by each stage.

    Returns:
        The raw (un-normalized) output scores ``yhat_linear``.
    """
    def report(tag, tensor):
        # Shape print-out for one stage, only when debugging is requested.
        if debug:
            print("%s shape: %s" % (tag, np.array(tensor.shape)))

    def conv_stage(inp, weight, bias, kernel, num_filter, gamma, beta, layer_name):
        # convolution -> batch norm -> relu -> 2x2 average pooling with stride 2
        conv = nd.Convolution(data=inp, weight=weight, bias=bias, kernel=kernel, num_filter=num_filter)
        normed = batch_norm(conv, gamma, beta, bn_ewv_means, bn_ewv_vars,
                            layer_name=layer_name, is_training=is_training)
        activated = relu(normed)
        return nd.Pooling(data=activated, pool_type="avg", kernel=(2,2), stride=(2,2))

    # First convolutional stage
    h1 = conv_stage(X, W1, b1, (3,3), 20, gamma1, beta1, 'bn1')
    report("h1", h1)

    # Second convolutional stage
    h2 = conv_stage(h1, W2, b2, (5,5), 50, gamma2, beta2, 'bn2')
    report("h2", h2)

    # Flatten h2 so it can feed the fully-connected layer
    h2 = nd.flatten(h2)
    report("Flat h2", h2)

    # Third (fully-connected) layer with batch norm and relu
    h3_linear = nd.dot(h2, W3) + b3
    h3 = relu(batch_norm(h3_linear, gamma3, beta3, bn_ewv_means, bn_ewv_vars,
                         layer_name='bn3', is_training=is_training))
    report("h3", h3)

    # Output layer
    yhat_linear = nd.dot(h3, W4) + b4
    report("yhat_linear", yhat_linear)

    return yhat_linear

In [17]:
# Fresh running-statistics dicts for the shape check below.
bn_ewv_means = {}
bn_ewv_vars = {}
# Take only the first batch from the loader and move its images to the device.
for first_batch in train_data:
    data = first_batch[0].as_in_context(ctx)
    break

In [20]:
# Dry run on a single batch with debug=True to print each stage's output shape.
output = net(
    data,
    bn_ewv_means,
    bn_ewv_vars,
    is_training=True,
    debug=True,
)

h1 shape: [64 20 13 13]
h2 shape: [64 50  4  4]
Flat h2 shape: [ 64 800]
h3 shape: [ 64 128]
yhat_linear shape: [64 10]


In [21]:
def SGD(params, lr):
    """One step of plain gradient descent, written back in place.

    Each array in ``params`` is overwritten with ``param - lr * param.grad``.
    """
    for weight in params:
        step = lr * weight.grad
        # slice assignment keeps the same array object
        weight[:] = weight - step

In [22]:
def evaluate_accuracy(data_iterator, net, bn_ewv_means, bn_ewv_vars):
    """Fraction of samples in ``data_iterator`` that ``net`` classifies correctly.

    Args:
        data_iterator: iterable yielding ``(data, label)`` batches.
        net: callable invoked as ``net(data, bn_ewv_means, bn_ewv_vars, is_training=False)``.
        bn_ewv_means: dict of running batch-norm means passed through to ``net``.
        bn_ewv_vars: dict of running batch-norm variances passed through to ``net``.

    Returns:
        The accuracy as a scalar.

    Raises:
        ZeroDivisionError: if ``data_iterator`` yields no batches.
    """
    numerator = 0.
    denominator = 0.
    for data, label in data_iterator:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        # evaluation pass: is_training=False
        output = net(data, bn_ewv_means, bn_ewv_vars, is_training=False)
        # the predicted class is the index of the largest score
        predictions = nd.argmax(output, axis=1)
        numerator += nd.sum(predictions == label)
        denominator += data.shape[0]
    return (numerator / denominator).asscalar()

In [23]:
epochs = 1
moving_loss = 0.
learning_rate = .001
# Running batch-norm statistics, filled in by net() during training.
bn_ewv_means, bn_ewv_vars = {}, {}

for e in range(epochs):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        label_one_hot = nd.one_hot(label, num_outputs)
        with autograd.record():
            # we are in training process,
            # so we normalize the data using batch mean and variance
            output = net(data, bn_ewv_means, bn_ewv_vars, is_training=True)
            loss = softmax_cross_entropy(output, label_one_hot)
        loss.backward()
        SGD(params, learning_rate)

        ##########################
        #  Keep a moving average of the losses
        ##########################
        curr_loss = nd.mean(loss).asscalar()
        # Seed the average only on the very first batch of the first epoch;
        # re-seeding on every epoch would throw the accumulated average away.
        if e == 0 and i == 0:
            moving_loss = curr_loss
        else:
            moving_loss = .99 * moving_loss + .01 * curr_loss

    # test and train accuracy are both computed with is_training=False,
    # i.e. normalizing with the running mean and variance
    test_accuracy = evaluate_accuracy(test_data, net, bn_ewv_means, bn_ewv_vars)
    train_accuracy = evaluate_accuracy(train_data, net, bn_ewv_means, bn_ewv_vars)
    print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" % (e, moving_loss, train_accuracy, test_accuracy))

Epoch 0. Loss: 0.06240094275482456, Train_acc 0.9881667, Test_acc 0.989
