In [None]:
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from matplotlib.patches import FancyArrowPatch
import matplotlib.gridspec as gridspec
from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)
import matplotlib.patches as mpatches
import proplot as pplt

import pandas as pd
import numpy as np
import torch
from scipy.stats import entropy
from torchvision import datasets, transforms

# Make the project's local helper modules (`plots`, `curve_utils`) importable.
# NOTE(review): these paths are relative, so they assume the kernel's working
# directory is this notebook's folder — confirm before running elsewhere.
sys.path.insert(0, './..')
sys.path.insert(0, '../data')
sys.path.insert(0, '../Curvature')

import plots as pl
from curve_utils import load_mnist, load_cifar, tab_name_to_hex

# Repository parent directory, so the `response_contour_analysis` package resolves.
sys.path.insert(0, './../../')

import response_contour_analysis.utils.plotting as plot_utils

In [None]:
# Dataset selector read by the loading cell: 0 -> MNIST, any other value -> CIFAR.
_DATASET_IDS = {'mnist': 0, 'cifar': 1}
dataset_type = _DATASET_IDS['cifar']

In [None]:
# Typography shared by proplot and matplotlib: LaTeX-rendered 8 pt serif text.
plot_settings = dict([
    ("text.usetex", True),
    ("font.family", "serif"),
    ("font.size", 8),
    ("axes.formatter.use_mathtext", True),
])

# Target figure width, given both as a proplot unit string and in inches.
# NOTE(review): 13.968 cm is ~5.4992 in, not exactly 5.50107 in — confirm which is intended.
figwidth = '13.968cm'
figwidth_inch = 5.50107
dpi = 600

# Apply the same settings to proplot first, then to matplotlib directly.
for rc_store in (pplt.rc, mpl.rcParams):
    rc_store.update(plot_settings)

In [None]:
# Ten tableau colours as hex strings. The first two (`two_plot_colors`) are the
# natural / robust model colours used throughout the figures.
_tab_color_names = (
    'tab:blue', 'tab:red', 'tab:green', 'tab:orange', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
)
plot_colors = list(map(tab_name_to_hex, _tab_color_names))

two_plot_colors = plot_colors[:2]
model_types = ('Natural', 'Adversarial')

# Load data & models

In [None]:
def _load_boundary_dists(seed):
    """Load the saved distance-to-boundary arrays for the natural and robust models.

    Returns ``[natural, robust]``, each the ``'data'`` entry of
    ``../data/distance_to_boundary_{natural,robust}_{seed}.npz``. Each archive is
    opened in a ``with`` block so its file handle is closed right after reading.
    """
    dists = []
    for model_kind in ('natural', 'robust'):
        with np.load(f'../data/distance_to_boundary_{model_kind}_{seed}.npz') as npz_file:
            dists.append(npz_file['data'])
    return dists


if dataset_type == 0: # MNIST
    seed = 0
    # Use the same seed for the loaded models/data and for the boundary-distance files.
    model_natural, data_natural, model_robust, data_robust = load_mnist(code_directory='../../', seed=seed)
    boundary_dists = _load_boundary_dists(seed)
    cmap = 'Gray'
    data_name = 'mnist'
    data_labels = [str(i) for i in range(10)]
else: # CIFAR
    model_natural, data_natural, model_robust, data_robust = load_cifar(code_directory='../../')
    ## TODO: Swap for CIFAR (these are the same files the MNIST branch loads with seed 0)
    boundary_dists = _load_boundary_dists(0)
    cmap = None
    data_name = 'cifar'
    data_labels = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Unpack the attack results for each model (keys as returned by load_mnist / load_cifar).
advs_nat = data_natural['advs']
pert_lengths_nat = data_natural['pert_lengths']
classes_nat = data_natural['adv_class']
dirs_nat = data_natural['dirs']
images_nat = data_natural['images']
labels_nat = data_natural['labels']

advs_robust = data_robust['advs']
pert_lengths_robust = data_robust['pert_lengths']
classes_robust = data_robust['adv_class']
dirs_robust = data_robust['dirs']
images_robust = data_robust['images']
labels_robust = data_robust['labels']

