# Synthetic Dataset Evaluator

In [None]:
# Number of digits contained in every synthetic sample image.
num_of_digits = 3

# Root data directory (relative path).
dataset_path = '../../data'

# Directory holding the samples generated for the chosen digit count.
samples_path = '/'.join((dataset_path, f'{num_of_digits}_digit_model'))

In [None]:
# imports and utils
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
from time import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim

import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder

In [None]:
# Load the pre-trained classifier stored in the `mnist_ocr_model` file.
#
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a
# fully pickled module; if this file is a whole-model pickle, pass
# weights_only=False on those versions -- confirm against the installed torch.
#
# map_location='cpu' keeps the weights on the CPU so they match the CPU
# tensors produced by the DataLoader, even if the checkpoint was saved from a
# GPU. `.eval()` switches layers such as dropout / batch-norm (if present) to
# inference behaviour; the model is only used for prediction in this notebook.
model = torch.load(
    f'{dataset_path}/mnist_ocr_model',
    map_location='cpu',
).eval()

In [None]:
# Pre-processing: convert each image to a tensor, then normalize it with
# mean 0.5 and std 0.5 (the scalar is applied to every channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])

# Samples are read from the folder tree rooted at `samples_path`.
train_dataset = ImageFolder(samples_path, transform=transform)

# Shuffled batches of 200 samples, fetched by a single worker process.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    num_workers=1,
    batch_size=200,
)

In [None]:
def imshow(img, title: str):
    """Plot a normalized image tensor in a new figure under ``title``.

    The tensor is mapped back with ``img / 2 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), converted to a NumPy array, and its
    axes are reordered from (0, 1, 2) to (1, 2, 0) before being drawn.
    Assumes ``img`` is a channels-first CPU tensor -- TODO confirm for other
    callers. Axis ticks are hidden.
    """
    plt.figure()
    # Undo the normalization so the values are back in display range.
    pixels = (img / 2 + 0.5).numpy()
    plt.imshow(pixels.transpose(1, 2, 0))
    plt.axis('off')
    plt.title(title)
    
def OCR(num_digits: int, valloader):
    """Predict and display the multi-digit number in each image of one batch.

    Only the first batch yielded by ``valloader`` is processed. For every
    image, the first channel is treated as a 28 x (28 * num_digits) strip and
    cut into ``num_digits`` tiles of 28 x 28 pixels. All tiles of an image are
    flattened and classified by the module-level ``model`` in a single forward
    pass; the highest-scoring class of each tile is concatenated left to right
    and shown through ``imshow`` as the figure title.

    Args:
        num_digits: Number of 28-pixel-wide digits per image; must be >= 1.
        valloader: Iterable yielding ``(images, labels)`` batches. The labels
            are ignored.

    Raises:
        ValueError: If ``num_digits`` is smaller than 1.
        RuntimeError: If the first channel of an image does not hold exactly
            ``28 * 28 * num_digits`` values (raised by ``reshape``).
    """
    if num_digits < 1:
        raise ValueError(f"num_digits must be >= 1, got {num_digits}")

    images, _ = next(iter(valloader))

    for image in images:
        # Reshape once per image (this was previously redone for every digit).
        strip = image[0].reshape(28, 28 * num_digits)

        # Cut the strip into 28-column tiles and flatten each one, giving a
        # (num_digits, 784) batch so the model runs once per image.
        tiles = strip.split(28, dim=1)
        flat_digits = torch.stack([tile.reshape(28 * 28) for tile in tiles])

        # Gradients are not needed for prediction.
        with torch.no_grad():
            logps = model(flat_digits)

        # The network outputs log-probabilities; exp() is monotonic, so the
        # argmax of the raw output already is the most probable class.
        predicted = logps.argmax(dim=1).tolist()
        res = ''.join(str(digit) for digit in predicted)

        imshow(image, title=f"Predicted Number: {res}")

In [None]:
# Run the OCR check on one shuffled batch of the synthetic samples; a figure
# with the predicted number as its title is drawn for every image.
OCR(valloader=train_loader, num_digits=num_of_digits)