In [2]:
%matplotlib inline
import os, sys, gc
import shutil
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from scipy.stats import entropy as scent
from matplotlib import gridspec
import matplotlib.image as mpimg

import collections

def softmax(x, axis=0):
    """Compute softmax values for the scores in x along one axis.

    Parameters
    ----------
    x : array_like
        Scores (logits). A 1-D input is treated as a single set of scores.
    axis : int, optional
        Axis along which the outputs are normalised to sum to 1. Defaults to
        0, which is the axis the original implementation summed over.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries sum to 1 along
        ``axis``.
    """
    x = np.asarray(x)
    # Shift by the maximum *along the normalised axis* rather than the global
    # maximum: softmax is invariant to a per-slice shift, and a global shift
    # can underflow every entry of a slice to 0, giving 0/0 = NaN.
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

def ent_fn(X):
    """Entropy of every row of X (one row per sample, one column per prob).

    Each row is passed on its own to ``scipy.stats.entropy`` (imported here as
    ``scent``); the results are returned as a 1-D float array with one entry
    per row.
    """
    return np.array([scent(sample) for sample in X], dtype=float)

# Directory (relative to the notebook) holding the saved prediction files.
loadDir = 'base_predictions'

# raw bins
# presumably one row per image and one column per category -- TODO confirm
bins = np.load('{0}/human_bincounts.npy'.format(loadDir))

# raw probabilities: every row of counts divided by its own row total
humans = bins / np.sum(bins, axis=1, keepdims=True)

# smoothed probabilities: add-one smoothing. The denominator must be the total
# of the *smoothed* counts, otherwise each row sums to more than 1 and is not a
# probability distribution.
humans_smoothed = (bins + 1) / np.sum(bins + 1, axis=1, keepdims=True)

# File names used to look up an image on disk from its integer index.
ordered_filenames = np.load('{0}/decoded_test_filename_order.npy'.format(loadDir))

# Short class names, indexed by integer class label.
labels = ['P', 'A', 'B', 'C', 'De', 'Do', 'F', 'H', 'S', 'T']

# NOTE(review): absolute, machine-specific path -- the notebook will not find
# the images on another machine unless this is edited.
im_dir = '/home/battleday/Academic/Berkeley/Superman/local/images/train_set_combined'

In [3]:
# List the saved prediction archives for each split: print the raw directory
# listing first, then the alphabetically sorted '.npz' entries that are kept.
train_listing = os.listdir(loadDir + '/train/')
print(train_listing)
train_files = sorted(name for name in train_listing if name.endswith('.npz'))
print(train_files)

test_listing = os.listdir(loadDir + '/test/')
print(test_listing)
test_files = sorted(name for name in test_listing if name.endswith('.npz'))
print(test_files)

