# Data augmentation

In [13]:
# Enable IPython's autoreload extension in mode 2, so imported modules (such
# as the project-local `main`) are reloaded before each cell is executed and
# edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
import torch
import torch.nn as nn
import multiprocessing
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from main import MNIST_dataset, MNIST_trainer

In [15]:
# Report the installed PyTorch build next to the results below.
print("Torch version: ", torch.__version__)

####################################################################
# Device selection: CUDA when available, otherwise CPU
####################################################################

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: ", device)

Torch version:  2.5.1+cu124
Device:  cuda


In [16]:
####################################################################
# DataLoader Class
####################################################################
# Augmentation pipeline handed to the training dataset only: a random
# rotation (degrees=10), AugMix, then conversion to a tensor.
da = transforms.Compose(
    [
        transforms.RandomRotation(10),
        transforms.AugMix(),
        transforms.ToTensor(),
    ]
)
train_dataset = MNIST_dataset(partition="train", da_transform=da)
test_dataset = MNIST_dataset(partition="test")

# Every CPU core but one is used as a loader worker.
num_workers = multiprocessing.cpu_count() - 1
batch_size = 50
print("Num workers", num_workers)

# Only the training stream is shuffled; the test stream keeps a fixed order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)


Loading MNIST  train  Dataset...
	Total Len.:  60000 
 --------------------------------------------------

Loading MNIST  test  Dataset...
	Total Len.:  10000 
 --------------------------------------------------
Num workers 11


In [17]:
####################################################################
# Neural Network Class
####################################################################


# Creating our Neural Network - Fully Connected
class Net(nn.Module):
    """Fully connected classifier assembled from a list of layer dimensions.

    Every entry of ``sizes`` except the last becomes a
    ``Linear -> BatchNorm1d -> ReLU`` group stored in ``self.layers``; the
    last entry becomes ``self.classifier``, a plain ``Linear`` layer whose
    output is returned without any activation.

    Args:
        sizes: Sequence of ``(in_features, out_features)`` pairs, one per
            linear layer, in forward order. Must hold at least one pair.
        criterion: Optional callable ``criterion(output, y)`` used by
            ``forward`` to compute a loss when targets are supplied.
    """

    def __init__(
        self,
        # Immutable default: a list default would be shared by every instance.
        sizes=((784, 1024), (1024, 1024), (1024, 1024), (1024, 10)),
        criterion=None,
    ):
        super().__init__()

        self.layers = nn.ModuleList()

        # Hidden groups: all entries but the last one.
        for dims in sizes[:-1]:
            self.layers.extend(
                [
                    nn.Linear(dims[0], dims[1]),
                    nn.BatchNorm1d(dims[1]),
                    nn.ReLU(),
                ]
            )

        # Output layer: the last entry, with no normalisation or activation.
        dims = sizes[-1]
        self.classifier = nn.Linear(dims[0], dims[1])

        self.criterion = criterion

    def forward(self, x, y=None):
        """Run the network and optionally compute the loss.

        Args:
            x: Input batch. It is passed straight to the first ``Linear``
                layer (no flattening happens here), so its last dimension
                must equal ``sizes[0][0]``.
            y: Optional targets forwarded to ``self.criterion``.

        Returns:
            The classifier output when ``y`` is ``None``; otherwise the tuple
            ``(loss, output)``.

        Raises:
            ValueError: If ``y`` is given but no criterion was configured.
        """
        for layer in self.layers:
            x = layer(x)
        x = self.classifier(x)

        # Identity test on purpose: `y != None` would call the target's own
        # `__ne__`, which is elementwise (and ambiguous) for array-like targets.
        if y is None:
            return x

        if self.criterion is None:
            raise ValueError(
                "Targets were passed to forward() but the network has no criterion"
            )
        loss = self.criterion(x, y)
        return loss, x


####################################################################
# Training settings
####################################################################

# Training hyperparameters
# Loss function; it is also handed to the network so that forward() can
# return the loss when targets are supplied.
criterion = nn.CrossEntropyLoss()

