In [1]:
import torch
import torch.nn as nn

In [19]:
# Example 4D input for a 2D batch-norm layer, laid out as
# (batch size: 2, channels: 3, height: 4, width: 4).
# The entries are simply 1.0 .. 96.0 in row-major order, so generate the run
# with arange and fold it into shape instead of spelling out every value.
input_tensor = torch.arange(1.0, 97.0).reshape(2, 3, 4, 4)

In [3]:
# Confirm the (batch, channels, height, width) layout of the input
input_tensor.size()

torch.Size([2, 3, 4, 4])

In [20]:
# Batch normalization layer for 2D (image-like) inputs: one scale/shift pair
# per channel, and the example input has 3 channels.
NUM_CHANNELS = 3
batch_norm = nn.BatchNorm2d(NUM_CHANNELS)

In [21]:
# List every learnable parameter of batch_norm: its name, then its tensor
for name, param in batch_norm.named_parameters():
    print(f'Name: {name}', f'Tensor: {param}', sep='\n')

Name: weight
Tensor: Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
Name: bias
Tensor: Parameter containing:
tensor([0., 0., 0.], requires_grad=True)


In [22]:
# Count learnable scalars by accumulating the element count of each parameter
num_params = 0
for p in batch_norm.parameters():
    num_params += p.numel()

print(f'Number of learnable parameters: {num_params}')

Number of learnable parameters: 6


In [23]:
# Overwrite the learnable affine parameters with hand-picked values so the
# scale/shift step is clearly visible in the layer's output.
# copy_ under no_grad writes into the existing Parameters in place; unlike
# rebinding `.data`, it stays visible to autograd's bookkeeping and raises if
# the new tensor's shape does not match the parameter.
with torch.no_grad():
    batch_norm.weight.copy_(torch.tensor([0.7, 0.4, 0.95]))  # Scale (γ) per channel
    batch_norm.bias.copy_(torch.tensor([0.2, 0.37, 0.67]))   # Shift (β) per channel

In [24]:
# Show the parameters again after the manual update: name, tensor and shape
for name, param in batch_norm.named_parameters():
    print(f'Name: {name}', f'Tensor: {param}', f'Shape: {param.shape}', sep='\n')

Name: weight
Tensor: Parameter containing:
tensor([0.7000, 0.4000, 0.9500], requires_grad=True)
Shape: torch.Size([3])
Name: bias
Tensor: Parameter containing:
tensor([0.2000, 0.3700, 0.6700], requires_grad=True)
Shape: torch.Size([3])


In [25]:
# Forward pass through BatchNorm layer.
# `output` is the reference result that the manual calculation in the cells
# below is compared against.
# NOTE(review): the layer is never switched to eval() in the visible cells, so
# this presumably normalizes with the statistics of this batch — confirm.
output = batch_norm(input_tensor)

In [10]:
# The layer's output has the same shape as its input
output.size()

torch.Size([2, 3, 4, 4])

In [26]:
# Manual calculation of batch normalization, step 1: the per-channel mean.
# Reducing over batch (0), height (2) and width (3) leaves one value per
# channel; keepdim=True keeps the result broadcastable against the input.
reduce_dims = (0, 2, 3)
mean = torch.mean(input_tensor, dim=reduce_dims, keepdim=True)
print(f'Mean per channel: {mean}')
print(f'Shape of mean tensor: {mean.shape}')

Mean per channel: tensor([[[[32.5000]],

         [[48.5000]],

         [[64.5000]]]])
Shape of mean tensor: torch.Size([1, 3, 1, 1])


In [27]:
# First Channel: slice channel 0 and unpack along the batch axis, giving the
# 4x4 map of sample 0 (X1C1) and of sample 1 (X2C1)
X1C1, X2C1 = input_tensor[:, 0]
print(X1C1)
print(X2C1)

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([[49., 50., 51., 52.],
        [53., 54., 55., 56.],
        [57., 58., 59., 60.],
        [61., 62., 63., 64.]])


In [28]:
# Verify the first channel's mean by hand: add the two samples' maps
# element-wise, total the result, and divide by the number of values that
# went into it.
Sum1 = X1C1 + X2C1
print(Sum1)
print(Sum1.sum())
# Derive the divisor from the tensors themselves rather than a hard-coded
# 16*2, so the check stays correct if the batch or spatial size changes.
print(Sum1.sum() / (X1C1.numel() + X2C1.numel()))

tensor([[50., 52., 54., 56.],
        [58., 60., 62., 64.],
        [66., 68., 70., 72.],
        [74., 76., 78., 80.]])
tensor(1040.)
tensor(32.5000)


In [29]:
# Second Channel: slice channel 1 and unpack along the batch axis, giving the
# 4x4 map of sample 0 (X1C2) and of sample 1 (X2C2)
X1C2, X2C2 = input_tensor[:, 1]
print(X1C2)
print(X2C2)

