# Applying Custom Weight Initialization in PyTorch


## Step 1: Setting Up the PyTorch Environment


In [None]:
import torch
import torch.nn as nn
# torch.nn.init is commonly imported as init
import torch.nn.init as init

## Step 2: Defining a Simple Neural Network with Custom Initialization


In [None]:
class CustomInitNetwork(nn.Module):
    """
    A simple feedforward classifier with custom weight initialization.

    Architecture:
        Linear(in_features, hidden1) -> ReLU
        -> Linear(hidden1, hidden2) -> ReLU
        -> Linear(hidden2, num_classes)   (raw logits, no activation)

    The defaults (784 -> 256 -> 128 -> 10) match flattened 28x28 single-channel
    images with 10 classes.

    Initialization:
    - fc1 and fc2 (hidden layers, followed by ReLU) use Kaiming (He) uniform init.
    - fc3 (output layer) uses Xavier (Glorot) uniform init.
    - All biases are set to zero.

    Args:
        in_features: Number of features per sample after flattening.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features=784, hidden1=256, hidden2=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)  # input -> first hidden layer
        self.fc2 = nn.Linear(hidden1, hidden2)      # first -> second hidden layer
        self.fc3 = nn.Linear(hidden2, num_classes)  # second hidden -> output logits

        # nn.Linear initializes its own parameters on construction; override
        # them here with the custom scheme.
        self.init_weights()

    def init_weights(self):
        """
        Re-initialize every layer's parameters in place.

        Applies Kaiming uniform initialization (ReLU gain) to the hidden
        layers fc1 and fc2, Xavier uniform initialization to the output
        layer fc3, and zeros to all biases. Layers are initialized in the
        order fc1, fc2, fc3, so results are reproducible under a fixed seed.
        """
        # Hidden layers feed into ReLU, so use He init with the ReLU gain.
        # kaiming_uniform_ uses mode='fan_in' by default.
        for layer in (self.fc1, self.fc2):
            init.kaiming_uniform_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                init.zeros_(layer.bias)

        # The output layer has no ReLU after it; it produces logits for a
        # softmax-based loss, for which Xavier init is a common choice.
        init.xavier_uniform_(self.fc3.weight)
        if self.fc3.bias is not None:
            init.zeros_(self.fc3.bias)

    def forward(self, x):
        """
        Compute raw class logits for a batch.

        Every dimension after the first (batch) dimension is flattened into a
        feature vector, e.g. [batch, channels, height, width] -> [batch, features].
        ReLU follows each hidden layer. The output layer returns raw logits,
        since a loss such as CrossEntropyLoss applies softmax internally.
        """
        # torch.flatten keeps dim 0 and merges the rest. Unlike
        # x.view(x.size(0), -1), it works on non-contiguous inputs (it copies
        # when needed) and on empty batches, where the -1 would be ambiguous.
        x = torch.flatten(x, start_dim=1)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

## Step 3: Running a Forward Pass with the Initialized Model


In [None]:
# Build the model; its constructor applies the custom weight initialization.
model = CustomInitNetwork()

# Push a random batch through the network to confirm the output shape:
# 64 samples of 784 features each (the size of a flattened 28x28 image).
# The model is left in its default training mode, which is fine for a
# forward-only shape check; call model.eval() first for real inference.
input_data = torch.randn((64, 784))
output = model(input_data)

print("Custom initialized network output shape:", output.size())

Custom initialized network output shape: torch.Size([64, 10])


## Step 4: Training the Custom-Initialized Model


In [None]:
# Define a loss function and an optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Dummy target data for demonstration (64 samples, 10 classes)
# In a real scenario, this would come from your dataset
dummy_targets = torch.randint(0, 10, (64,))

# Simple training loop
num_epochs = 20
model.train() # Set the model to training mode

for epoch in range(num_epochs):
    optimizer.zero_grad()  # Clear previous gradients
    outputs = model(input_data)  # Forward pass
    loss = criterion(outputs, dummy_targets)  # Compute loss
    loss.backward()  # Backward pass (compute gradients)
    optimizer.step()  # Update weights

    if (epoch + 1) % 5 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [5/20], Loss: 2.2939
Epoch [10/20], Loss: 1.6650
Epoch [15/20], Loss: 1.2325
Epoch [20/20], Loss: 0.9281
