# MNIST OCR Model
##### Taken from [here](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627)
##### This multi-digit MNIST OCR applies our original single-digit MNIST classifier to multi-digit images: each image is split into single-digit images, which are then classified separately. To use this notebook, train the model in [this](https://colab.research.google.com/drive/15_s6DFZJZFFgFvhzl0Xkm7-uGurq288N?usp=sharing#) notebook, save it, and load it here.

## Initial setup

In [2]:
# Root directory that the multi-digit dataset files are loaded from.
dataset_path = "../../data"
# Digit count used to build the dataset directory and file names.
num_of_digits = 3

In [3]:
# imports and utils
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim

import torchvision
from torchvision import datasets, transforms


# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def view_classify(img, ps):
    '''Show a digit image next to a bar chart of its predicted class probabilities.

    Args:
        img: Tensor holding exactly one 28x28 image (784 values, any layout
            that can be reshaped to ``(1, 28, 28)``).
        ps: Tensor with 10 class probabilities, e.g. of shape ``(1, 10)``.
    '''
    # detach().cpu() replaces the legacy `.data` accessor and also copes with
    # tensors that live on the GPU, for which a bare .numpy() raises.
    ps = ps.detach().cpu().numpy().squeeze()

    # reshape() leaves the caller's tensor untouched, whereas resize_() would
    # resize it in place (and silently accept a wrong number of elements).
    digit = img.detach().cpu().reshape(1, 28, 28).numpy().squeeze()

    _, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)

    # Left panel: the digit image itself, without axes.
    ax1.imshow(digit)
    ax1.axis('off')

    # Right panel: one horizontal bar per class 0-9.
    ax2.barh(np.arange(10), ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    ax2.set_yticklabels(np.arange(10))
    ax2.set_title('Class Probability')
    # Leave a little headroom past probability 1.0.
    ax2.set_xlim(0, 1.1)
    plt.tight_layout()

## Load Multi-Digit Dataset

In [4]:
# Both splits live in "<dataset_path>/<N>_digit_model/", where N is num_of_digits.
train_data = torch.load(
    f'{dataset_path}/{num_of_digits}_digit_model/mnist_{num_of_digits}_digit_train_data'
)
test_data = torch.load(
    f'{dataset_path}/{num_of_digits}_digit_model/mnist_{num_of_digits}_digit_test_data'
)

In [None]:
# Visualize the Data Set: show the first nine training samples in a 3x3 grid.
fig2, axes = plt.subplots(3, 3)
# axes.flat walks the grid row by row, i.e. the same order as (i // 3, i % 3).
for i, sub in enumerate(axes.flat):
    # Each sample is indexed as (image, label); image[0] is the 2-D plane drawn.
    sub.imshow(train_data[i][0][0], cmap='gray', interpolation='none')
    sub.set_title(f"Ground Truth: {train_data[i][1]}")
    sub.set_xticks([])
    sub.set_yticks([])
# Run the layout pass last so it can make room for the titles added above.
fig2.tight_layout()

In [None]:
# Wrap both splits in shuffling loaders that yield batches of 64 samples.
trainloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
)
valloader = torch.utils.data.DataLoader(
    test_data,
    batch_size=64,
    shuffle=True,
)

## Load OCR model - start here if you want to load an existing model


In [None]:
# Load model
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints you trust. Newer PyTorch releases (2.6+) default
# to weights_only=True; loading a fully pickled model there additionally needs
# weights_only=False.
# map_location='cpu' lets a model that was saved on a GPU load on any machine
# and keeps it on the CPU, which is where the evaluation code feeds its tensors.
model = torch.load('gdrive/My Drive/mnist/mnist_ocr_model', map_location='cpu')
# The model is only used for inference from here on.
model.eval()

## Evaluate the model

In [None]:
def OCR(num_digits: int, valloader):
    '''Classify every digit of one multi-digit image with the single-digit model.

    Takes the first image of the next batch from ``valloader``, displays it,
    cuts it into ``num_digits`` tiles of 28x28 pixels (left to right) and runs
    each tile through the module-level ``model``. Every prediction is printed
    and plotted with ``view_classify``.

    Args:
        num_digits: Number of 28-pixel-wide digits the image is made of.
        valloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        The predicted digit for each tile, ordered left to right.

    Raises:
        ValueError: If the image does not hold exactly ``num_digits`` tiles of
            28x28 pixels.
    '''
    images, _ = next(iter(valloader))
    image = images[0]

    # Fail early with a readable message instead of an opaque view() error.
    expected_size = 28 * 28 * num_digits
    if image.numel() != expected_size:
        raise ValueError(
            f"expected an image with {expected_size} pixels for "
            f"num_digits={num_digits}, got {image.numel()}"
        )

    plt.figure()
    plt.axis('off')
    plt.imshow(image.numpy().squeeze(), cmap='gray_r')

    # The 2-D layout is the same for every tile, so build it once.
    img = image.view(28, 28 * num_digits)

    predictions = []
    # Split the n-digit image into n equally wide parts
    for i in range(num_digits):
        single_digit = img[:, i * 28:(i + 1) * 28]
        # The classifier takes a flattened 784-value row per image.
        single_digit_reshaped = single_digit.reshape(1, 28 * 28)

        # Turn off gradients to speed up this part
        with torch.no_grad():
            logps = model(single_digit_reshaped)

        # Output of the network are log-probabilities, need to take exponential for probabilities
        ps = torch.exp(logps)
        # np.argmax returns the first maximum, like list.index(max(...)).
        digit = int(np.argmax(ps.numpy()[0]))
        predictions.append(digit)
        print("Predicted Digit =", digit)
        view_classify(single_digit_reshaped.view(1, 28, 28), ps)

    return predictions

In [None]:
# Use the configured digit count so the split matches the images in the loaded
# dataset; a hard-coded 4 cannot be reshaped from the 3-digit images.
OCR(num_digits=num_of_digits, valloader=valloader)