In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, './..')
sys.path.insert(0, '../data')

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from matplotlib.patches import FancyArrowPatch
import matplotlib.gridspec as gridspec
from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)
import matplotlib.patches as mpatches
import proplot as pplt

import pandas as pd
import numpy as np
import torch
from torchvision import datasets, transforms

from models import model, eval
import plots as pl
from utils import dev, load_data, classification

sys.path.insert(0, './../../')

#import response_contour_analysis.utils.model_handling as model_utils
#import response_contour_analysis.utils.dataset_generation as data_utils
#import response_contour_analysis.utils.histogram_analysis as hist_utils
#import response_contour_analysis.utils.principal_curvature as curve_utils
import response_contour_analysis.utils.plotting as plot_utils

# Compute device used for the models: the GPU when torch can see one, else the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
print(DEVICE)

In [None]:
# Global figure styling shared by proplot and plain matplotlib.
plot_settings = {
    "text.usetex": True,  # needs a working LaTeX installation
    "font.family": "serif",
    "font.size": 8,
    "axes.formatter.use_mathtext": True,
}
# Target figure width. proplot figures take the unit string (`figwidth`), while
# matplotlib setters such as Figure.set_figwidth take inches (`figwidth_inch`).
# The inch value is derived from the cm value so the two widths always agree.
figwidth_cm = 13.968
figwidth = f'{figwidth_cm}cm'
figwidth_inch = figwidth_cm / 2.54
dpi = 600
pplt.rc.update(plot_settings)
mpl.rcParams.update(plot_settings)

In [None]:
def tab_name_to_hex(tab):
    """Translate a matplotlib Tableau colour name (``'tab:<name>'``) to its hex code.

    The lookup is case-insensitive, and both ``gray`` and ``grey`` are accepted.
    Unknown names raise ``KeyError``.
    """
    tableau = dict([
        ("tab:blue", "#1f77b4"),
        ("tab:orange", "#ff7f0e"),
        ("tab:green", "#2ca02c"),
        ("tab:red", "#d62728"),
        ("tab:purple", "#9467bd"),
        ("tab:brown", "#8c564b"),
        ("tab:pink", "#e377c2"),
        ("tab:gray", "#7f7f7f"),
        ("tab:grey", "#7f7f7f"),
        ("tab:olive", "#bcbd22"),
        ("tab:cyan", "#17becf"),
    ])
    key = tab.lower()
    return tableau[key]

# Two-colour palette: index 0 for the naturally trained model, index 1 for the
# adversarially trained one.
plot_colors = [tab_name_to_hex(name) for name in ('tab:blue', 'tab:red')]

# Load data & models

In [None]:
seed = 0

# load data
# Each .npy file stores a pickled object (a dict keyed by field name), hence
# allow_pickle=True followed by .item().
# NOTE(review): unpickling can execute arbitrary code; only load trusted,
# locally generated files.
data_natural = np.load(f'../data/natural_{seed}.npy', allow_pickle=True).item()
(advs_nat, pert_lengths_nat, classes_nat,
 dirs_nat, images_nat, labels_nat) = (
    data_natural[key] for key in ('advs', 'pert_lengths', 'adv_class', 'dirs', 'images', 'labels'))

data_madry = np.load(f'../data/robust_{seed}.npy', allow_pickle=True).item()
(advs_madry, pert_lengths_madry, classes_madry,
 dirs_madry, images_madry, labels_madry) = (
    data_madry[key] for key in ('advs', 'pert_lengths', 'adv_class', 'dirs', 'images', 'labels'))

In [None]:
# load models
def _load_eval_model(path):
    """Build a ``model.madry_diff`` network, load the state dict at ``path`` and
    return it in eval mode.

    Checkpoints are mapped to, and the network is moved onto, the same device
    (``dev()``), so the natural and the robust model always live on one device.
    """
    device = torch.device(dev())
    net = model.madry_diff()
    net.load_state_dict(torch.load(path, map_location=device))
    net.to(device)
    net.eval()
    return net


model_natural = _load_eval_model(f'../models/natural_{seed}.pt')
model_madry = _load_eval_model(f'../models/robust_{seed}.pt')

