In [None]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [None]:
class Net(nn.Module):
    """Small CNN for 28x28 single-channel (MNIST) digits, ~2.35k parameters.

    Three conv -> ReLU -> BatchNorm -> MaxPool stages shrink the spatial size
    28 -> 14 -> 7 -> 3 while widening channels 1 -> 4 -> 8 -> 12, followed by a
    single linear classifier over the 12 * 3 * 3 = 108 flattened features.
    forward() returns log-probabilities of shape (batch, 10), for use with F.nll_loss.
    """

    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 4, 3, padding=1),  # 1*3*3*4 + 4 = 40 params
            nn.ReLU(),
            nn.BatchNorm2d(4),              # 8 params
            nn.MaxPool2d(2, 2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(4, 8, 3, padding=1),  # 4*3*3*8 + 8 = 296 params
            nn.ReLU(),
            nn.BatchNorm2d(8),              # 16 params
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(8, 12, 3, padding=1), # 8*3*3*12 + 12 = 876 params
            nn.ReLU(),
            nn.BatchNorm2d(12),             # 24 params
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1)
        )

        # 12 channels * 3 * 3 = 108 features after three max pools (28->14->7->3)
        self.fc1 = nn.Linear(12 * 3 * 3, 10)  # 108*10 + 10 = 1090 params

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Flatten per sample. flatten(1) keeps the batch dimension explicit, so a
        # wrongly sized input fails in fc1 instead of being folded into extra rows
        # the way view(-1, 108) would.
        x = torch.flatten(x, 1)
        # F.dropout defaults to training=True. Tie it to the module's mode so that
        # dropout is switched off under model.eval() (it used to stay active at test time).
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

In [None]:
# Layer-by-layer shapes and parameter counts via torchinfo (`%pip install torchinfo` if missing).
from torchinfo import summary

# Prefer Apple MPS, then CUDA, and fall back to the CPU.
use_cuda = torch.cuda.is_available()
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = Net().to(device)
# No input tensor is built here; torchinfo derives its probe input from input_size.
summary(model, input_size=(1, 1, 28, 28), device=device)

In [None]:


# Seed every RNG the pipeline draws from so that runs are repeatable:
# - torch: weight init and DataLoader shuffling;
# - Python `random` / NumPy: presumably used by albumentations (1.x) to sample
#   augmentation parameters. Confirm this for the installed version.
# With num_workers=0 all sampling happens in this process, so global seeds suffice.
import random
import numpy as np

SEED = 1456
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
batch_size = 512

# Pinned host memory only speeds up host->CUDA copies. On MPS PyTorch ignores it
# and emits a warning, so request it for CUDA only.
kwargs = {'num_workers': 0, 'pin_memory': True} if device.type == "cuda" else {}
# train_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=True, download=True,
#                     transform=transforms.Compose([
#                         transforms.ToTensor(),
#                         transforms.Normalize((0.1307,), (0.3081,))
#                     ])),
#     batch_size=batch_size, shuffle=True, **kwargs)
# test_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, transform=transforms.Compose([
#                         transforms.ToTensor(),
#                         transforms.Normalize((0.1307,), (0.3081,))
#                     ])),
#     batch_size=batch_size, shuffle=True, **kwargs)


import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2
import random
# Define the augmentation pipeline
# Training-time augmentation for (H, W, 1) MNIST images. It ends with the standard
# MNIST normalisation (mean 0.1307, std 0.3081) and conversion to a channels-first tensor.
# Every border/fill value is 0, so warps and cut-outs pad with background.
train_transforms = A.Compose([
    # Small random translation / zoom / rotation (up to 15 degrees).
    A.ShiftScaleRotate(
        p=0.7,
        rotate_limit=15,
        scale_limit=0.1,
        shift_limit=0.0625,
        value=0,
        border_mode=cv2.BORDER_CONSTANT,
    ),
    A.GridDistortion(p=0.3, distort_limit=0.2, num_steps=5),
    A.GaussNoise(p=0.3, var_limit=(5.0, 30.0)),
    A.Perspective(
        p=0.3,
        scale=(0.05, 0.1),
        keep_size=True,
        pad_val=0,
        pad_mode=cv2.BORDER_CONSTANT,
    ),

    # NOTE(review): alpha=1.0 gives a very small elastic displacement, so this step
    # may be close to a no-op. Confirm that this is intended.
    A.ElasticTransform(
        p=0.3,
        sigma=10.0,
        alpha=1.0,
        # Passed explicitly as None. Support for this argument differs between
        # albumentations versions; confirm it against the installed version.
        alpha_affine=None,
        value=0,
        border_mode=cv2.BORDER_CONSTANT,
        interpolation=cv2.INTER_LINEAR,
    ),

    # Input-space regularisation: erase 1-2 patches of at most 8x8 pixels.
    A.CoarseDropout(
        p=0.2,
        fill_value=0,
        min_holes=1,
        max_holes=2,
        max_width=8,
        max_height=8,
    ),

    A.Normalize(mean=[0.1307], std=[0.3081]),
    ToTensorV2(),
])