# Build the classifier: 784 inputs, three hidden layers of width 1024 and one
# output per class. Printing the module shows the resulting architecture.
num_classes = 10
net = Net(
    criterion=criterion,
    sizes=[
        [784, 1024],
        [1024, 1024],
        [1024, 1024],
        [1024, num_classes],
    ],
)
print(net)


def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print("Params: ", count_parameters(net))

# SGD with momentum and a small weight-decay penalty.
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-6, momentum=0.9)
epochs = 50

# Lower the learning rate once the monitored value has stopped decreasing
# ("min" mode) for more than `patience` epochs, then wait `cooldown` epochs
# before resuming normal operation.
# NOTE(review): the monitored value is chosen inside MNIST_trainer
# (presumably the test loss) - confirm there.
# The deprecated `verbose` argument is intentionally not passed: it prints
# nothing on current PyTorch, and the trainer already logs the LR each epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=5, threshold=0.001, cooldown=1
)

# Bundle model, data, optimisation settings and checkpoint path for training.
trainer = MNIST_trainer(
    net,
    train_dataloader,
    test_dataloader,
    optimizer,
    criterion,
    epochs,
    device,
    scheduler=scheduler,
    model_path="models/da3.pt",
)

Net(
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=1024, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (classifier): Linear(in_features=1024, out_features=10, bias=True)
  (criterion): CrossEntropyLoss()
)
Params:  2919434


In [18]:
####################################################################
# Training
####################################################################

# Run the training loop configured above. Judging by the cell output, the
# trainer logs the LR, losses and accuracies after every epoch and reports
# the best test accuracy at the end; the implementation lives in `main`.
trainer.train()


---- Start Training ----


Epoch 0: 100%|██████████| 1200/1200 [00:36<00:00, 32.66batch/s]
Test 0: 100%|██████████| 200/200 [00:01<00:00, 163.95batch/s]


	LR:  0.01
[Epoch 1] Train Loss: 0.004400 - Test Loss: 0.001864 - Train Accuracy: 93.09% - Test Accuracy: 97.07%


Epoch 1: 100%|██████████| 1200/1200 [00:34<00:00, 34.76batch/s]
Test 1: 100%|██████████| 200/200 [00:01<00:00, 169.69batch/s]

	LR:  0.01
[Epoch 2] Train Loss: 0.002241 - Test Loss: 0.001440 - Train Accuracy: 96.41% - Test Accuracy: 97.83%



Epoch 2: 100%|██████████| 1200/1200 [00:36<00:00, 32.64batch/s]
Test 2: 100%|██████████| 200/200 [00:01<00:00, 156.78batch/s]

	LR:  0.01
[Epoch 3] Train Loss: 0.001811 - Test Loss: 0.001212 - Train Accuracy: 97.09% - Test Accuracy: 98.10%



Epoch 3: 100%|██████████| 1200/1200 [00:36<00:00, 32.49batch/s]
Test 3: 100%|██████████| 200/200 [00:01<00:00, 128.82batch/s]

	LR:  0.01
[Epoch 4] Train Loss: 0.001402 - Test Loss: 0.001012 - Train Accuracy: 97.73% - Test Accuracy: 98.32%



Epoch 4: 100%|██████████| 1200/1200 [00:36<00:00, 33.20batch/s]
Test 4: 100%|██████████| 200/200 [00:01<00:00, 161.59batch/s]


	LR:  0.01
[Epoch 5] Train Loss: 0.001246 - Test Loss: 0.000981 - Train Accuracy: 97.97% - Test Accuracy: 98.52%


Epoch 5: 100%|██████████| 1200/1200 [00:36<00:00, 32.48batch/s]
Test 5: 100%|██████████| 200/200 [00:01<00:00, 136.32batch/s]


	LR:  0.01
[Epoch 6] Train Loss: 0.001099 - Test Loss: 0.001003 - Train Accuracy: 98.26% - Test Accuracy: 98.53%


Epoch 6: 100%|██████████| 1200/1200 [00:37<00:00, 32.31batch/s]
Test 6: 100%|██████████| 200/200 [00:01<00:00, 150.08batch/s]

	LR:  0.01
[Epoch 7] Train Loss: 0.000995 - Test Loss: 0.000846 - Train Accuracy: 98.33% - Test Accuracy: 98.64%



