### Import and install stuff


In [None]:
%%capture
import torch
import torch.nn as nn
import numpy as np
!sudo apt-get install libmlpack-dev 
torch.set_printoptions(precision=6)

### nn.InstanceNorm2d

In [None]:
# Input - N=2, C=3, H=3, W=2
# The values 1..36 are laid out in row-major (N, C, H, W) order, so the plane for
# sample n / channel c holds the six consecutive values starting at
# 1 + 6 * (3 * n + c), e.g. A[0, 0] = [[1, 2], [3, 4], [5, 6]].
A = torch.arange(1., 37.).reshape(2, 3, 3, 2)

# One set of affine parameters / running statistics per channel (dim 1 of A).
layer = nn.InstanceNorm2d(A.shape[1], affine=True, track_running_stats=True)

print('Before Instance Normalisation : ')
print('--------------------------------')
print('Input shape : {}'.format(A.shape))
print('--------------------------------')
print('Running mean: {}'.format(layer.running_mean))
print('Running variance: {}'.format(layer.running_var))
print('--------------------------------')

# A is already a float32 tensor, so it is passed to the layer directly (no
# legacy torch.FloatTensor(...) wrapper needed). The layer is in training mode
# here, so this call also updates the running statistics printed below.
output = layer(A)

print('After Instance Normalisation : ')
print('--------------------------------')
print('Output: {}'.format(output.detach()))
print('--------------------------------')
print('Output shape : {}'.format(output.shape))
print('--------------------------------')
print('Running mean: {}'.format(layer.running_mean))
print('Running variance: {}'.format(layer.running_var))
print('--------------------------------')
print('Gamma : {}'.format(layer.weight.detach()))
print('Beta : {}'.format(layer.bias.detach()))
print('--------------------------------')

Before Instance Normalisation : 
--------------------------------
Input shape : torch.Size([2, 3, 3, 2])
--------------------------------
Running mean: tensor([0., 0., 0.])
Running variance: tensor([1., 1., 1.])
--------------------------------
After Instance Normalisation : 
--------------------------------
Output: tensor([[[[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]]],


        [[[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]]]])
--------------------------------
Output shape : torch.S

### Direct NumPy implementation of the Instance Norm forward pass, taken from the reference test in the PyTorch source code [here](https://github.com/pytorch/pytorch/blob/e7fe64f6a65cd427e503491f192c14476e18033b/caffe2/python/hypothesis_test.py#L2176)

- This is used to verify the calculation of mean and variance values.

In [None]:
# The upstream reference test draws scale, bias and X at random
# (np.random.seed(1701), np.random.rand(...) shifted by +/- 0.5). Fixed values
# are used here instead so the numbers can be compared with the other cells.

input_channels = 3
batch_size = 2
size_a = 3
size_b = 2
epsilon = 1e-5
scale = np.ones(input_channels).astype(np.float32)
bias = np.zeros(input_channels).astype(np.float32)
# Values 1..36 in row-major (N, C, H, W) order: the plane for sample n / channel c
# holds the six consecutive values starting at 1 + 6 * (3 * n + c).
X = np.arange(1, 37, dtype=np.float32).reshape(
    batch_size, input_channels, size_a, size_b)


def ref_nchw(x, scale, bias):
    """Reference instance-norm forward pass for a 4-D NCHW array.

    Every (sample, channel) plane of ``x`` is normalised with its own mean and
    biased variance, then scaled and shifted per channel. The per-plane mean and
    variance are printed as a side effect.

    Parameters
    ----------
    x : np.ndarray
        Input of shape (N, C, H, W). The shape is read from the array itself.
    scale, bias : np.ndarray
        Per-channel affine parameters with C elements each.

    Returns
    -------
    tuple
        A 1-tuple ``(y,)`` where ``y`` has the same shape as ``x``.

    Notes
    -----
    Uses the module-level ``epsilon`` as the numerical-stability term.
    """
    n, c, h, w = x.shape
    # One row per (sample, channel) plane, so axis-1 reductions are per instance.
    flat = x.reshape(n * c, h * w)
    mean = flat.mean(1)
    var = flat.var(1)

    y = (flat - mean[:, np.newaxis]) / np.sqrt(var + epsilon)[:, np.newaxis]
    y = y.reshape(n, c, h, w)
    y = y * scale.reshape(1, c, 1, 1)
    y = y + bias.reshape(1, c, 1, 1)

    print('--------------------------------')
    print('Input mean : ')
    print(mean)
    print('--------------------------------')
    print('Input variance : ')
    print(var)
    print('--------------------------------')

    return (y, )

y = ref_nchw(X, scale, bias)

--------------------------------
Input mean : 
[ 3.5  9.5 15.5 21.5 27.5 33.5]
--------------------------------
Input variance : 
[2.9166667 2.9166667 2.9166667 2.9166667 2.9166667 2.9166667]
--------------------------------


