# Synthetic Dataset Evaluator

In [74]:
# --- Run configuration ---
# Digit count and distribution label of the evaluated dataset. In the cells
# visible here they are only used to build the results CSV file name.
num_of_digits = 5
distribution = 'uniform'
# NOTE(review): dataset_path is not referenced by any cell visible here.
dataset_path = '../../data'
# NOTE(review): the three paths below are absolute and machine-specific;
# they must be edited before the notebook can run on another machine.
# Serialized single-digit classifier passed to torch.load.
model_path = '/Users/razoren/data18_5/mnist_single_digit_model.zip'
# Directory whose image files are run through the classifier.
samples_dir = '/Users/razoren/data18_5/uniform_5_digit_model/samples'
# Directory that receives the OCR results CSV.
results_dir = '/Users/razoren/uniform_5_digit_model'

# Let PIL load image files whose data is truncated instead of raising.
# NOTE(review): the same import/flag pair is repeated in the imports cell.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [75]:
# imports and utils
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
from time import time
import asyncio
import csv

import PIL
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim

import torchvision
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [76]:
# Load Model
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled models; pass weights_only=False there if this file is
# trusted.
# map_location="cpu": the images are CPU tensors and OCR converts the model
# output with .numpy(), so the model must live on the CPU regardless of the
# device it was saved from.
model = torch.load(model_path, map_location="cpu")
# Inference only: switch layers such as dropout / batch-norm to eval behavior.
# The trailing semicolon suppresses the module repr in the cell output.
model.eval();

In [77]:
# Load Data
# ToTensor converts a PIL image to a float tensor (channels first); Normalize
# with mean 0.5 / std 0.5 then shifts and scales it (imshow undoes this with
# img / 2 + 0.5).
transform = transforms.Compose(
    [transforms.ToTensor(),
     # One-element tuples: Normalize documents mean/std as per-channel
     # sequences, and "(0.5)" without the comma is just a float.
     transforms.Normalize(mean=(0.5,), std=(0.5,))])

images = []
# NOTE: os.listdir returns entries in arbitrary (filesystem) order, so the
# order of `images` - and of the saved results - follows that order.
for image_path in os.listdir(samples_dir):
    # Skip the macOS Finder metadata file.
    if image_path == '.DS_Store':
        continue
    try:
        # The context manager closes the underlying file even if the
        # transform fails; the tensor is fully materialized inside the block.
        with Image.open(os.path.join(samples_dir, image_path)) as image:
            images.append(transform(image))
    except PIL.UnidentifiedImageError:
        # Not a readable image: report and move on to the next file.
        print("cannotidentify")

In [78]:
def imshow(img, title: str):
    """Display a normalized image tensor in a new matplotlib figure.

    Args:
        img: tensor indexed as (channels, height, width) - the axes are
            reordered to (height, width, channels) before plotting. Values are
            mapped back with ``img / 2 + 0.5`` first.
        title: text shown above the image.
    """
    plt.figure()
    # Undo the mean=0.5 / std=0.5 normalization.
    restored = img / 2 + 0.5
    # matplotlib expects channels last.
    channels_last = np.transpose(restored.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.axis('off')
    plt.title(title)
    
def save_result_to_csv(results: list[str]):
    """Write the OCR predictions to the results CSV and print its path.

    All entries of ``results`` are written as the fields of a single CSV row.
    The file is named ``{num_of_digits}_{distribution}_ocr_results.csv`` and is
    placed in ``results_dir`` (all three are notebook-level settings). An
    existing file is overwritten.

    Args:
        results: one predicted digit string per image.
    """
    path = f"{results_dir}/{num_of_digits}_{distribution}_ocr_results.csv"
    # Create the target directory on first use instead of failing in open().
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # newline="" is required by the csv module so that it controls line
    # endings itself (otherwise extra blank lines appear on some platforms).
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(results)
    print(f"saved results to {path}")

    
def OCR(images):
    """Classify every digit of each multi-digit image and save the predictions.

    For each image, channel 0 is viewed as a 28 x (28 * n) strip, where n is
    the image width divided by 28. The strip is cut into n tiles of 28 x 28,
    each tile is flattened to shape (1, 784) and passed to the notebook-level
    ``model``. The index of the largest output is taken as the predicted
    digit, and the n digits are concatenated into one string per image. The
    list of strings is handed to ``save_result_to_csv``.

    Progress is printed after every 100th image (starting with image 0).

    Args:
        images: iterable of image tensors indexed as (channel, height, width).
    """
    results = []
    for image_idx, image in enumerate(images):
        # Derived from the image width; deliberately not named num_of_digits
        # so it cannot be confused with the notebook-level setting.
        digit_count = int(image.shape[2] / 28)

        # Split the n-digit image into n equal 28-pixel-wide tiles.
        predicted = ''
        # A separate loop variable is essential here: reusing the image index
        # would clobber it and break the progress report below.
        for digit_idx in range(digit_count):
            # .view is zero-copy, so re-creating it per digit costs nothing.
            strip = image[0, :, :].view(28, 28 * digit_count)
            single_digit = strip[:, digit_idx * 28:(digit_idx + 1) * 28]
            single_digit_reshaped = single_digit.reshape(1, 28 * 28)

            # Turn off gradients to speed up this part
            with torch.no_grad():
                logps = model(single_digit_reshaped)

            # Output of the network are log-probabilities, need to take
            # exponential for probabilities
            probabilities = torch.exp(logps)[0].tolist()
            predicted += str(probabilities.index(max(probabilities)))

        results.append(predicted)

        if image_idx % 100 == 0:
            print(f"finished {image_idx} images")

    save_result_to_csv(results)

In [79]:
# Classify every loaded sample; OCR itself writes the predictions to the
# results CSV and prints the output path.
OCR(images)

saved results to /Users/razoren/uniform_5_digit_model/5_uniform_ocr_results.csv