Epoch 7: 100%|██████████| 1200/1200 [00:35<00:00, 33.83batch/s]
Test 7: 100%|██████████| 200/200 [00:01<00:00, 167.68batch/s]

	LR:  0.01
[Epoch 8] Train Loss: 0.000931 - Test Loss: 0.001035 - Train Accuracy: 98.49% - Test Accuracy: 98.56%



Epoch 8: 100%|██████████| 1200/1200 [00:35<00:00, 33.91batch/s]
Test 8: 100%|██████████| 200/200 [00:01<00:00, 164.06batch/s]

	LR:  0.01
[Epoch 9] Train Loss: 0.000890 - Test Loss: 0.000920 - Train Accuracy: 98.61% - Test Accuracy: 98.63%



Epoch 9: 100%|██████████| 1200/1200 [00:34<00:00, 34.93batch/s]
Test 9: 100%|██████████| 200/200 [00:01<00:00, 154.62batch/s]

	LR:  0.01
[Epoch 10] Train Loss: 0.000736 - Test Loss: 0.000849 - Train Accuracy: 98.82% - Test Accuracy: 98.75%



Epoch 10: 100%|██████████| 1200/1200 [00:34<00:00, 34.68batch/s]
Test 10: 100%|██████████| 200/200 [00:01<00:00, 163.78batch/s]

	LR:  0.01
[Epoch 11] Train Loss: 0.000715 - Test Loss: 0.000802 - Train Accuracy: 98.81% - Test Accuracy: 98.70%



Epoch 11: 100%|██████████| 1200/1200 [00:34<00:00, 34.85batch/s]
Test 11: 100%|██████████| 200/200 [00:01<00:00, 160.66batch/s]

	LR:  0.01
[Epoch 12] Train Loss: 0.000656 - Test Loss: 0.000745 - Train Accuracy: 98.92% - Test Accuracy: 98.81%



Epoch 12: 100%|██████████| 1200/1200 [00:35<00:00, 33.53batch/s]
Test 12: 100%|██████████| 200/200 [00:01<00:00, 154.58batch/s]

	LR:  0.01
[Epoch 13] Train Loss: 0.000609 - Test Loss: 0.000793 - Train Accuracy: 99.05% - Test Accuracy: 98.81%



Epoch 13: 100%|██████████| 1200/1200 [00:35<00:00, 33.60batch/s]
Test 13: 100%|██████████| 200/200 [00:01<00:00, 160.54batch/s]


	LR:  0.01
[Epoch 14] Train Loss: 0.000533 - Test Loss: 0.000784 - Train Accuracy: 99.10% - Test Accuracy: 98.87%


Epoch 14: 100%|██████████| 1200/1200 [00:37<00:00, 31.92batch/s]
Test 14: 100%|██████████| 200/200 [00:01<00:00, 126.04batch/s]

	LR:  0.01
[Epoch 15] Train Loss: 0.000576 - Test Loss: 0.000814 - Train Accuracy: 99.05% - Test Accuracy: 98.85%



Epoch 15: 100%|██████████| 1200/1200 [00:36<00:00, 32.59batch/s]
Test 15: 100%|██████████| 200/200 [00:01<00:00, 156.17batch/s]

	LR:  0.01
[Epoch 16] Train Loss: 0.000542 - Test Loss: 0.000753 - Train Accuracy: 99.12% - Test Accuracy: 98.87%



Epoch 16: 100%|██████████| 1200/1200 [00:36<00:00, 32.96batch/s]
Test 16: 100%|██████████| 200/200 [00:01<00:00, 141.89batch/s]

	LR:  0.01
[Epoch 17] Train Loss: 0.000542 - Test Loss: 0.000753 - Train Accuracy: 99.11% - Test Accuracy: 98.89%



Epoch 17: 100%|██████████| 1200/1200 [00:36<00:00, 33.28batch/s]
Test 17: 100%|██████████| 200/200 [00:01<00:00, 161.78batch/s]


	LR:  0.001
[Epoch 18] Train Loss: 0.000501 - Test Loss: 0.000849 - Train Accuracy: 99.21% - Test Accuracy: 98.89%


