In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Subset
from torch.utils.data import DataLoader, TensorDataset

import torchvision
from torchvision import transforms

In [2]:
# Define the path where the MNIST data will be downloaded
image_path = './local_training_data'

# Define a transform to convert the images to PyTorch tensors
transform = transforms.Compose([
    transforms.ToTensor()
])

# Download the MNIST training dataset
mnist_dataset = torchvision.datasets.MNIST(
    root=image_path, 
    train=True, 
    transform=transform, 
    download=True
)

# Split the dataset into training and validation sets
mnist_valid_dataset = Subset(mnist_dataset, torch.arange(10000))
mnist_train_dataset = Subset(mnist_dataset, torch.arange(10000, len(mnist_dataset)))

# Download the MNIST test dataset
mnist_test_dataset = torchvision.datasets.MNIST(
    root=image_path, 
    train=False, 
    transform=transform, 
    download=False  # Set to False assuming it's already downloaded
)

In [3]:
# Assuming mnist_train_dataset and mnist_valid_dataset are already defined
train_ds = mnist_train_dataset
valid_ds = mnist_valid_dataset

In [4]:
# Initialize the sequential model
model = nn.Sequential()

# Add the first convolutional layer
model.add_module('conv1', nn.Conv2d(
    in_channels=1, 
    out_channels=32, 
    kernel_size=5, 
    padding=2
))

# Add a ReLU activation layer
model.add_module('relu1', nn.ReLU())

# Add the first pooling layer
model.add_module('pool1', nn.MaxPool2d(kernel_size=2))

# Add the second convolutional layer
model.add_module('conv2', nn.Conv2d(
    in_channels=32, 
    out_channels=64, 
    kernel_size=5, 
    padding=2
))

# Add another ReLU activation layer
model.add_module('relu2', nn.ReLU())

# Add the second pooling layer
model.add_module('pool2', nn.MaxPool2d(kernel_size=2))

# Add a flatten layer to the model
model.add_module('flatten', nn.Flatten())

# Add fully connected layers with a dropout layer in between
model.add_module('fc1', nn.Linear(3136, 1024))
model.add_module('relu3', nn.ReLU())
model.add_module('dropout', nn.Dropout(p=0.5))
model.add_module('fc2', nn.Linear(1024, 10))

In [5]:
def fit_model(model: torch.nn.Module,
                train_ds: TensorDataset,
                loss_fn: nn.Module,
                optimizer: optim.Optimizer,
                valid_ds: TensorDataset,
                accuracy_fn = None,
                num_epochs = 100,
                batch_size = 32,
                seed = 1,
                transform_pred = None,
                device: torch.device = "cpu"):
    # variables
    loss_hist_train = [0] * num_epochs
    accuracy_hist_train = [0] * num_epochs
    loss_hist_valid = [0] * num_epochs
    accuracy_hist_valid = [0] * num_epochs    
    torch.manual_seed(seed)
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    valid_dl = None
    if valid_ds:
        valid_dl = DataLoader(dataset=valid_ds, batch_size=batch_size, shuffle=True)
    n_train = len(train_dl.dataset)    
    # set model to device
    model.to(device)
    for epoch in range(num_epochs):
        # set model to training mode
        model.train()
        # mini-batch training
        for x_batch, y_batch in train_dl:
            # Send data to device
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)

            # Forward pass
            y_pred = model(x_batch)
            if transform_pred:
                y_pred = transform_pred(y_pred)

            # Compute loss
            loss = loss_fn(y_pred, y_batch)

            # Do backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if accuracy_fn:
                accuracy_batch = accuracy_fn(y_pred, y_batch)
                accuracy_hist_train[epoch] += accuracy_batch
            loss_hist_train[epoch] += loss.item()            

        # Compute store loss and accuracy as percent
        loss_hist_train[epoch] /= n_train
        accuracy_hist_train[epoch] /= n_train
        
        if valid_dl:
            n_valid = len(valid_dl.dataset)
            # set model to training mode
            model.eval()
            with torch.no_grad():
                for x_batch, y_batch in valid_dl:                    
                    # Send data to device
                    x_batch, y_batch = x_batch.to(device), y_batch.to(device)

                    # Forward pass
                    y_pred = model(x_batch)
                    if transform_pred:
                        y_pred = transform_pred(y_pred)

                    # Compute loss
                    loss = loss_fn(y_pred, y_batch)

                    if accuracy_fn:
                        accuracy_batch = accuracy_fn(y_pred, y_batch)
                        accuracy_hist_valid[epoch] += accuracy_batch
                    loss_hist_valid[epoch] += loss.item()

                # Compute store loss and accuracy valid as percent
                accuracy_hist_valid[epoch] /= n_valid
                loss_hist_valid[epoch] /= n_valid
                    
        # Print out
        if epoch % 10 == 0:
            if valid_dl:
                print(f"Train loss: {loss_hist_train[epoch]:.5f} | Train accuracy: {accuracy_hist_train[epoch]:.2f}% | Val loss: {loss_hist_valid[epoch]:.5f} | Val accuracy: {accuracy_hist_valid[epoch]:.2f}%")
            else:
                print(f"Train loss: {loss_hist_train[epoch]:.5f} | Train accuracy: {accuracy_hist_train[epoch]:.2f}%")

    # Return result    
    return (loss_hist_train, accuracy_hist_train, loss_hist_valid, accuracy_hist_valid)

In [6]:
def get_accuracy_sum_multiclass(y_pred, y):
    is_correct = (torch.argmax(y_pred, dim=1) == y).float()
    accuracy = is_correct.sum()
    return accuracy

In [7]:
learning_rate = 0.001
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
loss_hist_train, accuracy_hist_train, loss_hist_valid, accuracy_hist_valid = fit_model(model, 
                                                                                       train_ds,
                                                                                       loss_fn, 
                                                                                       optimizer,
                                                                                       valid_ds=valid_ds,
                                                                                       num_epochs = 20,
                                                                                       batch_size=64, 
                                                                                       accuracy_fn=get_accuracy_sum_multiclass)

Train loss: 0.00244 | Train accuracy: 0.95% | Val loss: 0.00112 | Val accuracy: 0.98%
Train loss: 0.00017 | Train accuracy: 1.00% | Val loss: 0.00060 | Val accuracy: 0.99%


KeyboardInterrupt: 