In [None]:
# Per-model statistics of the classes that adversarial directions lead to
# (model index 0 = naturally trained, 1 = adversarially trained):
#   advs_frac_found[m, l, k]  fraction of class-l adversarials classified as k
#   advs_per_class[m, l]      mean number of directions found per class-l image
#   advs_num_classes[m, l]    mean of (#distinct adversarial classes - 1) / (#adversarials - 1)
n_classes = 10
advs_frac_found = np.zeros((2, n_classes, n_classes))  # [model_type, gt_label, fraction_label_to]
advs_per_class = np.zeros((2, n_classes))  # [model_type, gt_label]
advs_num_classes = np.zeros((2, n_classes))  # [model_type, gt_label]
model_data = [
    (classes_nat, pert_lengths_nat, labels_nat),
    (classes_madry, pert_lengths_madry, labels_madry),
]
for i, (classes_, pert_lengths_, labels_) in enumerate(model_data):
    found = ~np.isnan(pert_lengths_)
    # Blank the class of every direction without an adversarial. np.where returns
    # a new array, so classes_nat / classes_madry stay untouched for other cells.
    valid_classes = np.where(found, classes_, np.nan)
    # Only images with at least two found adversarials, for the "- 1" denominator below.
    mask_idx = found.sum(1) > 1
    masked_classes = valid_classes[mask_idx]
    masked_labels = labels_[mask_idx]
    for label in range(n_classes):
        class_rows = masked_classes[masked_labels == label]
        var = []
        for gt_classes in class_rows:
            masked_gt_classes = gt_classes[~np.isnan(gt_classes)]
            var.append((len(np.unique(masked_gt_classes)) - 1) / (len(masked_gt_classes) - 1))
        # NaN for classes without qualifying images (np.mean([]) would warn).
        advs_num_classes[i, label] = np.mean(var) if var else np.nan
        u, c = np.unique(class_rows, return_counts=True)
        c = c[~np.isnan(u)]
        u = u[~np.isnan(u)]
        advs_frac_found[i, label, u.astype(int)] = c / np.sum(c)
        # Use this model's labels (`labels_`): the global `labels` does not exist
        # yet on a fresh kernel and is later rebound to legend handles.
        p = pert_lengths_[labels_ == label]
        advs_per_class[i, label] = np.mean((~np.isnan(p)).sum(1))

# Check seed consistency

In [None]:
# Does the number of adversarial directions found per image depend on the
# training seed? Load every seed of one model type and compare the counts.
seeds = [0, 1, 2, 3]
model_type = 'natural'
n_advs = []  # one entry per seed: number of found adversarials for each image
n_dirs = 0   # largest number of pert_lengths columns (directions) across seeds
for model_id in seeds:
    seed_data = np.load(f'../data/{model_type}_{model_id}.npy', allow_pickle=True).item()
    seed_pert_lengths = seed_data['pert_lengths']
    n_dirs = max(n_dirs, seed_pert_lengths.shape[1])
    # A NaN perturbation length means no adversarial was found along that
    # direction; count the found ones directly instead of assuming 15 directions.
    n_advs.append((~np.isnan(seed_pert_lengths)).sum(1))
n_advs = np.stack(n_advs, axis=0)  # (seed, image)

fig, ax = pplt.subplots(figwidth=figwidth, dpi=dpi)
ax.hist(n_advs.T, alpha=0.5, bins=n_dirs + 1, range=(-0.5, n_dirs + 0.5))
ax.format(xlabel='number of adversarial directions found', ylabel='number of images')
pplt.show()

fig, ax = pplt.subplots(figwidth=figwidth, dpi=dpi)
ax.boxplot(n_advs.T)
ax.format(ylabel='number of adversarial directions found')
pplt.show()

print(f'mean over images of the std across seeds of the number of adversarial directions = {np.mean(np.std(n_advs, axis=0))}')

# Plot grid of adversarial examples

In [None]:
# All adversarial examples found for one natural-model sample image.
img_n = 0
sample_image = images_nat[img_n]
fig, ax = pl.plot_advs(
    advs=advs_nat[img_n],
    orig=sample_image,
    shape=sample_image.shape,
    classes=classes_nat[img_n],
    orig_class=labels_nat[img_n],
    n=5,
    vmin=0,
    vmax=1,
)
# Resize to the notebook's common figure width and resolution.
fig.set_figwidth(figwidth_inch)
fig.set_dpi(dpi)

