In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torchvision import datasets, transforms
from torchsummary import summary
import matplotlib.pyplot as plt

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Input Block - Enhanced feature extraction
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),    # 28x28x8
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 8, 3, padding=1),    # 28x28x8
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 16, 3, padding=1),   # 28x28x16
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(0.01)
        )

        # Transition Block 1
        self.trans1 = nn.Sequential(
            nn.MaxPool2d(2, 2),              # 14x14x16
        )

        # Convolution Block 2 - Focus on distinguishing similar digits
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1),  # 14x14x16
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, padding=1),  # 14x14x32
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(0.01)
        )

        # Transition Block 2
        self.trans2 = nn.Sequential(
            nn.MaxPool2d(2, 2),              # 7x7x32
        )

        # Convolution Block 3 - Final feature refinement
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),  # 7x7x32
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 16, 1),            # 7x7x16 (pointwise)
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(0.01)
        )

        self.gap = nn.Sequential(
            nn.AvgPool2d(kernel_size=7)      # 1x1x16
        )

        self.final = nn.Sequential(
            nn.Conv2d(16, 10, 1)             # 1x1x10
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.trans1(x)
        x = self.conv2(x)
        x = self.trans2(x)
        x = self.conv3(x)
        x = self.gap(x)
        x = self.final(x)
        x = x.view(-1, 10)
        return F.log_softmax(x, dim=1)


In [3]:
model = Net().to(device)
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 28, 28]              80
              ReLU-2            [-1, 8, 28, 28]               0
       BatchNorm2d-3            [-1, 8, 28, 28]              16
            Conv2d-4            [-1, 8, 28, 28]             584
              ReLU-5            [-1, 8, 28, 28]               0
       BatchNorm2d-6            [-1, 8, 28, 28]              16
            Conv2d-7           [-1, 16, 28, 28]           1,168
              ReLU-8           [-1, 16, 28, 28]               0
       BatchNorm2d-9           [-1, 16, 28, 28]              32
          Dropout-10           [-1, 16, 28, 28]               0
        MaxPool2d-11           [-1, 16, 14, 14]               0
           Conv2d-12           [-1, 16, 14, 14]           2,320
             ReLU-13           [-1, 16, 14, 14]               0
      BatchNorm2d-14           [-1, 16,

In [4]:
torch.manual_seed(1)
batch_size = 156

train_transforms = transforms.Compose([
    transforms.RandomRotation((-3.0, 3.0), fill=(1,)),  # Very conservative rotation
    transforms.RandomAffine(
        degrees=0,
        translate=(0.05, 0.05),  # Reduced translation
        scale=(0.98, 1.02),     # Minimal scaling
        shear=(-2, 2),          # Minimal shear
        fill=(1,)
    ),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

kwargs = {'num_workers': 2, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                  transform=train_transforms),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=test_transforms),
    batch_size=batch_size, shuffle=True, **kwargs)


In [5]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    train_loss = 0
    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(desc=f'Loss={loss.item():0.4f} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}%')

def test_with_misclassified(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    misclassified_images = []
    misclassified_pred = []
    misclassified_target = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)

            # Store misclassified images
            misclassified_mask = ~pred.eq(target.view_as(pred)).squeeze()
            if misclassified_mask.any():
                misclassified_imgs = data[misclassified_mask].cpu()
                pred_np = pred[misclassified_mask].cpu().numpy()
                target_np = target[misclassified_mask].cpu().numpy()

                # Handle both single and multiple misclassifications
                if len(misclassified_mask.size()) == 0:  # Single misclassification
                    misclassified_pred.append(int(pred_np))
                    misclassified_target.append(int(target_np))
                    misclassified_images.append(misclassified_imgs)
                else:  # Multiple misclassifications
                    for i in range(len(pred_np)):
                        misclassified_pred.append(int(pred_np[i]))
                        misclassified_target.append(int(target_np[i]))
                        misclassified_images.append(misclassified_imgs[i])

            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')

    # Plot first 25 misclassified images
    if len(misclassified_images) > 0:
        plt.figure(figsize=(10,10))
        for i in range(min(25, len(misclassified_images))):
            plt.subplot(5, 5, i+1)
            plt.imshow(misclassified_images[i].squeeze(), cmap='gray')
            plt.title(f'Pred: {misclassified_pred[i]}\nTrue: {misclassified_target[i]}')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(f'misclassified_epoch.png')
        plt.close()

    return accuracy

In [6]:
optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=1e-4
)

# Modified scheduler for better convergence
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.004,              # Slightly higher max_lr
    epochs=20,
    steps_per_epoch=len(train_loader),
    pct_start=0.2,             # Standard warmup
    anneal_strategy='cos',
    div_factor=10,
    final_div_factor=100
)

best_acc = 0
for epoch in range(1, 21):
    print(f'Epoch {epoch}')
    train(model, device, train_loader, optimizer, epoch)
    acc = test_with_misclassified(model, device, test_loader)
    scheduler.step()

    if acc > best_acc:
        best_acc = acc
        model = model.cpu()  # Move model to CPU before saving
        torch.save(model, 'mnist_best.pth')
        print(f'Best accuracy: {best_acc:.2f}%')


Epoch 1


Loss=0.5289 Batch_id=384 Accuracy=78.76%: 100%|██████████| 385/385 [00:28<00:00, 13.49it/s]
  misclassified_pred.append(int(pred_np[i]))