In [None]:
# Compare perturbation lengths of the natural and robust models with
# `pl.plot_pert_lengths`; every adversarial found is included.
fig , ax = pl.plot_pert_lengths(
    [pert_lengths_nat, pert_lengths_robust],
    n=15,
    labels=['Naturally','Robust'],
    ord=2,
)
# Full text width, 8:12 aspect ratio, at the publication dpi.
fig.set_size_inches(figwidth_inch, (8/12) * figwidth_inch)
fig.set_dpi(dpi)
plt.show()

In [None]:
# Per-model statistics of the adversarial examples found for each original class.
#   advs_frac_found[m, c, k]   fraction of class-c adversarials (model m) with adversarial class k
#   advs_total_counts[m, c, k] total number of class-c adversarials, stored at every k that occurs
#   advs_per_class[m, c]       mean number of adversarials found per class-c image (all images)
#   advs_num_classes[m][c]     per-image (distinct adversarial classes - 1); for MNIST divided by
#                              (adversarials found - 1) because that number varies per image
# Only images with more than one adversarial enter the class statistics.
n_classes = len(data_labels)
advs_frac_found = np.zeros((2, n_classes, n_classes))
advs_total_counts = np.zeros((2, n_classes, n_classes))
advs_per_class = np.zeros((2, n_classes))
advs_num_classes = []
for model_idx, (classes_, pert_lengths_, labels_) in enumerate(zip([classes_nat, classes_robust], [pert_lengths_nat, pert_lengths_robust], [labels_nat, labels_robust])):
    found = ~np.isnan(pert_lengths_)
    mask_idx = found.sum(1) > 1
    # Blank out classes where no adversarial was found. Build a copy instead of writing NaN
    # into the loaded arrays: that mutated data_natural/data_robust and fails for int dtypes.
    classes_ = np.where(found, classes_, np.nan)
    masked_classes = classes_[mask_idx]
    masked_labels = labels_[mask_idx]
    sub_advs_num_classes = []
    for label_idx in range(n_classes):
        label_classes = masked_classes[masked_labels == label_idx]
        unique_gt_classes = []
        for gt_classes in label_classes:
            masked_gt_classes = gt_classes[~np.isnan(gt_classes)]
            n_extra_classes = len(np.unique(masked_gt_classes)) - 1
            if data_name == 'mnist': # variable number of advs found
                unique_gt_classes.append(n_extra_classes / (len(masked_gt_classes) - 1))
            else: # always the same number of advs found
                unique_gt_classes.append(n_extra_classes)
        sub_advs_num_classes.append(np.array(unique_gt_classes))

        unique_classes, unique_counts = np.unique(label_classes, return_counts=True)
        valid = ~np.isnan(unique_classes)  # drop the NaN "class" (no adversarial found)
        unique_classes = unique_classes[valid].astype(int)
        unique_counts = unique_counts[valid]
        total_counts = np.sum(unique_counts)
        advs_total_counts[model_idx, label_idx, unique_classes] = total_counts
        advs_frac_found[model_idx, label_idx, unique_classes] = unique_counts / total_counts

        label_pert_lengths = pert_lengths_[labels_ == label_idx]
        advs_per_class[model_idx, label_idx] = np.mean((~np.isnan(label_pert_lengths)).sum(1))
    advs_num_classes.append(sub_advs_num_classes)

In [None]:
# Scratch cell: proplot's "Indicating error bounds" bar-plot example on synthetic data.
# NOTE(review): later cells rebuild `fig`, `axs`, `ax` and `data` before using them, so
# nothing depends on this cell — consider removing it from the final notebook.
state = np.random.RandomState(51423)
data = (
    state.rand(20, 8).cumsum(axis=0).cumsum(axis=1)[:, ::-1]
    + 20 * state.normal(size=(20, 8))
    + 30
)

pd_data = pd.DataFrame(data, columns=pd.Index(np.arange(0, 16, 2), name='column number'))
pd_data.name = 'variable'

fig, axs = pplt.subplots()
axs.format(abc='A.', suptitle='Indicating error bounds')