Epoch 18: 100%|██████████| 1200/1200 [00:36<00:00, 32.97batch/s]
Test 18: 100%|██████████| 200/200 [00:01<00:00, 131.67batch/s]

	LR:  0.001
[Epoch 19] Train Loss: 0.000345 - Test Loss: 0.000663 - Train Accuracy: 99.45% - Test Accuracy: 99.01%



Epoch 19: 100%|██████████| 1200/1200 [00:34<00:00, 34.80batch/s]
Test 19: 100%|██████████| 200/200 [00:01<00:00, 164.19batch/s]


	LR:  0.001
[Epoch 20] Train Loss: 0.000289 - Test Loss: 0.000623 - Train Accuracy: 99.55% - Test Accuracy: 99.09%


Epoch 20: 100%|██████████| 1200/1200 [00:35<00:00, 33.76batch/s]
Test 20: 100%|██████████| 200/200 [00:01<00:00, 153.95batch/s]

	LR:  0.001
[Epoch 21] Train Loss: 0.000262 - Test Loss: 0.000597 - Train Accuracy: 99.60% - Test Accuracy: 99.03%



Epoch 21: 100%|██████████| 1200/1200 [00:35<00:00, 34.22batch/s]
Test 21: 100%|██████████| 200/200 [00:01<00:00, 145.90batch/s]

	LR:  0.001
[Epoch 22] Train Loss: 0.000248 - Test Loss: 0.000613 - Train Accuracy: 99.59% - Test Accuracy: 99.03%



Epoch 22: 100%|██████████| 1200/1200 [00:34<00:00, 34.47batch/s]
Test 22: 100%|██████████| 200/200 [00:01<00:00, 139.44batch/s]

	LR:  0.001
[Epoch 23] Train Loss: 0.000250 - Test Loss: 0.000632 - Train Accuracy: 99.61% - Test Accuracy: 99.05%



Epoch 23: 100%|██████████| 1200/1200 [00:36<00:00, 33.15batch/s]
Test 23: 100%|██████████| 200/200 [00:01<00:00, 153.88batch/s]

	LR:  0.001
[Epoch 24] Train Loss: 0.000217 - Test Loss: 0.000598 - Train Accuracy: 99.66% - Test Accuracy: 99.06%



Epoch 24: 100%|██████████| 1200/1200 [00:35<00:00, 34.25batch/s]
Test 24: 100%|██████████| 200/200 [00:01<00:00, 162.96batch/s]

	LR:  0.001
[Epoch 25] Train Loss: 0.000254 - Test Loss: 0.000605 - Train Accuracy: 99.60% - Test Accuracy: 99.07%



Epoch 25: 100%|██████████| 1200/1200 [00:34<00:00, 34.48batch/s]
Test 25: 100%|██████████| 200/200 [00:01<00:00, 166.23batch/s]

	LR:  0.001
[Epoch 26] Train Loss: 0.000210 - Test Loss: 0.000623 - Train Accuracy: 99.68% - Test Accuracy: 99.08%



Epoch 26: 100%|██████████| 1200/1200 [00:36<00:00, 32.94batch/s]
Test 26: 100%|██████████| 200/200 [00:01<00:00, 159.78batch/s]

	LR:  0.001
[Epoch 27] Train Loss: 0.000222 - Test Loss: 0.000593 - Train Accuracy: 99.64% - Test Accuracy: 99.07%



Epoch 27: 100%|██████████| 1200/1200 [00:36<00:00, 33.26batch/s]
Test 27: 100%|██████████| 200/200 [00:01<00:00, 151.25batch/s]

	LR:  0.001
[Epoch 28] Train Loss: 0.000208 - Test Loss: 0.000597 - Train Accuracy: 99.67% - Test Accuracy: 99.05%



Epoch 28: 100%|██████████| 1200/1200 [00:37<00:00, 32.09batch/s]
Test 28: 100%|██████████| 200/200 [00:01<00:00, 161.07batch/s]

	LR:  0.001
[Epoch 29] Train Loss: 0.000211 - Test Loss: 0.000591 - Train Accuracy: 99.66% - Test Accuracy: 99.09%



Epoch 29: 100%|██████████| 1200/1200 [00:35<00:00, 33.49batch/s]
Test 29: 100%|██████████| 200/200 [00:01<00:00, 135.95batch/s]

	LR:  0.001