tensor([[17., 18., 19., 20.],
        [21., 22., 23., 24.],
        [25., 26., 27., 28.],
        [29., 30., 31., 32.]])
tensor([[65., 66., 67., 68.],
        [69., 70., 71., 72.],
        [73., 74., 75., 76.],
        [77., 78., 79., 80.]])


In [30]:
# Verify the second channel's mean by hand: add the two samples' maps
# element-wise, total the result, and divide by the number of values that
# went into it.
Sum2 = X1C2 + X2C2
print(Sum2)
print(Sum2.sum())
# Derive the divisor from the tensors themselves rather than a hard-coded
# 16*2, so the check stays correct if the batch or spatial size changes.
print(Sum2.sum() / (X1C2.numel() + X2C2.numel()))

tensor([[ 82.,  84.,  86.,  88.],
        [ 90.,  92.,  94.,  96.],
        [ 98., 100., 102., 104.],
        [106., 108., 110., 112.]])
tensor(1552.)
tensor(48.5000)


In [31]:
# Third Channel: slice channel 2 and unpack along the batch axis, giving the
# 4x4 map of sample 0 (X1C3) and of sample 1 (X2C3)
X1C3, X2C3 = input_tensor[:, 2]
print(X1C3)
print(X2C3)

tensor([[33., 34., 35., 36.],
        [37., 38., 39., 40.],
        [41., 42., 43., 44.],
        [45., 46., 47., 48.]])
tensor([[81., 82., 83., 84.],
        [85., 86., 87., 88.],
        [89., 90., 91., 92.],
        [93., 94., 95., 96.]])


In [32]:
# Verify the third channel's mean by hand: add the two samples' maps
# element-wise, total the result, and divide by the number of values that
# went into it.
Sum3 = X1C3 + X2C3
print(Sum3)
print(Sum3.sum())
# Derive the divisor from the tensors themselves rather than a hard-coded
# 16*2, so the check stays correct if the batch or spatial size changes.
print(Sum3.sum() / (X1C3.numel() + X2C3.numel()))

tensor([[114., 116., 118., 120.],
        [122., 124., 126., 128.],
        [130., 132., 134., 136.],
        [138., 140., 142., 144.]])
tensor(2064.)
tensor(64.5000)


In [33]:
# Step 2: per-channel variance over batch, height and width.
# unbiased=False selects the population (divide-by-N) variance.
variance = torch.var(input_tensor, dim=(0, 2, 3), unbiased=False, keepdim=True)
print(f'Variance per channel: {variance}')
print(f'shape of variance tensor: {variance.shape}')

Variance per channel: tensor([[[[597.2500]],

         [[597.2500]],

         [[597.2500]]]])
shape of variance tensor: torch.Size([1, 3, 1, 1])


In [34]:
# Read epsilon from the layer itself so the manual calculation below uses the
# exact value the layer is configured with.
epsilon = batch_norm.eps  # Small constant for numerical stability
print(epsilon)

1e-05


In [35]:
# Numerator of the normalization: subtract each channel's mean from the input
# (the [1, 3, 1, 1] mean broadcasts across batch, height and width)
print(torch.sub(input_tensor, mean))

