In [None]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.collections import LineCollection
import seaborn as sns

In [None]:
import torch
import numpy as np
# Configuration for this run.
#   nt           -- network tag, substituted into the epsilon file names below
#   true_epsilon -- reference epsilon drawn as a horizontal line and used as
#                   the threshold in the statistics further down
#   lt           -- split label, only used when building the plot title
nt = 'cifar'
true_epsilon = 0.0347
# Alternative configuration (swap with the two assignments above):
# nt = 'svhn'
# true_epsilon = 0.0078
lt = 'test'

# map_location='cpu' lets tensors that were saved from a GPU be loaded on a
# machine without CUDA; calling .cpu() only after torch.load() is too late,
# because deserialisation itself already fails in that situation.
# .detach() drops any autograd history so .numpy() can be used later.
correct = torch.load('pixel2_eps/{}_small_correct_epsilons.pth'.format(nt), map_location='cpu').detach()
incorrect = torch.load('pixel2_eps/{}_small_incorrect_epsilons.pth'.format(nt), map_location='cpu').detach()

In [None]:

# Pool both groups into one tensor, with a parallel 1/0 marker recording
# which group each entry came from (1 -> `correct`, 0 -> `incorrect`).
correctness = torch.cat((torch.ones_like(correct), torch.zeros_like(incorrect)), 0)
epsilons = torch.cat((correct, incorrect), 0)

# Drop the trailing remainder so the length is a multiple of 1000, which is
# the default bin count `n` of plot_epsilon (it reshapes with view(n, -1)-like
# semantics and needs an exact split).
_usable_len = len(epsilons) - len(epsilons) % 1000
epsilons = epsilons[:_usable_len]

# Sort ascending and carry the markers along using the same permutation.
# `idx` only contains positions < _usable_len, so indexing the untruncated
# `correctness` stays aligned with the truncated `epsilons`.
sort_eps, idx = epsilons.sort()
sort_correct = correctness[idx]

def plot_epsilon(axs, sort_eps, sort_correct, true_epsilon, n=1000, title=None):
    """Plot binned, sorted epsilon values as a line coloured by mean correctness.

    The sorted values are split into ``n`` equally sized consecutive bins.
    Each bin contributes one point (x = start index of the bin, y = mean
    epsilon of the bin); neighbouring points are joined by a segment whose
    colour is the average of the two bins' mean ``sort_correct`` values.
    A horizontal line marks ``true_epsilon`` and a colorbar labelled
    'Model accuracy' is attached to the figure that owns ``axs``.

    Parameters
    ----------
    axs : matplotlib Axes
        Axes to draw into.
    sort_eps : 1-D torch tensor
        Epsilon values in sorted order. ``sort_eps.size(0)`` must be an exact
        multiple of ``n``; otherwise the ``view`` call below raises.
    sort_correct : 1-D torch tensor
        Per-example values aligned with ``sort_eps`` (same length); their
        per-bin mean drives the colour.
    true_epsilon : float
        Height of the horizontal reference line.
    n : int, optional
        Number of bins (default 1000).
    title : str, optional
        Accepted for interface compatibility; it is currently not applied
        to the axes.

    Returns
    -------
    (line, axs)
        The added LineCollection and the axes that were passed in.
    """
    total = sort_eps.size(0)
    bin_size = total // n

    # One point per bin: x is the bin's starting index, y its mean epsilon.
    x = torch.arange(n).numpy() * total // n
    y = sort_eps.view(n, bin_size).mean(1).numpy()

    # Build (n-1) segments, each joining two consecutive points.
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Segment colour value = average of the two adjacent bins' mean correctness.
    c = sort_correct.view(n, bin_size).mean(1).numpy()
    c = np.array([c[:-1], c[1:]]).mean(axis=0)

    # Fully transparent helper line spanning every index: it only exists so
    # that autoscaling yields the x-limits we want to restore further down.
    axs.plot(np.arange(total), np.ones(total) * true_epsilon, alpha=0.0, color='green')
    xlim = axs.get_xlim()

    # Reference line at true_epsilon, drawn far beyond the visible range so it
    # always spans the full width once the x-limits are restored.
    axs.plot(np.array([-1e10, 1e10]), np.ones(2) * true_epsilon, alpha=0.7)

    # Coloured curve; the colour scale is fixed to [0.35, 1].
    lc = LineCollection(segments, cmap='viridis', norm=plt.Normalize(vmin=0.35, vmax=1))
    lc.set_array(c)  # values used for colour mapping
    lc.set_linewidth(5)
    line = axs.add_collection(lc)

    # Use the figure that owns `axs` rather than a notebook-global `fig`, so
    # the function also works for axes created under any other variable name.
    cb = axs.figure.colorbar(line, ax=axs)
    cb.set_label('Model accuracy')

    axs.set_yscale('log')
    axs.set_xlim(*xlim)
    axs.set_ylim(5e-4, 0.5)
    return line, axs

In [None]:
# Draw the binned epsilon curve for the configured network / split.
fig, ax = plt.subplots(1)
# Raw string: '\e' is not a valid Python escape sequence, so the non-raw form
# triggers an invalid-escape warning; the resulting text is unchanged.
line, ax = plot_epsilon(ax, sort_eps, sort_correct, true_epsilon, title=r'$\epsilon$-distances for {} network on {} set'.format(nt, lt))

In [None]:
# Sorted epsilon curves for the two groups, each with a horizontal reference
# line at true_epsilon.
fig, axs = plt.subplots(1,2)
# torch.sort keeps the sort vectorised; sorted() on a tensor compares 0-dim
# tensors one by one in Python and hands matplotlib a list of tensors.
axs[0].plot(torch.sort(correct)[0].numpy())
axs[0].plot([0,len(correct)], [true_epsilon, true_epsilon])
axs[0].set_title('eps dist for correct labels')
axs[1].plot(torch.sort(incorrect)[0].numpy())
axs[1].plot([0,len(incorrect)], [true_epsilon, true_epsilon])
axs[1].set_title('eps dist for incorrect labels')


def _fraction(count, total):
    """Return count / total as a float, or NaN when total is 0."""
    return count / total if total else float('nan')


# Convert the counts to plain Python ints before dividing: dividing an
# integer tensor by an int is integer division on older PyTorch releases,
# which would make every fraction below print as 0.
n_correct = len(correct)
n_incorrect = len(incorrect)
n_total = n_correct + n_incorrect
n_correct_at_or_above = int((correct >= true_epsilon).sum())
n_correct_below = int((correct < true_epsilon).sum())
n_incorrect_above = int((incorrect > true_epsilon).sum())
n_incorrect_below = int((incorrect < true_epsilon).sum())

print("Robust error at eps: {:.4f}".format(1 - _fraction(n_correct_at_or_above, n_total)))
print("Fraction of correct examples that aren't certified: {:.4f} ({:d})".format(_fraction(n_correct_below, n_correct), n_correct_below))
print("Fraction of incorrect examples that are more than eps away: {:.4f} ({:d})".format(_fraction(n_incorrect_above, n_incorrect), n_incorrect_above))
print("Examples left for cascade: {}/{}".format(n_correct_below + n_incorrect_below, n_total))

In [None]:
# Standalone view of the sorted `incorrect` epsilons with the true_epsilon
# reference line. torch.sort avoids the slow element-by-element Python sort
# that sorted() performs on a tensor.
plt.plot(torch.sort(incorrect)[0].numpy())
# Trailing ';' suppresses the list-of-Line2D repr in the cell output.
plt.plot([0,len(incorrect)], [true_epsilon, true_epsilon]);