# Normalization : Batch vs Group
In this experiment, we compare Batch Normalization and Group Normalization under varying batch sizes.


**Objective:**  
Batch normalization performs poorly when batch size is small, while group normalization remains stable regardless of batch size. This experiment compares the two.

**Conditions:**  
- Use two identical network architectures, differing only in normalization layer: BatchNorm vs GroupNorm  
- Test batch sizes: 2, 4, 8, 16, 32, 64, 128

**Expected Outcome:**  
- BatchNorm only performs well above a certain threshold batch size (N).  
- GroupNorm should maintain similar performance regardless of batch size.  
- Even when BatchNorm performs well, its accuracy should not differ from GroupNorm by more than 3%.

## Theoretical Background

**Batch Normalization** normalizes activations using the statistics (mean and variance) of the current mini-batch. While effective for large batch sizes, its performance often degrades when the batch size is small due to inaccurate batch statistics.

**Group Normalization**, in contrast, divides channels into groups and computes normalization within each group—making it independent of batch size.

## Experimental Setup
- **Dataset**: CIFAR-10
- **model**: CNN with fixed structure, differing only in normalization method.
- **Normalization type**: BatchNorm vs GroupNorm
- **Batch sizes tested**: `2, 4, 8, 16, 32, 64, 128`
- **Evaluation Metric**: Top-1 Accuracy
- **Runs**: Each experiment repeated 5 times for averaging

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np


## HyperParameters

In [9]:
hyperparameters = {
    "batch_sizes": [2, 4, 8, 16, 32, 64, 128],
    "num_blocks": 4,
    "learning_rate": 0.001,
    "epochs": 10,
}

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print("Device:", device)

Device: cuda


## Data Load

In [10]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


## Basic Block
- group size: 4

In [11]:
def get_norm(norm_type, num_channels):
    if norm_type == 'batch':
        return nn.BatchNorm2d(num_channels)
    else :
        return nn.GroupNorm(4, num_channels)