[Epoch 30] Train Loss: 0.000206 - Test Loss: 0.000614 - Train Accuracy: 99.71% - Test Accuracy: 99.07%



Epoch 30: 100%|██████████| 1200/1200 [00:37<00:00, 32.17batch/s]
Test 30: 100%|██████████| 200/200 [00:01<00:00, 133.01batch/s]

	LR:  0.001
[Epoch 31] Train Loss: 0.000200 - Test Loss: 0.000584 - Train Accuracy: 99.71% - Test Accuracy: 99.08%



Epoch 31: 100%|██████████| 1200/1200 [00:36<00:00, 32.47batch/s]
Test 31: 100%|██████████| 200/200 [00:01<00:00, 157.53batch/s]

	LR:  0.001
[Epoch 32] Train Loss: 0.000218 - Test Loss: 0.000591 - Train Accuracy: 99.69% - Test Accuracy: 99.10%



Epoch 32: 100%|██████████| 1200/1200 [00:36<00:00, 32.96batch/s]
Test 32: 100%|██████████| 200/200 [00:01<00:00, 154.03batch/s]

	LR:  0.001
[Epoch 33] Train Loss: 0.000203 - Test Loss: 0.000602 - Train Accuracy: 99.67% - Test Accuracy: 99.14%



Epoch 33: 100%|██████████| 1200/1200 [00:36<00:00, 33.26batch/s]
Test 33: 100%|██████████| 200/200 [00:01<00:00, 148.17batch/s]

	LR:  0.001
[Epoch 34] Train Loss: 0.000199 - Test Loss: 0.000598 - Train Accuracy: 99.68% - Test Accuracy: 99.06%



Epoch 34: 100%|██████████| 1200/1200 [00:36<00:00, 33.24batch/s]
Test 34: 100%|██████████| 200/200 [00:01<00:00, 154.39batch/s]

	LR:  0.001
[Epoch 35] Train Loss: 0.000189 - Test Loss: 0.000562 - Train Accuracy: 99.71% - Test Accuracy: 99.10%



Epoch 35: 100%|██████████| 1200/1200 [00:35<00:00, 33.38batch/s]
Test 35: 100%|██████████| 200/200 [00:01<00:00, 151.44batch/s]

	LR:  0.001
[Epoch 36] Train Loss: 0.000198 - Test Loss: 0.000578 - Train Accuracy: 99.68% - Test Accuracy: 99.05%



Epoch 36: 100%|██████████| 1200/1200 [00:38<00:00, 31.02batch/s]
Test 36: 100%|██████████| 200/200 [00:01<00:00, 133.08batch/s]

	LR:  0.001
[Epoch 37] Train Loss: 0.000193 - Test Loss: 0.000584 - Train Accuracy: 99.72% - Test Accuracy: 99.09%



Epoch 37: 100%|██████████| 1200/1200 [00:34<00:00, 34.36batch/s]
Test 37: 100%|██████████| 200/200 [00:01<00:00, 124.65batch/s]

	LR:  0.001
[Epoch 38] Train Loss: 0.000198 - Test Loss: 0.000585 - Train Accuracy: 99.69% - Test Accuracy: 99.12%



Epoch 38: 100%|██████████| 1200/1200 [00:35<00:00, 34.21batch/s]
Test 38: 100%|██████████| 200/200 [00:01<00:00, 160.40batch/s]

	LR:  0.001
[Epoch 39] Train Loss: 0.000192 - Test Loss: 0.000573 - Train Accuracy: 99.71% - Test Accuracy: 99.13%



Epoch 39: 100%|██████████| 1200/1200 [00:34<00:00, 34.47batch/s]
Test 39: 100%|██████████| 200/200 [00:01<00:00, 159.13batch/s]

	LR:  0.001
[Epoch 40] Train Loss: 0.000181 - Test Loss: 0.000577 - Train Accuracy: 99.71% - Test Accuracy: 99.12%



Epoch 40: 100%|██████████| 1200/1200 [00:36<00:00, 32.56batch/s]
Test 40: 100%|██████████| 200/200 [00:01<00:00, 157.00batch/s]

	LR:  0.0001