# train_transforms = A.Compose([
#     # Essential transforms for MNIST
#     A.ShiftScaleRotate(
#         shift_limit=0.0625,      # Increased slightly
#         scale_limit=0.1,     # Increased slightly
#         rotate_limit=15,      # Increased rotation range
#         p=0.8,               # Increased probability
#         border_mode=cv2.BORDER_CONSTANT,
#         value=0
#     ),
    
#     # Remove GridDistortion as it might be too aggressive for MNIST
    
#     # Modified noise parameters
#     A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
    
#     # Added more relevant transforms
#     A.RandomBrightnessContrast(
#         brightness_limit=0.2,
#         contrast_limit=0.2,
#         p=0.2
#     ),
    
#     # Correct parameters for ElasticTransform
#     A.ElasticTransform(
#         alpha=1.0,
#         sigma=10.0,
#         alpha_affine=None,  # Set to None as required by newer versions
#         interpolation=cv2.INTER_LINEAR,
#         border_mode=cv2.BORDER_CONSTANT,
#         value=0,
#         p=0.3
#     ),
    
#     A.Normalize(
#         mean=[0.1307],
#         std=[0.3081],
#     ),
#     ToTensorV2(),
# ])

# Custom Dataset class to work with Albumentations
class MNISTAlbumentations(datasets.MNIST):
    """torchvision MNIST that runs an Albumentations pipeline instead of torchvision transforms.

    Each image goes to `transform` as a NumPy array of shape (28, 28, 1) via the
    keyword call `transform(image=...)`, and the transformed "image" entry is returned.
    Labels are returned as Python ints, as in torchvision's MNIST. If
    `target_transform` is set, it is applied to them.
    """

    def __init__(self, root, train=True, download=True, transform=None, target_transform=None):
        # Withhold `transform` from the base class. torchvision would call it with a
        # PIL image, while Albumentations needs the `image=` keyword and a NumPy array.
        super().__init__(root, train=train, download=download, transform=None,
                         target_transform=target_transform)
        self.transform = transform

    def __getitem__(self, idx):
        img, label = self.data[idx], int(self.targets[idx])

        # np.array copies the tensor row, so an augmentation can never write back
        # into the cached dataset. The trailing channel axis gives (H, W, 1).
        img = np.expand_dims(np.array(img), axis=-1)

        if self.transform is not None:
            img = self.transform(image=img)["image"]
        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label


# Update the data loaders
# Training loader: augmented images, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    MNISTAlbumentations('../data', train=True, download=True, transform=train_transforms),
    batch_size=batch_size, shuffle=True,
    **kwargs)

# Test transforms: normalisation and tensor conversion only, no augmentation.
test_transforms = A.Compose([
    A.Normalize(
        mean=[0.1307],
        std=[0.3081],
    ),
    ToTensorV2(),
])

# Evaluation loader. Sample order does not change aggregate metrics, so keep it fixed
# (shuffle=False). Evaluation passes are then deterministic and do not consume torch RNG.
test_loader = torch.utils.data.DataLoader(
    MNISTAlbumentations('../data', train=False, transform=test_transforms),
    batch_size=batch_size, shuffle=False,
    **kwargs)

# Optional: Visualization function to check augmentations
def visualize_augmentations(dataset, idx=0, samples=5):
    """Plot `samples` independent draws of `dataset[idx]` in a single row.

    Every draw re-indexes the dataset, so a random transform gives a different
    augmentation in each panel. A channels-first image of shape (1, H, W), given as
    a tensor or an array, is moved to channels-last before it is drawn in grayscale.
    """
    import matplotlib.pyplot as plt

    plt.figure(figsize=(20, 4))
    for panel in range(1, samples + 1):
        image = dataset[idx][0]
        image = image.numpy() if isinstance(image, torch.Tensor) else image
        if image.shape[0] == 1:  # channels-first -> channels-last
            image = image.transpose(1, 2, 0)
        plt.subplot(1, samples, panel)
        plt.imshow(image.squeeze(), cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()



In [None]:
# Show several random augmentations of the first training image.
visualize_augmentations(train_loader.dataset)

# Report the size of each split.
for split, loader in (("train", train_loader), ("test", test_loader)):
    print(f"Number of samples in {split} dataset: {len(loader.dataset)}")


In [None]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch, scheduler=None):
    """Run one training epoch with NLL loss, showing a tqdm progress bar.

    `epoch` is accepted so every call site looks alike, but it is not used.
    If `scheduler` is given, it is stepped after every optimizer step (once per
    batch), which is the cadence CyclicLR and OneCycleLR expect.
    """
    model.train()
    progress = tqdm(train_loader)
    for step, batch in enumerate(progress):
        inputs, labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
        progress.set_description(desc=f'loss={loss.item()} batch_id={step}')


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

#### Option 1: SGD with Momentum

In [None]:
# Option 1: SGD at lr=0.01 with classical momentum 0.9, trained for 13 epochs.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=0.01)

for epoch in range(1, 13 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.523768961429596 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.69it/s]  

Test set: Average loss: 0.1979, Accuracy: 9418/10000 (94.18%)

