# Batch Normalization

In [8]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to local project code are
# picked up without restarting the kernel. Re-running this cell in a live
# kernel only prints an "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import torch
import torch.nn as nn
import multiprocessing
import torch.optim as optim
from torch.utils.data import DataLoader
from main import MNIST_dataset, MNIST_trainer

In [10]:
# Report the installed PyTorch build so results can be tied to a version.
print("Torch version: ", torch.__version__)

####################################################################
# Set Device
####################################################################

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: ", device)

Torch version:  2.5.1+cu124
Device:  cuda


In [11]:
####################################################################
# DataLoader Class
####################################################################
# Build both splits with the project's MNIST dataset wrapper.
train_dataset = MNIST_dataset(partition="train")
test_dataset = MNIST_dataset(partition="test")

batch_size = 100
# One fewer loader worker than the number of CPUs reported by the OS.
num_workers = multiprocessing.cpu_count() - 1
print("Num workers", num_workers)

# Only the training batches are reshuffled; the test order stays fixed.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)


Loading MNIST  train  Dataset...
	Total Len.:  60000 
 --------------------------------------------------

Loading MNIST  test  Dataset...
	Total Len.:  10000 
 --------------------------------------------------
Num workers 11


In [12]:
####################################################################
# Neural Network Class
####################################################################


# Creating our Neural Network - Fully Connected
class Net(nn.Module):
    """Fully connected classifier with batch normalization on hidden layers.

    Every entry of ``sizes`` except the last becomes a
    ``Linear -> BatchNorm1d -> ReLU`` stage. The last entry becomes the final
    ``Linear`` classifier, which has no normalization or activation, so the
    network outputs raw scores.

    Args:
        sizes: Sequence of ``(in_features, out_features)`` pairs, one per
            linear layer, in forward order. Must contain at least one pair
            (the classifier).
        criterion: Optional callable invoked as ``criterion(outputs, y)`` by
            ``forward`` when targets are supplied.
    """

    def __init__(
        self,
        # Immutable default: a list default would be shared between every
        # instance constructed without an explicit ``sizes``.
        sizes=((784, 1024), (1024, 1024), (1024, 1024), (1024, 10)),
        criterion=None,
    ):
        super().__init__()

        # Hidden stack: all pairs but the last one.
        self.layers = nn.ModuleList()
        for in_features, out_features in sizes[:-1]:
            self.layers.append(nn.Linear(in_features, out_features))
            self.layers.append(nn.BatchNorm1d(out_features))
            self.layers.append(nn.ReLU())

        # Output layer built from the last pair.
        in_features, out_features = sizes[-1]
        self.classifier = nn.Linear(in_features, out_features)

        self.criterion = criterion

    def forward(self, x, y=None):
        """Run the network and optionally compute the loss.

        Args:
            x: Input batch fed to the first linear layer.
            y: Optional targets. When given, the loss is computed with
                ``self.criterion``.

        Returns:
            The classifier output when ``y`` is ``None``; otherwise the tuple
            ``(loss, output)``.

        Raises:
            TypeError: If ``y`` is given but the network was built without a
                criterion.
        """
        for layer in self.layers:
            x = layer(x)
        x = self.classifier(x)

        # Identity test on purpose: ``y != None`` would dispatch to the
        # target's own ``__ne__``, which is element-wise for array-likes and
        # makes the truth test ambiguous.
        if y is not None:
            if self.criterion is None:
                raise TypeError(
                    "Net was built without a criterion; cannot compute a loss "
                    "for the given targets."
                )
            loss = self.criterion(x, y)
            return loss, x
        return x


####################################################################
# Training settings
####################################################################

# Seed the global RNG before any weights are created so that parameter
# initialisation (and anything else drawing from torch's default generator)
# is repeatable across kernel restarts. GPU kernels may still introduce
# small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Loss function; the same object is handed to the network and to the trainer.
criterion = nn.CrossEntropyLoss()
# Instantiating the network and printing its architecture
num_classes = 10
net = Net(
    sizes=[[784, 1024], [1024, 1024], [1024, 1024], [1024, num_classes]],
    criterion=criterion,
)
print(net)


