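# Data augmentation

Train a fully connected network (three 1024-unit hidden layers with BatchNorm + ReLU) on MNIST, applying random rotations (±10°) and elastic deformations to the training images only, then reload the best saved weights.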

In [8]:
# Re-import edited project modules (e.g. main.py) automatically before each cell runs.
# On a kernel where the extension is already loaded, %load_ext only prints a notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import torch
import torch.nn as nn
import multiprocessing
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from main import MNIST_dataset, MNIST_trainer

In [10]:
print("Torch version: ", torch.__version__)

####################################################################
# Set Device & Seed
####################################################################

# Seed torch's global RNG (torch.manual_seed also seeds the CUDA devices) so
# weight initialisation, shuffling and the random augmentations repeat across
# "Restart & Run All"; DataLoader workers derive their per-worker seeds from
# this RNG. Some CUDA kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

Torch version:  2.5.1+cu124
Device:  cuda


In [11]:
####################################################################
# Data augmentation & DataLoaders
####################################################################

# Training-only augmentation: a random rotation in [-10, +10] degrees, then an
# elastic deformation, then conversion to a tensor. The test dataset is built
# without `da_transform`, so (presumably) it is evaluated on unaugmented images.
da = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.ElasticTransform(alpha=90.0),
        transforms.ToTensor(),
    ]
)
train_dataset = MNIST_dataset(partition="train", da_transform=da)
test_dataset = MNIST_dataset(partition="test")

batch_size = 100
# Leave one core for the main process. cpu_count() >= 1, so this is >= 0,
# and 0 means "load batches in the main process".
num_workers = multiprocessing.cpu_count() - 1
print("Num workers", num_workers)

# persistent_workers keeps the worker processes alive between epochs instead of
# re-spawning them for every epoch of both loaders (only valid with workers > 0).
# pin_memory speeds up host->GPU copies and is only useful when CUDA is present.
loader_kwargs = dict(
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=num_workers > 0,
)
train_dataloader = DataLoader(
    train_dataset, batch_size, shuffle=True, **loader_kwargs
)
test_dataloader = DataLoader(
    test_dataset, batch_size, shuffle=False, **loader_kwargs
)


Loading MNIST  train  Dataset...
	Total Len.:  60000 
 --------------------------------------------------

Loading MNIST  test  Dataset...
	Total Len.:  10000 
 --------------------------------------------------
Num workers 11


In [12]:
####################################################################
# Neural Network Class
####################################################################


# Creating our Neural Network - Fully Connected
class Net(nn.Module):
    """Fully connected classifier: stacked [Linear -> BatchNorm1d -> ReLU] blocks + a linear head.

    Args:
        sizes: sequence of ``[in_features, out_features]`` pairs. Every pair
            except the last builds one hidden block; the last pair builds the
            output ``classifier`` layer. ``None`` (the default) selects
            784 -> 1024 -> 1024 -> 1024 -> 10.
        criterion: loss module applied in ``forward`` when targets are given.

    Raises:
        ValueError: if ``sizes`` is empty.
    """

    def __init__(
        self,
        sizes=None,
        criterion=None,
    ):
        super().__init__()

        # None sentinel instead of a mutable list default shared across calls.
        if sizes is None:
            sizes = [[784, 1024], [1024, 1024], [1024, 1024], [1024, 10]]
        if len(sizes) == 0:
            raise ValueError("sizes must contain at least the classifier layer")

        self.layers = nn.ModuleList()

        # All pairs but the last become hidden blocks.
        for dims in sizes[:-1]:
            self.layers.append(nn.Linear(dims[0], dims[1]))
            self.layers.append(nn.BatchNorm1d(dims[1]))
            self.layers.append(nn.ReLU())

        dims = sizes[-1]
        self.classifier = nn.Linear(dims[0], dims[1])

        self.criterion = criterion

    def forward(self, x, y=None):
        """Compute class logits, and the loss when targets are supplied.

        Args:
            x: input batch; assumes shape (batch, sizes[0][0]). In training mode
                BatchNorm1d needs batch > 1.
            y: optional targets passed to ``self.criterion``.

        Returns:
            ``(loss, logits)`` if ``y`` is given, otherwise ``logits``.

        Raises:
            ValueError: if ``y`` is given but the network has no criterion.
        """
        for layer in self.layers:
            x = layer(x)
        x = self.classifier(x)

        # `is not None`: comparing a tensor with `!=` dispatches to tensor ops.
        if y is not None:
            if self.criterion is None:
                raise ValueError("targets were given but Net has no criterion")
            loss = self.criterion(x, y)
            return loss, x
        return x


####################################################################
# Training settings
####################################################################

# Loss used both by Net.forward (when targets are passed) and by the trainer
criterion = nn.CrossEntropyLoss()

# Build the MLP: 784 input features -> three hidden layers of 1024 -> class logits
num_classes = 10
layer_sizes = [[784, 1024], [1024, 1024], [1024, 1024], [1024, num_classes]]
net = Net(sizes=layer_sizes, criterion=criterion)
print(net)


def count_parameters(model):
    """Return how many scalar parameters of ``model`` are trainable (requires_grad=True)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print("Params: ", count_parameters(net))

# Optimisation hyperparameters
lr = 0.1
momentum = 0.9
weight_decay = 1e-6
epochs = 50

optimizer = optim.SGD(
    net.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum
)

# Multiply the LR by 0.9 as soon as the monitored loss fails to improve by at
# least 0.1% (relative) for one epoch (patience=0), then wait one epoch
# (cooldown=1) before monitoring again. Judging by the logged LR drops, the
# trainer steps this scheduler with the *test* loss.
# NOTE(review): tuning the schedule and picking the best epoch on the test set
# leaks test information; a held-out validation split would be cleaner.
# `verbose=True` was dropped: it is deprecated in recent PyTorch releases, and
# the trainer already prints the current LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, "min", patience=0, factor=0.9, threshold=0.001, cooldown=1
)

trainer = MNIST_trainer(
    net,
    train_dataloader,
    test_dataloader,
    optimizer,
    criterion,
    epochs,
    device,
    scheduler=scheduler,
    model_path="models/da5.pt",
)

Net(
  (layers): ModuleList(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=1024, out_features=1024, bias=True)
    (7): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (classifier): Linear(in_features=1024, out_features=10, bias=True)
  (criterion): CrossEntropyLoss()
)
Params:  2919434


In [13]:
####################################################################
# Training
####################################################################

# Run the training loop (MNIST_trainer is defined in main.py). Per the log
# below, each epoch trains on the augmented loader, evaluates on the test
# loader and prints the current LR; the best test accuracy is reported at the end.
# NOTE(review): the "best" epoch is chosen on the test set, so the reported
# accuracy is an optimistic estimate; consider a separate validation split.
# NOTE(review): the final summary reports a 0-based epoch ("epoch 46") while the
# per-epoch lines are 1-based ("[Epoch 47]"). This looks like an off-by-one in
# the trainer's logging; confirm in main.py.
trainer.train()


---- Start Training ----


Epoch 0: 100%|██████████| 600/600 [00:42<00:00, 14.27batch/s]
Test 0: 100%|██████████| 100/100 [00:00<00:00, 117.51batch/s]


	LR:  0.1
[Epoch 1] Train Loss: 0.008046 - Test Loss: 0.000915 - Train Accuracy: 77.67% - Test Accuracy: 97.01%


Epoch 1: 100%|██████████| 600/600 [00:44<00:00, 13.59batch/s]
Test 1: 100%|██████████| 100/100 [00:00<00:00, 122.16batch/s]


	LR:  0.1
[Epoch 2] Train Loss: 0.003869 - Test Loss: 0.000657 - Train Accuracy: 87.69% - Test Accuracy: 97.70%


Epoch 2: 100%|██████████| 600/600 [00:44<00:00, 13.63batch/s]
Test 2: 100%|██████████| 100/100 [00:00<00:00, 118.33batch/s]

	LR:  0.1
[Epoch 3] Train Loss: 0.003100 - Test Loss: 0.000544 - Train Accuracy: 90.16% - Test Accuracy: 98.25%



Epoch 3: 100%|██████████| 600/600 [00:44<00:00, 13.61batch/s]
Test 3: 100%|██████████| 100/100 [00:00<00:00, 114.19batch/s]

	LR:  0.1
[Epoch 4] Train Loss: 0.002736 - Test Loss: 0.000517 - Train Accuracy: 91.36% - Test Accuracy: 98.29%



Epoch 4: 100%|██████████| 600/600 [00:43<00:00, 13.70batch/s]
Test 4: 100%|██████████| 100/100 [00:00<00:00, 117.82batch/s]


	LR:  0.1
[Epoch 5] Train Loss: 0.002538 - Test Loss: 0.000458 - Train Accuracy: 91.97% - Test Accuracy: 98.33%


Epoch 5: 100%|██████████| 600/600 [00:44<00:00, 13.44batch/s]
Test 5: 100%|██████████| 100/100 [00:00<00:00, 121.75batch/s]


	LR:  0.1
[Epoch 6] Train Loss: 0.002325 - Test Loss: 0.000372 - Train Accuracy: 92.61% - Test Accuracy: 98.73%


Epoch 6: 100%|██████████| 600/600 [00:43<00:00, 13.64batch/s]
Test 6: 100%|██████████| 100/100 [00:00<00:00, 121.23batch/s]

	LR:  0.09000000000000001
[Epoch 7] Train Loss: 0.002135 - Test Loss: 0.000394 - Train Accuracy: 93.30% - Test Accuracy: 98.74%



Epoch 7: 100%|██████████| 600/600 [00:44<00:00, 13.51batch/s]
Test 7: 100%|██████████| 100/100 [00:01<00:00, 97.88batch/s]


	LR:  0.09000000000000001
[Epoch 8] Train Loss: 0.001971 - Test Loss: 0.000337 - Train Accuracy: 93.72% - Test Accuracy: 98.96%


Epoch 8: 100%|██████████| 600/600 [00:43<00:00, 13.74batch/s]
Test 8: 100%|██████████| 100/100 [00:00<00:00, 123.23batch/s]

	LR:  0.08100000000000002
[Epoch 9] Train Loss: 0.001876 - Test Loss: 0.000349 - Train Accuracy: 94.03% - Test Accuracy: 98.94%



Epoch 9: 100%|██████████| 600/600 [00:44<00:00, 13.42batch/s]
Test 9: 100%|██████████| 100/100 [00:00<00:00, 115.05batch/s]

	LR:  0.08100000000000002
[Epoch 10] Train Loss: 0.001798 - Test Loss: 0.000284 - Train Accuracy: 94.25% - Test Accuracy: 99.04%



Epoch 10: 100%|██████████| 600/600 [00:43<00:00, 13.66batch/s]
Test 10: 100%|██████████| 100/100 [00:00<00:00, 124.90batch/s]

	LR:  0.07290000000000002
[Epoch 11] Train Loss: 0.001728 - Test Loss: 0.000298 - Train Accuracy: 94.51% - Test Accuracy: 98.96%



Epoch 11: 100%|██████████| 600/600 [00:42<00:00, 14.05batch/s]
Test 11: 100%|██████████| 100/100 [00:00<00:00, 128.90batch/s]

	LR:  0.07290000000000002
[Epoch 12] Train Loss: 0.001649 - Test Loss: 0.000293 - Train Accuracy: 94.75% - Test Accuracy: 99.00%



Epoch 12: 100%|██████████| 600/600 [00:42<00:00, 13.98batch/s]
Test 12: 100%|██████████| 100/100 [00:00<00:00, 117.23batch/s]

	LR:  0.06561000000000002
[Epoch 13] Train Loss: 0.001577 - Test Loss: 0.000316 - Train Accuracy: 94.83% - Test Accuracy: 98.95%



Epoch 13: 100%|██████████| 600/600 [00:43<00:00, 13.85batch/s]
Test 13: 100%|██████████| 100/100 [00:00<00:00, 120.95batch/s]


	LR:  0.06561000000000002
[Epoch 14] Train Loss: 0.001515 - Test Loss: 0.000280 - Train Accuracy: 95.12% - Test Accuracy: 99.17%


Epoch 14: 100%|██████████| 600/600 [00:42<00:00, 14.23batch/s]
Test 14: 100%|██████████| 100/100 [00:00<00:00, 125.02batch/s]

	LR:  0.06561000000000002
[Epoch 15] Train Loss: 0.001481 - Test Loss: 0.000267 - Train Accuracy: 95.31% - Test Accuracy: 99.15%



Epoch 15: 100%|██████████| 600/600 [00:42<00:00, 13.97batch/s]
Test 15: 100%|██████████| 100/100 [00:00<00:00, 125.60batch/s]

	LR:  0.05904900000000002
[Epoch 16] Train Loss: 0.001443 - Test Loss: 0.000283 - Train Accuracy: 95.48% - Test Accuracy: 99.05%



Epoch 16: 100%|██████████| 600/600 [00:43<00:00, 13.90batch/s]
Test 16: 100%|██████████| 100/100 [00:00<00:00, 117.11batch/s]

	LR:  0.05904900000000002
[Epoch 17] Train Loss: 0.001392 - Test Loss: 0.000264 - Train Accuracy: 95.51% - Test Accuracy: 99.13%



Epoch 17: 100%|██████████| 600/600 [00:42<00:00, 14.25batch/s]
Test 17: 100%|██████████| 100/100 [00:00<00:00, 122.72batch/s]

	LR:  0.05904900000000002
[Epoch 18] Train Loss: 0.001367 - Test Loss: 0.000256 - Train Accuracy: 95.58% - Test Accuracy: 99.09%



Epoch 18: 100%|██████████| 600/600 [00:42<00:00, 14.15batch/s]
Test 18: 100%|██████████| 100/100 [00:00<00:00, 124.42batch/s]

	LR:  0.05314410000000002
[Epoch 19] Train Loss: 0.001332 - Test Loss: 0.000259 - Train Accuracy: 95.70% - Test Accuracy: 99.15%



Epoch 19: 100%|██████████| 600/600 [00:42<00:00, 13.96batch/s]
Test 19: 100%|██████████| 100/100 [00:00<00:00, 126.58batch/s]

	LR:  0.05314410000000002
[Epoch 20] Train Loss: 0.001310 - Test Loss: 0.000263 - Train Accuracy: 95.76% - Test Accuracy: 99.15%



Epoch 20: 100%|██████████| 600/600 [00:44<00:00, 13.45batch/s]
Test 20: 100%|██████████| 100/100 [00:00<00:00, 127.78batch/s]

	LR:  0.05314410000000002
[Epoch 21] Train Loss: 0.001259 - Test Loss: 0.000246 - Train Accuracy: 95.98% - Test Accuracy: 99.19%



Epoch 21: 100%|██████████| 600/600 [00:43<00:00, 13.94batch/s]
Test 21: 100%|██████████| 100/100 [00:00<00:00, 126.37batch/s]

	LR:  0.05314410000000002
[Epoch 22] Train Loss: 0.001267 - Test Loss: 0.000231 - Train Accuracy: 96.02% - Test Accuracy: 99.18%



Epoch 22: 100%|██████████| 600/600 [00:43<00:00, 13.75batch/s]
Test 22: 100%|██████████| 100/100 [00:00<00:00, 104.78batch/s]


	LR:  0.05314410000000002
[Epoch 23] Train Loss: 0.001185 - Test Loss: 0.000223 - Train Accuracy: 96.24% - Test Accuracy: 99.24%


Epoch 23: 100%|██████████| 600/600 [00:43<00:00, 13.72batch/s]
Test 23: 100%|██████████| 100/100 [00:00<00:00, 122.54batch/s]

	LR:  0.04782969000000002
[Epoch 24] Train Loss: 0.001230 - Test Loss: 0.000244 - Train Accuracy: 96.10% - Test Accuracy: 99.24%



Epoch 24: 100%|██████████| 600/600 [00:43<00:00, 13.90batch/s]
Test 24: 100%|██████████| 100/100 [00:00<00:00, 122.98batch/s]

	LR:  0.04782969000000002
[Epoch 25] Train Loss: 0.001184 - Test Loss: 0.000235 - Train Accuracy: 96.22% - Test Accuracy: 99.19%



Epoch 25: 100%|██████████| 600/600 [00:43<00:00, 13.79batch/s]
Test 25: 100%|██████████| 100/100 [00:00<00:00, 118.36batch/s]

	LR:  0.043046721000000024
[Epoch 26] Train Loss: 0.001191 - Test Loss: 0.000233 - Train Accuracy: 96.20% - Test Accuracy: 99.26%



Epoch 26: 100%|██████████| 600/600 [00:43<00:00, 13.91batch/s]
Test 26: 100%|██████████| 100/100 [00:00<00:00, 108.72batch/s]


	LR:  0.043046721000000024
[Epoch 27] Train Loss: 0.001166 - Test Loss: 0.000230 - Train Accuracy: 96.31% - Test Accuracy: 99.23%


Epoch 27: 100%|██████████| 600/600 [00:44<00:00, 13.51batch/s]
Test 27: 100%|██████████| 100/100 [00:00<00:00, 123.60batch/s]

	LR:  0.043046721000000024
[Epoch 28] Train Loss: 0.001141 - Test Loss: 0.000218 - Train Accuracy: 96.32% - Test Accuracy: 99.28%



Epoch 28: 100%|██████████| 600/600 [00:43<00:00, 13.83batch/s]
Test 28: 100%|██████████| 100/100 [00:00<00:00, 124.95batch/s]

	LR:  0.03874204890000002
[Epoch 29] Train Loss: 0.001118 - Test Loss: 0.000231 - Train Accuracy: 96.35% - Test Accuracy: 99.20%



Epoch 29: 100%|██████████| 600/600 [00:42<00:00, 14.05batch/s]
Test 29: 100%|██████████| 100/100 [00:00<00:00, 121.22batch/s]


	LR:  0.03874204890000002
[Epoch 30] Train Loss: 0.001123 - Test Loss: 0.000214 - Train Accuracy: 96.42% - Test Accuracy: 99.32%


Epoch 30: 100%|██████████| 600/600 [00:42<00:00, 14.13batch/s]
Test 30: 100%|██████████| 100/100 [00:00<00:00, 127.53batch/s]

	LR:  0.03486784401000002
[Epoch 31] Train Loss: 0.001067 - Test Loss: 0.000216 - Train Accuracy: 96.56% - Test Accuracy: 99.28%



Epoch 31: 100%|██████████| 600/600 [00:42<00:00, 14.02batch/s]
Test 31: 100%|██████████| 100/100 [00:00<00:00, 112.64batch/s]

	LR:  0.03486784401000002
[Epoch 32] Train Loss: 0.001041 - Test Loss: 0.000228 - Train Accuracy: 96.64% - Test Accuracy: 99.22%



Epoch 32: 100%|██████████| 600/600 [00:43<00:00, 13.75batch/s]
Test 32: 100%|██████████| 100/100 [00:00<00:00, 125.82batch/s]

	LR:  0.03486784401000002
[Epoch 33] Train Loss: 0.001016 - Test Loss: 0.000211 - Train Accuracy: 96.78% - Test Accuracy: 99.32%



Epoch 33: 100%|██████████| 600/600 [00:43<00:00, 13.75batch/s]
Test 33: 100%|██████████| 100/100 [00:00<00:00, 124.53batch/s]

	LR:  0.03138105960900001
[Epoch 34] Train Loss: 0.001041 - Test Loss: 0.000248 - Train Accuracy: 96.72% - Test Accuracy: 99.20%



Epoch 34: 100%|██████████| 600/600 [00:42<00:00, 13.98batch/s]
Test 34: 100%|██████████| 100/100 [00:00<00:00, 123.42batch/s]

	LR:  0.03138105960900001
[Epoch 35] Train Loss: 0.001001 - Test Loss: 0.000206 - Train Accuracy: 96.79% - Test Accuracy: 99.31%



Epoch 35: 100%|██████████| 600/600 [00:43<00:00, 13.95batch/s]
Test 35: 100%|██████████| 100/100 [00:00<00:00, 128.88batch/s]

	LR:  0.028242953648100012
[Epoch 36] Train Loss: 0.000990 - Test Loss: 0.000236 - Train Accuracy: 96.89% - Test Accuracy: 99.21%



Epoch 36: 100%|██████████| 600/600 [00:45<00:00, 13.27batch/s]
Test 36: 100%|██████████| 100/100 [00:00<00:00, 117.33batch/s]

	LR:  0.028242953648100012
[Epoch 37] Train Loss: 0.000975 - Test Loss: 0.000214 - Train Accuracy: 96.91% - Test Accuracy: 99.30%



Epoch 37: 100%|██████████| 600/600 [00:44<00:00, 13.61batch/s]
Test 37: 100%|██████████| 100/100 [00:00<00:00, 119.36batch/s]

	LR:  0.025418658283290013
[Epoch 38] Train Loss: 0.000984 - Test Loss: 0.000217 - Train Accuracy: 96.82% - Test Accuracy: 99.23%



Epoch 38: 100%|██████████| 600/600 [00:43<00:00, 13.88batch/s]
Test 38: 100%|██████████| 100/100 [00:00<00:00, 117.94batch/s]

	LR:  0.025418658283290013
[Epoch 39] Train Loss: 0.000971 - Test Loss: 0.000194 - Train Accuracy: 96.87% - Test Accuracy: 99.31%



Epoch 39: 100%|██████████| 600/600 [00:43<00:00, 13.92batch/s]
Test 39: 100%|██████████| 100/100 [00:00<00:00, 124.33batch/s]

	LR:  0.022876792454961013
[Epoch 40] Train Loss: 0.000942 - Test Loss: 0.000205 - Train Accuracy: 97.00% - Test Accuracy: 99.30%



Epoch 40: 100%|██████████| 600/600 [00:43<00:00, 13.66batch/s]
Test 40: 100%|██████████| 100/100 [00:00<00:00, 124.24batch/s]

	LR:  0.022876792454961013
[Epoch 41] Train Loss: 0.000896 - Test Loss: 0.000215 - Train Accuracy: 97.08% - Test Accuracy: 99.31%



Epoch 41: 100%|██████████| 600/600 [00:43<00:00, 13.82batch/s]
Test 41: 100%|██████████| 100/100 [00:00<00:00, 129.38batch/s]

	LR:  0.020589113209464913
[Epoch 42] Train Loss: 0.000924 - Test Loss: 0.000203 - Train Accuracy: 97.05% - Test Accuracy: 99.29%



Epoch 42: 100%|██████████| 600/600 [00:43<00:00, 13.89batch/s]
Test 42: 100%|██████████| 100/100 [00:00<00:00, 117.42batch/s]

	LR:  0.020589113209464913
[Epoch 43] Train Loss: 0.000908 - Test Loss: 0.000193 - Train Accuracy: 97.03% - Test Accuracy: 99.39%



Epoch 43: 100%|██████████| 600/600 [00:42<00:00, 14.11batch/s]
Test 43: 100%|██████████| 100/100 [00:00<00:00, 124.04batch/s]

	LR:  0.01853020188851842
[Epoch 44] Train Loss: 0.000891 - Test Loss: 0.000204 - Train Accuracy: 97.14% - Test Accuracy: 99.37%



Epoch 44: 100%|██████████| 600/600 [00:42<00:00, 14.14batch/s]
Test 44: 100%|██████████| 100/100 [00:00<00:00, 122.56batch/s]

	LR:  0.01853020188851842
[Epoch 45] Train Loss: 0.000888 - Test Loss: 0.000199 - Train Accuracy: 97.16% - Test Accuracy: 99.33%



Epoch 45: 100%|██████████| 600/600 [00:42<00:00, 14.13batch/s]
Test 45: 100%|██████████| 100/100 [00:00<00:00, 123.04batch/s]

	LR:  0.01667718169966658
[Epoch 46] Train Loss: 0.000883 - Test Loss: 0.000210 - Train Accuracy: 97.13% - Test Accuracy: 99.32%



Epoch 46: 100%|██████████| 600/600 [00:42<00:00, 13.97batch/s]
Test 46: 100%|██████████| 100/100 [00:00<00:00, 103.71batch/s]

	LR:  0.01667718169966658
[Epoch 47] Train Loss: 0.000873 - Test Loss: 0.000185 - Train Accuracy: 97.17% - Test Accuracy: 99.46%



Epoch 47: 100%|██████████| 600/600 [00:42<00:00, 14.14batch/s]
Test 47: 100%|██████████| 100/100 [00:00<00:00, 126.99batch/s]

	LR:  0.01667718169966658
[Epoch 48] Train Loss: 0.000882 - Test Loss: 0.000184 - Train Accuracy: 97.23% - Test Accuracy: 99.35%



Epoch 48: 100%|██████████| 600/600 [00:43<00:00, 13.95batch/s]
Test 48: 100%|██████████| 100/100 [00:00<00:00, 126.68batch/s]

	LR:  0.015009463529699923
[Epoch 49] Train Loss: 0.000856 - Test Loss: 0.000187 - Train Accuracy: 97.21% - Test Accuracy: 99.38%



Epoch 49: 100%|██████████| 600/600 [00:43<00:00, 13.81batch/s]
Test 49: 100%|██████████| 100/100 [00:00<00:00, 123.78batch/s]

	LR:  0.015009463529699923
[Epoch 50] Train Loss: 0.000868 - Test Loss: 0.000198 - Train Accuracy: 97.16% - Test Accuracy: 99.33%

BEST TEST ACCURACY:  99.46  in epoch  46





In [14]:
####################################################################
# Load best weights
####################################################################

# Defined in main.py. Presumably restores the best checkpoint saved at
# model_path and re-evaluates it on the test loader (see the "Final best acc"
# output). TODO confirm.
trainer.get_model()

Test 49: 100%|██████████| 100/100 [00:00<00:00, 123.71batch/s]

Final best acc:  99.46



