In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torchsummary import summary
from tqdm import tqdm

In [2]:
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_available()
device: str
if use_cuda:
    device = "cuda"
elif use_mps:
    device = "mps"
else:
    device = "cpu"

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 8, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.20),
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(8, 16, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.20),
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(16, 16, 3),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 10, 3),
            nn.ReLU(),
            nn.BatchNorm2d(10),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Dropout(0.20),
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = x.view(x.size(0), -1)
        return F.log_softmax(x, dim=1)

    def summarize(self):
        summary(self, input_size=(1, 28, 28))

In [4]:
Net().summarize()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 28, 28]              80
              ReLU-2            [-1, 8, 28, 28]               0
       BatchNorm2d-3            [-1, 8, 28, 28]              16
            Conv2d-4            [-1, 8, 28, 28]             584
              ReLU-5            [-1, 8, 28, 28]               0
       BatchNorm2d-6            [-1, 8, 28, 28]              16
         MaxPool2d-7            [-1, 8, 14, 14]               0
           Dropout-8            [-1, 8, 14, 14]               0
            Conv2d-9           [-1, 16, 14, 14]           1,168
             ReLU-10           [-1, 16, 14, 14]               0
      BatchNorm2d-11           [-1, 16, 14, 14]              32
           Conv2d-12           [-1, 16, 14, 14]           2,320
             ReLU-13           [-1, 16, 14, 14]               0
      BatchNorm2d-14           [-1, 16,

In [5]:
torch.manual_seed(1)
batch_size = 64
kwargs = {"num_workers": 1, "pin_memory": True} if device == "cuda" else {}

# Define transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Download and create datasets
train_dataset = datasets.MNIST(
    root='../data', 
    train=True, 
    download=True, 
    transform=transform
)

test_dataset = datasets.MNIST(
    root='../data', 
    train=False, 
    download=True, 
    transform=transform
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    **kwargs,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=1000, 
    shuffle=False,
    **kwargs,
)

In [6]:
def train_step(epoch, model, device, train_loader, optimizer):
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        # Calculate accuracy
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        
        pbar.set_description(desc=f"epoch={epoch:02d} loss={loss.item():.4f} batch_id={batch_idx:04d} accuracy={100*correct/processed:.2f}%")


def test_step(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="mean").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)  # Average loss across batches

    print(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )

In [7]:
epochs = 20
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, epochs + 1):
    train_step(epoch, model, device, train_loader, optimizer)
    test_step(model, device, test_loader)

epoch=01 loss=0.3911 batch_id=0937 accuracy=83.33%: 100%|██████████| 938/938 [00:08<00:00, 110.16it/s]


Test set: Average loss: 0.0567, Accuracy: 9829/10000 (98.29%)



epoch=02 loss=0.3151 batch_id=0937 accuracy=86.32%: 100%|██████████| 938/938 [00:08<00:00, 114.43it/s]


Test set: Average loss: 0.0370, Accuracy: 9884/10000 (98.84%)



epoch=03 loss=0.4425 batch_id=0937 accuracy=86.77%: 100%|██████████| 938/938 [00:08<00:00, 110.87it/s]


Test set: Average loss: 0.0383, Accuracy: 9884/10000 (98.84%)



epoch=04 loss=0.7033 batch_id=0937 accuracy=86.91%: 100%|██████████| 938/938 [00:08<00:00, 116.45it/s]


Test set: Average loss: 0.0414, Accuracy: 9877/10000 (98.77%)



epoch=05 loss=0.3970 batch_id=0937 accuracy=87.32%: 100%|██████████| 938/938 [00:08<00:00, 111.97it/s]


Test set: Average loss: 0.0297, Accuracy: 9910/10000 (99.10%)



epoch=06 loss=0.3769 batch_id=0937 accuracy=87.36%: 100%|██████████| 938/938 [00:08<00:00, 111.03it/s]


Test set: Average loss: 0.0308, Accuracy: 9902/10000 (99.02%)



epoch=07 loss=0.3905 batch_id=0937 accuracy=87.67%: 100%|██████████| 938/938 [00:08<00:00, 111.79it/s]


Test set: Average loss: 0.0300, Accuracy: 9912/10000 (99.12%)



epoch=08 loss=0.1580 batch_id=0937 accuracy=87.67%: 100%|██████████| 938/938 [00:08<00:00, 112.25it/s]


Test set: Average loss: 0.0293, Accuracy: 9900/10000 (99.00%)



epoch=09 loss=0.1650 batch_id=0937 accuracy=87.46%: 100%|██████████| 938/938 [00:08<00:00, 112.56it/s]


Test set: Average loss: 0.0231, Accuracy: 9926/10000 (99.26%)



epoch=10 loss=0.2020 batch_id=0937 accuracy=87.65%: 100%|██████████| 938/938 [00:08<00:00, 109.69it/s]


Test set: Average loss: 0.0275, Accuracy: 9919/10000 (99.19%)



epoch=11 loss=0.1691 batch_id=0937 accuracy=87.74%: 100%|██████████| 938/938 [00:08<00:00, 114.99it/s]


Test set: Average loss: 0.0214, Accuracy: 9939/10000 (99.39%)



epoch=12 loss=0.1616 batch_id=0937 accuracy=87.74%: 100%|██████████| 938/938 [00:08<00:00, 115.67it/s]


Test set: Average loss: 0.0227, Accuracy: 9926/10000 (99.26%)



epoch=13 loss=0.2603 batch_id=0937 accuracy=88.06%: 100%|██████████| 938/938 [00:08<00:00, 115.08it/s]


Test set: Average loss: 0.0256, Accuracy: 9911/10000 (99.11%)



epoch=14 loss=0.2949 batch_id=0937 accuracy=87.80%: 100%|██████████| 938/938 [00:08<00:00, 114.78it/s]


Test set: Average loss: 0.0285, Accuracy: 9903/10000 (99.03%)



epoch=15 loss=0.2100 batch_id=0937 accuracy=87.82%: 100%|██████████| 938/938 [00:08<00:00, 115.65it/s]


Test set: Average loss: 0.0181, Accuracy: 9944/10000 (99.44%)



epoch=16 loss=0.0853 batch_id=0937 accuracy=87.87%: 100%|██████████| 938/938 [00:08<00:00, 115.68it/s]


Test set: Average loss: 0.0239, Accuracy: 9926/10000 (99.26%)



epoch=17 loss=0.1235 batch_id=0937 accuracy=87.90%: 100%|██████████| 938/938 [00:08<00:00, 115.43it/s]


Test set: Average loss: 0.0222, Accuracy: 9934/10000 (99.34%)



epoch=18 loss=0.2770 batch_id=0937 accuracy=88.02%: 100%|██████████| 938/938 [00:08<00:00, 114.97it/s]


Test set: Average loss: 0.0227, Accuracy: 9920/10000 (99.20%)



epoch=19 loss=0.1358 batch_id=0937 accuracy=87.95%: 100%|██████████| 938/938 [00:08<00:00, 115.15it/s]


Test set: Average loss: 0.0217, Accuracy: 9926/10000 (99.26%)



epoch=20 loss=0.5194 batch_id=0937 accuracy=87.83%: 100%|██████████| 938/938 [00:08<00:00, 114.67it/s]


Test set: Average loss: 0.0346, Accuracy: 9903/10000 (99.03%)

