In [44]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable

# Load the MNIST training and test splits, converting every image to a tensor.
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,  # request a download into ./data
)
# The test split is read from the same root directory; no download is requested here.
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [45]:
# Making Dataset Iterable
batch_size = 100
n_iters = 3000
# Number of full passes over the training set needed for roughly n_iters
# mini-batch updates. Pure integer arithmetic is used so that float rounding
# cannot push the quotient just below a whole number before truncation.
num_epochs = (n_iters * batch_size) // len(train_dataset)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# Accuracy over the whole test set does not depend on sample order, so the
# test loader is left unshuffled to keep evaluation deterministic.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

In [63]:
# Create Model Class
# Any non-linearity (sigmoid, tanh or ReLU) can serve as the activation; ReLU is used here.
# hidden_dim is the number of neurons in each hidden layer.

class FeedforwardNeuralNetModel(nn.Module):
    """Fully connected network: three ReLU hidden layers plus a linear readout.

    Args:
        input_size: Number of features in each (flattened) input sample.
        hidden_size: Number of units in each of the three hidden layers.
        num_classes: Number of output scores (one logit per class).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super(FeedforwardNeuralNetModel, self).__init__()
        # Layer sizes come from the constructor arguments, never from notebook
        # globals, so the class works regardless of cell execution order.

        # Linear function 1: input_size --> hidden_size
        self.fc1 = nn.Linear(input_size, hidden_size)
        # Non-linearity 1
        self.relu1 = nn.ReLU()

        # Linear function 2: hidden_size --> hidden_size
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        # Non-linearity 2
        self.relu2 = nn.ReLU()

        # Linear function 3: hidden_size --> hidden_size
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        # Non-linearity 3
        self.relu3 = nn.ReLU()

        # Linear function 4 (readout): hidden_size --> num_classes
        self.fc4 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw class scores (logits) for a batch ``x``.

        Args:
            x: Tensor whose last dimension has ``input_size`` features.

        Returns:
            Tensor whose last dimension has ``num_classes`` logits; no softmax
            is applied here.
        """
        # Hidden layer 1
        out = self.relu1(self.fc1(x))
        # Hidden layer 2
        out = self.relu2(self.fc2(out))
        # Hidden layer 3
        out = self.relu3(self.fc3(out))
        # Readout must use fc4 (hidden_size --> num_classes), not fc2 again.
        out = self.fc4(out)
        return out

In [64]:
# Instantiate the model: 28x28 input pixels, 100 units per hidden layer, 10 output classes
input_dim, hidden_dim, output_dim = 28 * 28, 100, 10

model = FeedforwardNeuralNetModel(
    input_size=input_dim,
    hidden_size=hidden_dim,
    num_classes=output_dim,
)

In [65]:
# Loss function: cross-entropy computed directly from the model's raw logits
criterion = torch.nn.CrossEntropyLoss()

In [66]:
# Optimizer: plain stochastic gradient descent over every model parameter
learning_rate = 0.1
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [67]:
# Train the model
#
# Runs num_epochs passes over train_loader. Every 500 mini-batch updates the
# model is scored on the full test set and the most recent training loss and
# the test accuracy (in percent) are printed.

# Number of mini-batch updates performed so far (named `iteration` so the
# builtin `iter` is not shadowed).
iteration = 0
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # Flatten each 28x28 image into a 784-element row for the linear layers
        images = images.view(-1, 28*28)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = model(images)

        # Calculate loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()

        # Count every mini-batch update; this must live inside the batch loop,
        # otherwise the counter only tracks epochs and the periodic evaluation
        # below is never reached.
        iteration += 1

        if iteration % 500 == 0:
            # Calculate accuracy on the test set
            correct = 0
            total = 0
            # No gradients are needed for evaluation
            with torch.no_grad():
                # Separate names keep the training batch and loss untouched
                for test_images, test_labels in test_loader:
                    test_images = test_images.view(-1, 28*28)

                    # Forward pass only to get logits/output
                    test_outputs = model(test_images)

                    # Predicted class = index of the largest logit
                    _, predicted = torch.max(test_outputs, 1)

                    # Total number of labels
                    total += test_labels.size(0)

                    # Total correct predictions, accumulated as a Python int
                    correct += (predicted == test_labels).sum().item()

            accuracy = 100.0 * correct / total

            # Print loss of the latest training batch and the test accuracy
            print('Iteration: {}. Loss: {}. Accuracy: {}'.format(iteration, loss.item(), accuracy))