loss=0.33777952194213867 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.81it/s]

Test set: Average loss: 0.1303, Accuracy: 9608/10000 (96.08%)

loss=0.39691492915153503 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s]

Test set: Average loss: 0.1093, Accuracy: 9699/10000 (96.99%)

loss=0.37994691729545593 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.81it/s]

Test set: Average loss: 0.0951, Accuracy: 9715/10000 (97.15%)

loss=0.2767679989337921 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.84it/s] 

Test set: Average loss: 0.0900, Accuracy: 9723/10000 (97.23%)

loss=0.42271316051483154 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s]

Test set: Average loss: 0.0809, Accuracy: 9764/10000 (97.64%)

loss=0.38566383719444275 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.83it/s]

Test set: Average loss: 0.0795, Accuracy: 9769/10000 (97.69%)

loss=0.2130003720521927 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s] 

Test set: Average loss: 0.0763, Accuracy: 9757/10000 (97.57%)

loss=0.3804439604282379 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.80it/s] 

Test set: Average loss: 0.0757, Accuracy: 9755/10000 (97.55%)

loss=0.31460073590278625 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.85it/s]

Test set: Average loss: 0.0647, Accuracy: 9785/10000 (97.85%)

loss=0.15340779721736908 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s]

Test set: Average loss: 0.0652, Accuracy: 9798/10000 (97.98%)

loss=0.3887993395328522 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.80it/s] 

Test set: Average loss: 0.0667, Accuracy: 9796/10000 (97.96%)

loss=0.20264939963817596 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s]

Test set: Average loss: 0.0633, Accuracy: 9810/10000 (98.10%)

"""


#### Option 2: Adam Optimizer

In [None]:
# Option 2: Adam at lr=1e-3 with its default betas, trained for 9 epochs.
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.7975637316703796 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.66it/s]

Test set: Average loss: 0.3636, Accuracy: 9021/10000 (90.21%)

loss=0.48072579503059387 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s]

Test set: Average loss: 0.1922, Accuracy: 9454/10000 (94.54%)

loss=0.447536438703537 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.63it/s]  

Test set: Average loss: 0.1552, Accuracy: 9551/10000 (95.51%)

loss=0.5516005158424377 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.74it/s] 

Test set: Average loss: 0.1295, Accuracy: 9623/10000 (96.23%)

loss=0.4160817563533783 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s] 

Test set: Average loss: 0.1199, Accuracy: 9633/10000 (96.33%)

loss=0.41406330466270447 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.69it/s]

Test set: Average loss: 0.1133, Accuracy: 9647/10000 (96.47%)

loss=0.301770955324173 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s]  

Test set: Average loss: 0.1034, Accuracy: 9690/10000 (96.90%)

loss=0.20516985654830933 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s]

Test set: Average loss: 0.0977, Accuracy: 9723/10000 (97.23%)

loss=0.6024749875068665 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s] 

Test set: Average loss: 0.1029, Accuracy: 9685/10000 (96.85%)


"""

#### Option 3: AdamW Optimizer

In [None]:
# Option 3: AdamW (decoupled weight decay 1e-4) at lr=1e-3, trained for 9 epochs.
model = Net().to(device)
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4, lr=0.001)
for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.8417987823486328 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.64it/s]

Test set: Average loss: 0.3707, Accuracy: 9013/10000 (90.13%)

loss=0.44986605644226074 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s]

Test set: Average loss: 0.1989, Accuracy: 9424/10000 (94.24%)

loss=0.4296775758266449 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s] 

Test set: Average loss: 0.1562, Accuracy: 9558/10000 (95.58%)

loss=0.6638230681419373 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s] 

Test set: Average loss: 0.1328, Accuracy: 9602/10000 (96.02%)

loss=0.4681873619556427 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.1234, Accuracy: 9619/10000 (96.19%)

loss=0.44057831168174744 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s]

Test set: Average loss: 0.1093, Accuracy: 9680/10000 (96.80%)

loss=0.20165415108203888 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]

Test set: Average loss: 0.1041, Accuracy: 9684/10000 (96.84%)

loss=0.3458675444126129 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s] 

Test set: Average loss: 0.0951, Accuracy: 9717/10000 (97.17%)

loss=0.3471454679965973 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.74it/s] 

Test set: Average loss: 0.0978, Accuracy: 9697/10000 (96.97%)

"""

# Option 3b: the same AdamW setup with a 10x larger learning rate (lr=1e-2), 9 epochs.
model = Net().to(device)
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4, lr=0.01)

for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.4291648864746094 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.65it/s] 

Test set: Average loss: 0.1278, Accuracy: 9611/10000 (96.11%)

loss=0.34206056594848633 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s]

Test set: Average loss: 0.1087, Accuracy: 9655/10000 (96.55%)

loss=0.23918040096759796 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.67it/s]

Test set: Average loss: 0.0967, Accuracy: 9680/10000 (96.80%)

loss=0.271117240190506 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s]  

Test set: Average loss: 0.0756, Accuracy: 9761/10000 (97.61%)

loss=0.5112438797950745 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s] 

Test set: Average loss: 0.0997, Accuracy: 9659/10000 (96.59%)

loss=0.2769640386104584 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.80it/s] 

Test set: Average loss: 0.0741, Accuracy: 9770/10000 (97.70%)

loss=0.3575074374675751 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s] 

Test set: Average loss: 0.0723, Accuracy: 9774/10000 (97.74%)

loss=0.3100968301296234 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s] 

Test set: Average loss: 0.0778, Accuracy: 9753/10000 (97.53%)

loss=0.29725849628448486 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s]

Test set: Average loss: 0.0745, Accuracy: 9778/10000 (97.78%)


"""


