# MNIST Classifier
Train a basic feedforward neural network (one hidden layer) to classify
handwritten digits from the MNIST database.

In [None]:
import torch
import torchvision
import numpy as np
import os
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from tqdm import tqdm
plt.style.use('ggplot')

In [None]:
# --- Training hyper-parameters -------------------------------------------
n_epochs = 3
learning_rate = 0.01
momentum = 0.5
batch_size_train, batch_size_test = 64, 1000  # mini-batch sizes per split
log_interval = 10  # presumably: log every N batches (used by a later cell)

# --- Reproducibility -----------------------------------------------------
# Fix the global torch RNG seed and switch cuDNN off (presumably to avoid
# its nondeterministic kernels when running on a GPU).
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

In [None]:
# Both MNIST splits live under the same root directory and must share the
# same preprocessing, so define each once instead of duplicating it.
data_root = './datasets/'

# PIL image -> float tensor, then standardize with a single-channel
# mean/std. (0.1307, 0.3081) are presumably the MNIST pixel mean/std.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# NOTE(review): shuffling the test split is not needed for evaluation; it is
# kept so the preview cell below still draws a random (seeded) batch.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_root, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)

In [None]:
# Pull the first (index, batch) pair from the test loader for visual checks.
examples = enumerate(test_loader)
first_item = next(examples)
batch_idx = first_item[0]
example_data, example_targets = first_item[1]

In [None]:
# Preview the first three test images, titled with their target labels.
fig, ax = plt.subplots(1, 3)
for i, axis in enumerate(ax):
    image = example_data[i].squeeze()
    axis.imshow(image, cmap='gray')
    axis.set(title=f"{example_targets[i]}")
    # Hide tick marks; they carry no information for an image preview.
    axis.set_xticks([])
    axis.set_yticks([])


In [75]:
# Create the neural network
class NeuralNet(torch.nn.Module):
    """Fully connected classifier with a single ReLU hidden layer.

    Architecture: ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, num_classes)``. ``forward`` returns raw, unnormalized
    scores; no softmax is applied.

    Args:
        input_size: Number of input features per sample (784 for a
            flattened 28x28 image).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores for ``x``.

        ``x`` may already be flat, shaped ``(..., input_size)``, in which case
        it is used as-is. It may instead be image-shaped, e.g. the
        ``(batch, 1, 28, 28)`` batches an MNIST ``DataLoader`` yields. In that
        case every dimension after the batch dimension is flattened first.
        """
        # Only flatten when the trailing dim does not already match fc1, so
        # flat inputs (including extra leading dims) behave exactly as before.
        if x.dim() > 1 and x.shape[-1] != self.fc1.in_features:
            x = torch.flatten(x, start_dim=1)
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)

In [83]:
# Model dimensions: each 28x28 image is fed to the first layer as one flat
# vector of pixel values.
input_size = 28 * 28  # = 784 inputs per image
hidden_size = 196
num_classes = 10  # one output score per digit class
model = NeuralNet(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
)