# Train a classifier on MNIST

## Setup


Import libraries

In [1]:
import importlib
# Project module; reloaded so that edits to its source file take effect
# without restarting the kernel (an alternative is `%load_ext autoreload`
# followed by `%autoreload 2`).
import regene_models
importlib.reload(regene_models)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
# Second project module, reloaded for the same reason. Because this reload is
# the cell's last expression, Jupyter echoes the returned module object (and
# its absolute file path) as the cell output.
import new_CLUE
importlib.reload(new_CLUE)

<module 'new_CLUE' from '/Users/conor/Documents/College terms/College/Thesis/Thesis_Code_Minimised/MyImplementation/new_CLUE.py'>

Select the compute device: CUDA if available, otherwise Apple's MPS backend, otherwise the CPU.

In [2]:
# Choose the compute backend in order of preference: CUDA, then Apple's
# Metal Performance Shaders (MPS), and finally the CPU as a fallback.
device_type = "cpu"
if torch.cuda.is_available():
    device_type = "cuda"
elif torch.backends.mps.is_available():
    device_type = "mps"
device = torch.device(device_type)

print(f"Using device: {device}")

Using device: mps


Load the MNIST training set

In [3]:
# Load the MNIST training set.
# ToTensor() converts each image to a float tensor scaled to [0, 1];
# no further normalisation is applied.
SEED = 0  # change to get a different (but still reproducible) run

# Seed the global torch RNG. Later cells that draw from it (e.g. the
# classifier's weight initialisation) become reproducible when run in order.
torch.manual_seed(SEED)

transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)

# A dedicated, seeded generator makes the shuffle order (and the base seed
# handed to the worker processes) reproducible across kernel restarts.
shuffle_generator = torch.Generator().manual_seed(SEED)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    generator=shuffle_generator,
)

Set the latent dimension (not passed to the classifier trained below)

In [4]:
latent_dim = 2 ** 8  # 256; not passed to the classifier trained in this notebook section

Create the `model_saves` directory (if it doesn't already exist) to hold the trained weights

In [5]:
# Trained weights are written under ./model_saves (relative to the kernel's
# working directory); exist_ok makes this cell safe to re-run.
save_dir = 'model_saves'
os.makedirs(save_dir, exist_ok=True)

## Train the classifier

Create the model

In [9]:
# Project model class from the local `mnist_classifier` module (not shown here).
from mnist_classifier import MNISTClassifier

# The target device is passed to the constructor and there is no separate
# `.to(device)` call, so presumably MNISTClassifier places its own parameters
# on `device` — TODO confirm in mnist_classifier.py.
classifier = MNISTClassifier(device=device)

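Train it with the Adam optimizer and cross-entropy loss, then save the weights to `model_saves/mnist_classifier.pth`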

In [11]:
# Train the classifier.
# NOTE: this cell trains whatever `classifier` currently holds (with a freshly
# created optimizer). Re-running it without first re-running the
# "Create the model" cell continues training the same weights.
learning_rate = 0.001
num_epochs = 5
log_every = 100  # print running statistics every `log_every` batches
save_path = 'model_saves/mnist_classifier.pth'


def train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs, log_every=100):
    """Run one training pass over ``loader`` and update ``model`` in place.

    Args:
        model: network to train; switched to training mode here.
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: loss applied as ``criterion(outputs, labels)``. The epoch
            average assumes it returns the per-batch *mean* (the default
            reduction of ``nn.CrossEntropyLoss``, which is what this cell uses).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.
        epoch: zero-based epoch index, used only in progress messages.
        num_epochs: total number of epochs, used only in progress messages.
        log_every: print progress after every ``log_every`` batches.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean loss per sample and accuracy in
        percent over every sample seen this epoch, or ``(0.0, 0.0)`` if the
        loader yields no batches.
    """
    # Re-enable training-mode behaviour (e.g. dropout / batch-norm updates)
    # in case an earlier cell left the model in eval mode.
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass, loss, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics. Weight each batch's mean loss by its size so the smaller
        # final batch does not skew the epoch average.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        running_loss += batch_loss * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        if (i + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(loader)}], '
                  f'Loss: {batch_loss:.4f}, '
                  f'Accuracy: {100 * correct/total:.2f}%')

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, 100 * correct / total


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    epoch_loss, epoch_acc = train_one_epoch(
        classifier, trainloader, criterion, optimizer, device,
        epoch, num_epochs, log_every=log_every,
    )
    print(f'Epoch [{epoch+1}/{num_epochs}] completed. '
          f'Average Loss: {epoch_loss:.4f}, '
          f'Accuracy: {epoch_acc:.2f}%')

# Save the trained weights. Create the directory here as well so this cell
# does not depend on the earlier makedirs cell having been run.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(classifier.state_dict(), save_path)
print("Training completed and model saved!")


Epoch [1/5], Step [100/938], Loss: 0.2128, Accuracy: 93.44%
Epoch [1/5], Step [200/938], Loss: 0.2149, Accuracy: 94.89%
Epoch [1/5], Step [300/938], Loss: 0.0173, Accuracy: 95.60%
Epoch [1/5], Step [400/938], Loss: 0.0576, Accuracy: 96.07%
Epoch [1/5], Step [500/938], Loss: 0.0732, Accuracy: 96.46%
Epoch [1/5], Step [600/938], Loss: 0.0171, Accuracy: 96.71%
Epoch [1/5], Step [700/938], Loss: 0.0374, Accuracy: 96.92%
Epoch [1/5], Step [800/938], Loss: 0.0246, Accuracy: 97.05%
Epoch [1/5], Step [900/938], Loss: 0.0269, Accuracy: 97.19%
Epoch [1/5] completed. Average Loss: 0.0921, Accuracy: 97.25%
Epoch [2/5], Step [100/938], Loss: 0.1045, Accuracy: 98.50%
Epoch [2/5], Step [200/938], Loss: 0.0144, Accuracy: 98.60%
Epoch [2/5], Step [300/938], Loss: 0.0085, Accuracy: 98.56%
Epoch [2/5], Step [400/938], Loss: 0.1092, Accuracy: 98.61%
Epoch [2/5], Step [500/938], Loss: 0.0116, Accuracy: 98.66%
Epoch [2/5], Step [600/938], Loss: 0.0213, Accuracy: 98.68%
Epoch [2/5], Step [700/938], Loss: 0.0