In [None]:
# Grid of natural-model adversarial examples: one row per sample image (every
# `sample_stride`-th image), the original in column 0 and its first `n_adv_cols`
# adversarials to the right, each labelled with its adversarial class.
n_rows, n_adv_cols, sample_stride = 5, 5, 50
img_side = 28  # images are stored flattened and reshaped to 28x28 for display

fig, ax = pplt.subplots(nrows=n_rows, ncols=n_adv_cols + 1, figwidth=figwidth, dpi=dpi)
for j in range(n_rows):
    idx = j * sample_stride
    orig = np.reshape(images_nat[idx], [img_side, img_side])
    if j == 0:
        ax[j, 0].set_title('original')
    ax[j, 0].imshow(orig, cmap='gray', vmin=0, vmax=1)
    ax[j, 0].set_xticks([])
    ax[j, 0].set_yticks([])
    ax[j, 0].set_xlabel(str(labels_nat[idx]))

    for i, a in enumerate(advs_nat[idx, :n_adv_cols]):
        if j == 0:
            ax[j, i + 1].set_title('Adv. ' + str(i + 1))
        ax[j, i + 1].imshow(a.reshape([img_side, img_side]), cmap='gray', vmin=0, vmax=1)
        adv_class = classes_nat[idx, i]
        # classes_nat may hold NaN for directions without an adversarial (NaN
        # perturbation length); int(nan) raises ValueError, so show '?' instead.
        class_str = '?' if np.isnan(adv_class) else str(int(adv_class))
        ax[j, i + 1].format(
            xlabel=r'$\rightarrow$ ' + class_str,
            xticks=[],
            yticks=[],
        )
pplt.show()

# Plot Madry (adversarially trained) and natural adversarial examples in one figure

In [None]:
# Comparison for five sample images (every 50th), one per column: the original
# image on top, then the first adversarial found for the natural CNN and for the
# Madry (adversarially trained) CNN.
def _class_str(c):
    """Format a class id (possibly float or NaN) as an integer string, '?' if missing."""
    return '?' if np.isnan(c) else str(int(c))


n_cols, sample_stride, img_side = 5, 50, 28
fig = plt.figure()
for i in range(n_cols):
    idx = i * sample_stride
    # (row label, image to show, panel title) for the three rows of this column.
    # Uses labels_nat: the global `labels` is undefined on a fresh kernel.
    rows = [
        ("Original Image", images_nat[idx], 'Orig. class ' + str(labels_nat[idx])),
        ("Natural CNN", advs_nat[idx, 0], 'Adv. class ' + _class_str(classes_nat[idx, 0])),
        ("Madry CNN", advs_madry[idx, 0], 'Adv. class ' + _class_str(classes_madry[idx, 0])),
    ]
    for row, (row_label, img, title) in enumerate(rows):
        sub_ax = fig.add_subplot(len(rows), n_cols, row * n_cols + i + 1)
        sub_ax.set_title(title)
        sub_ax.imshow(np.reshape(img, [img_side, img_side]), cmap='gray', vmin=0, vmax=1)
        sub_ax.set_xticks([])
        sub_ax.set_yticks([])
        if i == 0:
            sub_ax.set_ylabel(row_label)
fig.suptitle('Adversarials of non-robust and robust models')
fig.set_figwidth(figwidth_inch)
fig.set_figheight((7/12) * figwidth_inch)
fig.set_dpi(dpi)
plt.show()

# Perturbation length comparison

In [None]:
# Perturbation lengths of every adversarial found, natural vs. robust model.
# (n=15 is forwarded to pl.plot_pert_lengths; presumably the number of
# adversarial directions shown.)
fig, ax = pl.plot_pert_lengths(
    [pert_lengths_nat, pert_lengths_madry],
    n=15,
    labels=['naturally trained', 'adversarially trained'],
    ord=2,
)
fig.set_figheight((8/12) * figwidth_inch)
fig.set_figwidth(figwidth_inch)
fig.set_dpi(dpi)
plt.show()


In [None]:
# Same comparison, restricted to images for which at least `n` adversarial
# directions were found (non-NaN perturbation lengths).
n = 8

p_natural, p_robust = (
    lengths[(~np.isnan(lengths)).sum(-1) >= n]
    for lengths in (pert_lengths_nat, pert_lengths_madry)
)