[Epoch 41] Train Loss: 0.000179 - Test Loss: 0.000599 - Train Accuracy: 99.72% - Test Accuracy: 99.08%



Epoch 41: 100%|██████████| 1200/1200 [00:40<00:00, 29.38batch/s]
Test 41: 100%|██████████| 200/200 [00:01<00:00, 138.97batch/s]

	LR:  0.0001
[Epoch 42] Train Loss: 0.000179 - Test Loss: 0.000599 - Train Accuracy: 99.75% - Test Accuracy: 99.05%



Epoch 42: 100%|██████████| 1200/1200 [00:40<00:00, 29.44batch/s]
Test 42: 100%|██████████| 200/200 [00:01<00:00, 150.21batch/s]


	LR:  0.0001
[Epoch 43] Train Loss: 0.000174 - Test Loss: 0.000595 - Train Accuracy: 99.76% - Test Accuracy: 99.08%


Epoch 43: 100%|██████████| 1200/1200 [00:38<00:00, 31.03batch/s]
Test 43: 100%|██████████| 200/200 [00:01<00:00, 108.34batch/s]

	LR:  0.0001
[Epoch 44] Train Loss: 0.000186 - Test Loss: 0.000609 - Train Accuracy: 99.72% - Test Accuracy: 99.06%



Epoch 44: 100%|██████████| 1200/1200 [00:42<00:00, 28.55batch/s]
Test 44: 100%|██████████| 200/200 [00:01<00:00, 108.90batch/s]


	LR:  0.0001
[Epoch 45] Train Loss: 0.000171 - Test Loss: 0.000587 - Train Accuracy: 99.75% - Test Accuracy: 99.10%


Epoch 45: 100%|██████████| 1200/1200 [00:38<00:00, 31.02batch/s]
Test 45: 100%|██████████| 200/200 [00:01<00:00, 140.14batch/s]

	LR:  0.0001
[Epoch 46] Train Loss: 0.000187 - Test Loss: 0.000595 - Train Accuracy: 99.74% - Test Accuracy: 99.10%



Epoch 46: 100%|██████████| 1200/1200 [00:41<00:00, 29.03batch/s]
Test 46: 100%|██████████| 200/200 [00:01<00:00, 106.55batch/s]

	LR:  0.0001
[Epoch 47] Train Loss: 0.000196 - Test Loss: 0.000589 - Train Accuracy: 99.71% - Test Accuracy: 99.08%



Epoch 47: 100%|██████████| 1200/1200 [00:40<00:00, 29.29batch/s]
Test 47: 100%|██████████| 200/200 [00:01<00:00, 152.57batch/s]

	LR:  1e-05
[Epoch 48] Train Loss: 0.000166 - Test Loss: 0.000585 - Train Accuracy: 99.76% - Test Accuracy: 99.14%



Epoch 48: 100%|██████████| 1200/1200 [00:38<00:00, 30.99batch/s]
Test 48: 100%|██████████| 200/200 [00:01<00:00, 150.85batch/s]

	LR:  1e-05
[Epoch 49] Train Loss: 0.000175 - Test Loss: 0.000596 - Train Accuracy: 99.72% - Test Accuracy: 99.05%



Epoch 49: 100%|██████████| 1200/1200 [00:36<00:00, 33.15batch/s]
Test 49: 100%|██████████| 200/200 [00:01<00:00, 119.96batch/s]

	LR:  1e-05
[Epoch 50] Train Loss: 0.000181 - Test Loss: 0.000572 - Train Accuracy: 99.72% - Test Accuracy: 99.14%

BEST TEST ACCURACY:  99.14  in epoch  32





In [19]:
####################################################################
# Load best weights
####################################################################

# Judging by the cell output, this reloads the checkpoint saved at the
# trainer's model_path and evaluates it once more on the test set.
# NOTE(review): the warning fragment in the output points at a bare
# torch.load(...) inside the trainer; consider weights_only=True there -
# confirm in main.py.
trainer.get_model()

  self.net.load_state_dict(torch.load(self.model_path))
Test 49: 100%|██████████| 200/200 [00:01<00:00, 136.00batch/s]

Final best acc:  99.14