### nn.BatchNorm2d with reshaped input - emulates the approach used in the PyTorch source code.

- This idea is used in the mlpack implementation, where the Instance Norm just acts as a wrapper class for the BatchNorm class.

In [None]:
# Input - N=2, C=3, H=3, W=2
# The values 1..36 are laid out in row-major (N, C, H, W) order, so the plane for
# sample n / channel c holds the six consecutive values starting at
# 1 + 6 * (3 * n + c).
B = torch.arange(1., 37.).reshape(2, 3, 3, 2)

# Fold the batch into the channel dimension: (N, C, H, W) -> (1, N*C, H, W).
# Batch norm over this single-sample tensor then normalises every
# (sample, channel) plane on its own.
C = B.reshape(1, B.shape[0] * B.shape[1], *B.shape[2:])
layer = nn.BatchNorm2d(C.shape[1], affine=True, track_running_stats=True)

print('Before Instance Normalisation : ')
print('--------------------------------')
print('Input shape : {}'.format(B.shape))
print('Input reshape : {}'.format(C.shape))
print('--------------------------------')
print('Running mean: {}'.format(layer.running_mean))
print('Running variance: {}'.format(layer.running_var))
print('--------------------------------')

# C is already a float32 tensor, so it is passed to the layer directly (no
# legacy torch.FloatTensor(...) wrapper needed). The layer is in training mode
# here, so this call also updates the running statistics printed below.
output = layer(C)

print('After Instance Normalisation : ')
print('--------------------------------')
print('Output: {}'.format(output.detach()))
print('--------------------------------')
print('Output shape : {}'.format(output.shape))
print('--------------------------------')
print('These have to be reshaped to the format (3,2) from this (6,1) shape and then mean has to be taken across the rows.')
print('Running mean: {}'.format(layer.running_mean))
print('Running variance: {}'.format(layer.running_var))
print('--------------------------------')
print('Gamma : {}'.format(layer.weight.detach()))
print('Beta : {}'.format(layer.bias.detach()))
print('--------------------------------')