fig, ax = pl.plot_pert_lengths(
    [p_natural, p_robust],
    n=n,
    labels=['naturally trained', 'adversarially trained'],
    ord=2,
)
fig.set_figheight((5/7) * figwidth_inch)
fig.set_figwidth(figwidth_inch)
fig.set_dpi(dpi)
plt.show()

# Plot variation of target classes

In [None]:
# Spread of adversarial target classes for the natural CNN (swap the *_nat
# arrays for *_madry to plot the robust model instead).
classes_ = classes_nat
pert_lengths_ = pert_lengths_nat
labels_ = labels_nat

found = ~np.isnan(pert_lengths_)
# Keep images with more than one adversarial direction found.
mask_idx = found.sum(1) > 1
# Blank classes of directions without an adversarial. np.where returns a copy,
# so classes_nat itself is not modified for other cells.
classes_ = np.where(found, classes_, np.nan)[mask_idx]
# The model's own labels; the global `labels` is undefined here on a fresh kernel.
labels_ = labels_[mask_idx]
fig, ax = pl.plot_var_hist(classes_, labels_, title='Natural CNN', with_colors=False)
fig.set_figheight((5/7) * figwidth_inch)
fig.set_figwidth(figwidth_inch)
fig.set_dpi(dpi)
plt.show()

In [None]:
# Mean number of adversarial directions found per original class, natural vs.
# adversarially trained model, drawn as side-by-side bars. The figure is bound
# to `fig` so the size setters below act on this figure, not a stale one.
fig, ax = plt.subplots(dpi=dpi)
bar_width = 0.35
class_ids = np.arange(10)
ax.bar(class_ids - bar_width / 2, advs_per_class[0, :], bar_width,
       color=plot_colors[0], label='Natural')
ax.bar(class_ids + bar_width / 2, advs_per_class[1, :], bar_width,
       color=plot_colors[1], label='Adversarial')
ax.set_xticks(class_ids)
ax.set_xlabel('original class label')
ax.set_ylabel('mean number of directions found')
ax.set_title('Natural vs. adversarially trained model')
ax.set_ylim(0, 25)
ax.legend()
fig.set_figwidth(figwidth_inch)
fig.set_figheight((7/5) * figwidth_inch)
plt.show()

# Distance to decision boundary

In [None]:
n_dim = 8
# Presumably the number of sampled points per dimension; used below as the
# denominator of the out-of-bounds rate.
n_samples = 100
seed = 0


def _load_npz_data(path):
    """Return the 'data' array of an .npz archive, closing the file handle afterwards."""
    with np.load(path) as archive:
        return archive['data']


# dists[0]: naturally trained model, dists[1]: robust model.
dists = [_load_npz_data(f'../data/distance_to_boundary_{kind}_{seed}.npz')
         for kind in ('natural', 'robust')]

In [None]:
# Natural model: distribution of the distance to the decision boundary per
# dimension of the adversarial subspace (boxes), overlaid with the mean length
# of the k-th adversarial perturbation (dots).
# NOTE(review): assumes dists[m] is (image, dimension, sample) — TODO confirm.
mean_dists = np.nanmean(dists[0], axis=-1)
n_dims = mean_dists.shape[1]  # derived from the data instead of a hard-coded 8
dims = np.arange(1, n_dims + 1)
mask = ~np.isnan(mean_dists)
filtered_data = [d[m] for d, m in zip(mean_dists.T, mask.T)]

fig, ax = pplt.subplots(figwidth=figwidth, dpi=dpi)
ax.boxplot(filtered_data)
# Dots: only images with more than n_dims adversarials found contribute.
has_many_advs = (~np.isnan(pert_lengths_nat)).sum(-1) > n_dims
ax.plot(dims, np.nanmean(pert_lengths_nat[has_many_advs, :n_dims], axis=0), 'b.')
ax.format(
    xlabel='dimension of adversarial space',
    ylabel='mean distance to decision boundary')
pplt.show()

# Rate of NaN (out-of-bounds) samples per dimension, averaged over the first
# axis; dimension 1 is skipped.
oob_rate = np.mean(np.isnan(dists[0]).sum(-1), axis=0) / n_samples
fig, ax = pplt.subplots(dpi=dpi)
ax.plot(dims[1:], oob_rate[1:], 'k.')
ax.format(
    xlabel='dimension of adversarial space',
    ylabel='rate of out of bounds samples')
