In [1]:
import torch

from torch import nn

## Step 1: We will first create a batch of 2 image-like inputs containing channels 2 x width 3 x height 3. The batch shape should be [2, 2, 3, 3]

## Image-Like random channels are created using torch.normal function which takes mean, standard deviation, size attributes.

## For random channel 1, a mean of 2 and std deviation of 3 are used. 

## For random channel 2, a mean of 3 and std deviation of 2 are used. 

In [2]:
img1_ch1 = torch.normal( 2, 3, size=(3, 3) )
img1_ch2 = torch.normal( 3, 2, size=(3, 3) )

img1 = torch.stack( (img1_ch1, img1_ch2), dim=0 )


img2_ch1 = torch.normal( 2, 3, size=(3, 3))
img2_ch2 = torch.normal( 3, 2, size=(3, 3))

img2 = torch.stack( (img2_ch1, img2_ch2), dim=0 )

batch = torch.stack( (img1, img2), dim=0 )

print( batch.shape )

batch

torch.Size([2, 2, 3, 3])


tensor([[[[ 0.7420, -1.8935,  4.7950],
          [ 2.1419,  1.8809,  5.0003],
          [ 4.2511,  1.8103, -5.3357]],

         [[ 3.0039,  5.6690,  1.4707],
          [ 2.8406, -0.7124,  4.6247],
          [ 0.2142,  4.1729,  3.5648]]],


        [[[ 3.6382, -1.7946,  0.0931],
          [-1.4709,  5.2376,  2.2670],
          [ 6.0320,  5.0774,  0.6713]],

         [[ 2.2279,  2.8390,  3.7729],
          [-0.2636,  0.0373,  2.5901],
          [ 5.9240,  4.0341,  4.1209]]]])

## Step 2: We will compute the 2d Batch Normalization manually using the below formula:

<div style="background-color:white">
    <img src="./imgs/normalization_formula.png" />
</div>

## While γ and β are the learnable scale and shift parameters, we will use γ=1 and β=0 values which are also the default values for the Batch Norm2D. Also, we will use ϵ = 0.00001 which is again the default value for Batch Norm2D

## The tensor 'mean' function is used to compute the mean of the values of each 'channel' dimension across the entire batch, as shown below. Likewise, the tensor 'var' function is used to compute the variance of the values of each 'channel' dimension across the entire batch, as shown below.

In [3]:
epsilon=1e-05

'''
The tensor 'mean' function is used to compute the mean of the values of each 'channel' dimension across the entire batch, as shown below.
'''
mean = batch.mean( [0,2,3] )


'''
The tensor 'var' function is used to compute the variance of the values of each 'channel' dimension across the entire batch, as shown below.
'''
var = batch.var( [0,2,3], unbiased=False )


'''
Note that shape of the evaluated mean and var as shown below
'''

print( "Shape of Evaluated Mean: {}\n".format( mean.shape ) )
print( "Evaluated Mean:\n {}".format( mean ) )

print( "Shape of Evaluated Variance: {}\n".format( var.shape ) )
print( "Evaluated Variance:\n {}".format( var ) )

Shape of Evaluated Mean: torch.Size([2])

Evaluated Mean:
 tensor([1.8413, 2.7851])
Shape of Evaluated Variance: torch.Size([2])

Evaluated Variance:
 tensor([9.0480, 3.6792])


## Step 3: The evaluated mean and variance tensors will have to be reshaped so that they can be used in element-wise computations, as in the normalization formula above.

In [4]:
'''
We will reshape the mean and variance tensors so that they can be used in element-wise computations, as in the normalization formula above.
'''
reshaped_mean =  mean[None, :, None, None]
reshaped_var = var[None, :, None, None] 

print( "Shape of Reshaped Mean: {}\n".format( reshaped_mean.shape ) )
print( "Reshaped Mean:\n\n {}\n".format( reshaped_mean ) )
print( "\n####################################\n" )
print( "Shape of Reshaped Variance: {}\n".format( reshaped_var.shape ) )
print( "Reshaped Variance:\n\n {}".format( reshaped_var ) )

Shape of Reshaped Mean: torch.Size([1, 2, 1, 1])

Reshaped Mean:

 tensor([[[[1.8413]],

         [[2.7851]]]])


####################################

Shape of Reshaped Variance: torch.Size([1, 2, 1, 1])

Reshaped Variance:

 tensor([[[[9.0480]],

         [[3.6792]]]])


## Step 4: We will normalize the batch manually by evaluating the formula.

## Note that broadcasting of the values will be automatically applied when element-wise computations are carried out, when any corresponding dimensions such as height and width do not match.

## γ=1 and β=0 values are used in the manually evaluated formula.

