<a href="https://colab.research.google.com/github/mertkaya1033/MNIST_digit_classify/blob/master/MNIST_digit_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST Handwritten Digit Classification


## Imports and Helpers

In [None]:
import time
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.autograd import Variable

from typing import Tuple

# https://stackoverflow.com/questions/39832721/meaning-of-self-dict-self-in-a-class-definition
class Config(dict):
  """A dictionary whose entries can also be read and written as attributes.

  The instance's ``__dict__`` is bound to the dictionary itself, so
  ``cfg.key`` and ``cfg["key"]`` always refer to the same stored value.

  Accepts the same arguments as ``dict`` (a mapping, an iterable of key/value
  pairs and/or keyword arguments). Called without arguments it creates an
  empty configuration.
  """
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Alias the attribute namespace to the mapping so both access styles
    # share a single storage.
    self.__dict__ = self
  

torch.Size([60000, 28, 28])

## Convolutional Neural Network 

In [None]:
class MNIST_CNN(nn.Module):
  """A small convolutional classifier: two conv blocks and a linear head.

  Each conv block applies Conv2d -> MaxPool2d(2) -> BatchNorm2d -> ReLU, and
  the second block doubles the number of filters. The result is flattened and
  mapped to ``num_classes`` raw scores by a single linear layer.

  Parameters:
    image_size (int) -- side length of the (square) input images
    num_classes (int) -- number of output scores
    num_in_channels (int) -- number of channels of the input images
    train_args -- object exposing ``kernel`` (convolution kernel size) and
                  ``num_filters`` (output channels of the first conv block)

  Raises:
    ValueError -- if the feature map is reduced to less than one pixel by the
                  two conv blocks
  """
  def __init__(self, image_size, num_classes, num_in_channels, train_args) -> None:
    super().__init__()

    kernel = train_args.kernel
    num_filters = train_args.num_filters

    padding = kernel // 2

    self.convBlock1 = nn.Sequential(
        nn.Conv2d(
            in_channels=num_in_channels,
            out_channels=num_filters,
            kernel_size=kernel,
            padding=padding),
        nn.MaxPool2d(kernel_size=2),
        nn.BatchNorm2d(num_features=num_filters),
        nn.ReLU())

    self.convBlock2 = nn.Sequential(
        nn.Conv2d(
            in_channels=num_filters,
            out_channels=num_filters*2,
            kernel_size=kernel,
            padding=padding),
        nn.MaxPool2d(kernel_size=2),
        nn.BatchNorm2d(num_features=num_filters*2),
        nn.ReLU())

    self.flatten = nn.Flatten()

    # Follow the spatial size through both blocks instead of assuming the
    # convolutions preserve it: with an even kernel, padding = kernel // 2
    # makes every convolution grow the map by one pixel, so image_size // 4
    # would give the wrong number of input features.
    side = image_size
    for _ in range(2):
      conv_side = side + 2 * padding - kernel + 1  # stride-1 convolution
      side = conv_side // 2                        # 2x2 max pooling (floor)
    if side < 1:
      raise ValueError(
          f"image_size={image_size} is too small for two conv/pool blocks")

    in_features = (num_filters*2) * side * side

    self.linearBlock = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

  def forward(self, x):
    """Return the class scores for a batch of images ``x``."""
    convolved1 = self.convBlock1(x)
    convolved2 = self.convBlock2(convolved1)
    flattened = self.flatten(convolved2)
    prediction = self.linearBlock(flattened)
    return prediction

## MNIST Digit Classification Trainer

In [None]:
class MNIST_Classify():
  """Downloads the MNIST dataset, and trains models that predict the handwritten
  digit written on the provided 28x28 image."""

  # The datasets are cached on the class, so they are loaded once and shared
  # by every instance.
  _training_data = None
  _test_data = None

  @staticmethod
  def load_data(root: str = "data") -> Tuple[Dataset, Dataset]:
    """Downloads MNIST dataset into local file system, into the root directory.
    Loads the dataset into memory.

    A split that was already loaded by an earlier call is reused as is, so
    ``root`` only takes effect for a split that has not been loaded yet.

    Parameters:
      root (str) -- the directory to download the dataset. (default: "data")

    Returns:
      tuple --  (training dataset, test dataset)
    """
    if MNIST_Classify._training_data is None:
      MNIST_Classify._training_data = datasets.MNIST(
          root=root,
          train=True,
          download=True,
          transform=ToTensor()
      )
    if MNIST_Classify._test_data is None:
      MNIST_Classify._test_data = datasets.MNIST(
          root=root,
          train=False,
          download=True,
          transform=ToTensor()
      )
    return MNIST_Classify._training_data, MNIST_Classify._test_data

  def _prepare_data(self, args: Config) -> None:
    """Prepares the dataset for training.

    Builds the training and test data loaders, both batched with
    ``args.batch_size``.

    Parameters:
        args (Config) --  hyperparameters of the model being trained
    """
    training_data, test_data = MNIST_Classify.load_data()

    self._train_dataloader = DataLoader(
        training_data, batch_size=args.batch_size, shuffle=True)
    # Only the training batches need to be reshuffled; the test set is read
    # in its stored order.
    self._test_dataloader = DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False)

  def train_model(self, args: Config) -> nn.Module:
    """Trains an instance of a model provided in args with the given hyperparameters
    for the task of classifying handwritten digits using the MNIST dataset.

    Reads ``Model``, ``batch_size``, ``learn_rate`` and ``epochs`` from args.
    When args has a ``seed`` that is not None, PyTorch's random number
    generator is seeded with it before the model is created.

    Parameters:
        args (Config) --  stores the model and the hyperparameters in which the
                          model instance should be trained

    Returns:
        The trained model instance
    """
    # Seed before building the model and iterating the shuffled loader so
    # that weight initialisation and batch order are repeatable.
    seed = getattr(args, "seed", None)
    if seed is not None:
      torch.manual_seed(seed)

    self._prepare_data(args)

    num_digits = 10
    num_channels = 1
    image_size = MNIST_Classify._training_data.data.size()[1] #28x28
    num_batches = len(self._train_dataloader)

    mnist_model = args.Model(image_size, num_digits, num_channels, args)
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=mnist_model.parameters(), lr=args.learn_rate)

    mnist_model.train()

    for epoch in range(args.epochs):
      for step, (images, labels) in enumerate(self._train_dataloader, start=1):
        optimizer.zero_grad()

        predictions = mnist_model(images)
        loss = loss_func(predictions, labels)

        loss.backward()
        optimizer.step()

        # Report progress every 100 batches.
        if step % 100 == 0:
          print(f'Epoch [{epoch+1}/{args.epochs}], Step [{step}/{num_batches}], Loss: {loss.item():.4f}')
    return mnist_model


## Train the CNN

In [None]:
# Hyperparameters and model class for the CNN experiment.
cnn_args_dict = dict(
    experiment_name="MNIST_digit_classification_CNN",
    Model=MNIST_CNN,
    seed=0,
    epochs=10,
    learn_rate=0.001,
    batch_size=64,
    kernel=5,
    num_filters=16,
)

# Copy the settings into a Config so they are reachable as attributes.
cnn_args = Config()
cnn_args.update(cnn_args_dict)

# Train the CNN with the settings above.
trainer = MNIST_Classify()
cnn_model = trainer.train_model(cnn_args)