# Medians with percentile ranges; per proplot this is equivalent to
# median=True, barpctile=(5, 95), boxpctile=(25, 75).
ax = axs[0]
kw = {
    'color': 'light red', 'edgecolor': 'k', 'legend': True,
    'median': True, 'barpctile': 90, 'boxpctile': True,
}
ax.bar(pd_data, **kw)
ax.format(title='Bar plot')

pplt.show()

In [None]:
# Figure comparing which adversarial classes the natural and robust models fall into.
# Layout: panel 1 spans both rows (mean distinct adversarial classes per original class);
# panels 2-11 show one original class each, comparing both models' normalized
# adversarial-class distributions. Also collects the entropy of those distributions.
array = [
    [1,  1,  2,  3,  4,  5,  6],
    [1,  1,  7,  8,  9, 10, 11]
]
hspace = [1.2,]
wspace = [0, 4.0] + [0.5,]*4

fig, axs = pplt.subplots(array, sharey=False, sharex=False,
                         dpi=dpi, figwidth=figwidth, wspace=wspace, hspace=hspace)

# --- Panel 1: unique adversarial classes per original class ---------------------------
ax = axs[0]
bar_width = 0.35
plot_labels = [str(i) for i in range(10)]

# TODO: Error bars

x = np.arange(len(plot_labels))
for model_idx, bar_shift in enumerate((-bar_width / 2, bar_width / 2)):
    ax.bar(x + bar_shift,
           [label_classes.mean() for label_classes in advs_num_classes[model_idx]],
           bar_width, color=plot_colors[model_idx])

ax.format(
    xlabel='Original class',
    ylabel='Unique adversarial classes',
    xticks=range(10),
    xtickminor=[],
    xgrid=False,
)
legend_handles = [mpatches.Patch(color=plot_colors[0], label='natural'),
                  mpatches.Patch(color=plot_colors[1], label='robust')]
legend = ax.legend(handles=legend_handles, loc='upper left', ncols=1, frame=False)

for ax_loc in ['top', 'right']:
    ax.spines[ax_loc].set_color('none')

# --- Panels 2-11: adversarial-class distribution per original class -------------------
colors = pplt.Cycle([mpl.colors.ColorConverter.to_rgba(color) for color in plot_colors])

# Normalize over the adversarial-class axis. A (model, class) pair without adversarials
# gives 0/0 = NaN; the nan-aware max keeps the shared y-limit finite in that case.
with np.errstate(invalid='ignore', divide='ignore'):
    frac_found_norm = advs_frac_found / advs_frac_found.sum(axis=2, keepdims=True)
y_max = np.nanmax(frac_found_norm)

# Label every other adversarial class so the small panel stays readable.
sparse_xticklabels = [str(i) if i % 2 == 0 else '' for i in range(10)]
labelled_ax_idx = 6  # first small panel of the second row (layout number 7)

frac_found_entropy = []
for gt_label in range(10): # label index
    ax_idx = gt_label + 1  # axs[0] is the summary panel
    ax = axs[ax_idx]
    frac_found_slice = frac_found_norm[:, gt_label, :]  # [model_type, adversarial class]
    frac_found_entropy.append([entropy(frac_found_slice[i, :]) for i in range(2)])

    data = pd.DataFrame(frac_found_slice.transpose((1,0)), columns=pd.Index(model_types, name=''))
    ax.bar(data, cycle=colors, edgecolor='none')
    ax.format(
        xtickminor=False,
        grid=False,
        ytickminor=False,
        ylim=[0, y_max],
        title=data_labels[gt_label],
        titleabove=False,
        titlepad=-1
    )
    for ax_loc in ['top', 'right']:
        ax.spines[ax_loc].set_color('none')
    if ax_idx == labelled_ax_idx:
        ax.format(
            xticks=list(range(10)),
            xticklabels=sparse_xticklabels,
            xlabel='Adversarial class',
            ylabel='Class density',
            yticks=[0, np.fix(10*y_max)/10],
            yticklabels=[f'{np.fix(10*i)/10:.1f}' for i in [0, y_max]],
        )
    else:
        ax.format(
            xticks=[],
            yticks=[],
            yticklabels=[],
        )
        for ax_loc in ['bottom', 'left']:
            ax.spines[ax_loc].set_color('none')

