In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, utils
import os
import numpy as np
import sys
import matplotlib.pyplot as plt
from PIL import Image
import random

sys.path.insert(0,'/home/kylecshan/Kaushik/')
from model import initialize_model


In [None]:
# --- Evaluation data configuration -------------------------------------------
input_size = 224
batch_size = 256

base_dir = '/home/kylecshan/data/images224/train_ms2000_v3/'
train_dir = base_dir + 'train/'
test_dir = base_dir + 'val/'

# Preprocessing applied to every evaluation image: resize, center-crop to a
# square of side input_size, convert to a tensor, then normalize per channel.
data_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Class labels come from the sub-directory layout of test_dir (ImageFolder).
test_dataset = datasets.ImageFolder(root=test_dir, transform=data_transform)

# shuffle=False keeps the evaluation order fixed between runs.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Architecture settings passed to initialize_model, followed by the checkpoint
# whose weights are evaluated in this notebook.
num_classes = 2000
freeze_layers = 100  # forwarded as-is to initialize_model; exact meaning is defined there
model_name = "densenet169"

model_path = '/home/kylecshan/saved/05_16_2019/model815.pth'

In [None]:
# Build the network, restore the trained weights and switch to inference mode.

# Pick the device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = initialize_model(model_name, num_classes, freeze_layers, use_pretrained=True)
# The model is wrapped before loading the state dict; presumably the checkpoint
# was saved from a DataParallel model (keys prefixed with 'module.') — TODO confirm.
model = nn.DataParallel(model)

# map_location makes the CPU fallback above actually work: without it, tensors
# saved from a GPU are restored onto that GPU and loading fails on a CPU-only host.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
load_dict = torch.load(model_path, map_location=device)
model.load_state_dict(load_dict)

# eval() returns the module itself; ending the cell on an assignment keeps the
# notebook from printing the (very long) model repr.
model = model.to(device).eval()

In [None]:

# Evaluate the model on test_loader.
#
# Produces:
#   correct[k][c] - number of class-c samples whose label is in the top-k predictions
#   seen[k][c]    - number of class-c samples evaluated
#   test_acc[k]   - per-class top-k accuracy (0 for classes with no samples)
#   wrong         - one record per top-1 misclassified sample, see below
top = (1, 5)
correct = {x: np.zeros(num_classes) for x in top}
seen = {x: np.zeros(num_classes) for x in top}
top_pred = {}

# Each record is a tuple:
#   (input image tensor on CPU, still normalized,
#    top-1 predicted class index,
#    true class index,
#    array of the top-max(top) predicted class indices)
wrong = []

widest_k = max(top)
num_batches = len(test_loader)

for batch_idx, (inputs, labels) in enumerate(test_loader):
    inputs = inputs.to(device)
    labels = labels.numpy()
    print('batch %d / %d' % (batch_idx + 1, num_batches))

    with torch.no_grad():
        outputs = model(inputs)

        for k in top:
            _, pred = torch.topk(outputs, k=k, dim=1)
            top_pred[k] = pred.cpu().numpy()

            # 1 where the true label appears among the k predictions, else 0.
            batch_correct = np.sum(top_pred[k] == labels[:, None], axis=1)
            # np.add.at accumulates correctly when a class repeats within a batch.
            np.add.at(correct[k], labels, batch_correct)
            np.add.at(seen[k], labels, 1)

        # One batched argmax instead of a per-sample argmax + .item() sync.
        top1 = outputs.argmax(dim=1).cpu().numpy()

    # Record every top-1 miss. The stored prediction is the top-1 class itself
    # (previously the numerically largest index among the top-5 was stored,
    # which made the "Predicted Class" panel show an unrelated class).
    for j in np.flatnonzero(top1 != labels):
        j = int(j)
        wrong.append((inputs[j].cpu(), top1[j], labels[j], top_pred[widest_k][j, :]))

    # Release the batch outputs before the next batch is loaded.
    del outputs

