In [1]:
import os
import glob
import numpy as np
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load both MNIST splits from ../data (downloading if needed), with ToTensor
# applied as the per-sample transform.
train_data = datasets.MNIST(
    root="../data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.MNIST(
    root="../data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Remove every .gz archive found anywhere under ../data.
for file_path in glob.glob("../data/**/*.gz", recursive=True):
    os.remove(file_path)

In [3]:
# Sanity check: display the number of samples in the test split.
len(test_data)

10000

In [4]:
# Training batches: 32 samples each, reshuffled every epoch.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
)
# Evaluation: the whole test split is served as one batch, in dataset order.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=len(test_data),
    shuffle=False,
)

In [None]:
class MNISTNet(torch.nn.Module):
    """Fully connected classifier with two ReLU hidden layers.

    The network maps ``input_dim`` features to ``output_dim`` log-probabilities
    (log-softmax over dimension 1), so it pairs with ``torch.nn.NLLLoss``.
    """

    def __init__(self, input_dim=28*28, hidden1=128, hidden2=64, output_dim=10):
        """Create the three linear layers.

        Args:
            input_dim: Number of input features per sample.
            hidden1: Width of the first hidden layer.
            hidden2: Width of the second hidden layer.
            output_dim: Number of output classes.
        """
        super().__init__()
        make_layer = torch.nn.Linear
        # Layers are created in input -> hidden -> output order; the attribute
        # names double as the state_dict key prefixes.
        self.input = make_layer(input_dim, hidden1)
        self.hidden = make_layer(hidden1, hidden2)
        self.output = make_layer(hidden2, output_dim)

    def forward(self, x):
        """Return log-probabilities for ``x``.

        The input is fed straight into the first linear layer, so its last
        dimension must equal ``input_dim``; log-softmax is taken over dim 1.
        """
        activations = x
        for layer in (self.input, self.hidden):
            activations = F.relu(layer(activations))
        logits = self.output(activations)
        return F.log_softmax(logits, dim=1)
    
def create_mnist_net(learning_rate=.001, input_dim=28*28, hidden=(128, 64), output_dim=10, **kwargs):
    """Build an ``MNISTNet`` together with its loss function and optimizer.

    Args:
        learning_rate: Learning rate handed to Adam as ``lr``.
        input_dim: Number of input features per sample.
        hidden: Pair ``(hidden1, hidden2)`` with the widths of the two hidden
            layers. Defaults to ``(128, 64)``.
        output_dim: Number of output classes.
        **kwargs: Accepted so existing call sites keep working; currently unused.
            NOTE(review): decide whether these should be forwarded (e.g. to the
            optimizer) or rejected - silently ignoring them can hide typos.

    Returns:
        A ``(net, loss_fn, optimizer)`` tuple: the ``MNISTNet`` instance, a
        ``torch.nn.NLLLoss`` and a ``torch.optim.Adam`` over ``net.parameters()``.

    Raises:
        TypeError, ValueError: If ``hidden`` cannot be unpacked into exactly
            two values.
    """
    hidden1, hidden2 = hidden
    net = MNISTNet(input_dim=input_dim, hidden1=hidden1, hidden2=hidden2, output_dim=output_dim)
    # NLLLoss consumes log-probabilities, which MNISTNet.forward produces via log_softmax.
    loss_fn = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    return net, loss_fn, optimizer

In [None]:
def test_create_mnist_net():
    net, loss_fn, optimizer = create_mnist_net(learning_rate=.001)
    assert isinstance(net, MNISTNet), "net should be an instance of MNISTNet"
    assert isinstance(loss_fn, torch.nn.NLLLoss), "loss_fn should be NLLLoss"
    assert isinstance(optimizer, torch.optim.Adam), "optimizer should be Adam"
    assert len(list(net.parameters())) > 0, "net should have parameters"


create_mnist_net() test passed.


In [None]:
def train_mnist_net(net, train_loader, loss_fn, optimizer, epochs=100):
    losses = []
    train_acc = []
    test_acc = []
    
    net.train()
    for epoch in range(epochs):
        batch_losses = []
        batch_acc = []
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(data.size(0), -1)  # Flatten the images

            # Forward pass
            y_hat = net(data)
            loss = loss_fn(y_hat, target)
            
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Accumulate loss and accuracy
            batch_losses.append(loss.item())
            matches = (y_hat.argmax(dim=1) == target)
            
            
            if batch_idx % 100 == 0:
                print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}")