In [21]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

In [22]:
train_dataset = MNIST(train=True, download=True, root='/data/', transform=transforms.ToTensor())
test_dataset = MNIST(download=True, root='/data/', transform=transforms.ToTensor())

In [23]:
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=4, shuffle=True)

In [32]:
class SimpleFNN(nn.Module):
  """
  A simple Feed forward Neural Network contians
  the 2 Linear Layers followed by ReLu
  1 output layer
  """
  def __init__(self, input_features, output_features):
    super(SimpleFNN, self).__init__()
    self.layer1 = nn.Linear(input_features, 256)
    self.layer2 = nn.ReLU()
    self.layer3 = nn.Linear(256, 512)
    self.layer4 = nn.ReLU()
    self.layer5 = nn.Linear(512, output_features)

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.layer5(x)
    return x

epochs = 10 # how many epochs to train
learning_rate = 0.001

# to run the model on the GPU
device = 'cuda'
model = SimpleFNN(784, 10).to(device)

loss = nn.CrossEntropyLoss() # loss function
optimizers = SGD(model.parameters(), learning_rate) # optimizer

for epoch in range(epochs):
  for idx, (data, labels) in enumerate(train_loader):
    data = data.to(device)
    data = data.reshape(data.size(0), -1)
    labels = labels.to(device)
    y_hat = model.forward(data)
    l = loss(y_hat, labels)

    optimizers.zero_grad()
    l.backward()
    optimizers.step()
  print(f"Epoch {epoch}: {epochs}")

Epoch 0: 5
Epoch 1: 5
Epoch 2: 5
Epoch 3: 5
Epoch 4: 5


In [33]:
num_samples = 0
num_correct = 0

model.eval()
with torch.no_grad():
  for data, labels in test_loader:
    data = data.to('cuda')
    labels = labels.to('cuda')
    data = data.reshape(data.size(0), -1)
    output = model(data)
    _, prediction = output.max(1)
    num_correct += (prediction == labels).sum()
    num_samples += prediction.size(0)

print(f"The model accuracy on test data: {num_correct/num_samples}")


The model accuracy on test data: 0.9308666586875916