test_acc = {}
for k in top:
    # Per-class accuracy; classes never seen keep the value 0.
    test_acc[k] = np.divide(correct[k], seen[k],
                            out=np.zeros(num_classes), where=seen[k] > 0)
    total_seen = np.sum(seen[k])
    overall = np.sum(correct[k]) / total_seen if total_seen > 0 else float('nan')
    print('Top %d accuracy: %f' % (k, overall))


In [None]:
# Preprocessing for images that are only displayed: resize and center-crop to
# input_size, then convert to a tensor. There is no Normalize step here.
vis_transform = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
])

def example_grid(classid, base_dir, num_examples):
    """Build an image grid from randomly chosen example files of one class.

    Args:
        classid: Name of the class sub-directory (converted with str()).
        base_dir: Directory containing one sub-directory per class; it is
            concatenated directly, so it must end with a path separator.
        num_examples: Maximum number of images to sample. If the class
            directory holds fewer files, all of them are used.

    Returns:
        The tensor produced by utils.make_grid with 3 images per row.

    Raises:
        ValueError: If no image can be sampled (empty directory or
            num_examples <= 0).
    """
    img_dir = base_dir + str(classid) + '/'
    available = os.listdir(img_dir)

    # random.sample raises when asked for more items than exist, so cap the count.
    count = min(num_examples, len(available))
    if count <= 0:
        raise ValueError('no example images available in %s' % img_dir)
    chosen = random.sample(available, count)

    tensors = []
    for file in chosen:
        # The context manager closes the file handle; convert('RGB') matches what
        # ImageFolder's loader does, so grayscale/RGBA/palette files also work.
        with Image.open(img_dir + file) as img:
            tensors.append(vis_transform(img.convert('RGB')))

    # Stacking sizes the batch from the transformed images themselves instead of
    # assuming a fixed 224x224 resolution.
    return utils.make_grid(torch.stack(tensors), nrow=3)

In [None]:
def visualize_error(idx_wrong):
    """Plot one entry of the global ``wrong`` list with class examples beside it.

    Three panels are drawn: the stored image (de-normalized), a grid of training
    images for the class at position 1 of the record (titled 'Predicted Class'),
    and a grid for the class at position 2 (titled 'Actual Class').

    Args:
        idx_wrong: Index into ``wrong``; between 0 and the number of
            misclassified entries minus 1.

    Returns:
        Everything in the record after the image, i.e. ``wrong[idx_wrong][1:]``.
    """
    record = wrong[idx_wrong]

    # Undo the per-channel normalization, clamp to [0, 1], and go CHW -> HWC.
    channel_mean = np.array([0.485, 0.456, 0.406])[:, None, None]
    channel_std = np.array([0.229, 0.224, 0.225])[:, None, None]
    picture = record[0].cpu().numpy() * channel_std + channel_mean
    picture = np.clip(picture, 0, 1).transpose((1, 2, 0))

    plt.rcParams['figure.figsize'] = [16, 8]
    f, axarr = plt.subplots(1, 3)

    image_panel = axarr[0]
    image_panel.imshow(picture)
    image_panel.set_title('Incorrectly Classified Image')
    image_panel.axis('off')

    # Panels 1 and 2 show sampled training images for the classes stored at
    # record positions 1 and 2, in that order.
    for position, caption in ((1, 'Predicted Class'), (2, 'Actual Class')):
        class_name = test_dataset.classes[record[position]]
        samples = example_grid(class_name, train_dir, 9)
        panel = axarr[position]
        panel.imshow(samples.numpy().transpose((1, 2, 0)))
        panel.set_title(caption)
        panel.axis('off')

    return record[1:]

In [None]:
# Plot the entry at index 3 of `wrong` and display the rest of its record.
visualize_error(3)

In [None]:
# Plot the entry at index 4 of `wrong` and display the rest of its record.
visualize_error(4)

In [None]:
# Plot the entry at index 2 of `wrong` and display the rest of its record.
visualize_error(2)

In [None]:
# Number of records collected in `wrong` during the evaluation loop.
len(wrong)