pplt.show()

In [None]:
# Natural vs. robust model on one axis.
# - Boxes: distance-to-boundary distribution per subspace dimension (10th-90th
#   percentile whiskers, no median line, no fliers).
# - Dots: mean length of the k-th adversarial perturbation, over images with
#   more than n_dims adversarials found.
colors = ['blue', 'orange']
model_labels = ['naturally trained', 'adversarially trained']
box_shift = 0.2  # shift each model's boxes left/right so they do not overlap
box_width = 0.3

fig, ax = plt.subplots(dpi=dpi)
for i, (color, lengths) in enumerate(zip(colors, [pert_lengths_nat, pert_lengths_madry])):
    mean_dists = np.nanmean(dists[i], axis=-1)
    n_dims = mean_dists.shape[1]
    x = np.arange(1, n_dims + 1)
    positions = x + (2 * i - 1) * box_shift
    mask = ~np.isnan(mean_dists)
    filtered_data = [d[m] for d, m in zip(mean_dists.T, mask.T)]
    ax.boxplot(filtered_data, positions=positions, widths=box_width,
               whis=[10, 90], showfliers=False, showmeans=False,
               boxprops=dict(color=color, linewidth=1.5, alpha=0.7),
               whiskerprops=dict(color=color, alpha=0.7),
               capprops=dict(color=color, alpha=0.7),
               medianprops=dict(linestyle=None, linewidth=0))
    y = lengths[(~np.isnan(lengths)).sum(-1) > n_dims, :n_dims]
    ax.scatter(positions, np.nanmean(y, axis=0), color=color, marker='.',
               label=model_labels[i])
# boxplot puts ticks at the shifted positions; restore one tick per dimension.
ax.set_xticks(x)
ax.set_xticklabels([str(d) for d in x])
ax.set_xlabel('dimension of adversarial space')
ax.set_ylabel('mean distance to decision boundary')
ax.legend()

fig.set_figwidth(figwidth_inch)
fig.set_figheight((7/5) * figwidth_inch)

plt.show()

In [None]:
# Two-part figure.
# - Panel 1 (left): per original class, the mean normalised count of distinct
#   adversarial classes for both models.
# - Panels 2-21: one pie pair (natural, adversarial) per original class,
#   showing which classes its adversarials land in.
# - Panel 22: the shared pie legend. 0 marks empty slots.
array = [
    [1,  1,  1,  1,  2,  3,  4,  5,  6,  7],
    [1,  1,  1,  1,  8,  9, 10, 11, 12, 13],
    [1,  1,  1,  1, 14, 15, 16, 17, 18, 19],
    [1,  1,  1,  1, 20, 21,  0, 22, 22,  0],
]
hspace = [0.8,]*3
# Negative gaps pull each natural/adversarial pie pair together.
wspace = [0, 0, 0] + [0.1, -1,]*3



fig, axs = pplt.subplots(array, sharey=False, sharex=False,
                         dpi=dpi, figwidth=figwidth, wspace=wspace, hspace=hspace)

ax = axs[0]
bar_width = 0.35
plot_labels = [str(i) for i in range(10)]
x = np.arange(len(plot_labels))

ax.bar(x-bar_width/2, advs_num_classes[0, :], bar_width, color=plot_colors[0])
ax.bar(x+bar_width/2, advs_num_classes[1, :], bar_width, color=plot_colors[1])
ax.format(
    xlabel='Original class label',
    ylabel='Mean number of adversarial classes',
    xticks=range(10),
    xtickminor=[],
    ylim=[0, 1],
    xgrid=False,
)
legend_handles = [mpatches.Patch(color=plot_colors[0], label='Natural'),
                  mpatches.Patch(color=plot_colors[1], label='Adversarial')]
ax.legend(handles=legend_handles, loc='upper right', ncols=1, frame=False)

# Ten-colour cycle, one colour per adversarial class. Reuses tab_name_to_hex
# from the colour-helper cell near the top of the notebook (no redefinition).
colors = pplt.Cycle([mpl.colors.ColorConverter.to_rgba(tab_name_to_hex(color))
          for color in [
              'tab:blue',
              'tab:orange',
              'tab:green',
              'tab:purple',
              'tab:red',
              'tab:brown',
              'tab:grey',
              'tab:pink',
              'tab:cyan',
              'tab:olive']
         ])

