In [None]:
#!/usr/bin/env python2.7
from __future__ import division

# ---
# Interactively explore failure cases of ResNet-50 on ImageNet 2012 Val.
# Li Ding, Oct. 20, 2017
# ---

from keras.applications.resnet50 import decode_predictions
import numpy as np
import xmltodict
import glob
import os


# set path -- fill in data_path; '' means the current working directory.
# os.path.join (instead of data_path + '/...') keeps '' from resolving to the
# filesystem root and tolerates a trailing slash on data_path.
data_path = ''
img_path = os.path.join(data_path, 'ILSVRC2012_img_val')
label_path = os.path.join(data_path, 'val')


# find prediction file: rows are fed to keras' decode_predictions below, so they
# are expected to be (n_images, 1000) class scores -- one row per validation
# image, in the same order as the sorted annotation files.
pred_file = os.path.join(data_path, 'val_prob.csv')
assert os.path.isfile(pred_file), "Prediction file does not exist: %s" % pred_file
preds = np.genfromtxt(pred_file, delimiter=",")


# get ground truth labels: for each annotation file (in sorted filename order),
# the de-duplicated list of <object><name> values it contains.
names = sorted(glob.glob(label_path + '/*'))
gt = []
for xml_file in names:
    with open(xml_file) as fh:
        objects = xmltodict.parse(fh.read())['annotation']['object']
    # xmltodict returns a single mapping for one <object>, a list for several
    if not isinstance(objects, list):
        objects = [objects]
    gt.append(list(set(obj['name'] for obj in objects)))


# number of images (50,000 for the full val set) -- taken from the data rather
# than hard-coded, so a partial prediction file still indexes correctly below.
n = len(gt)
assert len(preds) == n, \
    "got %d prediction rows for %d annotation files" % (len(preds), n)

# how many top-ranked classes count as a hit: 5 for top-5 error, 1 for top-1
top_k = 5
p = decode_predictions(preds, top=top_k)


# evaluate: per image, the fraction of its ground-truth labels that are
# missing from the top-k predicted class IDs
err = []
for labels, top in zip(gt, p):
    hits = set(c[0] for c in top)  # class IDs of the top-k predictions
    err.append(sum(wnid not in hits for wnid in labels) / len(labels))

err = np.array(err)
ind = np.arange(n)
# an image is a failure case only when none of its labels were predicted
wrong = ind[err == 1]

print('Top-%d Error: %s' % (top_k, err.mean()))


# get true labels, instead of IDs: map WordNet noun offsets to synsets.
# ImageNet wnids are 'n' + a noun offset. Offsets are only unique within one
# part of speech, so indexing every POS lets a verb/adj/adv synset with the same
# offset overwrite the noun in the dict -- index noun synsets only.

from nltk.corpus import wordnet
syns = list(wordnet.all_synsets(wordnet.NOUN))
offsets_list = [(s.offset(), s) for s in syns]
offsets_dict = dict(offsets_list)


In [None]:
# explore failure cases
from IPython.display import display,Image
from ipywidgets import interact, widgets, interactive, Layout

# show a failure case
def show_wrong(index):
    """Print labels and top predictions for one failure case and display its image.

    Args:
        index: position in `wrong` (0 .. len(wrong) - 1), not an image number.

    Reads the notebook-level `wrong`, `gt`, `p`, `names`, `img_path` and
    `offsets_dict` built in the first cell.
    """
    k = wrong[index]  # row shared by gt, p and names
    print('Ground Truth:')
    for wnid in gt[k]:
        # wnid looks like 'n01440764'; strip the 'n' to get the WordNet offset.
        # Synset.name() is 'lemma.pos.nn'; unlike parsing str(synset), it is not
        # broken by names containing an apostrophe (e.g. jack-o'-lantern).
        synset = offsets_dict[int(wnid[1:])]
        print('------ ' + synset.name().rsplit('.', 2)[0])
    print('Prediction: ')
    for _, label, prob in p[k]:
        print('------ %s %.2f' % (label, prob))
    # show the image whose annotation produced gt[k]: the JPEG shares the
    # annotation file's basename, under the configured img_path
    stem = os.path.splitext(os.path.basename(names[k]))[0]
    display(Image(os.path.join(img_path, stem + '.JPEG')))

# interactive widgets: the slider picks which failure case show_wrong renders;
# the Play button steps the slider every 2000 ms. Both span every index of `wrong`.
last_case = len(wrong) - 1
slider = widgets.IntSlider(min=0, max=last_case, step=1, value=0,
                           layout=Layout(width='80%'))
play = widgets.Play(interval=2000, value=0, min=0, max=last_case, step=1,
                    description="Press play", disabled=False)
w = interactive(show_wrong, index=slider)
widgets.jslink((play, 'value'), (slider, 'value'))  # keep Play and slider in sync in the front end
widgets.VBox([play, w])