#### Option 4: RMSprop Optimizer

In [None]:
# Option 4: RMSprop at lr=0.01 with momentum 0.9, trained for 9 epochs.
model = Net().to(device)
optimizer = optim.RMSprop(model.parameters(), momentum=0.9, lr=0.01)
for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.6785914301872253 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.65it/s]

Test set: Average loss: 0.3590, Accuracy: 8925/10000 (89.25%)

loss=0.45515215396881104 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.70it/s]

Test set: Average loss: 0.2337, Accuracy: 9300/10000 (93.00%)

loss=0.6533524394035339 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s] 

Test set: Average loss: 0.1720, Accuracy: 9494/10000 (94.94%)

loss=0.63240647315979 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s]   

Test set: Average loss: 0.1292, Accuracy: 9632/10000 (96.32%)

loss=0.5105965733528137 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s] 

Test set: Average loss: 0.1237, Accuracy: 9606/10000 (96.06%)

loss=0.28851065039634705 batch_id=117: 100%|██████████| 118/118 [00:21<00:00,  5.58it/s]

Test set: Average loss: 0.1153, Accuracy: 9636/10000 (96.36%)

loss=0.36825188994407654 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s]

Test set: Average loss: 0.1034, Accuracy: 9659/10000 (96.59%)

loss=0.27222302556037903 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s]

Test set: Average loss: 0.0979, Accuracy: 9690/10000 (96.90%)

loss=0.3961434066295624 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s] 

Test set: Average loss: 0.1065, Accuracy: 9688/10000 (96.88%)

"""

#### Option 5: Lamb Optimizer

In [None]:
# Option 5: LAMB (from the third-party torch_optimizer package) at lr=0.01 with
# weight decay 1e-4, trained for 9 epochs.
from torch_optimizer import Lamb

model = Net().to(device)
optimizer = Lamb(model.parameters(), weight_decay=1e-4, lr=0.01)
for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.5342070460319519 batch_id=117: 100%|██████████| 118/118 [00:23<00:00,  5.09it/s] 

Test set: Average loss: 0.2091, Accuracy: 9390/10000 (93.90%)

loss=0.29361119866371155 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.21it/s]

Test set: Average loss: 0.1253, Accuracy: 9619/10000 (96.19%)

loss=0.27645498514175415 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.22it/s]

Test set: Average loss: 0.0946, Accuracy: 9711/10000 (97.11%)

loss=0.31653961539268494 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.18it/s]

Test set: Average loss: 0.0861, Accuracy: 9742/10000 (97.42%)

loss=0.33590951561927795 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.19it/s]

Test set: Average loss: 0.0920, Accuracy: 9706/10000 (97.06%)

loss=0.3409155309200287 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.18it/s] 

Test set: Average loss: 0.0747, Accuracy: 9762/10000 (97.62%)

loss=0.35509681701660156 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.20it/s]

Test set: Average loss: 0.0786, Accuracy: 9763/10000 (97.63%)

loss=0.1332807093858719 batch_id=117: 100%|██████████| 118/118 [00:22<00:00,  5.25it/s] 

Test set: Average loss: 0.0698, Accuracy: 9781/10000 (97.81%)

loss=0.2934856414794922 batch_id=117: 100%|██████████| 118/118 [13:58<00:00,  7.11s/it] 

Test set: Average loss: 0.0623, Accuracy: 9778/10000 (97.78%)


"""

#### Option 6: SGD with a CyclicLR Schedule (LR range test first)

In [None]:
def lr_range_test(model, train_loader, optimizer, device, start_lr=1e-7, end_lr=1, num_iter=100):
    """LR range test: train for up to `num_iter` batches while sweeping the learning rate.

    The learning rate grows geometrically from `start_lr` (first batch) to `end_lr`
    (batch `num_iter - 1`). The sampled LRs are therefore evenly spaced on a log
    axis, which is how the result is meant to be plotted. A linear sweep from 1e-7
    would jump to ~1e-2 on the second step and leave the low-LR region empty.

    Args:
        model: network returning log-probabilities (the loss is F.nll_loss).
        train_loader: iterable of (data, target) batches; fewer than `num_iter`
            batches simply ends the sweep early.
        optimizer: optimizer whose param-group 'lr' is overwritten at every step.
        device: device the batches are moved to.
        start_lr, end_lr: first and last learning rate of the sweep (both > 0).
        num_iter: maximum number of batches / LR values to try.

    Returns:
        (lrs, losses): the learning rate used at each step and its training loss.

    Note: this updates `model`'s weights and `optimizer`'s lr in place. Use a fresh
    model for real training afterwards.
    """
    model.train()
    # Constant multiplicative step, chosen so that step num_iter - 1 lands on end_lr.
    ratio = (end_lr / start_lr) ** (1.0 / max(num_iter - 1, 1))
    lrs = []
    losses = []

    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx >= num_iter:
            break

        current_lr = start_lr * ratio ** batch_idx
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        lrs.append(current_lr)
        losses.append(loss.item())

    return lrs, losses