class NormBlock(nn.Module):
    def __init__(self, in_channels, out_channels, norm_type):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = get_norm(norm_type, out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

## Network

In [12]:
class Network(nn.Module):
    def __init__(self, num_blocks, norm_type):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.blocks = nn.Sequential(*[NormBlock(64, 64, norm_type) for _ in range(num_blocks)])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

## Train, Evaluate

In [13]:
def train_model(model, train_loader, criterion, optimizer, device, norm_type):
    model.train()
    correct, total, loss_sum = 0, 0, 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        _, preds = outputs.max(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # average and standard variation
        if batch_idx == 0 and norm_type == 'batchnorm':  # first batch
            for module in model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    print(f"[BN running_mean] mean={module.running_mean.mean().item():.4f}, std={module.running_var.mean().item():.4f}")
                    break 

    return loss_sum / len(train_loader), 100. * correct / total


def evaluate_model(model, test_loader, criterion, device):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss_sum += loss.item()
            _, preds = outputs.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return loss_sum / len(test_loader), 100. * correct / total

In [14]:
criterion = nn.CrossEntropyLoss()
results = {"batchnorm": [], "groupnorm": []}

for batch_size in hyperparameters['batch_sizes']:
    print(f"\nBatch size: {batch_size}")
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    for norm in ["batchnorm", "groupnorm"]:
        print(f"  Norm type: {norm}")
        model = Network(num_blocks=hyperparameters["num_blocks"], norm_type=norm[:-4]).to(device)
        optimizer = optim.Adam(model.parameters(), lr=hyperparameters['learning_rate'])

        best_test_acc = 0
        for epoch in range(hyperparameters['epochs']):
            train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device, norm_type=norm[:-4])
            test_loss, test_acc = evaluate_model(model, test_loader, criterion, device)
            best_test_acc = max(best_test_acc, test_acc)
            print(f"    Epoch {epoch+1}/{hyperparameters['epochs']} - Test Acc: {test_acc:.2f}%")

        results[norm].append(best_test_acc)

# print result
print("\nSummary")
print("{:<10} {:<15} {:<15}".format("Batch", "BatchNorm (%)", "GroupNorm (%)"))
for i, b in enumerate(hyperparameters['batch_sizes']):
    print("{:<10} {:<15.2f} {:<15.2f}".format(b, results['batchnorm'][i], results['groupnorm'][i]))


Batch size: 2
  Norm type: batchnorm
    Epoch 1/10 - Test Acc: 44.94%
    Epoch 2/10 - Test Acc: 52.55%
    Epoch 3/10 - Test Acc: 54.83%
    Epoch 4/10 - Test Acc: 58.21%
    Epoch 5/10 - Test Acc: 60.30%
    Epoch 6/10 - Test Acc: 63.93%
    Epoch 7/10 - Test Acc: 66.70%
    Epoch 8/10 - Test Acc: 64.05%
    Epoch 9/10 - Test Acc: 68.09%
    Epoch 10/10 - Test Acc: 69.39%
  Norm type: groupnorm
    Epoch 1/10 - Test Acc: 44.28%
    Epoch 2/10 - Test Acc: 54.44%
    Epoch 3/10 - Test Acc: 57.81%
    Epoch 4/10 - Test Acc: 61.56%
    Epoch 5/10 - Test Acc: 62.63%
    Epoch 6/10 - Test Acc: 68.48%
    Epoch 7/10 - Test Acc: 67.35%
    Epoch 8/10 - Test Acc: 72.20%
    Epoch 9/10 - Test Acc: 70.41%
    Epoch 10/10 - Test Acc: 71.22%

Batch size: 4
  Norm type: batchnorm
    Epoch 1/10 - Test Acc: 55.87%
    Epoch 2/10 - Test Acc: 60.45%
    Epoch 3/10 - Test Acc: 66.17%
    Epoch 4/10 - Test Acc: 68.30%
    Epoch 5/10 - Test Acc: 70.41%
    Epoch 6/10 - Test Acc: 71.26%
    Epoch 7/10 

## Results

| **Batch Size** | **BatchNorm (Accuracy %)** | **GroupNorm (Accuracy %)** |
| --- | --- | --- |
| 2 | 67.83 ± 2.21 | **71.97 ± 1.91** |
| 4 | 75.92 ± 1.07 | 71.71 ± 1.15 |
| 8 | 77.74 ± 0.76 | 71.91 ± 0.53 |
| 16 | **77.79 ± 0.54** | 71.49 ± 0.54 |
| 32 | 76.83 ± 1.01 | 70.43 ± 1.26 |
| 64 | 75.37 ± 1.23 | 69.71 ± 0.55 |
| 128 | 72.96 ± 0.85 | 67.77 ± 0.87 |

**Test Accuracy by Increasing Number of Blocks**
- ReLU: 54.99% → 80.07% (+25.08%)
- Sigmoid: 44.59% → 32.97% (−11.62%)

**Gradient Magnitude Reduction**
- ReLU: 0.0124 → 0.0039 (−68.5%)
- Sigmoid: approximately 0.0090 → 0.0024 (−73.3%)

## Analysis

**Batch Normalization**

Batch Normalization exhibited lower accuracy and higher variance at smaller batch sizes. In particular, when the batch size was 2, it achieved the lowest performance and the highest standard deviation, indicating unstable training under insufficient batch statistics. From batch size 2 to 16, the accuracy increased significantly by 8.09%, peaking at 77.79% with a batch size of 16. However, as the batch size further increased, performance gradually declined, dropping by 4.83% to 72.96% at a batch size of 128. These findings align with the theoretical understanding that BatchNorm is highly dependent on sufficiently large batch sizes to estimate accurate normalization statistics. In this experiment, the method showed relatively robust performance within the mid-range batch sizes (between 4 and 32).
The average standard deviation across all batch sizes was 3.54, suggesting notable fluctuations in test accuracy across training runs.

**Group Normalization**

In contrast, Group Normalization maintained consistent performance across all tested batch sizes, with test accuracy ranging from 67.77% to 71.97%. The accuracy variation due to changes in batch size was relatively small. Moreover, the average standard deviation was 1.55, which is 1.99 points lower than that of BatchNorm, indicating more stable behavior under various training conditions. While GroupNorm yielded slightly lower peak accuracy than BatchNorm, it demonstrated significantly lower sensitivity to batch size, validating its design goal of being independent from batch dimensions.