model_names = ('Natural', 'Adversarial')
ax_idx = 1
for label_index in range(10):  # original class
    for model_index in range(2):
        ax = axs[ax_idx]
        patches, texts = ax.pie(advs_frac_found[model_index, label_index, :], cycle=colors, startangle=90, normalize=True)
        # Title only the first (natural) pie of each pair with the original class.
        if ax_idx % 2 != 0:
            ax.format(
                title = f'{plot_labels[label_index]}',
                titleloc = 'right',
                titleabove = False,
                titlepad = -1
            )
        ax_idx += 1

# The wedges of the last pie serve as legend handles (the colour cycle is shared).
legend = axs[-1].legend(patches, labels=plot_labels,
               title='Adversarial class',
               columnspacing=-1., markerfirst=False,
               loc='fill', frame=False, ncols=5, pad=0)
# Legend.legendHandles was renamed to legend_handles in matplotlib 3.7 (and the
# old name later removed); support both.
legend_artists = getattr(legend, 'legend_handles', None)
if legend_artists is None:
    legend_artists = legend.legendHandles
for handle in legend_artists:
    handle.set_width(4.0)

fig.savefig('../data/adv_class_comp_pie.pdf', dpi=dpi, transparent=True, bbox_inches="tight", pad_inches=0.01)
pplt.show()

In [None]:
# Two-part figure.
# - Panel 1 (left): per original class, the mean normalised count of distinct
#   adversarial classes for both models.
# - Panels 2-11: for each original class, the distribution of adversarial
#   target classes of both models as grouped bars.
array = [
    [1,  1,  2,  3,  4,  5,  6],
    [1,  1,  7,  8,  9, 10, 11]
]
hspace = [1.2,]
wspace = [0, 4.0] + [0.5,]*4

fig, axs = pplt.subplots(array, sharey=False, sharex=False,
                         dpi=dpi, figwidth=figwidth, wspace=wspace, hspace=hspace)

ax = axs[0]
bar_width = 0.35
plot_labels = [str(i) for i in range(10)]
x = np.arange(len(plot_labels))
ax.bar(x-bar_width/2, advs_num_classes[0, :], bar_width, color=plot_colors[0])
ax.bar(x+bar_width/2, advs_num_classes[1, :], bar_width, color=plot_colors[1])
ax.format(
    xlabel='Original class',
    ylabel='Number of alternate classes',
    xticks=range(10),
    xtickminor=[],
    ylim=[0, 1],
    xgrid=False,
)
legend_handles = [mpatches.Patch(color=plot_colors[0], label='Natural'),
                  mpatches.Patch(color=plot_colors[1], label='Adversarial')]
ax.legend(handles=legend_handles, loc='upper right', ncols=1, frame=False)

model_types = ('Natural', 'Adversarial')
colors = pplt.Cycle([mpl.colors.ColorConverter.to_rgba(color) for color in plot_colors])

# TODO: derive the shared ylim from the data (max class density over all
# panels) instead of the fixed 0.7 used below.
ax_idx = 1
for label_index in range(10):  # original class
    ax = axs[ax_idx]
    frac_found_slice = advs_frac_found[:, label_index, :].transpose((1, 0))  # (adv class, model)
    # Normalise each model's column to sum to 1; a column without any
    # adversarial stays 0 instead of becoming NaN.
    col_sums = np.sum(frac_found_slice, axis=0, keepdims=True)
    frac_found_slice = np.divide(frac_found_slice, col_sums,
                                 out=np.zeros_like(frac_found_slice), where=col_sums > 0)
    data = pd.DataFrame(frac_found_slice, columns=pd.Index(model_types, name=''))
    ax.bar(data, cycle=colors, edgecolor='none')
    ax.format(
        xtickminor=False,
        grid=False,
        ytickminor=False,
        title=f'{label_index}',
        titleabove=False,
        titlepad=-1
    )
    for ax_loc in ['top', 'right']:
        ax.spines[ax_loc].set_color('none')
    # axs[6] is the first panel of the second row; only it carries axis ticks and labels.
    if ax_idx == 6:
        ax.format(
            xticks=list(range(10)),
            # Label every other class: 0, 2, 4, 6, 8.
            xticklabels=[str(i) if i % 2 == 0 else '' for i in range(10)],
            xlabel='Adversarial class',
            ylabel='Class density',
            ylim=[0, 0.7],
            yticks=[0, 0.6],
            yticklabels=[str(i) for i in [0, 0.6]],
        )
    else:
        ax.format(
            xticks=[],
            ylim=[0, 0.7],
            yticks=[],
            yticklabels=[],
        )
        for ax_loc in ['bottom', 'left']:
            ax.spines[ax_loc].set_color('none')
    ax_idx += 1

