In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import os
import pandas as pd
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import urllib
#------------------
from dataset import RandomTestDataset
from resnet_family import resnet20_cifar

## Evaluate the confidence of the model on unseen categories

In [None]:
# Preprocessing applied to every test image: resize to 32x32, convert to a
# tensor, then normalise each channel.
# NOTE(review): the mean/std values are presumably the CIFAR-10 training-set
# statistics -- confirm they match what the model was trained with.
channel_mean = [0.491, 0.482, 0.446]
channel_std = [0.247, 0.243, 0.261]
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# One dataset per image folder; both share the same preprocessing.
seen_classes = RandomTestDataset('./images/seen_cifar10', transform=transform)
unseen_classes = RandomTestDataset('./images/unseen_cifar10', transform=transform)

# Batch size 1 and no shuffling, so the loaders yield one image at a time in
# dataset order.
seen_dataloader, unseen_dataloader = (
    DataLoader(dataset, batch_size=1, shuffle=False)
    for dataset in (seen_classes, unseen_classes)
)

In [None]:
# Human-readable class names, indexed by the model's predicted class id.
labels_cifar10 = ['airplanes', 'cars', 'birds', 'cats', 'deer', 'dogs', 'frogs', 'horses', 'ships', 'trucks']

# NOTE(review): no checkpoint is loaded in this cell; presumably
# resnet20_cifar() returns a model with trained weights -- confirm, otherwise
# the confidence scores below are not meaningful.
model = resnet20_cifar()
model.eval()


def _predict_folder(dataloader, image_dir):
    """Run ``model`` over ``dataloader`` and record the top-1 prediction per image.

    Args:
        dataloader: yields ``(image, file_name)`` batches. A batch size of 1 is
            expected, because only ``file_name[0]`` is read and ``.item()`` is
            called on the score and label.
        image_dir: folder the file names are relative to; joined onto each file
            name to build the ``'fname'`` entry.

    Returns:
        A list with one dict per batch, holding ``'fname'`` (path to the image),
        ``'prediction'`` (entry of ``labels_cifar10``) and ``'score'`` (top
        softmax probability rounded to 3 decimals).
    """
    records = []
    # Evaluation only: turn autograd off so no graph is built per forward pass.
    with torch.no_grad():
        for image, file_name in dataloader:
            # Softmax over dim 1, then take the largest probability and its index.
            probabilities = F.softmax(model(image), dim=1)
            score, label = probabilities.max(1)
            records.append({
                'fname': os.path.join(image_dir, file_name[0]),
                # .item() yields a plain int, so the list is not indexed by a tensor.
                'prediction': labels_cifar10[label.item()],
                'score': np.round(score.item(), 3),
            })
    return records


seen = _predict_folder(seen_dataloader, './images/seen_cifar10')
unseen = _predict_folder(unseen_dataloader, './images/unseen_cifar10')

In [None]:
# Show every image with its predicted label and confidence score as the title:
# top row = images from `seen`, bottom row = images from `unseen`.
rows = [seen, unseen]

# Size the grid from the data instead of hard-coding 5 columns, so lists of
# different lengths (or more than 5 images) no longer raise IndexError.
n_cols = max(1, max(len(records) for records in rows))

# squeeze=False keeps `ax` two-dimensional even when there is a single column.
fig, ax = plt.subplots(2, n_cols, figsize=(16, 8), squeeze=False)

for row_idx, records in enumerate(rows):
    for col_idx in range(n_cols):
        axis = ax[row_idx, col_idx]
        if col_idx >= len(records):
            # This row has fewer images than the other: hide the unused panel.
            axis.axis('off')
            continue
        record = records[col_idx]
        # Open inside a context manager so the file handle is released once
        # the image has been drawn.
        with Image.open(record['fname']) as image:
            axis.imshow(image)
        axis.set_title(f"{record['prediction']}, {record['score']}")