# MNIST OCR Model
##### Taken from [here](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627)
This model trains on an n-digit MNIST dataset. It takes much longer to train, but it reaches good accuracy once trained on a large enough dataset.

## Initial setup

In [None]:
# Number of digits per sample image. Later cells use it to pick the dataset and
# model file names, to size the network layers (784 * num_of_digits inputs,
# 10**num_of_digits classes) and to view images as 28 x (28 * num_of_digits).
num_of_digits = 3
# Path prefix under which the datasets and the saved model are read and written.
dataset_path = "../../data"

In [None]:
# imports and utils
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim

import torchvision
from torchvision import datasets, transforms


# Select the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any other cell visible in this
# notebook, so the model and batches are never explicitly moved to it — confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def view_classify(img, ps):
    """Display a multi-digit image with the axes hidden.

    Args:
        img: Tensor holding exactly ``28 * 28 * num_of_digits`` values; it is
            shown as a 28-row by ``28 * num_of_digits``-column image.
        ps: Predicted class probabilities. Accepted so existing callers keep
            working; it is not plotted.

    Raises:
        RuntimeError: If ``img`` does not hold the expected number of values.
    """
    # reshape() leaves the caller's tensor untouched and rejects inputs of the
    # wrong size, unlike the in-place, low-level resize_(). detach()/cpu() let
    # the helper accept tensors that track gradients or live on the GPU.
    pixels = img.detach().cpu().reshape(28, 28 * num_of_digits).numpy()
    _, ax1 = plt.subplots(figsize=(6, 9), ncols=1)
    ax1.imshow(pixels)
    ax1.axis('off')
    plt.tight_layout()

## Load Multi Digit MNIST Data Set

In [None]:
# Load the pre-built train and test datasets for the configured digit count.
# NOTE(review): torch.load unpickles the file contents — only load files from a
# trusted source.
# NOTE(review): presumably these are dataset objects yielding (image, label)
# pairs, since they are handed straight to DataLoader in the next cell — confirm
# against the script that produced the files.
train_data = torch.load(f'{dataset_path}/{num_of_digits}_digit_model/mnist_{num_of_digits}_digit_train_data')
test_data = torch.load(f'{dataset_path}/{num_of_digits}_digit_model/mnist_{num_of_digits}_digit_test_data')

In [None]:
# Wrap both datasets in loaders that yield shuffled mini-batches of 64 samples.
# NOTE(review): the validation loader is shuffled as well; this only changes the
# order of samples, e.g. which image the single-sample visualization cell shows.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

## Training - start here if you want to train a new model
##### Only do this if really necessary, because training this model takes a very long time. Otherwise, just load the pre-trained model below.

In [None]:
# Visualize the data
dataiter = iter(trainloader)
# Fetch one batch through the public iterator protocol rather than the
# loader iterator's private _next_data() hook.
images, labels = next(dataiter)
print(f'images type: {type(images)}')
print(f'images shape: {images.shape}')
print(f'labels shape: {labels.shape}')

# Show the first sample of the batch together with its label.
print(f'label: {labels[0]}')
plt.imshow(images[0].numpy().squeeze(), cmap='gray_r');

# Show further samples in a 6 x 10 grid. Sample 0 is already shown above, so
# the grid starts at index 1; the count is capped so that a batch with fewer
# than 61 samples cannot raise an IndexError.
figure = plt.figure()
num_of_images = min(60, len(images) - 1)
for index in range(1, num_of_images + 1):
    plt.subplot(6, 10, index)
    plt.axis('off')
    plt.imshow(images[index].numpy().squeeze(), cmap='gray_r')

In [None]:
# Layer details for the neural network: every width scales with the digit count
input_size = 784 * num_of_digits  # one flattened 28 x (28 * num_of_digits) image
hidden_sizes = [width * num_of_digits for width in (512, 256, 128, 64)]
output_size = 10**num_of_digits  # one class per possible num_of_digits-digit label
print(output_size)

