In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Device configuration: pick the first available accelerator in priority
# order CUDA -> Apple MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
else:
    device_type = "cpu"
device = torch.device(device_type)
print(f"{device=}")

# mnist dataset
# Both splits share one on-disk root and one transform (torchvision ToTensor),
# so they are named once here instead of repeated per split.
DATA_ROOT = "../../dataset/"
to_tensor = transforms.ToTensor()

train_dataset = datasets.MNIST(
    root=DATA_ROOT, train=True, transform=to_tensor, download=True
)
test_dataset = datasets.MNIST(
    root=DATA_ROOT, train=False, transform=to_tensor, download=True
)

print(f"{train_dataset.data.shape=}")
print(f"{test_dataset.data.shape=}")


# hyperparameters
num_classes = 10
num_epochs = 3
batch_size = 100
learning_rate = 0.001

# Seed torch's global RNG so weight initialisation (next cell) and the
# training DataLoader's shuffle order repeat across Restart & Run All.
# NOTE: some GPU/MPS kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# data loader
# Training batches are reshuffled every epoch; the test loader keeps dataset
# order so evaluation runs over the same sequence each time.
common_loader_args = {"batch_size": batch_size}
train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_args)


class CNNModel(nn.Module):
    """Small two-stage CNN classifier for 1x28x28 (MNIST-sized) images.

    Architecture: [Conv3x3 -> ReLU -> MaxPool2] x 2, then one linear layer
    mapping the flattened 16x5x5 feature map to class logits.

    Args:
        n_classes: Number of output logits. When omitted, the notebook-level
            ``num_classes`` hyperparameter is read at construction time
            (the original behaviour).
    """

    def __init__(self, n_classes=None):
        super().__init__()
        if n_classes is None:
            # Backward-compatible fallback to the global hyperparameter.
            n_classes = num_classes
        # 1x28x28 -> conv3 (no padding) -> 12x26x26 -> pool2 -> 12x13x13
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 12, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # 12x13x13 -> conv3 -> 16x11x11 -> pool2 (floor) -> 16x5x5
        self.layer2 = nn.Sequential(
            nn.Conv2d(12, 16, 3), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Attribute name `fn` is kept so existing state_dict keys stay valid.
        self.fn = nn.Linear(16 * 5 * 5, n_classes)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to unnormalised logits of shape (N, n_classes)."""
        out = self.layer1(x)  # (N, 12, 13, 13)
        out = self.layer2(out)  # (N, 16, 5, 5)
        out = out.flatten(start_dim=1)  # (N, 400)
        return self.fn(out)

device=device(type='mps')
train_dataset.data.shape=torch.Size([60000, 28, 28])
test_dataset.data.shape=torch.Size([10000, 28, 28])


In [2]:
# Instantiate the model and sanity-check the forward pass on a dummy batch
# before training: 8 random 1x28x28 inputs must map to 8 rows of logits.
model = CNNModel().to(device)
example_input = torch.randn(8, 1, 28, 28).to(device)
with torch.no_grad():
    example_output = model(example_input)
assert example_output.shape == (8, num_classes), example_output.shape
print(example_output.shape)

# loss and optimizer
# CrossEntropyLoss expects raw logits (it applies log-softmax internally),
# which is why CNNModel.forward ends without a softmax.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

torch.Size([8, 10])


In [3]:
# train the model
log_every = 100  # print the current batch's loss every `log_every` steps
n_steps = len(train_loader)
for epoch in range(num_epochs):
    # Switch to training mode once per epoch; nothing inside the batch loop
    # changes the mode, so calling it per batch was redundant.
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % log_every == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_steps}], Loss: {loss.item():.4f}"
            )

Epoch [1/3], Step [100/600], Loss: 0.3237
Epoch [1/3], Step [200/600], Loss: 0.4018
Epoch [1/3], Step [300/600], Loss: 0.2853
Epoch [1/3], Step [400/600], Loss: 0.0773
Epoch [1/3], Step [500/600], Loss: 0.1636
Epoch [1/3], Step [600/600], Loss: 0.2160
Epoch [2/3], Step [100/600], Loss: 0.1227
Epoch [2/3], Step [200/600], Loss: 0.1308
Epoch [2/3], Step [300/600], Loss: 0.1625
Epoch [2/3], Step [400/600], Loss: 0.1162
Epoch [2/3], Step [500/600], Loss: 0.0593
Epoch [2/3], Step [600/600], Loss: 0.2359
Epoch [3/3], Step [100/600], Loss: 0.1225
Epoch [3/3], Step [200/600], Loss: 0.1817
Epoch [3/3], Step [300/600], Loss: 0.0765
Epoch [3/3], Step [400/600], Loss: 0.1577
Epoch [3/3], Step [500/600], Loss: 0.0309
Epoch [3/3], Step [600/600], Loss: 0.0749


In [4]:
# test the model
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    n_class_correct = [0] * num_classes
    n_class_samples = [0] * num_classes
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)

        # index of the largest logit in each row is the predicted class
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Walk the *actual* batch (the last one may be shorter than
        # batch_size), converting to Python ints once per batch instead of
        # forcing a device sync for every element.
        for pred, label in zip(predicted.tolist(), labels.tolist()):
            if pred == label:
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    print(f"Accuracy: {100 * n_correct / n_samples}")
    for i in range(num_classes):
        if n_class_samples[i] == 0:
            # avoid ZeroDivisionError for a class absent from the test set
            print(f"Accuracy of {i}: n/a (no samples)")
            continue
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f"Accuracy of {i}: {acc} %")

Accuracy: 97.72
Accuracy of 0: 99.48979591836735 %
Accuracy of 1: 98.8546255506608 %
Accuracy of 2: 97.86821705426357 %
Accuracy of 3: 98.7128712871287 %
Accuracy of 4: 97.55600814663951 %
Accuracy of 5: 97.86995515695067 %
Accuracy of 6: 97.39039665970772 %
Accuracy of 7: 96.59533073929961 %
Accuracy of 8: 96.09856262833675 %
Accuracy of 9: 96.63032705649158 %