['vgg_15_BN_64_train.npz', 'resnet_basic_110_train.npz', 'resnet_preact_bottleneck_164_train.npz', 'shake_shake_26_2x64d_SSI_cutout16_train.npz', 'densenet_BC_100_12_train.npz', 'wrn_28_10_train.npz', 'resnext_29_8x64d_train.npz', 'pyramidnet_basic_110_270_train.npz']
['densenet_BC_100_12_train.npz', 'pyramidnet_basic_110_270_train.npz', 'resnet_basic_110_train.npz', 'resnet_preact_bottleneck_164_train.npz', 'resnext_29_8x64d_train.npz', 'shake_shake_26_2x64d_SSI_cutout16_train.npz', 'vgg_15_BN_64_train.npz', 'wrn_28_10_train.npz']
['resnext_29_8x64d_test.npz', 'resnet_preact_bottleneck_164_test.npz', 'densenet_BC_100_12_test.npz', 'vgg_15_BN_64_test.npz', 'pyramidnet_basic_110_270_test.npz', 'wrn_28_10_test.npz', 'shake_shake_26_2x64d_SSI_cutout16_test.npz', 'resnet_basic_110_test.npz']
['densenet_BC_100_12_test.npz', 'pyramidnet_basic_110_270_test.npz', 'resnet_basic_110_test.npz', 'resnet_preact_bottleneck_164_test.npz', 'resnext_29_8x64d_test.npz', 'shake_shake_26_2x64d_SSI_cutout1

In [34]:
# Sanity check on one saved archive: the stored probabilities should match a
# softmax of the stored logits.
test = np.load('{0}/train/{1}'.format(loadDir, train_files[4]))

print(test.keys())
# Pull each array out of the archive once instead of re-reading it per use.
logits_all, probs_all, true_labels = test['logits'], test['probs'], test['labels']
print(logits_all.shape, probs_all.shape)
guess = np.argmax(probs_all, axis = 1)
print(guess.shape)
# Number of rows whose arg-max prediction equals the stored label.
print(np.sum(true_labels == guess))

# Compare stored vs. recomputed probabilities for the first and last rows.
for row in (0, -1):
    recomputed = softmax(logits_all[row, :])
    print(np.abs(probs_all[row, :] - recomputed) < 0.001)

# Show the class names of the final ten stored labels.
for class_idx in true_labels[-10:]:
    print(labels[class_idx])

['labels', 'logits', 'probs']
(50000, 10) (50000, 10)
(50000,)
50000
[ True  True  True  True  True  True  True  True  True  True]
[ True  True  True  True  True  True  True  True  True  True]
De
B
P
A
P
B
F
T
A
A


In [6]:
# Read every training archive fully into memory, keyed by file name without
# its extension, and attach the row-wise entropy of the stored probabilities.
train_dict = collections.OrderedDict()
for m in train_files:
    # splitext strips only the trailing '.npz', so a model name that itself
    # contains a dot is not truncated.
    model = os.path.splitext(m)[0]
    # np.load returns a lazy archive that keeps the file open; the context
    # manager closes it as soon as every array has been copied out.
    with np.load('{0}/train/{1}'.format(loadDir, m)) as raw:
        train_dict[model] = {prop: raw[prop] for prop in raw.keys()}
    train_dict[model]['entropy'] = ent_fn(train_dict[model]['probs'])
print(train_dict.keys())

odict_keys(['densenet_BC_100_12_train', 'pyramidnet_basic_110_270_train', 'resnet_basic_110_train', 'resnet_preact_bottleneck_164_train', 'resnext_29_8x64d_train', 'shake_shake_26_2x64d_SSI_cutout16_train', 'vgg_15_BN_64_train', 'wrn_28_10_train'])


In [7]:
# Read every test archive fully into memory, keyed by file name without its
# extension, and attach the row-wise entropy of the stored probabilities.
test_dict = collections.OrderedDict()
for m in test_files:
    # splitext strips only the trailing '.npz', so a model name that itself
    # contains a dot is not truncated.
    model = os.path.splitext(m)[0]
    # np.load returns a lazy archive that keeps the file open; the context
    # manager closes it as soon as every array has been copied out.
    with np.load('{0}/test/{1}'.format(loadDir, m)) as raw:
        test_dict[model] = {prop: raw[prop] for prop in raw.keys()}
    test_dict[model]['entropy'] = ent_fn(test_dict[model]['probs'])
print(test_dict.keys())

odict_keys(['densenet_BC_100_12_test', 'pyramidnet_basic_110_270_test', 'resnet_basic_110_test', 'resnet_preact_bottleneck_164_test', 'resnext_29_8x64d_test', 'shake_shake_26_2x64d_SSI_cutout16_test', 'vgg_15_BN_64_test', 'wrn_28_10_test'])


In [18]:
# Element-wise mean of the stored 'entropy' arrays over every model in
# train_dict.
averageEnt = np.mean([entry['entropy'] for entry in train_dict.values()], axis=0)
print(averageEnt.shape)

num_ims = 5
# Indices ordered from highest to lowest mean entropy.
# NOTE(review): the recorded output of this cell shows (5,) for top.shape,
# which this code cannot produce (it yields one index per entry of
# averageEnt). Confirm whether a [:num_ims] slice was intended here.
top = np.argsort(averageEnt)[::-1]
print(top.shape)

(50000,)
(5,)


In [None]:
def _show_image(ax, im_index):
    """Draw the image with index `im_index` on `ax`, hiding all axis ticks."""
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(mpimg.imread(im_dir + '/' + ordered_filenames[im_index]))


def _plot_prob_row(ax, probs, ylabel, bold, first_col, last_row):
    """Draw one bar chart of per-class probabilities in a single grid cell."""
    n_classes = len(labels)
    ax.bar(np.arange(n_classes), probs)
    ax.set_xlim([-1, n_classes])
    ax.set_ylim([0, 1])

    if first_col:
        # Only the left-most column carries a y label and y ticks.
        font = {'fontsize': 12}
        if bold:
            font['fontweight'] = 'bold'
        ax.set_ylabel(ylabel, **font)
    else:
        ax.set_yticks([])

    if last_row:
        # One tick per class: the tick positions must match the number of
        # labels, independent of how many image columns the figure has.
        ax.set_xticks(np.arange(n_classes))
        ax.set_xticklabels(labels, rotation=-45, ha='center')
    else:
        ax.set_xticks([])


def fig_fn(title, save_path, ims, 
           num_ims, added_rows, bolded_axes,
           base_rows = 7):
    """Build and save a grid figure of images with probability bar charts.

    Layout, top to bottom: ``added_rows - 1`` rows of plain thumbnails, one
    empty spacer row, then for each of the last ``num_ims`` entries of ``ims``
    a column holding the image (three grid rows) followed by four bar charts
    labelled 'Human Pr', 'NN Pr', 'PT Pr' and 'EX Pr'.

    Besides its arguments the function reads these notebook-level globals:
    ``ordered_filenames``, ``im_dir``, ``labels``, ``humans`` and the tables
    ``NN``, ``PT`` and ``EX``. The last three must be defined before calling;
    each is indexed by image index and plotted as one bar per class.

    Parameters
    ----------
    title : str
        Figure title; also the stem of the saved file name.
    save_path : str
        Directory the figure is written to as ``<save_path>/<title>.png``.
    ims : sequence of int
        Image indices. The last ``num_ims`` get the detailed columns; the last
        ``num_ims * (added_rows - 1)`` fill the thumbnail rows.
    num_ims : int
        Number of columns in the grid.
    added_rows : int
        Number of grid rows placed above the detailed columns.
    bolded_axes : container of int
        Which bar-chart row labels are bold: 0 = Human, 1 = NN, 2 = PT,
        3 = EX.
    base_rows : int, optional
        Grid rows reserved for the detailed columns. The layout uses seven
        (three for the image plus four charts), so smaller values fail.
    """
    n_extra = num_ims * (added_rows - 1)
    extended_ims = ims[-n_extra:] if n_extra > 0 else []
    ims = ims[-num_ims:]

    fig = plt.figure(figsize=(num_ims * 2, added_rows + base_rows))
    gs = gridspec.GridSpec(added_rows + base_rows, num_ims)
    fig.suptitle(title, fontsize = 20, fontweight='bold')

    # Thumbnail rows, filled left-to-right, top-to-bottom.
    for i in range(n_extra):
        r, c = divmod(i, num_ims)
        _show_image(fig.add_subplot(gs[r, c]), extended_ims[i])

    # (y label, per-image probability table) for the four bar-chart rows; the
    # position in this list is the index matched against `bolded_axes`.
    prob_rows = [('Human Pr', humans), ('NN Pr', NN), ('PT Pr', PT), ('EX Pr', EX)]
    last = len(prob_rows) - 1

    for col in range(num_ims):
        im = ims[col]
        _show_image(fig.add_subplot(gs[added_rows:added_rows + 3, col]), im)
        for k, (ylabel, table) in enumerate(prob_rows):
            ax = fig.add_subplot(gs[added_rows + 3 + k, col])
            _plot_prob_row(ax, table[im], ylabel,
                           bold=k in bolded_axes,
                           first_col=(col == 0),
                           last_row=(k == last))

    fig.savefig(save_path + '/' + title + '.png')

    