# Logistic Regression on the MNIST dataset using PyTorch

In [2]:
#Importing basic required libraries

import torch
import torchvision
import numpy as np

from torchvision.datasets import MNIST

We will use 60,000 images for training and validation, and 10,000 images for testing.

In [3]:
# Load MNIST from data/ with download=True and no transform applied to the images.
dataset = MNIST(root='data/', download=True)
# Last expression of the cell: shows how many samples were loaded.
len(dataset)

60000

Since PyTorch only works with tensors, we first need to convert the images into PyTorch tensors; `torchvision.transforms.ToTensor()` does just that.

In [4]:
import torchvision.transforms as transforms

In [5]:
# Reload the training split (train=True), now passing ToTensor() as the transform.
# Note: this rebinds `dataset`, replacing the untransformed dataset loaded earlier.
dataset = MNIST(root='data/', 
                train=True,
                transform=transforms.ToTensor())

In [6]:
# Held-out test split (train=False), using the same ToTensor() transform as the training data.
test_dataset = MNIST(root='data/', train=False, transform=transforms.ToTensor())
# Last expression of the cell: shows the number of test samples.
len(test_dataset)

10000

Out of the 60,000 images, we will use 54,000 for training and the remaining 6,000 for validation. 

In [7]:
from torch.utils.data import random_split

# Split with an explicitly seeded generator so the train/validation partition
# is identical on every Restart & Run All; an unseeded split assigns different
# images to each subset on every run, which makes validation numbers
# incomparable between runs.
train_ds, val_ds = random_split(
    dataset,
    [54000, 6000],
    generator=torch.Generator().manual_seed(42),
)
len(train_ds), len(val_ds)

(54000, 6000)

In [8]:
from torch.utils.data import DataLoader

# Mini-batch size shared by the training and validation loaders.
batch_size = 128

# Only the training loader is created with shuffle=True.
train_loader = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=batch_size)
# A batch size of 10000 matches the test-set size, so the test set arrives as one batch.
test_loader = DataLoader(dataset=test_dataset, batch_size=10000)

Each image in the dataset is 28\*28 pixels (784 values once flattened), and each digit must be classified into one of ten possible classes. Our linear model therefore maps 784 inputs to 10 outputs, as set up below.

In [9]:
import torch.nn as nn

# One input feature per pixel of a 28*28 image.
input_size = 28*28
# Ten possible digit classes.
num_classes = 10

# Logistic regression model
# A single linear layer with input_size inputs and num_classes outputs.
# Note: `model` is rebound to an MnistModel instance in a later cell.
model = nn.Linear(input_size, num_classes)

In [10]:
import torch.nn.functional as F

# Alias for the cross-entropy loss. Nothing in the visible notebook reads
# `loss_fn`; MnistModel calls F.cross_entropy directly.
loss_fn = F.cross_entropy

In [11]:
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train `model` for a number of epochs, validating after each one.

    Args:
        epochs: number of passes over `train_loader`.
        lr: learning rate handed to the optimizer constructor.
        model: object providing `parameters()`, `training_step(batch)` and
            `epoch_end(epoch, result)`; it is also passed to `evaluate`.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches, passed to `evaluate`.
        opt_func: optimizer factory called as `opt_func(params, lr)`.

    Returns:
        A list with one validation result per epoch, in epoch order.
    """
    def _run_training_pass(opt):
        # One optimizer update per mini-batch; gradients are cleared after each step.
        for minibatch in train_loader:
            model.training_step(minibatch).backward()
            opt.step()
            opt.zero_grad()

    optim = opt_func(model.parameters(), lr)
    per_epoch_results = []

    for epoch_idx in range(epochs):
        _run_training_pass(optim)

        # Validate, let the model report the result, then record it.
        metrics = evaluate(model, val_loader)
        model.epoch_end(epoch_idx, metrics)
        per_epoch_results.append(metrics)

    return per_epoch_results

In [12]:
def eval_test(test_loader):
    for batch in test_loader:
        print("Test accuracy: ", model.test_performance(batch).item())

In [13]:
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample.
        labels: tensor of target class indices, one per row of `outputs`.

    Returns:
        A 0-dimensional tensor holding correct / total.
    """
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

In [14]:
def evaluate(model, val_loader):
    """Run `model.validation_step` on every validation batch and aggregate.

    Args:
        model: object providing `validation_step(batch)` and
            `validation_epoch_end(outputs)`.
        val_loader: iterable of validation batches.

    Returns:
        Whatever `model.validation_epoch_end` returns for the list of
        per-batch results.
    """
    # Validation never calls backward(). Without no_grad every per-batch loss
    # kept in `outputs` would hold its autograd graph alive until the epoch
    # summary is computed.
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

