In [51]:
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

from collections import defaultdict
from PIL import Image, ImageOps
import numpy as np
import os

In [20]:
# File names of the NASA test images; the dataset cell opens
# "test_scenario/preprocess_<name>" for each entry.
nasa_files = [
    "nasa1.png",
    "nasa2.jpeg",
    "nasa3.jpg",
    "nasa4.jpg",
]

# Sub-directory of "data/" that is scanned for one folder per font.
dataset_name = "10_fonts_4_sentences"

In [57]:
class TestDataset(Dataset):
    """Pairs every preprocessed NASA test image with every font sample image.

    NASA images are read from ``test_scenario/preprocess_<file>`` for each
    entry of the module-level ``nasa_files`` list. Font samples are read from
    ``data/<dataset_name>/<font>/<filename>``. Every image is colour-inverted,
    converted to grayscale, cropped to ``self.w`` x ``self.h`` and stored as a
    NumPy array with one extra leading axis.

    Item ``index`` maps to the pair
    ``(nasa_images[index % n_nasa], images[index // n_nasa])``, so the NASA
    image index varies fastest.
    """

    def __init__(self):
        super().__init__()
        # Target width and height (in pixels) every image is cropped to.
        self.w, self.h = 512, 71

        self.nasa_images = []
        for file in nasa_files:
            img = self._load_inverted_gray(f"test_scenario/preprocess_{file}")
            w, h = img.size
            # Rescale to the target height while keeping the aspect ratio.
            # Image.LANCZOS is the filter that used to be exposed as
            # Image.ANTIALIAS; that alias was removed in Pillow 10.
            img = img.resize((int(w * self.h / h), self.h), Image.LANCZOS)
            self.nasa_images.append(self._crop_to_array(img))
        self.n_nasa_images = len(self.nasa_images)

        # self.images[k] is a font sample and self.fonts[k] the name of the
        # font directory it came from.
        self.images, self.fonts = [], []
        dataset_root = os.path.join("data", dataset_name)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible between runs and machines.
        for font in sorted(os.listdir(dataset_root)):
            font_path = os.path.join(dataset_root, font)
            for filename in sorted(os.listdir(font_path)):
                img = self._load_inverted_gray(os.path.join(font_path, filename))
                self.fonts.append(font)
                self.images.append(self._crop_to_array(img))

    @staticmethod
    def _load_inverted_gray(path):
        """Open ``path`` and return it colour-inverted in mode ``'L'``."""
        # The context manager closes the underlying file once the pixel data
        # has been copied into the converted image.
        with Image.open(path) as raw:
            return ImageOps.invert(raw.convert('RGB')).convert('L')

    def _crop_to_array(self, img):
        """Crop ``img`` to the top-left ``self.w`` x ``self.h`` box and return
        it as a NumPy array with a new leading axis."""
        cropped = img.crop((0, 0, self.w, self.h))
        return np.array(cropped)[np.newaxis, :, :]

    def __getitem__(self, index):
        """Return ``(nasa_image, font_image, nasa_image_index, font_name)``.

        Raises:
            IndexError: if ``index`` maps past the last font sample.
        """
        nasa_image_index = index % self.n_nasa_images
        image_index = index // self.n_nasa_images
        return (
            self.nasa_images[nasa_image_index],
            self.images[image_index],
            nasa_image_index,
            self.fonts[image_index],
        )

    def __len__(self):
        """Number of (NASA image, font sample) pairs."""
        return len(self.nasa_images) * len(self.images)

In [58]:
# Load all NASA/font image pairs into memory.
dataset = TestDataset()

# One pair per batch, in dataset order (no shuffling).
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [59]:
# Use the CUDA device when PyTorch reports one as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

Device: cpu


In [16]:
# Load the serialized model object and map its tensors onto `device`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load model files from a trusted source. Presumably the pickled
# `Model` class must already be defined/importable in the kernel; its
# definition is not visible in this part of the notebook -- TODO confirm.
model = torch.load("test_scenario/model", map_location=device)
# Put the module in evaluation mode (the network contains Dropout layers).
# Left as the last bare expression so the cell displays the module repr.
model.eval()

Model(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): MaxPool2d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): MaxPool2d(kernel_size=(2, 3), stride=(2, 3), padding=0, dilation=1, ceil_mode=False)
    (4): Dropout(p=0.5, inplace=True)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (6): MaxPool2d(kernel_size=(2, 3), stride=(2, 3), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (8): MaxPool2d(kernel_size=(2, 3), stride=(2, 3), padding=0, dilation=1, ceil_mode=False)
  )
  (siamese_feed_forward): Sequential(
    (0): Linear(in_features=6528, out_features=512, bias=True)
    (1): Dropout(p=0.5, inplace=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=512, out_features=64, bias=True)
    (4): Sigmoid()
  )
  (feed_forward): Sequential(
    (0): Linear(in_features=13056, out_features

In [60]:
# Name of the font directory treated as the positive class: predictions for
# samples from this font go to the "<i>_1" bucket, all others to "<i>_0".
nasa_font = "nasalization_old10"

In [61]:
# Maps "<nasa_image_index>_<1|0>" to the list of model outputs for that NASA
# image; the suffix is 1 when the paired sample comes from `nasa_font`.
predictions_test_scenario = defaultdict(list)

# Inference only: without no_grad() every forward pass records an autograd
# graph that is never used, costing time and memory.
with torch.no_grad():
    for img1, img2, nasa_image_index, font in loader:
        img1 = img1.float().to(device)
        img2 = img2.float().to(device)

        # The loader uses batch_size=1, so each batch field holds one entry.
        is_nasa_font = 1 if font[0] == nasa_font else 0
        key = f"{nasa_image_index.item()}_{is_nasa_font}"
        predictions_test_scenario[key].append(model(img1, img2).item())

In [62]:
def mean(l):
    """Return the arithmetic mean of the values in ``l``.

    ``l`` must support both ``sum()`` and ``len()``.

    Raises:
        ZeroDivisionError: if ``l`` is empty.
    """
    total = sum(l)
    count = len(l)
    return total / count

In [63]:
# For each NASA image, print the mean model output over the pairs bucketed
# under "<i>_0" (sample font differs from nasa_font) and "<i>_1" (same font).
for i, file in enumerate(nasa_files):
    print(f"\nFile {file}:")
    for label in (0, 1):
        bucket = predictions_test_scenario[f'{i}_{label}']
        print(f'{label}- ', mean(bucket))


File nasa1.png:
0-  0.7222222222222222
1-  1.0

File nasa2.jpeg:
0-  0.5833333333333334
1-  0.75

File nasa3.jpg:
0-  0.4722222222222222
1-  0.75

File nasa4.jpg:
0-  0.4722222222222222
1-  0.75