fig.savefig('../data/adv_class_comp_bar.pdf', dpi=dpi, transparent=True, bbox_inches="tight", pad_inches=0.01)
pplt.show()

# Decision Space Visualization

In [None]:
# Quick low-resolution previews of the robust model's decision space for a few
# sample images, spanned (presumably) by the first two adversarial directions.
preview_dpi = dpi // 10
for img_n in (0, 50):  # extend with 100, 150, ..., 450 for more samples
    fig, ax = pplt.subplots(dpi=preview_dpi)
    orig, adv1, adv2 = images_madry[img_n], advs_madry[img_n, 0], advs_madry[img_n, 1]
    model_ = model_madry
    _ = pl.plot_dec_space(orig, adv1, adv2, model_, show_legend=True, show_advs=True,
                          overlay_inbounds=True, ax=ax)
    plt.show()

In [None]:
# Annotated decision-space figure for one sample of the robust (Madry) model.
# - The plane spanned by its first two adversarial perturbations.
# - The original and both adversarials overlaid at their grid locations.
# - The perturbation images drawn along the axes, with "+" / "=" glyphs forming
#   orig + pert = adv.
img_n = 50
n_grid = 100
offset = 0.5  # the grid starts at -offset along both axes (see np.linspace below)
len_grid_scale = 1.5

orig = images_madry[img_n]
adv1 = advs_madry[img_n, 0]
adv2 = advs_madry[img_n, 1]

pert1 = adv1 - orig.reshape(1, -1)
pert2 = adv2 - orig.reshape(1, -1)
len1 = np.linalg.norm(pert1)
len2 = np.linalg.norm(pert2)
len_grid = len_grid_scale * np.maximum(len1, len2)

# Grid-index positions of the adversarials: [[adv1_y, adv1_x], [adv2_y, adv2_x]].
adv_locs = [
    [
        offset * n_grid / (offset + len_grid),
        (offset + len1) * n_grid / (offset + len_grid)
    ], [
        (offset + len2) * n_grid / (offset + len_grid),
        offset * n_grid / (offset + len_grid)
    ]
]

x = np.linspace(-offset, len_grid, n_grid)
y = np.linspace(-offset, len_grid, n_grid)

# Dense candidate positions between the origin and floor(len_grid); every
# marker location below is snapped to the nearest of these.
n_ticks = int(n_grid**2)
data_ticks = np.linspace(offset*n_grid/(offset+len_grid), (offset+np.floor(len_grid))*n_grid/(offset+len_grid), n_ticks)
plot_ticks = np.linspace(0, np.floor(len_grid), n_ticks)


def _snap_to_tick(loc):
    """Return the entry of `data_ticks` closest to `loc`."""
    return data_ticks[np.argmin(np.abs(data_ticks - loc))]


origin_loc = [offset * n_grid / (offset + len_grid),]*2
origin_plot_loc = [_snap_to_tick(loc) for loc in origin_loc]
adv1_plot_loc = [_snap_to_tick(loc) for loc in adv_locs[0]]
adv2_plot_loc = [_snap_to_tick(loc) for loc in adv_locs[1]]

yx_range = [[0, data_ticks[-1]], [0, data_ticks[-1]]]

fig, ax = pplt.subplots(figwidth=figwidth_inch/2, dpi=dpi)

# The second return value is a list of legend handles. Keep it under its own name:
# binding it to `labels` clobbered the global that other cells index as class labels.
dec_advs, dec_legend_handles = pl.plot_dec_space(
    orig, adv1, adv2, model_madry,
    offset=offset, n_grid=n_grid, len_grid_scale=len_grid_scale,
    show_legend=False, show_advs=True, overlay_inbounds=False, ax=ax)

ax.legend(handles=dec_legend_handles, loc='upper right', ncols=1, title='predicted\nclass')

