In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
class MLP(nn.Module):
    """Three-layer fully connected classifier (defaults sized for 28x28 images, e.g. MNIST).

    Architecture:
        Flatten -> Linear(in_features, hidden1) -> ReLU
                -> Linear(hidden1, hidden2)     -> ReLU
                -> Linear(hidden2, num_classes) -> Softmax(dim=1)

    ``forward`` returns per-class probabilities (softmax over dim 1).
    ``logits`` returns the raw pre-softmax scores; use it with
    ``nn.CrossEntropyLoss``, which applies log-softmax internally and expects
    unnormalized scores.

    Args:
        in_features: Number of values per sample after flattening every dim
            except the batch dim (``nn.Flatten`` default ``start_dim=1``).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output classes.
    """

    def __init__(self, in_features=28 * 28, hidden1=128, hidden2=64, num_classes=10):
        super().__init__()
        # Attribute names are kept stable: saved state_dicts use the
        # fc1/fc2/fc3 keys, and training code may reference these submodules.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden2, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def logits(self, x):
        """Return raw class scores of shape (batch, num_classes), before softmax."""
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)

    def forward(self, x):
        """Return class probabilities: ``softmax(self.logits(x), dim=1)``."""
        return self.softmax(self.logits(x))

In [None]:
def get_mnist_loader(root='../data', batch_size=64, download=True):
    """Build DataLoaders for the MNIST train and test splits.

    Images go through ``transforms.ToTensor()`` only; no normalization is
    applied.

    Args:
        root: Directory where torchvision stores or looks for the MNIST files.
        batch_size: Batch size used for both loaders.
        download: Passed to ``datasets.MNIST``; when True the files are fetched
            into ``root`` if they are not already there.

    Returns:
        ``(train_loader, test_loader)``. The training loader reshuffles every
        epoch; the test loader keeps a fixed order so evaluation is repeatable.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.MNIST(root=root, train=True, download=download, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=download, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
def _mlp_logits(net, inputs):
    """Run the MLP layers up to, but excluding, the final softmax.

    ``nn.CrossEntropyLoss`` expects unnormalized scores, but ``MLP.forward``
    ends in ``nn.Softmax``. This walks the same submodules ``forward`` uses
    (flatten, fc1, relu1, fc2, relu2, fc3) and stops before the softmax.
    """
    x = net.flatten(inputs)
    x = net.relu1(net.fc1(x))
    x = net.relu2(net.fc2(x))
    return net.fc3(x)


def train_mlp_model(model, train_loader, lr=1e-3, epochs=5, device='cpu', log_every=200):
    """Instantiate ``model`` and train it with Adam and cross-entropy loss.

    Args:
        model: The model *class* (not an instance). It is called with no
            arguments and must expose ``flatten``, ``fc1``, ``relu1``, ``fc2``,
            ``relu2`` and ``fc3`` submodules (as ``MLP`` does).
        train_loader: Iterable of ``(inputs, labels)`` batches.
        lr: Adam learning rate.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.
        log_every: Print the mean loss and running accuracy every this many
            batches. Must be positive.

    Returns:
        The trained model instance, on ``device``.

    Raises:
        ValueError: If ``log_every`` is not positive.
    """
    if log_every <= 0:
        raise ValueError(f'log_every must be positive, got {log_every}')

    net = model().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # One forward pass serves both purposes: the logits feed the loss,
            # and because softmax is monotonic their argmax is the predicted class.
            logits = _mlp_logits(net, inputs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            predicted = logits.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()

            if (i + 1) % log_every == 0:
                # Average over the batches actually accumulated since the last log.
                print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/log_every:.4f}, Accuracy: {100*correct/total:.2f}%')
                running_loss = 0.0

        # Guard against an empty loader rather than dividing by zero.
        accuracy = 100 * correct / total if total else 0.0
        print(f'Epoch {epoch+1} completed, Accuracy: {accuracy:.2f}%\n')

    print('Training completed!')
    return net

In [None]:
# Seed once so weight init and the training loader's shuffle order are reproducible.
torch.manual_seed(0)

train_loader, test_loader = get_mnist_loader()

# Prefer CUDA, then Apple-silicon MPS, otherwise CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

model = train_mlp_model(MLP, train_loader, lr=1e-3, epochs=5, device=device)

MODEL_PATH = Path('../models/mlp_model.pth')
# torch.save does not create missing parent directories.
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)