# Run the test
# Sweep the learning rate on a throwaway model/optimizer and record the loss at each step.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-7, momentum=0.9)
lrs, losses = lr_range_test(model, train_loader, optimizer, device)

# Loss versus learning rate, with a log-scaled x axis.
import matplotlib.pyplot as plt
plt.semilogx(lrs, losses)
plt.xlabel('Learning Rate')
plt.ylabel('Loss')
plt.show()

# ---
# Bounds of the cyclic learning rate. `optimal_lr` is picked by hand from the
# range-test plot above; it is not computed from `lrs`/`losses`.
optimal_lr = 0.4
base_lr = optimal_lr / 10  # lower bound: roughly where the loss starts decreasing
max_lr = optimal_lr        # upper bound: roughly where the loss starts increasing

iterations_per_epoch = len(train_loader)  # number of batches per epoch
step_size = iterations_per_epoch * 4      # half-cycle of 4 epochs (full cycle = 8 epochs)
print(step_size)

# Build the model and optimizer FIRST and attach the scheduler to *this* optimizer.
# The scheduler used to be created on the LR-range-test optimizer, so the optimizer
# trained below stayed at a constant lr=0.01 and the cyclic schedule had no effect.
model = Net().to(device)
# CyclicLR overwrites this initial lr with base_lr (and momentum with max_momentum).
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=base_lr,
    max_lr=max_lr,
    step_size_up=step_size,
    mode="triangular2",    # amplitude halves after each cycle
    # mode="triangular",   # fixed amplitude
    # mode="exp_range",    # amplitude decays exponentially (uses gamma)
    cycle_momentum=True,   # cycles SGD momentum inversely to the lr; needs momentum > 0
    gamma=0.99997          # decay factor, used only by mode="exp_range"
)

# train() calls scheduler.step() after every batch, the cadence CyclicLR expects.
for epoch in range(1, 10):
    train(model, device, train_loader, optimizer, epoch, scheduler)
    test(model, device, test_loader)

# --

"""
loss=0.4133894443511963 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s] 

Test set: Average loss: 0.1753, Accuracy: 9485/10000 (94.85%)

loss=0.3412211835384369 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.1345, Accuracy: 9591/10000 (95.91%)

loss=0.3795698583126068 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s] 

Test set: Average loss: 0.1139, Accuracy: 9653/10000 (96.53%)

loss=0.41385218501091003 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s]

Test set: Average loss: 0.1062, Accuracy: 9698/10000 (96.98%)

loss=0.3342282474040985 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.65it/s] 

Test set: Average loss: 0.0974, Accuracy: 9713/10000 (97.13%)

loss=0.38326773047447205 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.66it/s]

Test set: Average loss: 0.0857, Accuracy: 9730/10000 (97.30%)

loss=0.24389390647411346 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]

Test set: Average loss: 0.0832, Accuracy: 9755/10000 (97.55%)

loss=0.2546258866786957 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.82it/s] 

Test set: Average loss: 0.0837, Accuracy: 9761/10000 (97.61%)

loss=0.2959340512752533 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.67it/s] 

Test set: Average loss: 0.0798, Accuracy: 9757/10000 (97.57%)

"""

#### Option 7: Lookahead Optimizer


In [None]:
# Option 7: Lookahead (torch_optimizer) wrapped around Adam at lr=1e-3, using
# Lookahead's library-default hyper-parameters; trained for 9 epochs.
from torch_optimizer import Lookahead

model = Net().to(device)
optimizer = Lookahead(optim.Adam(model.parameters(), lr=0.001))

for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=1.1670705080032349 batch_id=117: 100%|██████████| 118/118 [00:21<00:00,  5.60it/s]

Test set: Average loss: 0.7041, Accuracy: 8209/10000 (82.09%)

loss=0.6810237765312195 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s]

Test set: Average loss: 0.3329, Accuracy: 9124/10000 (91.24%)

loss=0.6628736257553101 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.74it/s]

Test set: Average loss: 0.2315, Accuracy: 9401/10000 (94.01%)

loss=0.6918551921844482 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.1902, Accuracy: 9463/10000 (94.63%)

loss=0.5152751803398132 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s] 

Test set: Average loss: 0.1681, Accuracy: 9508/10000 (95.08%)

loss=0.5009039640426636 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.70it/s] 

Test set: Average loss: 0.1530, Accuracy: 9569/10000 (95.69%)

loss=0.3961903154850006 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.1427, Accuracy: 9590/10000 (95.90%)

loss=0.20646844804286957 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.79it/s]

Test set: Average loss: 0.1291, Accuracy: 9614/10000 (96.14%)

