# MLP Model Training

This notebook contains the source code needed to build an MLP classification model on the MNIST dataset. No resources are required other than the dependencies installed in the virtual environment and the code in the `src/` directory.

In [None]:
import os, sys
import torch
import torch.nn as nn
import torch.optim as optim

sys.path.insert(0, os.path.abspath('../src'))

from models import MLP
from utils import get_mnist_loader

In [None]:
def train_mlp_model(model, train_loader, lr=1e-3, epochs=5, device='cpu', log_interval=200):
    """Train a classifier with Adam and cross-entropy loss.

    Args:
        model: Either a zero-argument callable that builds an ``nn.Module``
            (typically the model class itself), which is instantiated here,
            or an already-constructed ``nn.Module`` instance, which is
            trained as-is.
        train_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        lr: Learning rate passed to ``optim.Adam``.
        epochs: Number of full passes over ``train_loader``.
        device: Device the model and every batch are moved to.
        log_interval: Number of batches between progress printouts. The
            printed loss is the mean over the batches in that window.

    Returns:
        The trained model, moved to ``device`` and left in training mode.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError('log_interval must be a positive integer')

    # Accept both a model class/factory (original usage) and a ready instance.
    if not isinstance(model, nn.Module):
        model = model()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Single forward pass, reused for both the loss and the accuracy
            # bookkeeping, so the function does not depend on layer attribute
            # names of one specific model class.
            # NOTE(review): assumes model.forward returns the raw logits
            # (no softmax applied) -- confirm against src/models.py.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # detach() keeps the accuracy computation out of the autograd graph.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

            if (i + 1) % log_interval == 0:
                # running_loss holds exactly log_interval batches here, so
                # dividing by log_interval yields the mean loss of the window.
                accuracy = 100 * correct / total if total else 0.0
                print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/log_interval:.4f}, Accuracy: {accuracy:.2f}%')
                running_loss = 0.0

        # Guard against an empty loader (total == 0) instead of dividing by zero.
        accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch+1} completed, Accuracy: {accuracy:.2f}%\n')

    print('Training completed!')
    return model

## Exporting the Model

The following code block trains the model and exports its learned weights (the `state_dict`) to `mlp_model.pth` in the `models/` directory.

In [None]:
# Build the MNIST data loaders; only the training loader is used in this cell.
train_loader, test_loader = get_mnist_loader()
# Use Apple's Metal (MPS) backend when present, otherwise fall back to the CPU.
# torch.backends.mps.is_available() is the documented availability check and
# also exists on PyTorch versions that do not provide torch.mps.is_available().
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

model = train_mlp_model(MLP, train_loader, lr=1e-3, epochs=5, device=device)
# Save only the learned parameters (state_dict), creating the output
# directory first so the save does not fail on a fresh checkout.
os.makedirs('../models', exist_ok=True)
torch.save(model.state_dict(), '../models/mlp_model.pth')