def count_parameters(model):
    """Return the total number of elements across the trainable parameters.

    Only parameters yielded by ``model.parameters()`` whose ``requires_grad``
    flag is set contribute to the count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the number of trainable parameters in the network.
print("Params: ", count_parameters(net))

# Fixed training length; SGD with momentum and a small weight decay.
epochs = 25
optimizer = optim.SGD(
    net.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-6,
)

# Bundle model, data, optimizer and loss into the project's trainer helper.
trainer = MNIST_trainer(
    net,
    train_dataloader,
    test_dataloader,
    optimizer,
    criterion,
    epochs,
    device,
    model_path="models/batchnorm.pt",
)

Net(
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=1024, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (classifier): Linear(in_features=1024, out_features=10, bias=True)
  (criterion): CrossEntropyLoss()
)
Params:  2919434


In [13]:
####################################################################
# Training
####################################################################

# Run the training loop of the trainer built in the previous cell. The
# recorded output shows one train and one test pass per epoch, followed by
# the best test accuracy seen.
# NOTE(review): MNIST_trainer lives in main.py — presumably it also writes
# the best checkpoint to `model_path`; confirm there.
trainer.train()


---- Start Training ----


Epoch 0: 100%|██████████| 600/600 [00:07<00:00, 85.16batch/s]
Test 0: 100%|██████████| 100/100 [00:00<00:00, 145.47batch/s]


[Epoch 1] Train Loss: 0.001681 - Test Loss: 0.000838 - Train Accuracy: 94.90% - Test Accuracy: 97.35%


Epoch 1: 100%|██████████| 600/600 [00:06<00:00, 89.37batch/s] 
Test 1: 100%|██████████| 100/100 [00:00<00:00, 119.05batch/s]


[Epoch 2] Train Loss: 0.000535 - Test Loss: 0.000657 - Train Accuracy: 98.33% - Test Accuracy: 97.93%


Epoch 2: 100%|██████████| 600/600 [00:07<00:00, 83.57batch/s]
Test 2: 100%|██████████| 100/100 [00:00<00:00, 111.12batch/s]

[Epoch 3] Train Loss: 0.000271 - Test Loss: 0.000592 - Train Accuracy: 99.15% - Test Accuracy: 98.22%



Epoch 3: 100%|██████████| 600/600 [00:07<00:00, 75.78batch/s]
Test 3: 100%|██████████| 100/100 [00:00<00:00, 117.29batch/s]

[Epoch 4] Train Loss: 0.000164 - Test Loss: 0.000631 - Train Accuracy: 99.52% - Test Accuracy: 98.03%



Epoch 4: 100%|██████████| 600/600 [00:06<00:00, 86.65batch/s]
Test 4: 100%|██████████| 100/100 [00:00<00:00, 104.01batch/s]

[Epoch 5] Train Loss: 0.000101 - Test Loss: 0.000598 - Train Accuracy: 99.71% - Test Accuracy: 98.27%



Epoch 5: 100%|██████████| 600/600 [00:06<00:00, 88.83batch/s]
Test 5: 100%|██████████| 100/100 [00:00<00:00, 115.75batch/s]

[Epoch 6] Train Loss: 0.000073 - Test Loss: 0.000545 - Train Accuracy: 99.80% - Test Accuracy: 98.44%



Epoch 6: 100%|██████████| 600/600 [00:08<00:00, 71.39batch/s]
Test 6: 100%|██████████| 100/100 [00:00<00:00, 110.96batch/s]

[Epoch 7] Train Loss: 0.000042 - Test Loss: 0.000546 - Train Accuracy: 99.92% - Test Accuracy: 98.48%



Epoch 7: 100%|██████████| 600/600 [00:07<00:00, 85.15batch/s]
Test 7: 100%|██████████| 100/100 [00:01<00:00, 99.64batch/s] 

[Epoch 8] Train Loss: 0.000041 - Test Loss: 0.000547 - Train Accuracy: 99.90% - Test Accuracy: 98.47%



Epoch 8: 100%|██████████| 600/600 [00:06<00:00, 93.70batch/s] 
Test 8: 100%|██████████| 100/100 [00:00<00:00, 119.86batch/s]

[Epoch 9] Train Loss: 0.000030 - Test Loss: 0.000588 - Train Accuracy: 99.93% - Test Accuracy: 98.41%



Epoch 9: 100%|██████████| 600/600 [00:06<00:00, 90.65batch/s]
Test 9: 100%|██████████| 100/100 [00:00<00:00, 108.75batch/s]

[Epoch 10] Train Loss: 0.000023 - Test Loss: 0.000572 - Train Accuracy: 99.95% - Test Accuracy: 98.43%



Epoch 10: 100%|██████████| 600/600 [00:06<00:00, 87.25batch/s]
Test 10: 100%|██████████| 100/100 [00:00<00:00, 111.12batch/s]

[Epoch 11] Train Loss: 0.000017 - Test Loss: 0.000560 - Train Accuracy: 99.96% - Test Accuracy: 98.50%



Epoch 11: 100%|██████████| 600/600 [00:07<00:00, 80.42batch/s] 
Test 11: 100%|██████████| 100/100 [00:00<00:00, 115.10batch/s]

[Epoch 12] Train Loss: 0.000010 - Test Loss: 0.000516 - Train Accuracy: 99.99% - Test Accuracy: 98.56%



Epoch 12: 100%|██████████| 600/600 [00:06<00:00, 88.14batch/s]
Test 12: 100%|██████████| 100/100 [00:00<00:00, 136.07batch/s]

[Epoch 13] Train Loss: 0.000014 - Test Loss: 0.000520 - Train Accuracy: 99.97% - Test Accuracy: 98.59%



Epoch 13: 100%|██████████| 600/600 [00:06<00:00, 88.55batch/s]
Test 13: 100%|██████████| 100/100 [00:00<00:00, 116.62batch/s]

[Epoch 14] Train Loss: 0.000008 - Test Loss: 0.000529 - Train Accuracy: 99.99% - Test Accuracy: 98.67%



Epoch 14: 100%|██████████| 600/600 [00:07<00:00, 82.26batch/s]
Test 14: 100%|██████████| 100/100 [00:00<00:00, 118.33batch/s]

[Epoch 15] Train Loss: 0.000004 - Test Loss: 0.000531 - Train Accuracy: 100.00% - Test Accuracy: 98.59%



Epoch 15: 100%|██████████| 600/600 [00:07<00:00, 81.79batch/s]
Test 15: 100%|██████████| 100/100 [00:00<00:00, 115.56batch/s]

[Epoch 16] Train Loss: 0.000005 - Test Loss: 0.000523 - Train Accuracy: 100.00% - Test Accuracy: 98.62%



Epoch 16: 100%|██████████| 600/600 [00:06<00:00, 90.25batch/s] 
Test 16: 100%|██████████| 100/100 [00:00<00:00, 119.22batch/s]

[Epoch 17] Train Loss: 0.000003 - Test Loss: 0.000520 - Train Accuracy: 100.00% - Test Accuracy: 98.61%



Epoch 17: 100%|██████████| 600/600 [00:06<00:00, 92.16batch/s]
Test 17: 100%|██████████| 100/100 [00:00<00:00, 123.04batch/s]

[Epoch 18] Train Loss: 0.000003 - Test Loss: 0.000506 - Train Accuracy: 100.00% - Test Accuracy: 98.62%



Epoch 18: 100%|██████████| 600/600 [00:06<00:00, 91.23batch/s]
Test 18: 100%|██████████| 100/100 [00:00<00:00, 114.85batch/s]

[Epoch 19] Train Loss: 0.000003 - Test Loss: 0.000543 - Train Accuracy: 100.00% - Test Accuracy: 98.58%



Epoch 19: 100%|██████████| 600/600 [00:06<00:00, 90.54batch/s]
Test 19: 100%|██████████| 100/100 [00:00<00:00, 124.66batch/s]

[Epoch 20] Train Loss: 0.000008 - Test Loss: 0.000531 - Train Accuracy: 99.98% - Test Accuracy: 98.59%



Epoch 20: 100%|██████████| 600/600 [00:06<00:00, 92.11batch/s]
Test 20: 100%|██████████| 100/100 [00:01<00:00, 99.30batch/s]

[Epoch 21] Train Loss: 0.000011 - Test Loss: 0.000583 - Train Accuracy: 99.97% - Test Accuracy: 98.48%



Epoch 21: 100%|██████████| 600/600 [00:07<00:00, 83.02batch/s]
Test 21: 100%|██████████| 100/100 [00:00<00:00, 110.57batch/s]

[Epoch 22] Train Loss: 0.000005 - Test Loss: 0.000547 - Train Accuracy: 100.00% - Test Accuracy: 98.54%



Epoch 22: 100%|██████████| 600/600 [00:07<00:00, 78.10batch/s]
Test 22: 100%|██████████| 100/100 [00:01<00:00, 87.61batch/s]

[Epoch 23] Train Loss: 0.000007 - Test Loss: 0.000548 - Train Accuracy: 99.98% - Test Accuracy: 98.66%



Epoch 23: 100%|██████████| 600/600 [00:09<00:00, 64.95batch/s]
Test 23: 100%|██████████| 100/100 [00:00<00:00, 101.13batch/s]

[Epoch 24] Train Loss: 0.000007 - Test Loss: 0.000584 - Train Accuracy: 99.98% - Test Accuracy: 98.52%



Epoch 24: 100%|██████████| 600/600 [00:07<00:00, 83.08batch/s]
Test 24: 100%|██████████| 100/100 [00:00<00:00, 116.03batch/s]

[Epoch 25] Train Loss: 0.000003 - Test Loss: 0.000522 - Train Accuracy: 100.00% - Test Accuracy: 98.65%

BEST TEST ACCURACY:  98.67  in epoch  13





In [14]:
####################################################################
# Load best weights
####################################################################

# NOTE(review): get_model() is defined in main.py. Judging by the section
# title and the recorded output (a test pass followed by "Final best acc"),
# it presumably reloads the best checkpoint and re-evaluates it on the test
# set — confirm against the implementation.
trainer.get_model()

  test_loss += self.criterion(outputs, labels)
Test 24: 100%|██████████| 100/100 [00:00<00:00, 101.85batch/s]

Final best acc:  98.67