loss=0.4047769606113434 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s] 

Test set: Average loss: 0.1265, Accuracy: 9638/10000 (96.38%)


"""

# Option 7b: the same Lookahead(Adam) setup with a 10x larger inner lr (1e-2), 9 epochs.
from torch_optimizer import Lookahead

model = Net().to(device)
optimizer = Lookahead(optim.Adam(model.parameters(), lr=0.01))

for epoch in range(1, 9 + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.5237801671028137 batch_id=117: 100%|██████████| 118/118 [00:21<00:00,  5.57it/s] 

Test set: Average loss: 0.1530, Accuracy: 9549/10000 (95.49%)

loss=0.2843589782714844 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s] 

Test set: Average loss: 0.1014, Accuracy: 9697/10000 (96.97%)

loss=0.26408156752586365 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.67it/s]

Test set: Average loss: 0.1001, Accuracy: 9689/10000 (96.89%)

loss=0.2643115818500519 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.0803, Accuracy: 9749/10000 (97.49%)

loss=0.2205265313386917 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s] 

Test set: Average loss: 0.0878, Accuracy: 9706/10000 (97.06%)

loss=0.2551535665988922 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s] 

Test set: Average loss: 0.0705, Accuracy: 9772/10000 (97.72%)

loss=0.13928599655628204 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.69it/s]

Test set: Average loss: 0.0698, Accuracy: 9784/10000 (97.84%)

loss=0.13714684545993805 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s]

Test set: Average loss: 0.0768, Accuracy: 9751/10000 (97.51%)

loss=0.22165195643901825 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.74it/s]

Test set: Average loss: 0.0665, Accuracy: 9790/10000 (97.90%)


"""

#### Option 8: Stochastic Weight Averaging (SWA) on top of SGD

In [None]:
from torch.optim.swa_utils import AveragedModel, SWALR

model = Net().to(device)
# Running equal-weight average of `model`'s parameters (the SWA model).
swa_model = AveragedModel(model)

# Regular optimizer for the main model.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Weight averaging begins IN epoch `swa_start`, i.e. after swa_start - 1 = 3 plain
# epochs. From then on SWALR is stepped once per epoch and anneals the lr from
# 0.01 towards swa_lr.
swa_start = 4
swa_scheduler = SWALR(optimizer, swa_lr=0.001)

for epoch in range(1, 15):  # 14 epochs in total
    train(model, device, train_loader, optimizer, epoch)
    if epoch >= swa_start:
        # Fold the latest weights into the average, then anneal the lr.
        swa_model.update_parameters(model)
        swa_scheduler.step()

    print(f"\nRegular Model - Epoch {epoch}:")
    test(model, device, test_loader)

    if epoch >= swa_start:
        # Averaged weights need BatchNorm running statistics recomputed with a pass
        # over the training data before they can be evaluated.
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        print(f"\nSWA Model - Epoch {epoch}:")
        test(swa_model, device, test_loader)

"""
loss=0.5269340872764587 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.64it/s] 

Regular Model - Epoch 1:

Test set: Average loss: 0.1894, Accuracy: 9450/10000 (94.50%)

loss=0.2970282733440399 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.81it/s] 

Regular Model - Epoch 2:

Test set: Average loss: 0.1354, Accuracy: 9609/10000 (96.09%)

loss=0.3302697539329529 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s] 

Regular Model - Epoch 3:

Test set: Average loss: 0.1174, Accuracy: 9652/10000 (96.52%)

loss=0.38752785325050354 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s]

Regular Model - Epoch 4:

Test set: Average loss: 0.0956, Accuracy: 9705/10000 (97.05%)


SWA Model - Epoch 4:

Test set: Average loss: 0.0937, Accuracy: 9720/10000 (97.20%)

loss=0.38465166091918945 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s]

Regular Model - Epoch 5:

Test set: Average loss: 0.0907, Accuracy: 9731/10000 (97.31%)


SWA Model - Epoch 5:

Test set: Average loss: 0.0916, Accuracy: 9717/10000 (97.17%)

loss=0.09959688037633896 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s]

Regular Model - Epoch 6:

Test set: Average loss: 0.0877, Accuracy: 9744/10000 (97.44%)


SWA Model - Epoch 6:

Test set: Average loss: 0.0890, Accuracy: 9736/10000 (97.36%)

loss=0.35393691062927246 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.80it/s]

Regular Model - Epoch 7:

Test set: Average loss: 0.0828, Accuracy: 9734/10000 (97.34%)


SWA Model - Epoch 7:

Test set: Average loss: 0.0848, Accuracy: 9746/10000 (97.46%)

loss=0.3833090364933014 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.69it/s] 

Regular Model - Epoch 8:

Test set: Average loss: 0.0799, Accuracy: 9757/10000 (97.57%)


SWA Model - Epoch 8:

Test set: Average loss: 0.0831, Accuracy: 9754/10000 (97.54%)

loss=0.25518763065338135 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s]

Regular Model - Epoch 9:

Test set: Average loss: 0.0761, Accuracy: 9785/10000 (97.85%)


SWA Model - Epoch 9:

Test set: Average loss: 0.0837, Accuracy: 9740/10000 (97.40%)

loss=0.26413026452064514 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.70it/s]

Regular Model - Epoch 10:

Test set: Average loss: 0.0738, Accuracy: 9776/10000 (97.76%)


SWA Model - Epoch 10:

Test set: Average loss: 0.0791, Accuracy: 9780/10000 (97.80%)

loss=0.14917625486850739 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.70it/s]

Regular Model - Epoch 11:

Test set: Average loss: 0.0753, Accuracy: 9771/10000 (97.71%)


SWA Model - Epoch 11:

Test set: Average loss: 0.0793, Accuracy: 9756/10000 (97.56%)

loss=0.2749112844467163 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.79it/s] 

Regular Model - Epoch 12:

Test set: Average loss: 0.0740, Accuracy: 9784/10000 (97.84%)


SWA Model - Epoch 12:

Test set: Average loss: 0.0812, Accuracy: 9770/10000 (97.70%)

loss=0.2318316102027893 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s] 

Regular Model - Epoch 13:

Test set: Average loss: 0.0731, Accuracy: 9784/10000 (97.84%)


SWA Model - Epoch 13:

Test set: Average loss: 0.0771, Accuracy: 9770/10000 (97.70%)

loss=0.2490023970603943 batch_id=42:  36%|███▋      | 43/118 [00:07<00:13,  5.55it/s] 

"""

