In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
class SimpleNNWithBatchNorm(nn.Module):
    """Small fully connected network with batch normalisation.

    Data flow: (Linear -> BatchNorm1d -> ReLU) applied twice, followed by a
    final Linear layer whose result is squashed with a sigmoid.

    Args:
        input_dim: Number of input features per sample.
        hidden1_dim: Number of neurons in the first hidden layer.
        hidden2_dim: Number of neurons in the second hidden layer.
        output_dim: Number of neurons in the output layer.
    """

    def __init__(self, input_dim, hidden1_dim, hidden2_dim, output_dim):
        # Run nn.Module's own initialiser first (same idea as `super` in
        # Java). Without it, the layers assigned below would not be
        # registered as sub-modules/parameters of this model.
        super().__init__()

        # Hidden stage 1: affine map, then normalisation of its outputs.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden1_dim)
        self.bn1 = nn.BatchNorm1d(num_features=hidden1_dim)

        # Hidden stage 2: same pattern on the second hidden width.
        self.fc2 = nn.Linear(in_features=hidden1_dim, out_features=hidden2_dim)
        self.bn2 = nn.BatchNorm1d(num_features=hidden2_dim)

        # Output projection (no batch norm on the final layer).
        self.fc3 = nn.Linear(in_features=hidden2_dim, out_features=output_dim)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor; the notebook passes shape (batch, input_dim).

        Returns:
            Tensor of sigmoid activations with ``output_dim`` values per sample.
        """
        hidden = x
        # Both hidden stages share the same Linear -> BatchNorm -> ReLU recipe.
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.relu(norm(linear(hidden)))
        logits = self.fc3(hidden)
        return torch.sigmoid(logits)

In [7]:
# Random input batch: 16 samples with 4 features each.
X = torch.randn(16, 4)

# 4 input features -> two hidden layers of 8 neurons each -> 1 output neuron.
model = SimpleNNWithBatchNorm(
    input_dim=4,
    hidden1_dim=8,
    hidden2_dim=8,
    output_dim=1,
)

# One forward pass; print the output shape as a sanity check.
y_pred = model(X)
print(y_pred.shape)

torch.Size([16, 1])


In [8]:
def describe_features(heading, batch):
    """Print ``heading``, then the mean and std of ``batch`` over dim 0, then its values."""
    print(heading)
    print("Mean  :", batch.mean(dim=0))
    print("Std   :", batch.std(dim=0))
    print("Data  :\n", batch)


# Fixed seed so the printed numbers are reproducible.
torch.manual_seed(42)
X = torch.randn(8, 4) * 20 + 50  # 8 samples x 4 features, mean ~50, std ~20

describe_features("Before BatchNorm:", X)

# BatchNorm layer for 4 features. A freshly constructed layer is in training
# mode, so it normalises with the statistics of the current batch.
bn = nn.BatchNorm1d(num_features=4)
X_bn = bn(X)

# Note: Tensor.std() defaults to the unbiased estimator (divides by N-1),
# while BatchNorm normalises with the biased batch variance (divides by N).
# With N = 8 rows the reported std is therefore sqrt(8/7) ~ 1.069, not 1.0.
describe_features("\nAfter BatchNorm:", X_bn)

Before BatchNorm:
Mean  : tensor([57.6425, 58.9973, 50.8966, 49.9086])
Std   : tensor([22.7563, 21.4063, 12.7993, 29.3725])
Data  :
 tensor([[88.5383, 79.7457, 68.0143,  7.8896],
        [63.5684, 25.3091, 49.1386, 17.9067],
        [34.9573, 82.9745, 42.1504, 21.9279],
        [35.4424, 38.8114, 34.6232, 65.2489],
        [82.8463, 46.8081, 40.0520, 58.7918],
        [34.8374, 71.5664, 66.0160, 83.6124],
        [75.5825, 75.9285, 62.2093, 76.6948],
        [45.3675, 50.8352, 44.9685, 67.1972]])

After BatchNorm:
Mean  : tensor([-1.1921e-07,  7.4506e-08, -1.4901e-07, -1.6391e-07],
       grad_fn=<MeanBackward1>)
Std   : tensor([1.0690, 1.0690, 1.0690, 1.0690], grad_fn=<StdBackward0>)
Data  :
 tensor([[ 1.4514,  1.0362,  1.4297, -1.5293],
        [ 0.2784, -1.6824, -0.1468, -1.1647],
        [-1.0657,  1.1974, -0.7305, -1.0184],
        [-1.0429, -1.0081, -1.3592,  0.5583],
        [ 1.1840, -0.6087, -0.9058,  0.3233],
        [-1.0713,  0.6277,  1.2628,  1.2267],
        [ 0.8428,  0.