tensor([[[[-31.5000, -30.5000, -29.5000, -28.5000],
          [-27.5000, -26.5000, -25.5000, -24.5000],
          [-23.5000, -22.5000, -21.5000, -20.5000],
          [-19.5000, -18.5000, -17.5000, -16.5000]],

         [[-31.5000, -30.5000, -29.5000, -28.5000],
          [-27.5000, -26.5000, -25.5000, -24.5000],
          [-23.5000, -22.5000, -21.5000, -20.5000],
          [-19.5000, -18.5000, -17.5000, -16.5000]],

         [[-31.5000, -30.5000, -29.5000, -28.5000],
          [-27.5000, -26.5000, -25.5000, -24.5000],
          [-23.5000, -22.5000, -21.5000, -20.5000],
          [-19.5000, -18.5000, -17.5000, -16.5000]]],


        [[[ 16.5000,  17.5000,  18.5000,  19.5000],
          [ 20.5000,  21.5000,  22.5000,  23.5000],
          [ 24.5000,  25.5000,  26.5000,  27.5000],
          [ 28.5000,  29.5000,  30.5000,  31.5000]],

         [[ 16.5000,  17.5000,  18.5000,  19.5000],
          [ 20.5000,  21.5000,  22.5000,  23.5000],
          [ 24.5000,  25.5000,  26.5000,  27.5000],
  

In [43]:
# Denominator of the normalization: per-channel standard deviation, with
# epsilon added under the square root
print((variance + epsilon).sqrt())

tensor([[[[24.4387]],

         [[24.4387]],

         [[24.4387]]]])


In [36]:
# Normalize: (x - mean) / sqrt(variance + epsilon), built from two named
# pieces so each half of the formula can be inspected on its own
centered = input_tensor - mean
std = torch.sqrt(variance + epsilon)
x_normalized = centered / std
print(x_normalized)


tensor([[[[-1.2889, -1.2480, -1.2071, -1.1662],
          [-1.1253, -1.0843, -1.0434, -1.0025],
          [-0.9616, -0.9207, -0.8798, -0.8388],
          [-0.7979, -0.7570, -0.7161, -0.6752]],

         [[-1.2889, -1.2480, -1.2071, -1.1662],
          [-1.1253, -1.0843, -1.0434, -1.0025],
          [-0.9616, -0.9207, -0.8798, -0.8388],
          [-0.7979, -0.7570, -0.7161, -0.6752]],

         [[-1.2889, -1.2480, -1.2071, -1.1662],
          [-1.1253, -1.0843, -1.0434, -1.0025],
          [-0.9616, -0.9207, -0.8798, -0.8388],
          [-0.7979, -0.7570, -0.7161, -0.6752]]],


        [[[ 0.6752,  0.7161,  0.7570,  0.7979],
          [ 0.8388,  0.8798,  0.9207,  0.9616],
          [ 1.0025,  1.0434,  1.0843,  1.1253],
          [ 1.1662,  1.2071,  1.2480,  1.2889]],

         [[ 0.6752,  0.7161,  0.7570,  0.7979],
          [ 0.8388,  0.8798,  0.9207,  0.9616],
          [ 1.0025,  1.0434,  1.0843,  1.1253],
          [ 1.1662,  1.2071,  1.2480,  1.2889]],

         [[ 0.6752,  0.7161,

In [45]:
# View the per-channel scale as [1, C, 1, 1] so it broadcasts over batch,
# height and width when multiplied with the normalized input
broadcast_shape = (1, -1, 1, 1)
w1 = batch_norm.weight.view(broadcast_shape)
print(w1)
print(f'shape of w1: {w1.shape}')

tensor([[[[0.7000]],

         [[0.4000]],

         [[0.9500]]]], grad_fn=<ViewBackward0>)
shape of w1: torch.Size([1, 3, 1, 1])


In [47]:
# Scale step on its own: γ (viewed as [1, C, 1, 1]) times the normalized input
torch.mul(batch_norm.weight.view(1, -1, 1, 1), x_normalized)

tensor([[[[-0.9023, -0.8736, -0.8450, -0.8163],
          [-0.7877, -0.7590, -0.7304, -0.7018],
          [-0.6731, -0.6445, -0.6158, -0.5872],
          [-0.5585, -0.5299, -0.5013, -0.4726]],

         [[-0.5156, -0.4992, -0.4828, -0.4665],
          [-0.4501, -0.4337, -0.4174, -0.4010],
          [-0.3846, -0.3683, -0.3519, -0.3355],
          [-0.3192, -0.3028, -0.2864, -0.2701]],

         [[-1.2245, -1.1856, -1.1467, -1.1079],
          [-1.0690, -1.0301, -0.9913, -0.9524],
          [-0.9135, -0.8746, -0.8358, -0.7969],
          [-0.7580, -0.7191, -0.6803, -0.6414]]],


        [[[ 0.4726,  0.5013,  0.5299,  0.5585],
          [ 0.5872,  0.6158,  0.6445,  0.6731],
          [ 0.7018,  0.7304,  0.7590,  0.7877],
          [ 0.8163,  0.8450,  0.8736,  0.9023]],

         [[ 0.2701,  0.2864,  0.3028,  0.3192],
          [ 0.3355,  0.3519,  0.3683,  0.3846],
          [ 0.4010,  0.4174,  0.4337,  0.4501],
          [ 0.4665,  0.4828,  0.4992,  0.5156]],

         [[ 0.6414,  0.6803,

In [37]:
# Scale and shift: γ * x_normalized + β
# Both parameters are viewed as [1, C, 1, 1] so they broadcast per channel.
gamma = batch_norm.weight.view(1, -1, 1, 1)
beta = batch_norm.bias.view(1, -1, 1, 1)
manual_output = gamma * x_normalized + beta
print(manual_output)

tensor([[[[-7.0226e-01, -6.7361e-01, -6.4497e-01, -6.1633e-01],
          [-5.8769e-01, -5.5904e-01, -5.3040e-01, -5.0176e-01],
          [-4.7311e-01, -4.4447e-01, -4.1583e-01, -3.8718e-01],
          [-3.5854e-01, -3.2990e-01, -3.0125e-01, -2.7261e-01]],

         [[-1.4558e-01, -1.2921e-01, -1.1284e-01, -9.6473e-02],
          [-8.0106e-02, -6.3738e-02, -4.7371e-02, -3.1003e-02],
          [-1.4636e-02,  1.7316e-03,  1.8099e-02,  3.4467e-02],
          [ 5.0834e-02,  6.7202e-02,  8.3569e-02,  9.9937e-02]],

         [[-5.5449e-01, -5.1562e-01, -4.7675e-01, -4.3787e-01],
          [-3.9900e-01, -3.6013e-01, -3.2126e-01, -2.8238e-01],
          [-2.4351e-01, -2.0464e-01, -1.6576e-01, -1.2689e-01],
          [-8.8019e-02, -4.9146e-02, -1.0274e-02,  2.8599e-02]]],


        [[[ 6.7261e-01,  7.0125e-01,  7.2990e-01,  7.5854e-01],
          [ 7.8718e-01,  8.1583e-01,  8.4447e-01,  8.7311e-01],
          [ 9.0176e-01,  9.3040e-01,  9.5904e-01,  9.8769e-01],
          [ 1.0163e+00,  1.0450e

In [38]:
# Show the tensor produced by the nn.BatchNorm2d forward pass so it can be
# compared with the hand-computed result.
print("\nBatch Normalization Output (PyTorch):", output, sep="\n")



Batch Normalization Output (PyTorch):
tensor([[[[-7.0226e-01, -6.7361e-01, -6.4497e-01, -6.1633e-01],
          [-5.8769e-01, -5.5904e-01, -5.3040e-01, -5.0176e-01],
          [-4.7311e-01, -4.4447e-01, -4.1583e-01, -3.8718e-01],
          [-3.5854e-01, -3.2990e-01, -3.0125e-01, -2.7261e-01]],

         [[-1.4558e-01, -1.2921e-01, -1.1284e-01, -9.6473e-02],
          [-8.0106e-02, -6.3738e-02, -4.7371e-02, -3.1003e-02],
          [-1.4636e-02,  1.7316e-03,  1.8099e-02,  3.4467e-02],
          [ 5.0834e-02,  6.7202e-02,  8.3569e-02,  9.9937e-02]],

         [[-5.5449e-01, -5.1562e-01, -4.7675e-01, -4.3787e-01],
          [-3.9900e-01, -3.6013e-01, -3.2126e-01, -2.8238e-01],
          [-2.4351e-01, -2.0464e-01, -1.6576e-01, -1.2689e-01],
          [-8.8019e-02, -4.9146e-02, -1.0273e-02,  2.8599e-02]]],


        [[[ 6.7261e-01,  7.0125e-01,  7.2990e-01,  7.5854e-01],
          [ 7.8718e-01,  8.1583e-01,  8.4447e-01,  8.7311e-01],
          [ 9.0176e-01,  9.3040e-01,  9.5904e-01,  9.8769

In [40]:
# Show the hand-computed result next to the layer's output printed earlier.
print("\nBatch Normalization Output (Manual Calculation):")
print(manual_output)

# Comparing two long tensor dumps by eye is error-prone, so verify the match
# programmatically. The small absolute tolerance absorbs float32 rounding
# differences between the layer's internal computation and the manual one.
assert torch.allclose(output, manual_output, atol=1e-6), (
    "manual batch normalization does not match nn.BatchNorm2d output"
)


Batch Normalization Output (Manual Calculation):
tensor([[[[-7.0226e-01, -6.7361e-01, -6.4497e-01, -6.1633e-01],
          [-5.8769e-01, -5.5904e-01, -5.3040e-01, -5.0176e-01],
          [-4.7311e-01, -4.4447e-01, -4.1583e-01, -3.8718e-01],
          [-3.5854e-01, -3.2990e-01, -3.0125e-01, -2.7261e-01]],

         [[-1.4558e-01, -1.2921e-01, -1.1284e-01, -9.6473e-02],
          [-8.0106e-02, -6.3738e-02, -4.7371e-02, -3.1003e-02],
          [-1.4636e-02,  1.7316e-03,  1.8099e-02,  3.4467e-02],
          [ 5.0834e-02,  6.7202e-02,  8.3569e-02,  9.9937e-02]],

         [[-5.5449e-01, -5.1562e-01, -4.7675e-01, -4.3787e-01],
          [-3.9900e-01, -3.6013e-01, -3.2126e-01, -2.8238e-01],
          [-2.4351e-01, -2.0464e-01, -1.6576e-01, -1.2689e-01],
          [-8.8019e-02, -4.9146e-02, -1.0274e-02,  2.8599e-02]]],


        [[[ 6.7261e-01,  7.0125e-01,  7.2990e-01,  7.5854e-01],
          [ 7.8718e-01,  8.1583e-01,  8.4447e-01,  8.7311e-01],
          [ 9.0176e-01,  9.3040e-01,  9.5904e-