from torch.optim.swa_utils import AveragedModel, SWALR

model = Net().to(device)
swa_model = AveragedModel(model)

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

# Phase 1 (epochs 1 .. swa_start - 1): OneCycleLR, stepped per batch inside train().
# NOTE(review): the schedule is sized for 15 epochs but is stepped only during the
# first swa_start - 1 = 5 epochs. Only the 30% warm-up (4.5 epochs) and the start of
# the decay are used; the annealing tail never runs.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    epochs=15,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,          # first 30% of the planned steps ramp the lr up to max_lr
    div_factor=10,          # initial lr = max_lr / 10
    final_div_factor=100    # final lr = initial lr / 100
)

# Phase 2 (epochs swa_start .. 14): no per-batch scheduling. The weights are averaged
# into swa_model each epoch, and SWALR is stepped once per epoch to anneal the lr
# towards swa_lr.
swa_start = 6
swa_scheduler = SWALR(optimizer, swa_lr=0.001)

for epoch in range(1, 15):  # 14 epochs in total
    if epoch < swa_start:
        train(model, device, train_loader, optimizer, epoch, scheduler)
    else:
        train(model, device, train_loader, optimizer, epoch)
        swa_model.update_parameters(model)
        swa_scheduler.step()

    test(model, device, test_loader)

    if epoch >= swa_start:
        # Recompute BatchNorm running stats for the averaged weights before evaluating them.
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        test(swa_model, device, test_loader)
    

"""
loss=0.504370391368866 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.62it/s]  

Test set: Average loss: 0.1486, Accuracy: 9529/10000 (95.29%)

loss=0.20520512759685516 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.80it/s]

Test set: Average loss: 0.0993, Accuracy: 9698/10000 (96.98%)

loss=0.42784056067466736 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s]

Test set: Average loss: 0.1130, Accuracy: 9641/10000 (96.41%)

loss=0.3493490219116211 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.0911, Accuracy: 9718/10000 (97.18%)

loss=0.3604229986667633 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.83it/s] 

Test set: Average loss: 0.0863, Accuracy: 9746/10000 (97.46%)

loss=0.29685112833976746 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.73it/s]

Test set: Average loss: 0.0728, Accuracy: 9764/10000 (97.64%)


Test set: Average loss: 0.0716, Accuracy: 9783/10000 (97.83%)

loss=0.26517605781555176 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]

Test set: Average loss: 0.0740, Accuracy: 9776/10000 (97.76%)


Test set: Average loss: 0.0715, Accuracy: 9772/10000 (97.72%)

loss=0.2582552433013916 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.79it/s] 

Test set: Average loss: 0.0803, Accuracy: 9742/10000 (97.42%)


Test set: Average loss: 0.0660, Accuracy: 9787/10000 (97.87%)

loss=0.3765523433685303 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s] 

Test set: Average loss: 0.0682, Accuracy: 9775/10000 (97.75%)


Test set: Average loss: 0.0639, Accuracy: 9797/10000 (97.97%)

loss=0.2534942328929901 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s] 

Test set: Average loss: 0.0629, Accuracy: 9805/10000 (98.05%)


Test set: Average loss: 0.0648, Accuracy: 9791/10000 (97.91%)

loss=0.3718498945236206 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.64it/s] 

Test set: Average loss: 0.0641, Accuracy: 9810/10000 (98.10%)


Test set: Average loss: 0.0618, Accuracy: 9810/10000 (98.10%)

loss=0.3044315278530121 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s] 

Test set: Average loss: 0.0614, Accuracy: 9809/10000 (98.09%)


Test set: Average loss: 0.0602, Accuracy: 9793/10000 (97.93%)

loss=0.11558184772729874 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.82it/s]

Test set: Average loss: 0.0589, Accuracy: 9820/10000 (98.20%)


Test set: Average loss: 0.0611, Accuracy: 9804/10000 (98.04%)

loss=0.25364577770233154 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.81it/s]

Test set: Average loss: 0.0564, Accuracy: 9829/10000 (98.29%)


Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98.16%)


"""

