In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision.datasets import MNIST
from torchvision.transforms import v2 as T
from tqdm import tqdm
import matplotlib

In [None]:
# Geometry shared by the training and evaluation pipelines: a 22x22 window is
# cut out of the image and scaled back up to 28x28.
CROP_SIZE = (22, 22)
IMAGE_SIZE = (28, 28)

# Training pipeline: convert to a float32 tensor scaled to [0, 1], then apply
# random augmentation (rotation of up to +/-15 degrees, random 22x22 crop that
# is resized back to 28x28).
transforms = T.Compose(
    [
        T.ToImage(),
        T.ToDtype(dtype=torch.float32, scale=True),
        T.RandomRotation(degrees=(-15, 15), interpolation=T.InterpolationMode.BILINEAR),
        T.RandomCrop(size=CROP_SIZE),
        T.Resize(size=IMAGE_SIZE, interpolation=T.InterpolationMode.BILINEAR),
        # NOTE: mean 0 / std 1 leaves pixel values unchanged; kept as a
        # placeholder so dataset statistics can be plugged in later.
        T.Normalize(mean=[0], std=[1]),
    ]
)

# Evaluation pipeline: the deterministic counterpart of the training pipeline.
# Random augmentation must not be applied to the test split, otherwise the
# reported accuracy changes from run to run and is measured on distorted
# images. A centre crop followed by the same resize keeps the zoom factor the
# model sees during training, without any randomness.
eval_transforms = T.Compose(
    [
        T.ToImage(),
        T.ToDtype(dtype=torch.float32, scale=True),
        T.CenterCrop(size=CROP_SIZE),
        T.Resize(size=IMAGE_SIZE, interpolation=T.InterpolationMode.BILINEAR),
        T.Normalize(mean=[0], std=[1]),
    ]
)

train_data = MNIST(root="data", train=True, download=True, transform=transforms)
test_data = MNIST(root="data", train=False, download=True, transform=eval_transforms)

# Only the training batches are shuffled; evaluation order is irrelevant.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

In [None]:
class Model(nn.Module):
    """Small convolutional classifier: two conv/pool stages plus two dense layers.

    Accepts single-channel images and produces 10 raw class scores (logits)
    per image. ``fc1`` expects 64 channels of 7x7 features, which is what two
    2x2 poolings leave of a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding 1 keep the spatial size unchanged, so only
        # the pooling steps in ``forward`` shrink the feature maps.
        conv_opts = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(1, 32, **conv_opts)
        self.conv2 = nn.Conv2d(32, 64, **conv_opts)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the logits, shape ``(batch, 10)``, for a batch of images ``x``."""
        # Each stage: convolution -> ReLU -> 2x2 max pooling (halves H and W).
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2)
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2)
        # Collapse everything but the batch dimension for the dense layers.
        flat = stage2.view(stage2.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        # No softmax here: the raw scores are returned as-is.
        return self.fc2(hidden)


# Pick the best available backend instead of assuming Apple's MPS, so the
# notebook also runs on CUDA machines and on CPU-only machines. MPS is checked
# first, which keeps the previous behaviour wherever it is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model, optimiser and loss all live on / operate on the selected device.
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# CrossEntropyLoss consumes the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()

# Layer-by-layer overview for a single 1x28x28 input (last expression, so it
# is rendered as the cell output).
summary(model, input_size=(1, 1, 28, 28), device=device)

In [None]:
num_epochs = 1
for epoch in range(num_epochs):
    # ---- Training pass: one optimiser step per mini-batch ----
    model.train()
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step, then run
        # forward / backward / update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Show the loss and accuracy of the current batch on the progress bar.
        batch_acc = (outputs.argmax(dim=1) == labels).float().mean().item()
        pbar.set_postfix(
            {
                "loss": f"{loss.item():.4f}",
                "acc": f"{100 * batch_acc:.2f}%",
            }
        )

    # ---- Evaluation pass: count correct predictions on the test split ----
    model.eval()
    correct = total = 0
    with torch.no_grad():  # no gradients needed while scoring
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest logit in each row.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Half-width of the 95% normal-approximation interval for a proportion:
    # 1.96 * sqrt(p * (1 - p) / n), reported in percentage points.
    accuracy = correct / total
    std_err = (accuracy * (1 - accuracy) / total) ** 0.5
    print(
        f"Epoch {epoch + 1}, Accuracy: {100 * correct / total:.2f}%, CI: +/- {100 * 1.96 * std_err:.2f}%"
    )