# dec_advs has n_grid**2 leading entries (one per grid point); make them 2-D.
new_shape = (n_grid, n_grid) + dec_advs.shape[1:]
dec_advs = dec_advs.reshape(new_shape)

arrowprops = dict(
    linestyle='--',
    lw=2,
    arrowstyle='->, head_width=0.5, head_length=0.5',
    mutation_scale=0.08,
    color='k'
)

# Overlay the images at the origin, adversarial 1 (x axis) and adversarial 2 (y axis).
for plot_loc, overlay_offset in [
        (origin_plot_loc, [-52, -52]),
        (adv1_plot_loc, [0, -52]),
        (adv2_plot_loc, [-52, 0]),
]:
    y_pos, x_pos = plot_loc
    _ = plot_utils.overlay_image(ax, dec_advs, y_pos, x_pos, yx_range, overlay_offset, arrowprops)

num_pixels = int(np.sqrt(orig.size))
arrowprops = dict(visible=False)


def _add_perturbation_image(target_ax, pert, yx, box_offset):
    """Draw `pert` as a square grayscale inset, offset (in points) from the data point `yx` = [y, x]."""
    imagebox = OffsetImage(pert.reshape(num_pixels, num_pixels), zoom=0.75, cmap='gray')
    imagebox.image.axes = target_ax
    ab = AnnotationBbox(imagebox,
        xy=yx[::-1],
        xybox=box_offset,
        xycoords=target_ax.transData,
        boxcoords='offset points',
        pad=0.0,
        arrowprops=arrowprops
    )
    target_ax.add_artist(ab)
    return ab


# pert 1 (x axis), anchored at adversarial 1 relative to the origin
pert1_yx = [a - o for a, o in zip(adv1_plot_loc, origin_plot_loc)]
_ = _add_perturbation_image(ax, pert1, pert1_yx, [-60, -34])

# pert 2 (y axis), anchored at adversarial 2 relative to the origin
pert2_yx = [a - o for a, o in zip(adv2_plot_loc, origin_plot_loc)]
_ = _add_perturbation_image(ax, pert2, pert2_yx, [-34, -60])


def _add_glyph(pos, text_offset, symbol, rotation):
    """Place a large `symbol` at `pos` + `text_offset` (data coordinates, [x, y])."""
    ax.text(x=pos[0]+text_offset[0], y=pos[1]+text_offset[1], s=symbol, ha='left',
            rotation=rotation, size=20, transform=ax.transData)


# left math (next to the y-axis perturbation)
_add_glyph(pert2_yx[::-1], [-23, -61], '+', 'vertical')
_add_glyph(pert2_yx[::-1], [-24, -13], '=', 'vertical')

# bottom math (below the x axis)
_add_glyph([origin_plot_loc[1] - origin_plot_loc[1]/2, 0], [-3, -21], '+', 'horizontal')
_add_glyph([origin_plot_loc[1] + adv1_plot_loc[1]/2, 0], [1, -21], '=', 'horizontal')


def _add_dashed_arrow(start_xy, end_xy):
    """Add a white dashed arrow from `start_xy` to `end_xy` (data coordinates, [x, y])."""
    arrow = FancyArrowPatch(start_xy, end_xy, mutation_scale=10, lw=1, arrowstyle='->',
                            color='white', linestyle='dashed')
    ax.add_artist(arrow)
    return arrow


# Dashed arrows from the origin towards further (presumably random) directions.
rand1 = _add_dashed_arrow([origin_plot_loc[1], origin_plot_loc[0]],
                          [adv1_plot_loc[1] + 1.0, origin_plot_loc[0] + 5])
rand2 = _add_dashed_arrow([origin_plot_loc[1], origin_plot_loc[0]],
                          [adv1_plot_loc[1] + 2, origin_plot_loc[0] + 15])

ax.text(x=origin_plot_loc[1] + adv1_plot_loc[1] / 6,
        y=origin_plot_loc[0] + adv2_plot_loc[0] / 6 + 8,
        s=r'$\ldots$',
        fontsize=30, color='white')

rand3 = _add_dashed_arrow([origin_plot_loc[1], origin_plot_loc[0]],
                          [origin_plot_loc[1] + 11, adv2_plot_loc[0] - 1])

#fig.savefig('../data/example_decomp.pdf', dpi=dpi, transparent=True, bbox_inches="tight", pad_inches=0.01)
pplt.show()