In [5]:
'''
Note that broadcasting of the values will be automatically applied when element-wise computations are carried out, when any corresponding dimensions such as height and width do not match.

Note that γ=1 and β=0 values are used in the below formula.
'''

manually_normalized = ( batch - reshaped_mean ) / (torch.sqrt( reshaped_var + epsilon ) )

## Manually normalized batch of input tensors must match the tensors normalized by batch normalization 2D module.

In [6]:
'''
Manually normalized batch of input tensors must match the tensors normalized by batch normalization 2D module.
'''

print( "Manually Normalized Batch:\n\n {}".format( manually_normalized ) )

Manually Normalized Batch:

 tensor([[[[-0.3654, -1.2416,  0.9819],
          [ 0.0999,  0.0132,  1.0502],
          [ 0.8011, -0.0103, -2.3860]],

         [[ 0.1141,  1.5035, -0.6852],
          [ 0.0290, -1.8234,  0.9591],
          [-1.3403,  0.7235,  0.4065]]],


        [[[ 0.5974, -1.2088, -0.5812],
          [-1.1011,  1.1291,  0.1415],
          [ 1.3932,  1.0758, -0.3890]],

         [[-0.2905,  0.0281,  0.5150],
          [-1.5894, -1.4325, -0.1016],
          [ 1.6365,  0.6512,  0.6965]]]])


## Step 5: An instance of PyTorch Batch Norm 2D module will be used to normalize the input batch.

## An instance of the Batch Norm 2D module is created as below by indicating that the input batch has 2 channel dimensions.

## In addition, the 'affine' boolean property is set to False to indicate that the scale and shift parameters need not be applied for this simple example case. Batch Norm 2D module by default applied the scale and shift parameters which are updated during training.

In [7]:
'''
An instance of the Batch Norm 2D module is created as below by indicating that the input batch has 2 channel dimensions. In addition, the 'affine' boolean property is set to False 
to indicate that the scale and shift parameters need not be applied for this simple example case. Batch Norm 2D module by default applied the scale and shift parameters which are updated during training.
'''
bnorm2d = nn.BatchNorm2d( 2, affine=False ) 

'''
Batches of inputs are then normalized by the module instance as follows.
'''
bn_model_normalized = bnorm2d( batch )

## Step 6: Compare the manually normalized and module normalized batches of inputs to verify that normalized tensor values match.

## As can be seen from the display of the normalized batches, both manually normalized and module-normalized batches of input tensors match in values.

In [8]:
'''
Note that both manually normalized and module-normalized batches of input tensors match in values.  
'''

print( "Module Normalized Batch:\n\n {}".format( bn_model_normalized ) )

Module Normalized Batch:

 tensor([[[[-0.3654, -1.2416,  0.9819],
          [ 0.0999,  0.0132,  1.0502],
          [ 0.8011, -0.0103, -2.3860]],

         [[ 0.1141,  1.5035, -0.6852],
          [ 0.0290, -1.8234,  0.9591],
          [-1.3403,  0.7235,  0.4065]]],


        [[[ 0.5974, -1.2088, -0.5812],
          [-1.1011,  1.1291,  0.1415],
          [ 1.3932,  1.0758, -0.3890]],

         [[-0.2905,  0.0281,  0.5150],
          [-1.5894, -1.4325, -0.1016],
          [ 1.6365,  0.6512,  0.6965]]]])


In [9]:
if torch.allclose( manually_normalized, bn_model_normalized ):
    print( "SUCCESS: Both manually normalized and module normalized batches match in values" )
else:
    raise Exception("ERROR: Manually normalized and module normalized batches DO NOT match in values" )

SUCCESS: Both manually normalized and module normalized batches match in values


## Step 7: Verify that the mean and standard deviation of each channel in the normalized batch is zero and one respectively

## ( Mean and Standard Deviation are expected to be very close to 0 and 1).

In [10]:
print( "Normalized Batch Channel 1 Mean: {}\n".format( torch.mean( bn_model_normalized[ :, 0, :, : ] ) ) )
print( "Normalized Batch Channel 1 Std Deviation: {}\n".format( torch.std( bn_model_normalized[ :, 0, :, : ] ) ) )
print( "\n####################################\n" )
print( "Normalized Batch Channel 2 Mean: {}\n".format( torch.mean( bn_model_normalized[ :, 1, :, : ] ) ) )
print( "Normalized Batch Channel 2 Std Deviation: {}\n".format( torch.std( bn_model_normalized[ :, 1, :, : ] ) ) )

Normalized Batch Channel 1 Mean: 6.622738357719982e-09

Normalized Batch Channel 1 Std Deviation: 1.0289908647537231


####################################

Normalized Batch Channel 2 Mean: 9.934107758624577e-09

Normalized Batch Channel 2 Std Deviation: 1.0289901494979858