fig.savefig(f'../data/{data_name}_adv_class_comp.pdf', dpi=dpi, transparent=True, bbox_inches="tight", pad_inches=0.01)
pplt.show()

In [None]:
# Entropy of each model's adversarial-class distribution (scipy.stats.entropy with its
# default natural log): one row per original class, one column per model type.
# Rows are labelled with the class names so CIFAR rows are readable.
frac_found_entropy = np.stack(frac_found_entropy)
pd_entropy = pd.DataFrame(
    frac_found_entropy,
    index=pd.Index(data_labels, name='original class'),
    columns=pd.Index(model_types, name=''),
)
pd_entropy

# Decision Space Visualization

In [None]:
# Decision-space figure for one robust-model example: the plane spanned by the
# perturbations towards two adversarial examples, with image insets for the clean image,
# both adversarials and both perturbations. All pixel/point offsets below are hand-tuned.

def _snap_to_ticks(value, ticks):
    """Return the entry of `ticks` nearest to `value`."""
    return ticks[np.argmin(np.abs(ticks - value))]


def _perturbation_display_image(pert, image_shape):
    """Reshape a flattened perturbation to `image_shape`, move the first axis last and
    min-max scale to [0, 1] for display.

    NOTE(review): assumes `image_shape` is channels-first (C, H, W) — confirm for MNIST.
    A constant perturbation is shown as zeros instead of dividing by zero.
    """
    pert_image = pert.reshape(image_shape).transpose(1, 2, 0)
    pert_min = pert.min()
    pert_range = pert.max() - pert_min
    if pert_range == 0:
        return np.zeros_like(pert_image, dtype=float)
    return (pert_image - pert_min) / pert_range


def _add_image_inset(ax, image, xy, box_offset, cmap):
    """Draw `image` anchored at data point `xy` = (x, y), shifted by `box_offset` points, without an arrow."""
    imagebox = OffsetImage(image, zoom=0.75, cmap=cmap)
    imagebox.image.axes = ax
    ab = AnnotationBbox(
        imagebox,
        xy=xy,
        xybox=box_offset,
        xycoords=ax.transData,
        boxcoords='offset points',
        pad=0.0,
        arrowprops=dict(visible=False),
    )
    ax.add_artist(ab)
    return ab


img_n = 50
n_grid = 100
offset = 0.1  # grid margin passed to pl.plot_dec_space (grid assumed to span [-offset, len_grid])
len_grid_scale = 1.5

orig = images_robust[img_n]
adv1 = advs_robust[img_n, 0]
adv2 = advs_robust[img_n, 1]

pert1 = adv1 - orig.reshape(1, -1)
pert2 = adv2 - orig.reshape(1, -1)
len1 = np.linalg.norm(pert1)
len2 = np.linalg.norm(pert2)
len_grid = len_grid_scale * np.maximum(len1, len2)

# Positions in grid-pixel units: value * n_grid / (offset + len_grid).
grid_span = offset + len_grid
adv_locs = [  # [[adv1_y, adv1_x], [adv2_y, adv2_x]]
    [offset * n_grid / grid_span, (offset + len1) * n_grid / grid_span],
    [(offset + len2) * n_grid / grid_span, offset * n_grid / grid_span],
]

# Dense set of candidate locations; marker positions are snapped onto it.
n_ticks = int(n_grid**2)
data_ticks = np.linspace(offset * n_grid / grid_span,
                         (offset + np.floor(len_grid)) * n_grid / grid_span,
                         n_ticks)

origin_loc = [offset * n_grid / grid_span,] * 2
origin_plot_loc = [_snap_to_ticks(loc, data_ticks) for loc in origin_loc]
adv1_plot_loc = [_snap_to_ticks(loc, data_ticks) for loc in adv_locs[0]]
adv2_plot_loc = [_snap_to_ticks(loc, data_ticks) for loc in adv_locs[1]]