Test set: Average loss: 0.4424, Accuracy: 9646/10000 (96.46%)
Best accuracy: 96.46%
Epoch 2


Loss=0.1680 Batch_id=384 Accuracy=96.72%: 100%|██████████| 385/385 [00:25<00:00, 15.23it/s]


Test set: Average loss: 0.1454, Accuracy: 9812/10000 (98.12%)
Best accuracy: 98.12%
Epoch 3


Loss=0.0822 Batch_id=384 Accuracy=97.96%: 100%|██████████| 385/385 [00:26<00:00, 14.56it/s]


Test set: Average loss: 0.0809, Accuracy: 9872/10000 (98.72%)
Best accuracy: 98.72%
Epoch 4


Loss=0.0728 Batch_id=384 Accuracy=98.47%: 100%|██████████| 385/385 [00:26<00:00, 14.40it/s]


Test set: Average loss: 0.0600, Accuracy: 9886/10000 (98.86%)
Best accuracy: 98.86%
Epoch 5


Loss=0.0951 Batch_id=384 Accuracy=98.68%: 100%|██████████| 385/385 [00:25<00:00, 14.96it/s]


Test set: Average loss: 0.0478, Accuracy: 9892/10000 (98.92%)
Best accuracy: 98.92%
Epoch 6


Loss=0.0786 Batch_id=384 Accuracy=98.83%: 100%|██████████| 385/385 [00:25<00:00, 14.91it/s]


Test set: Average loss: 0.0406, Accuracy: 9910/10000 (99.10%)
Best accuracy: 99.10%
Epoch 7


Loss=0.0628 Batch_id=384 Accuracy=98.94%: 100%|██████████| 385/385 [00:25<00:00, 14.92it/s]


Test set: Average loss: 0.0354, Accuracy: 9911/10000 (99.11%)
Best accuracy: 99.11%
Epoch 8


Loss=0.0326 Batch_id=384 Accuracy=99.06%: 100%|██████████| 385/385 [00:26<00:00, 14.79it/s]


Test set: Average loss: 0.0301, Accuracy: 9918/10000 (99.18%)
Best accuracy: 99.18%
Epoch 9


Loss=0.0499 Batch_id=384 Accuracy=99.14%: 100%|██████████| 385/385 [00:26<00:00, 14.69it/s]


Test set: Average loss: 0.0289, Accuracy: 9919/10000 (99.19%)
Best accuracy: 99.19%
Epoch 10


Loss=0.0342 Batch_id=384 Accuracy=99.14%: 100%|██████████| 385/385 [00:26<00:00, 14.78it/s]


Test set: Average loss: 0.0262, Accuracy: 9924/10000 (99.24%)
Best accuracy: 99.24%
Epoch 11


Loss=0.0139 Batch_id=384 Accuracy=99.19%: 100%|██████████| 385/385 [00:26<00:00, 14.70it/s]


Test set: Average loss: 0.0278, Accuracy: 9919/10000 (99.19%)
Epoch 12


Loss=0.0289 Batch_id=384 Accuracy=99.29%: 100%|██████████| 385/385 [00:25<00:00, 14.92it/s]


Test set: Average loss: 0.0221, Accuracy: 9934/10000 (99.34%)
Best accuracy: 99.34%
Epoch 13


Loss=0.0037 Batch_id=384 Accuracy=99.28%: 100%|██████████| 385/385 [00:26<00:00, 14.52it/s]


Test set: Average loss: 0.0258, Accuracy: 9924/10000 (99.24%)
Epoch 14


Loss=0.0066 Batch_id=384 Accuracy=99.33%: 100%|██████████| 385/385 [00:26<00:00, 14.59it/s]


Test set: Average loss: 0.0232, Accuracy: 9937/10000 (99.37%)
Best accuracy: 99.37%
Epoch 15


Loss=0.0267 Batch_id=384 Accuracy=99.36%: 100%|██████████| 385/385 [00:26<00:00, 14.61it/s]


Test set: Average loss: 0.0256, Accuracy: 9922/10000 (99.22%)
Epoch 16


Loss=0.0214 Batch_id=384 Accuracy=99.41%: 100%|██████████| 385/385 [00:26<00:00, 14.78it/s]


Test set: Average loss: 0.0214, Accuracy: 9934/10000 (99.34%)
Epoch 17


Loss=0.0119 Batch_id=384 Accuracy=99.40%: 100%|██████████| 385/385 [00:25<00:00, 14.84it/s]


Test set: Average loss: 0.0223, Accuracy: 9931/10000 (99.31%)
Epoch 18


Loss=0.0253 Batch_id=384 Accuracy=99.47%: 100%|██████████| 385/385 [00:25<00:00, 14.86it/s]


Test set: Average loss: 0.0265, Accuracy: 9910/10000 (99.10%)
Epoch 19


Loss=0.0377 Batch_id=384 Accuracy=99.42%: 100%|██████████| 385/385 [00:26<00:00, 14.68it/s]


Test set: Average loss: 0.0221, Accuracy: 9935/10000 (99.35%)
Epoch 20


Loss=0.0070 Batch_id=384 Accuracy=99.48%: 100%|██████████| 385/385 [00:26<00:00, 14.58it/s]


Test set: Average loss: 0.0188, Accuracy: 9946/10000 (99.46%)
Best accuracy: 99.46%