# Build a feed-forward network: a Linear + ReLU pair per hidden layer, created
# in input-to-output order, followed by the classifier head.
layer_widths = [input_size, *hidden_sizes]
layers = []
for fan_in, fan_out in zip(layer_widths, layer_widths[1:]):
    layers.append(nn.Linear(fan_in, fan_out))
    layers.append(nn.ReLU())
# LogSoftmax over the class dimension yields log-probabilities
layers.append(nn.Linear(hidden_sizes[-1], output_size))
layers.append(nn.LogSoftmax(dim=1))
model = nn.Sequential(*layers)
print(model)

In [None]:
# Set loss criterion
# NLLLoss consumes log-probabilities, which is what the model's final
# LogSoftmax layer outputs.
criterion = nn.NLLLoss()
# Sanity check: run a single training batch through the model and the loss.
images, labels = next(iter(trainloader))
# Flatten every image in the batch into one row vector.
images = images.view(images.shape[0], -1)

logps = model(images)
# This loss value is only computed here; the training cell recomputes the loss
# for each batch.
loss = criterion(logps, labels)

In [None]:
# Train the model with SGD + momentum, printing the mean batch loss per epoch
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
time0 = time()
epochs = 100
for e in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # Flatten each image in the batch into a single row vector
        images = images.view(images.shape[0], -1)

        # Training pass: clear the gradients left over from the previous step
        optimizer.zero_grad()

        output = model(images)
        loss = criterion(output, labels)

        # This is where the model learns by backpropagating
        loss.backward()

        # And optimizes its weights here
        optimizer.step()

        running_loss += loss.item()

    # The batch loop never breaks, so the epoch summary is a plain statement
    # after it rather than a for/else clause.
    print(f"Epoch {e} - Training loss: {running_loss / len(trainloader)}")
print("\nTraining Time (in minutes) =", (time() - time0) / 60)


In [None]:
# Save model
# Passing the module itself to torch.save pickles the whole model object rather
# than only its state_dict; the file is written next to the datasets.
torch.save(model, f'{dataset_path}/{num_of_digits}_digit_model/mnist_{num_of_digits}_digit_ocr_model')

## Load OCR model - start here if you want to load an existing model


In [None]:
# Load model
# Reads back the fully pickled model object written by the save cell.
# NOTE(review): torch.load unpickles arbitrary objects — only load model files
# from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled modules like this one — confirm against the
# installed version.
model = torch.load(f'{dataset_path}/{num_of_digits}_digit_model/mnist_{num_of_digits}_digit_ocr_model')

## Evaluate the model

In [None]:
# Visualize the model's ability to classify digits
images, labels = next(iter(valloader))

# Flatten the first image of the batch into a 1 x (784 * num_of_digits) row
img = images[0].view(1, 784 * num_of_digits)

# Turn off gradients to speed up this part
with torch.no_grad():
    logps = model(img)

# Output of the network are log-probabilities, need to take exponential for probabilities
ps = torch.exp(logps)
# argmax finds the most probable class directly on the tensor, avoiding a
# Python list with 10**num_of_digits entries
print("Predicted Digit =", ps.argmax(dim=1).item())
view_classify(img.view(1, 28, 28 * num_of_digits), ps)

In [None]:
# Evaluate the model's accuracy over the whole validation loader
correct_count, all_count = 0, 0
for images, labels in valloader:
    # Flatten the whole batch so it is classified in a single forward pass
    # instead of one pass per image
    batch = images.view(len(labels), 784 * num_of_digits)
    # Turn off gradients to speed up this part
    with torch.no_grad():
        logps = model(batch)

    # exp() is monotonic, so the largest log-probability already identifies the
    # most probable class; no conversion to probabilities is needed
    pred_labels = logps.argmax(dim=1)
    correct_count += (pred_labels == labels.view(-1)).sum().item()
    all_count += len(labels)

print(f"Number Of Images Tested = {all_count}")
if all_count:
    print(f"Model Accuracy = {round(correct_count*100/all_count, 2)}%")
else:
    # An empty loader would otherwise cause a division by zero
    print("Model Accuracy = n/a (no images tested)")