In [15]:
class MnistModel(nn.Module):
    """Logistic-regression classifier: one linear layer over flattened images.

    Besides `forward`, the class bundles the per-batch training, validation
    and test computations that `fit`, `evaluate` and `eval_test` call.
    """

    def __init__(self):
        super().__init__()
        # Layer size comes from the notebook-level input_size / num_classes.
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        """Flatten each sample in `xb` and return the linear layer's scores."""
        # Flatten to the layer's own input width instead of a hard-coded 784,
        # so forward stays consistent with however the layer was sized.
        xb = xb.reshape(-1, self.linear.in_features)
        return self.linear(xb)

    def training_step(self, batch):
        """Return the cross-entropy loss for one (images, labels) batch."""
        images, labels = batch
        out = self(images)                  # Generate predictions
        return F.cross_entropy(out, labels) # Calculate loss

    def validation_step(self, batch):
        """Return {'val_loss', 'val_acc'} tensors for one (images, labels) batch."""
        images, labels = batch
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        # Detach: the validation loss is only averaged and reported, never
        # back-propagated, so it should not keep its autograd graph alive.
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch losses and accuracies into Python floats.

        Note that this is an unweighted mean over batches, so a smaller final
        batch counts as much as a full one.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the validation loss and accuracy recorded for `epoch`."""
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))

    def test_performance(self, batch):
        """Return the accuracy tensor for one (images, labels) batch."""
        images, labels = batch
        preds = self(images)
        return accuracy(preds, labels)
    
# Rebind `model` to an MnistModel instance, replacing the bare nn.Linear assigned earlier.
model = MnistModel()

In [16]:
# Train for 20 epochs at learning rate 0.01; the returned history is the cell's output.
fit(epochs=20, lr=0.01, model=model, train_loader=train_loader, val_loader=val_loader)

Epoch [0], val_loss: 0.8627, val_acc: 0.8277
Epoch [1], val_loss: 0.6518, val_acc: 0.8530
Epoch [2], val_loss: 0.5664, val_acc: 0.8641
Epoch [3], val_loss: 0.5180, val_acc: 0.8736
Epoch [4], val_loss: 0.4866, val_acc: 0.8790
Epoch [5], val_loss: 0.4639, val_acc: 0.8838
Epoch [6], val_loss: 0.4464, val_acc: 0.8867
Epoch [7], val_loss: 0.4328, val_acc: 0.8878
Epoch [8], val_loss: 0.4214, val_acc: 0.8910
Epoch [9], val_loss: 0.4121, val_acc: 0.8925
Epoch [10], val_loss: 0.4040, val_acc: 0.8937
Epoch [11], val_loss: 0.3973, val_acc: 0.8961
Epoch [12], val_loss: 0.3912, val_acc: 0.8964
Epoch [13], val_loss: 0.3858, val_acc: 0.8977
Epoch [14], val_loss: 0.3810, val_acc: 0.8986
Epoch [15], val_loss: 0.3768, val_acc: 0.9001
Epoch [16], val_loss: 0.3729, val_acc: 0.9014
Epoch [17], val_loss: 0.3694, val_acc: 0.9021
Epoch [18], val_loss: 0.3664, val_acc: 0.9014
Epoch [19], val_loss: 0.3632, val_acc: 0.9026


[{'val_loss': 0.8626750111579895, 'val_acc': 0.8276737928390503},
 {'val_loss': 0.6518032550811768, 'val_acc': 0.8529872298240662},
 {'val_loss': 0.566364049911499, 'val_acc': 0.8641242384910583},
 {'val_loss': 0.5179910659790039, 'val_acc': 0.873622715473175},
 {'val_loss': 0.48659783601760864, 'val_acc': 0.8789656758308411},
 {'val_loss': 0.463875412940979, 'val_acc': 0.8838098645210266},
 {'val_loss': 0.44637295603752136, 'val_acc': 0.8866593241691589},
 {'val_loss': 0.43275874853134155, 'val_acc': 0.8878229260444641},
 {'val_loss': 0.42144763469696045, 'val_acc': 0.8910049200057983},
 {'val_loss': 0.41208750009536743, 'val_acc': 0.8925247192382812},
 {'val_loss': 0.4040427505970001, 'val_acc': 0.8937119841575623},
 {'val_loss': 0.3973042070865631, 'val_acc': 0.8960866928100586},
 {'val_loss': 0.39118969440460205, 'val_acc': 0.896419107913971},
 {'val_loss': 0.38582512736320496, 'val_acc': 0.8977251052856445},
 {'val_loss': 0.3810231685638428, 'val_acc': 0.898580014705658},
 {'val_l

In [17]:
# Print the accuracy of the trained `model` on each batch of the test loader.
eval_test(test_loader)

Test accuracy:  0.907800018787384


We have achieved an accuracy of about 90.8% on the MNIST test set, which is good enough for this simple model.