Before Instance Normalisation : 
--------------------------------
Input shape : torch.Size([2, 3, 3, 2])
Input reshape : torch.Size([1, 6, 3, 2])
--------------------------------
Running mean: tensor([0., 0., 0., 0., 0., 0.])
Running variance: tensor([1., 1., 1., 1., 1., 1.])
--------------------------------
After Instance Normalisation : 
--------------------------------
Output: tensor([[[[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.463848]],

         [[-1.463848, -0.878309],
          [-0.292770,  0.292770],
          [ 0.878309,  1.4638

### mlpack - standalone Armadillo emulation of the Instance Norm forward and backward passes


In [None]:
%%capture
%%writefile test.cpp
#include <iostream>
#include <armadillo>

using namespace std;
using namespace arma;

// Standalone walk-through of an instance-norm forward and backward pass that is
// implemented by reshaping the input and running batch-norm arithmetic on it.
// Prints the per-channel statistics, the output, the running statistics and
// the gradient so they can be compared with the other cells of the notebook.
int main()
{
  ///////////////////////////ANN LAYER TEST (USER INPUT)////////////////////////
  // 18 x 2 test input holding the values 1..36: the first column is 1..18, the
  // second 19..36. Presumably one column per batch sample, each column stacking
  // `size` channels of 6 values -- this matches the reshape arithmetic below;
  // TODO confirm against the mlpack layer convention.
  arma::mat input;
  input << 1  << 19  << arma::endr
        << 2  << 20  << arma::endr
        << 3  << 21  << arma::endr
        << 4  << 22  << arma::endr
        << 5  << 23  << arma::endr
        << 6  << 24  << arma::endr
        << 7  << 25  << arma::endr
        << 8  << 26  << arma::endr
        << 9  << 27  << arma::endr
        << 10 << 28  << arma::endr
        << 11 << 29  << arma::endr
        << 12 << 30  << arma::endr
        << 13 << 31  << arma::endr
        << 14 << 32  << arma::endr
        << 15 << 33  << arma::endr
        << 16 << 34  << arma::endr
        << 17 << 35  << arma::endr
        << 18 << 36  << arma::endr;

  size_t size = 3; // number of channels
  const double eps = 1e-5;
  const double momentum = 0.1;
  //////////////////////////////////////////////////////////////////////////////

  ///////////////////////////INSTANCE NORM FORWARD//////////////////////////////
  // Remember the original layout (18 rows, 2 columns, 3 channels) so it can be
  // restored after the batch-norm section.
  const size_t shapeA = input.n_rows;
  const size_t shapeB = input.n_cols;
  const size_t shapeC = size;
  arma::mat runningTemp = arma::zeros(shapeC, 1);
  // Fold the columns into the channel count: `size` becomes 3 * 2 = 6 and the
  // input becomes a single 36 x 1 column, so the code below sees one sample
  // with 6 channels.
  size *= input.n_cols;
  input = arma::vectorise(input);

    /////////////////////////BATCH NORM RESET + FORWARD/////////////////////////
    // Running statistics start at mean 0 / variance 1.
    arma::mat weights, runningMean, runningVariance;
    weights.set_size(size + size, 1); 
    runningMean.zeros(size, 1); 
    runningVariance.ones(size, 1); 
    // gamma and beta are built over the first and second half of `weights`
    // (copy_aux_mem = false), then set to the identity transform (1 and 0).
    // NOTE(review): whether they still alias `weights` after the assignment
    // depends on Armadillo's move-assignment behaviour; `weights` is not read
    // again below, so the printed results do not depend on it.
    arma::mat gamma, beta;
    gamma = arma::mat(weights.memptr(), size, 1, false, false);  
    beta = arma::mat(weights.memptr() + gamma.n_elem, size, 1, false, false); 
    gamma.fill(1.0);
    beta.fill(0.0);
    // After the vectorise above: batchSize = 1 and inputSize = 36 / 6 = 6.
    const size_t batchSize = input.n_cols;
    const size_t inputSize = input.n_rows / size;
    arma::mat output;
    output.set_size(arma::size(input));
    // Cubes of shape (inputSize, size, batchSize) constructed directly over the
    // memory of `input` / `output` (copy_aux_mem = false), so everything
    // written to outputTemp lands in `output`.
    arma::cube inputTemp(const_cast<arma::mat&>(input).memptr(), inputSize, size, batchSize, false, false);
    arma::cube outputTemp(const_cast<arma::mat&>(output).memptr(), inputSize, size, input.n_cols, false, false);
    outputTemp = inputTemp;
    // Per-channel mean: average over slices (dim 2), then over rows (dim 0),
    // giving a 1 x size row vector.
    arma::mat mean = arma::mean(arma::mean(inputTemp, 2), 0);
    // Per-channel biased variance: mean of squared deviations over the same
    // two dimensions.
    arma::mat variance = arma::mean(arma::mean(arma::pow(inputTemp.each_slice() - arma::repmat(mean,inputSize, 1), 2), 2), 0);
    // Centre the data; the centred values are kept in inputMean because the
    // backward pass below reads them.
    outputTemp.each_slice() -= arma::repmat(mean, inputSize, 1);
    arma::cube inputMean;
    inputMean.set_size(arma::size(inputTemp));
    inputMean = outputTemp;
    // Divide by sqrt(variance + eps) to obtain the normalised values.
    outputTemp.each_slice() /= arma::sqrt(arma::repmat(variance, inputSize, 1) + eps);
    // Copy of the normalised values; not referenced again in this listing.
    arma::cube normalized;
    normalized.set_size(arma::size(inputTemp));
    normalized = outputTemp;
    // Affine step: element-wise scale by gamma, then shift by beta.
    outputTemp.each_slice() %= arma::repmat(gamma.t(),inputSize, 1);
    outputTemp.each_slice() += arma::repmat(beta.t(), inputSize, 1);
    // With n_elem = 36 and size = 6, input.n_elem * nElements is
    // 36 / (30 + eps), i.e. roughly 6/5; it rescales the biased variance before
    // it enters the running estimate (looks like an n / (n - 1) correction for
    // 6 values per channel -- TODO confirm).
    double nElements = 1.0 / (input.n_elem - size + eps);
    // Exponential moving averages weighted by `momentum`.
    runningMean = (1 - momentum) * runningMean + momentum * mean.t();
    runningVariance = (1 - momentum) * runningVariance + input.n_elem * nElements * momentum * variance.t();
    //////////////////////////////////////////////////////////////////////////////
 
  // Restore the original 18 x 2 layout of input and output.
  input.reshape(shapeA, shapeB);
  output.reshape(shapeA, shapeB);
  // Collapse the 6 running statistics to 3: reshape to (shapeC x shapeB) = 3 x 2
  // and average along each row (dim 1), leaving a 3 x 1 column.
  runningMean.reshape(shapeC, shapeB);
  runningVariance.reshape(shapeC, shapeB);
  runningTemp = arma::mean(runningMean, 1);
  runningMean.set_size(shapeC, 1);
  runningMean = runningTemp;
  runningTemp = arma::mean(runningVariance, 1);
  runningVariance.set_size(shapeC, 1);
  runningVariance = runningTemp;
  // Lay the 6 means out as 3 x 2 for printing; `variance` stays 1 x 6 because
  // the backward pass below still uses it in that shape.
  mean.reshape(shapeC, shapeB);
  //////////////////////////////////////////////////////////////////////////////
 
  ///////////////////////////ANN LAYER TEST (USER INPUT)////////////////////////
  // The forward output is reused as the upstream gradient for the test.
  arma::mat gy = output;
  //////////////////////////////////////////////////////////////////////////////

  ///////////////////////////INSTANCE NORM BACKWARD/////////////////////////////
  // Flatten to 36 x 1 again; `size` is still 6 from the forward pass.
  gy = arma::vectorise(gy);
  input = arma::vectorise(input);
 
    ///////////////////////////BATCH NORM BACKWARD////////////////////////////////
    arma::mat g;
    // 1 / sqrt(variance + eps) per channel, as a 1 x size row vector.
    const arma::mat stdInv = 1.0 / arma::sqrt(variance + eps);
    g.set_size(arma::size(input));
    // Cubes over the memory of gy and g (copy_aux_mem = false); results
    // assigned to gTemp are written into `g`. input.n_cols is 1 at this point,
    // so each cube has a single slice and the divisions by input.n_cols below
    // divide by 1.
    arma::cube gyTemp(const_cast<arma::mat&>(gy).memptr(), input.n_rows / size, size, input.n_cols, false, false);
    arma::cube gTemp(const_cast<arma::mat&>(g).memptr(), input.n_rows / size, size, input.n_cols, false, false);
    // Upstream gradient scaled element-wise by gamma.
    arma::cube norm = gyTemp.each_slice() % arma::repmat(gamma.t(), input.n_rows / size, 1);
    // NOTE(review): the sums below reduce over the slice dimension (dim 2)
    // only, whereas the forward statistics were averaged over slices and rows.
    // Confirm that this is the intended gradient.
    arma::mat temp = arma::sum(norm % inputMean, 2);
    arma::mat vars = temp % arma::repmat(arma::pow(stdInv, 3), input.n_rows / size, 1) * -0.5;
    gTemp = (norm.each_slice() % arma::repmat(stdInv, input.n_rows / size, 1) + (inputMean.each_slice() % vars * 2)) / input.n_cols;
    arma::mat normTemp = arma::sum(norm.each_slice() %arma::repmat(-stdInv, input.n_rows / size, 1) , 2) / input.n_cols;
    gTemp.each_slice() += normTemp;
    //////////////////////////////////////////////////////////////////////////////
 
  // Back to the original 18 x 2 layout for printing.
  input.reshape(shapeA, shapeB);
  output.reshape(shapeA, shapeB);
  g.reshape(shapeA, shapeB);
  gy.reshape(shapeA, shapeB);
  //////////////////////////////////////////////////////////////////////////////
 
  cout << "-----------------------------------" << endl;
  mean.print("Input mean: ");
  cout << "-----------------------------------" << endl;
  variance.print("Input variance: ");
  cout << "-----------------------------------" << endl;
  output.print("Output: ");
  cout << "-----------------------------------" << endl;
  runningMean.print("Running Mean: ");
  cout << "-----------------------------------" << endl;
  runningVariance.print("Running Variance: ");
  cout << "-----------------------------------" << endl;
  g.print("g: ");
  cout << "-----------------------------------" << endl;
  gy.print("gy: ");
  cout << "-----------------------------------" << endl;
  cout << "Sum of values in g matrix : " << arma::accu(g) << endl;
  cout << "-----------------------------------" << endl;

  return 0;
}

In [None]:
%%script bash
# Compile test.cpp (written by the %%writefile cell above) and link it against
# Armadillo; the binary is run only if compilation succeeds.
g++ test.cpp -o test -larmadillo && ./test

-----------------------------------
Input mean: 
    3.5000   21.5000
    9.5000   27.5000
   15.5000   33.5000
-----------------------------------
Input variance: 
   2.9167   2.9167   2.9167   2.9167   2.9167   2.9167
-----------------------------------
Output: 
  -1.4638  -1.4638
  -0.8783  -0.8783
  -0.2928  -0.2928
   0.2928   0.2928
   0.8783   0.8783
   1.4638   1.4638
  -1.4638  -1.4638
  -0.8783  -0.8783
  -0.2928  -0.2928
   0.2928   0.2928
   0.8783   0.8783
   1.4638   1.4638
  -1.4638  -1.4638
  -0.8783  -0.8783
  -0.2928  -0.2928
   0.2928   0.2928
   0.8783   0.8783
   1.4638   1.4638
-----------------------------------
Running Mean: 
   1.2500
   1.8500
   2.4500
-----------------------------------
Running Variance: 
   1.2500
   1.2500
   1.2500
-----------------------------------
g: 
   1.8367   1.8367
   0.3967   0.3967
   0.0147   0.0147
  -0.0147  -0.0147
  -0.3967  -0.3967
  -1.8367  -1.8367
   1.8367   1.8367
   0.3967   0.3967
   0.0147   0.0147
  -0.0147  -0.01