#### Option 9: Adam optimizer with gradient-norm clipping & StepLR

In [None]:
# Adam (lr=1e-3) for 9 epochs with gradient-norm clipping (max_norm=1.0).
# Calling clip_grad_norm_ once before training does nothing: no parameter has a
# .grad yet. The clip is therefore registered as an optimizer step pre-hook, so
# it runs at every optimizer.step(). Assuming train() calls loss.backward()
# before optimizer.step(), each clip sees that batch's gradients.
# NOTE: the log captured below predates this fix (clipping was not applied then).
model = Net().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)


def clip_grads_before_step(opt, args, kwargs):
    """Optimizer step pre-hook: clip the total gradient norm of `opt`'s params to 1.0.

    Returns None so that optimizer.step() runs with its original arguments.
    """
    params = [p for group in opt.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)


optimizer.register_step_pre_hook(clip_grads_before_step)

for epoch in range(1, 10):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

"""
loss=0.8795673251152039 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.64it/s]

Test set: Average loss: 0.3629, Accuracy: 9010/10000 (90.10%)

loss=0.4695564806461334 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.82it/s] 

Test set: Average loss: 0.1969, Accuracy: 9444/10000 (94.44%)

loss=0.45793116092681885 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]

Test set: Average loss: 0.1540, Accuracy: 9564/10000 (95.64%)

loss=0.5754876732826233 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s] 

Test set: Average loss: 0.1352, Accuracy: 9612/10000 (96.12%)

loss=0.37852492928504944 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]

Test set: Average loss: 0.1236, Accuracy: 9632/10000 (96.32%)

loss=0.3127531111240387 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.68it/s] 

Test set: Average loss: 0.1112, Accuracy: 9669/10000 (96.69%)

loss=0.2938586175441742 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.75it/s] 

Test set: Average loss: 0.1095, Accuracy: 9681/10000 (96.81%)

loss=0.19834460318088531 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s]

Test set: Average loss: 0.0979, Accuracy: 9709/10000 (97.09%)

loss=0.3420853614807129 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s] 

Test set: Average loss: 0.1007, Accuracy: 9701/10000 (97.01%)

"""

# Adam with gradient-norm clipping (max_norm=1.0) and a StepLR schedule.
# lr=0.1 is 100x Adam's default of 1e-3. StepLR divides it by 10 every 3 epochs:
# 0.1 -> 0.01 -> 0.001 -> 0.0001 over the 11 epochs.
# A one-off clip_grad_norm_ call before training would be a no-op (no .grad
# exists yet), so the clip is registered as a step pre-hook that runs at every
# optimizer.step(). This assumes train() calls backward() before step().
# NOTE: the log captured below predates this fix (clipping was not applied then).
model = Net().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.1)


def clip_grads_before_step(opt, args, kwargs):
    """Optimizer step pre-hook: clip the total gradient norm of `opt`'s params to 1.0.

    Returns None so that optimizer.step() runs with its original arguments.
    """
    params = [p for group in opt.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)


optimizer.register_step_pre_hook(clip_grads_before_step)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(1, 12):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()  # epoch-level schedule: step once after each epoch

"""
loss=0.4786355495452881 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.69it/s] 

Test set: Average loss: 0.1827, Accuracy: 9431/10000 (94.31%)

loss=0.41493701934814453 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.79it/s]

Test set: Average loss: 0.1620, Accuracy: 9514/10000 (95.14%)

loss=0.4339393675327301 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.72it/s] 

Test set: Average loss: 0.1107, Accuracy: 9623/10000 (96.23%)

loss=0.3085823655128479 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.79it/s] 

Test set: Average loss: 0.0857, Accuracy: 9735/10000 (97.35%)

loss=0.28189340233802795 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s]

Test set: Average loss: 0.0792, Accuracy: 9761/10000 (97.61%)

loss=0.3234998285770416 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.71it/s] 

Test set: Average loss: 0.0769, Accuracy: 9759/10000 (97.59%)

loss=0.39721545577049255 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.78it/s]

Test set: Average loss: 0.0769, Accuracy: 9767/10000 (97.67%)

loss=0.1926095336675644 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.76it/s] 

Test set: Average loss: 0.0772, Accuracy: 9760/10000 (97.60%)

loss=0.2487497478723526 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.79it/s] 

Test set: Average loss: 0.0783, Accuracy: 9753/10000 (97.53%)

loss=0.25021860003471375 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.77it/s]

Test set: Average loss: 0.0748, Accuracy: 9764/10000 (97.64%)

loss=0.2767579257488251 batch_id=117: 100%|██████████| 118/118 [00:20<00:00,  5.66it/s] 

Test set: Average loss: 0.0755, Accuracy: 9761/10000 (97.61%)


"""