yx_range = [[0, data_ticks[-1]], [0, data_ticks[-1]]]

fig, ax = pplt.subplots(figwidth=figwidth_inch/2, dpi=dpi)

dec_advs, labels = pl.plot_dec_space(
    orig, adv1, adv2, model_robust,
    offset=offset, n_grid=n_grid, len_grid_scale=len_grid_scale,
    show_legend=False, show_advs=True, overlay_inbounds=False,
    colors=plot_colors, ax=ax)

ax.legend(handles=labels, loc='upper right', ncols=1, title='predicted\nclass')

new_shape = (n_grid, n_grid) + dec_advs.shape[1:]
dec_advs = dec_advs.reshape(new_shape)

overlay_arrowprops = dict(
    linestyle='--',
    lw=2,
    arrowstyle='->, head_width=0.5, head_length=0.5',
    mutation_scale=0.08,
    color='k'
)

# Image overlays at the origin and at each adversarial ((y, x) grid positions,
# hand-tuned offsets passed to plot_utils.overlay_image).
for (y_pos, x_pos), overlay_offset in [
    (origin_plot_loc, [-45, -37]),  # origin
    (adv1_plot_loc, [0, -37]),      # adversarial 1 (x axis)
    (adv2_plot_loc, [-45, 0]),      # adversarial 2 (y axis)
]:
    plot_utils.overlay_image(ax, dec_advs, y_pos, x_pos, yx_range,
                             overlay_offset, overlay_arrowprops, cmap=cmap)

# Perturbation insets, anchored relative to the origin; xy is (x, y) in data coordinates.
pert1_xy = [adv1_plot_loc[1] - origin_plot_loc[1], adv1_plot_loc[0] - origin_plot_loc[0]]
_add_image_inset(ax, _perturbation_display_image(pert1, orig.shape), pert1_xy, [-30, -27], cmap)  # x axis
pert2_xy = [adv2_plot_loc[1] - origin_plot_loc[1], adv2_plot_loc[0] - origin_plot_loc[0]]
_add_image_inset(ax, _perturbation_display_image(pert2, orig.shape), pert2_xy, [-35, -48], cmap)  # y axis

symbol_kw = dict(color='black', ha='left', size=20, transform=ax.transData)

# left math, next to the pert-2 inset
ax.text(x=pert2_xy[0] - 33, y=pert2_xy[1] - 59, s='+', rotation='vertical', **symbol_kw)
ax.text(x=pert2_xy[0] - 32, y=pert2_xy[1] - 16, s='=', rotation='vertical', **symbol_kw)

# bottom math, below the x axis
ax.text(x=(origin_plot_loc[1] - origin_plot_loc[1]/2) - 11, y=-22, s='+',
        rotation='horizontal', **symbol_kw)
ax.text(x=(origin_plot_loc[1] + adv1_plot_loc[1]/2) - 4, y=-22, s='=',
        rotation='horizontal', **symbol_kw)

# Dashed white arrows from the origin to hand-placed end points, with an ellipsis between.
arrow_kw = dict(mutation_scale=10, lw=1, arrowstyle='->', color='white', linestyle='dashed')
origin_xy = (origin_plot_loc[1], origin_plot_loc[0])
for end_xy in ((adv1_plot_loc[1] + 1.0, origin_plot_loc[0] + 5),
               (adv1_plot_loc[1] + 2, origin_plot_loc[0] + 15)):
    ax.add_artist(FancyArrowPatch(origin_xy, end_xy, **arrow_kw))

ax.text(x=origin_plot_loc[1] + adv1_plot_loc[1] / 6,
        y=origin_plot_loc[0] + adv2_plot_loc[0] / 6 + 7,
        s=r'$\ldots$', fontsize=30, rotation=-45, color='white')

ax.add_artist(FancyArrowPatch(origin_xy, (origin_plot_loc[1] + 11, adv2_plot_loc[0]), **arrow_kw))

fig.savefig(f'../data/{data_name}_example_decomp.pdf', dpi=dpi, transparent=True, bbox_inches="tight", pad_inches=0.01)
pplt.show()