In [4]:
# Imports
import torch
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [5]:
# Download the MNIST dataset.
# The root directory is the empty string, so the files are stored relative to
# the current working directory. ToTensor is applied to every image.
train = datasets.MNIST(
    "",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test = datasets.MNIST(
    "",
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Wrap both splits in loaders that yield shuffled mini-batches of 10 samples.
trainset = torch.utils.data.DataLoader(train, shuffle=True, batch_size=10)
testset = torch.utils.data.DataLoader(test, shuffle=True, batch_size=10)

In [7]:
# Create PyTorch module
#
# The network architecture is defined in the __init__ method.
# nn.Linear applies an affine transformation to its input (y = x @ W.T + b).
class Net(nn.Module):
    """Fully connected digit classifier: 784 -> 64 -> 64 -> 64 -> 10.

    The three hidden layers use a sigmoid activation. The output layer is
    followed by a log-softmax, so ``forward`` returns log-probabilities that
    can be passed directly to ``F.nll_loss``.
    """

    # Create network layers
    def __init__(self):
        super().__init__()

        # Input layer: 28 * 28 = 784 input features, 64 output features
        self.fc1 = nn.Linear(28 * 28, 64)

        # Hidden layer #2: 64 inputs & 64 outputs
        self.fc2 = nn.Linear(64, 64)

        # Hidden layer #3: 64 inputs & 64 outputs
        self.fc3 = nn.Linear(64, 64)

        # Output layer: 64 inputs & 10 outputs (one per possible digit)
        self.fc4 = nn.Linear(64, 10)

    # Compute the output of the network
    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: Tensor of shape (batch, 784); each row feeds ``fc1``.

        Returns:
            Tensor of shape (batch, 10) holding the log-probability of each
            class for every sample (each row exponentiates to a distribution
            that sums to 1).
        """
        # Sigmoid activation after each of the three hidden layers
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        x = self.fc4(x)

        # Return log-probabilities rather than probabilities: F.nll_loss
        # expects log-probabilities, and feeding it plain softmax output does
        # not produce the cross-entropy loss. The argmax over classes is the
        # same either way because log is monotonic.
        return F.log_softmax(x, dim=1)


net = Net()

In [None]:
# Set up the optimiser and train the network
optimiser = optim.SGD(net.parameters(), lr=0.001)

# Number of full passes over the training loader
Epochs = 3

for epoch in range(Epochs):
    for X, y in trainset:
        # Clear the gradients left over from the previous step; the optimiser
        # was built from net.parameters(), so this covers every parameter.
        optimiser.zero_grad()

        # Flatten each image to a 784-element row. Call the module itself
        # rather than .forward() so nn.Module.__call__ (and any registered
        # hooks) is used.
        output = net(X.view(-1, 28 * 28))

        # NOTE: F.nll_loss expects log-probabilities as its input, so the
        # network's forward pass must emit log-softmax values.
        loss = F.nll_loss(output, y)
        loss.backward()
        optimiser.step()

In [None]:
# Initialise counters
correct = 0
total = 0

# Test the model; no gradients are needed for evaluation
with torch.no_grad():
    for X, y in testset:
        # Flatten each image to a 784-element row and call the module itself
        # (not .forward()) so nn.Module.__call__ is used.
        output = net(X.view(-1, 28 * 28))

        # Predicted class = index of the largest score in each row. Compare
        # the whole batch at once instead of looping over samples.
        predictions = output.argmax(dim=1)
        correct += (predictions == y).sum().item()
        total += y.size(0)

# Print result
print("accuracy: ", round(correct/total, 3))