In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, './..')
sys.path.insert(0, '../data')

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import proplot as pplt
import pandas as pd
import dill

from tqdm import tqdm
import numpy as np
import torch
from torchvision import datasets, transforms

from models import eval
from models import model as model_loader
import plots as pl
from utils import dev, load_data, classification, make_orth_basis
from robustness1.datasets import CIFAR

sys.path.insert(0, './../../')

import response_contour_analysis.utils.model_handling as model_utils
import response_contour_analysis.utils.dataset_generation as data_utils
import response_contour_analysis.utils.histogram_analysis as hist_utils
import response_contour_analysis.utils.principal_curvature as curve_utils
import response_contour_analysis.utils.plotting as plot_utils

# check device
print(dev())

# Experiment parameters

In [None]:
# Options for the SR1 Hessian estimator; 'hessian_dist' is added in a later cell
# once the perturbation lengths are known.
hess_params = {
    'hessian_num_pts': 1.0e4,
    'hessian_lr': 1e-4,
    'hessian_random_walk': False,
    'return_points': False,
    'lr_decay': False,  # True
}

# Boundary search between an image and its paired / adversarial image.
num_iters = 2  # number of successive refinements of the search line
num_steps_per_iter = 100  # samples on the search line in each refinement

# Experiment size and behaviour switches.
buffer_portion = 0.25
num_eps = 1000
batch_size = 1000
num_images = 50  # labels_nat.size
hess_radius_mult = 0.7  # times the min adv perturbation length
num_advs = 8
autodiff = True  # True: autodiff Hessian, False: SR1 Hessian
load = False  # True: read cached results instead of recomputing them
num_hessian_tests = 0  # 10

# Shared typography for proplot and matplotlib.
plot_settings = {
    "text.usetex": True,
    "font.family": "serif",
    "font.size": 8,
    "axes.formatter.use_mathtext": True,
}
for rc_target in (pplt.rc, mpl.rcParams):
    rc_target.update(plot_settings)

figwidth = '13.968cm'
figwidth_inch = 5.50107
dpi = 600
model_types = ['Naturally trained', 'Adversarially trained']

In [None]:
def tab_name_to_hex(tab):
    """Translate a Tableau color name such as ``'tab:blue'`` into its hex code.

    The lookup ignores letter case. A name that is not in the table raises
    ``KeyError``.
    """
    names = ("blue", "orange", "green", "red", "purple", "brown",
             "pink", "gray", "grey", "olive", "cyan")
    codes = ("#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b",
             "#e377c2", "#7f7f7f", "#7f7f7f", "#bcbd22", "#17becf")
    lookup = {f"tab:{name}": code for name, code in zip(names, codes)}
    key = tab.lower()
    return lookup[key]


# First color: naturally trained model; second: adversarially trained model.
plot_colors = [tab_name_to_hex(name) for name in ('tab:blue', 'tab:red')]

# Load data & models

In [None]:
seed = 0
# Apply the seed: later cells draw images and directions with np.random.choice /
# np.random.randint, which are only repeatable if the global generator is seeded.
np.random.seed(seed)

# load data
# NOTE: allow_pickle=True unpickles arbitrary Python objects; only load files you trust.
data_natural = np.load('../data/cifar_natural_diff.npy', allow_pickle=True).item()
advs_nat = data_natural['advs']
pert_lengths_nat = data_natural['pert_lengths']
classes_nat = data_natural['adv_class']
dirs_nat = data_natural['dirs']
images_nat = data_natural['images']
labels_nat = data_natural['labels']

data_madry = np.load('../data/cifar_robust_diff.npy', allow_pickle=True).item()
advs_madry = data_madry['advs']
pert_lengths_madry = data_madry['pert_lengths']
classes_madry = data_madry['adv_class']
dirs_madry = data_madry['dirs']
images_madry = data_madry['images']
labels_madry = data_madry['labels']

# Class names indexed by CIFAR-10 label.
cifar_labels = ['airplane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [None]:
# Keep only the first `num_advs` perturbation lengths of every image.
sub_pert_lengths_nat = pert_lengths_nat[:, :num_advs]
sub_pert_lengths_madry = pert_lengths_madry[:, :num_advs]

# Non-finite entries are dropped before taking statistics.
finite_pert_lengths = [
    lengths[np.isfinite(lengths)]
    for lengths in (sub_pert_lengths_nat, sub_pert_lengths_madry)
]
mean_pert_lengths = np.mean([lengths.mean() for lengths in finite_pert_lengths])
min_pert_lengths = np.min([lengths.min() for lengths in finite_pert_lengths])

# radius around the target image
hess_params['hessian_dist'] = min_pert_lengths * hess_radius_mult

In [None]:
def load_cifar_resnet50(dataset, resume_path):
    """Build a ResNet-50 for `dataset` and restore its weights from `resume_path`.

    The checkpoint's weights are read from its 'model' entry when present and
    from 'state_dict' otherwise. A leading 'module.' is removed from parameter
    names that carry it; names without that prefix are kept unchanged rather
    than losing their first seven characters.

    Returns the model on `dev()`, converted to double precision, in eval mode.
    """
    classifier_model = dataset.get_model('resnet50', False)
    model = model_loader.cifar_pretrained(classifier_model, dataset)

    # NOTE: torch.load unpickles (here via dill) arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(resume_path, pickle_module=dill, map_location=torch.device(dev()))
    state_dict_path = 'model' if 'model' in checkpoint else 'state_dict'

    prefix = 'module.'
    sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint[state_dict_path].items()
    }
    model.load_state_dict(sd)
    model.to(dev())
    model.double()
    model.eval()
    return model


# load models
ds = CIFAR('../data/cifar-10-batches-py')
model_natural = load_cifar_resnet50(ds, '../models/nat_diff.pt')
model_madry = load_cifar_resnet50(ds, '../models/rob_diff.pt')

In [None]:
def torchify(img):
    """Wrap a NumPy array as a float64 tensor on `dev()` with gradient tracking switched on."""
    tensor = torch.from_numpy(img).double().to(dev())  # always double, always autodiff
    tensor.requires_grad_(True)
    return tensor


def get_paired_boundary_image(model, origin, alt_image, num_steps_per_iter, num_iters):
    """Find where the predicted class first changes on the line from `origin` to `alt_image`.

    The line is sampled at `num_steps_per_iter` evenly spaced points and scanned
    from the origin until the arg-max class differs from the class of the first
    sample. The interval between the last unchanged sample and the first changed
    one is then resampled, `num_iters` times in total.

    Parameters
    ----------
    model : callable
        Maps a batched tensor to class scores.
    origin : np.ndarray
        Start image; its shape (with a leading batch axis of 1) is the model input shape.
    alt_image : np.ndarray
        End image with the same number of elements as `origin`.
    num_steps_per_iter : int
        Number of samples on the line in every refinement.
    num_iters : int
        Number of refinements, at least 1.

    Returns
    -------
    boundary_image : np.ndarray
        First sample with a changed label in the last refinement, shaped like `origin`.
    direction : np.ndarray
        Flattened unit vector pointing from the boundary image to the origin.
    pert_length : float
        Euclidean distance between the origin and the boundary image.

    Raises
    ------
    ValueError
        If `num_iters` is smaller than 1 (no boundary image would be produced).
    IndexError
        If no sample on the line receives a different label than the first one.
    """
    if num_iters < 1:
        raise ValueError(f'num_iters must be at least 1, got {num_iters}')
    input_shape = (1,)+origin.shape

    def predict(flat_image):
        return torch.argmax(model(torchify(flat_image.reshape(input_shape))))

    def find_pert(image_line):
        correct_lbl = predict(image_line[0, ...])
        # Start at 1: the label of the first sample is the reference.
        for step_idx in range(1, image_line.shape[0]):
            pert_image = image_line[step_idx, ...]
            if predict(pert_image) != correct_lbl:
                return step_idx, pert_image
        raise IndexError(
            f'no label change found along the {image_line.shape[0]} samples between '
            'origin and alt_image; the model assigns both ends the same class')

    image_line = np.linspace(origin.reshape(-1), alt_image.reshape(-1), num_steps_per_iter)
    for search_iter in range(num_iters):
        step_idx, pert_image = find_pert(image_line)
        # Zoom into the interval that brackets the label change.
        image_line = np.linspace(image_line[step_idx - 1, ...], image_line[step_idx, ...], num_steps_per_iter)
    delta_image = origin.reshape(-1) - pert_image
    pert_length = np.linalg.norm(delta_image)
    direction = delta_image / pert_length
    return pert_image.reshape(origin.shape), direction, pert_length


def generate_paired_dict(data_dict, model, num_images, num_advs):
    """Pair correctly classified images with differently labelled ones and locate the boundary between them.

    `num_images` origins are drawn without replacement from the images in
    `data_dict` that `model` classifies correctly. Each origin is paired with up
    to `num_advs` correctly classified images of another class, and
    `get_paired_boundary_image` finds the decision boundary on the line between
    the two. The notebook-level `num_steps_per_iter` and `num_iters` control
    that search.

    Parameters
    ----------
    data_dict : dict
        Must provide 'images' and 'labels', indexed along their first axis.
    model : callable
        Maps a batched tensor to class scores.
    num_images : int
        Number of origin images to draw.
    num_advs : int
        Maximum number of paired images per origin.

    Returns
    -------
    dict
        'images' (num_images, *image_shape), 'labels' (num_images,),
        'dirs' and 'advs' (num_images, num_advs, 1, num_pixels),
        'adv_class' and 'pert_lengths' (num_images, num_advs). Slots for which
        no paired image was available stay zero.
    """
    image_shape = data_dict['images'].shape[1:]
    num_pixels = int(np.prod(image_shape))
    images = np.zeros((num_images,) + image_shape)
    labels = np.zeros((num_images), dtype=int)  # the `np.int` alias no longer exists in NumPy >= 1.24
    dirs = np.zeros((num_images, num_advs, 1, num_pixels))
    advs = np.zeros((num_images, num_advs, 1, num_pixels))
    pert_lengths = np.zeros((num_images, num_advs))
    adv_class = np.zeros((num_images, num_advs))

    # Classify in small batches to bound memory use.
    batch_size = 10
    image_splits = torch.split(torchify(data_dict['images']), batch_size)
    # concatenate (not stack) so a shorter final batch is handled as well
    model_predictions = np.concatenate(
        [torch.argmax(model(batch), dim=1).detach().cpu().numpy() for batch in image_splits],
        axis=0)
    valid_indices = [
        image_idx for image_idx in range(data_dict['images'].shape[0])
        if model_predictions[image_idx] == data_dict['labels'][image_idx]
    ]

    origin_indices = np.random.choice(valid_indices, size=num_images, replace=False)
    for image_idx, origin_idx in enumerate(origin_indices):
        images[image_idx, ...] = data_dict['images'][origin_idx, ...]
        labels[image_idx] = data_dict['labels'][origin_idx]
        # Random order of candidate partners, restricted to other classes.
        shuffled_valid_indices = np.random.choice(valid_indices, size=len(valid_indices), replace=False)
        alt_indices = [idx for idx, alt_class in zip(shuffled_valid_indices, data_dict['labels'][shuffled_valid_indices]) if alt_class != labels[image_idx]]
        for dir_idx, alt_idx in enumerate(alt_indices[:num_advs]):
            alt_image = data_dict['images'][alt_idx, ...]
            boundary_image, boundary_dir, pert_length = get_paired_boundary_image(
                model, images[image_idx, ...], alt_image, num_steps_per_iter, num_iters)
            dirs[image_idx, dir_idx, ...] = boundary_dir.reshape(1, -1)
            advs[image_idx, dir_idx, ...] = boundary_image.reshape(1, -1)
            adv_class[image_idx,  dir_idx] = torch.argmax(model(torchify(boundary_image[None, ...]))).item()
            pert_lengths[image_idx, dir_idx] =  pert_length
    output_dict = {
        'images': images,
        'labels': labels,
        'dirs': dirs,
        'advs': advs,
        'adv_class': adv_class,
        'pert_lengths': pert_lengths,
    }
    return output_dict


def paired_activation(model, image, neuron1, neuron2):
    """Return the activation of `neuron1` minus the activation of `neuron2` for `image`.

    Gradient tracking is enabled on `image` if it is off, and the model's
    accumulated gradients are cleared before the two activations are evaluated.
    """
    if not image.requires_grad:
        image.requires_grad_(True)
    model.zero_grad()
    first, second = (
        model_utils.unit_activation(model, image, neuron, compute_grad=True)
        for neuron in (neuron1, neuron2)
    )
    return first - second


def paired_activation_and_gradient(model, image, neuron1, neuron2):
    """Return the activation difference of two units and its gradient with respect to `image`."""
    difference = paired_activation(model, image, neuron1, neuron2)
    (gradient,) = torch.autograd.grad(difference, image)
    return difference, gradient


def get_curvature(condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, autodiff=False):
    """
    Compute curvature quantities at decision-boundary points for every (model, data) pair.

    For each pair, `num_images` origins are drawn without replacement from the
    images that are classified correctly and have finite perturbation lengths
    for their first `num_advs` entries. For every origin and adversarial index,
    the boundary image is located with `get_paired_boundary_image`, and the
    Hessian and gradient of the (clean class - adversarial class) activation
    difference there are passed to
    `curve_utils.local_response_curvature_isoresponse_surface`.

    Parameters:
        condition_zip: iterable of (model, data dict) pairs; each dict provides
            'images', 'labels', 'advs', 'adv_class' and 'pert_lengths'.
        num_images: number of origin images per model.
        num_advs: number of adversarial entries used per image.
        num_eps, batch_size, buffer_portion: accepted but not used by this function.
        autodiff: True uses torch.autograd.functional.hessian; False uses
            curve_utils.sr1_hessian configured from the notebook-level `hess_params`.

    The boundary search reads the notebook-level `num_steps_per_iter` and `num_iters`.

    Returns:
        shape_operators: (num_models, num_images, num_advs, num_dims, num_dims)
        principal_curvatures: (num_models, num_images, num_advs, num_dims)
        principal_directions: (num_models, num_images, num_advs, image_size, num_dims)
        origin_indices: (num_models, num_images) integer dataset indices of the origins
        where image_size is the number of elements of one image and num_dims = image_size - 1.

    A note on the gradient of the difference in activations:
        The gradient points in the direction of the origin from the boundary image.
        Therefore, for large enough eps, origin - eps * grad/|grad| will reach the boundary; and boundary + eps * grad/|grad| will reach the origin 
    """
    models, model_data = zip(*condition_zip)
    num_models = len(models)
    image_shape = model_data[0]['images'][0, ...][None, ...].shape
    image_size = int(np.prod(image_shape))
    num_dims = image_size - 1 #removes normal direction
    shape_operators = np.zeros((num_models, num_images, num_advs, num_dims, num_dims))
    principal_curvatures = np.zeros((num_models, num_images, num_advs, num_dims))
    principal_directions = np.zeros((num_models, num_images, num_advs, image_size, num_dims))
    origin_indices = np.zeros((num_models, num_images), dtype=int)  # the `np.int` alias no longer exists in NumPy >= 1.24
    for model_idx, (model_, data_)  in enumerate(zip(models, model_data)):
        model_predictions = torch.argmax(model_(torchify(data_['images'])), dim=1).detach().cpu().numpy()
        # Need to ensure that all images are correctly labeled & have valid adversarial examples
        valid_indices = [
            image_idx for image_idx in range(data_['images'].shape[0])
            if model_predictions[image_idx] == data_['labels'][image_idx]  # correctly labeled
            and np.all(np.isfinite(data_['pert_lengths'][image_idx, :num_advs]))  # enough adversaries found
        ]
        origin_indices[model_idx, :] = np.random.choice(valid_indices, size=num_images, replace=False)
        # The context manager closes the bar even if a computation below raises.
        with tqdm(total=num_advs*num_images, leave=False) as pbar:
            for image_idx, origin_idx in enumerate(list(origin_indices[model_idx, :])):
                clean_lbl = int(data_['labels'][origin_idx])
                for adv_idx in range(num_advs):
                    boundary_image = get_paired_boundary_image(
                        model=model_,
                        origin=data_['images'][origin_idx, ...],
                        alt_image=data_['advs'][origin_idx, adv_idx, ...],
                        num_steps_per_iter=num_steps_per_iter,
                        num_iters=num_iters
                    )[0]
                    adv_lbl = int(data_['adv_class'][origin_idx, adv_idx])
                    if autodiff:
                        def func(x):
                            return paired_activation(model_, x, clean_lbl, adv_lbl)
                        hessian = torch.autograd.functional.hessian(func, torchify(boundary_image[None,...]))
                        hessian = hessian.reshape((int(boundary_image.size), int(boundary_image.size)))
                    else:
                        def func(x):
                            return paired_activation_and_gradient(model_, x, clean_lbl, adv_lbl)
                        hessian = curve_utils.sr1_hessian(
                            func, torchify(boundary_image[None, ...]),
                            distance=hess_params['hessian_dist'],
                            n_points=hess_params['hessian_num_pts'],
                            random_walk=hess_params['hessian_random_walk'],
                            learning_rate=hess_params['hessian_lr'],
                            return_points=False,
                            progress=False)
                    activation, gradient = paired_activation_and_gradient(model_, torchify(boundary_image[None, ...]), clean_lbl, adv_lbl)
                    gradient = gradient.reshape(-1)
                    curvature = curve_utils.local_response_curvature_isoresponse_surface(gradient, hessian)
                    # The three returned entries are stored as shape operator,
                    # principal curvatures and principal directions, in that order.
                    shape_operators[model_idx, image_idx, adv_idx, ...] = curvature[0].detach().cpu().numpy()
                    principal_curvatures[model_idx, image_idx, adv_idx, :] = curvature[1].detach().cpu().numpy()
                    principal_directions[model_idx, image_idx, adv_idx, ...] = curvature[2].detach().cpu().numpy()
                    pbar.update(1)
    return shape_operators, principal_curvatures, principal_directions, origin_indices


def get_hessian_error(model, origin, clean_lbl, adv_lbl, abscissa, ordinate, hess_params):
    """Compare an SR1 Hessian with an autodiff Hessian of the class-activation difference at `origin`.

    Both Hessians are handed to `curve_utils.hessian_approximate_response`
    (presumably a second-order approximation of the response; confirm against
    that helper) on a 10 x 100 grid of points spanned by `abscissa` and
    `ordinate` around `origin`. The grid reaches +-hessian_dist/2 along
    `abscissa` and +-1.25*hessian_dist along `ordinate`.

    Gradient tracking is enabled on `origin` in place.

    Returns the pair (SR1 RMS error, autodiff RMS error), each measured against
    the exact activation difference on the grid.
    """
    def response(x):
        return paired_activation(model, x, clean_lbl, adv_lbl)

    def response_and_gradient(x):
        return paired_activation_and_gradient(model, x, clean_lbl, adv_lbl)

    def rms(residual):
        return np.sqrt(np.mean(np.square(residual.detach().cpu().numpy())))

    origin.requires_grad = True
    radius = hess_params['hessian_dist']

    hessian_sr1 = curve_utils.sr1_hessian(
        response_and_gradient, origin,
        distance=radius,
        n_points=hess_params['hessian_num_pts'],
        random_walk=hess_params['hessian_random_walk'],
        learning_rate=hess_params['hessian_lr'],
        return_points=False,
        progress=True)
    num_inputs = int(origin.numel())
    hessian_autodiff = torch.autograd.functional.hessian(response, origin).reshape((num_inputs, num_inputs))

    # Grid of offsets in the plane spanned by the two given vectors.
    n_x_samples, n_y_samples = 10, 100
    x_offsets = np.linspace(-radius/2, radius/2, n_x_samples)
    y_offsets = np.linspace(-radius*1.25, radius*1.25, n_y_samples)
    grid_x, grid_y = np.meshgrid(x_offsets, y_offsets)
    offsets = abscissa * grid_x.reshape((-1, 1)) + ordinate * grid_y.reshape((-1, 1))
    samples = origin + torchify(offsets.reshape((-1,) + origin.shape[1:]))

    exact_response = response(samples)
    approx_sr1 = curve_utils.hessian_approximate_response(response_and_gradient, samples, hessian_sr1)
    approx_autodiff = curve_utils.hessian_approximate_response(response_and_gradient, samples, hessian_autodiff)
    return rms(exact_response - approx_sr1), rms(exact_response - approx_autodiff)

In [None]:
# Optional sanity check: how closely do the SR1 and autodiff Hessians reproduce
# the response of the naturally trained model near the boundary?
if num_hessian_tests > 0:
    sr1_errors, autodiff_errors = [], []
    # NOTE(review): validity is checked on perturbation column 2 while columns 0 and 1
    # are the ones used below; confirm this is intended.
    valid_indices = [
        i for i in range(data_natural['images'].shape[0])
        if np.isfinite(data_natural['pert_lengths'][i, 2])
    ]
    image_indices = np.random.choice(valid_indices, size=num_hessian_tests, replace=False)
    for image_idx in image_indices:
        adv_idx = 0
        boundary_image, boundary_dir, pert_length = get_paired_boundary_image(
            model_natural,
            data_natural['images'][image_idx, ...],
            data_natural['advs'][image_idx, adv_idx, ...],
            num_steps_per_iter,
            num_iters,
        )

        # NOTE(review): the ordinate is read from 'advs' rather than 'dirs';
        # confirm that an image, not a direction, is what should span the plane.
        sr1_rms_error, autodiff_rms_error = get_hessian_error(
            model_natural,
            torchify(boundary_image[None,...]),
            int(data_natural['labels'][image_idx]),
            int(data_natural['adv_class'][image_idx, adv_idx]),
            boundary_dir,
            data_natural['advs'][image_idx, adv_idx+1],
            hess_params,
        )
        sr1_errors += [sr1_rms_error]
        autodiff_errors += [autodiff_rms_error]

    print(f'Average RMS error over {num_hessian_tests} images with an l2 radius of {hess_params["hessian_dist"]:.3f} for\nSR1 Hessian:\t\t{np.mean(sr1_errors):.3f}\nAutodiff hessian:\t{np.mean(autodiff_errors):.3f}.')

In [None]:
# Cache file for the curvature results; one file per Hessian method.
if autodiff:
    filename = '../data/cifar_curvatures_and_directions_autodiff.npz'
else:
    filename = '../data/cifar_curvatures_and_directions_sr1.npz'

if load:
    # NOTE: allow_pickle=True unpickles arbitrary Python objects; only load files you trust.
    data = np.load(filename, allow_pickle=True)['data'].item()
    data_paired_natural = data['data_paired_natural']
    data_paired_madry = data['data_paired_madry']
    paired_shape_operators = data['paired_shape_operators']
    paired_principal_curvatures = data['paired_principal_curvatures']
    paired_principal_directions = data['paired_principal_directions']
    paired_mean_curvatures = data['paired_mean_curvatures']
    paired_origin_indices = data['paired_origin_indices']
    adv_shape_operators = data['adv_shape_operators']
    adv_principal_curvatures = data['adv_principal_curvatures']
    adv_principal_directions = data['adv_principal_directions']
    adv_mean_curvatures = data['adv_mean_curvatures']
    adv_origin_indices = data['adv_origin_indices']
    del data
else:
    # The paired datasets must be generated here: nothing else defines them when
    # load is False, so a fresh kernel would otherwise stop with a NameError.
    data_paired_natural = generate_paired_dict(data_natural, model_natural, num_images, num_advs)
    data_paired_madry = generate_paired_dict(data_madry, model_madry, num_images, num_advs)

    # Curvature at boundaries between pairs of test images.
    paired_condition_zip = zip([model_natural, model_madry], [data_paired_natural, data_paired_madry])
    paired_shape_operators, paired_principal_curvatures, paired_principal_directions, paired_origin_indices = get_curvature(
        paired_condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, autodiff)
    paired_mean_curvatures = np.mean(paired_principal_curvatures, axis=-1)

    # Curvature at boundaries reached along adversarial perturbations.
    adv_condition_zip = zip([model_natural, model_madry], [data_natural, data_madry])
    adv_shape_operators, adv_principal_curvatures, adv_principal_directions, adv_origin_indices = get_curvature(
        adv_condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, autodiff)
    adv_mean_curvatures = np.mean(adv_principal_curvatures, axis=-1)

    save_dict = {
        'data_paired_natural':data_paired_natural,
        'data_paired_madry':data_paired_madry,
        'paired_shape_operators': paired_shape_operators,
        'paired_principal_curvatures': paired_principal_curvatures,
        'paired_principal_directions': paired_principal_directions,
        'paired_mean_curvatures': paired_mean_curvatures,
        'paired_origin_indices': paired_origin_indices,
        'adv_shape_operators': adv_shape_operators,
        'adv_principal_curvatures': adv_principal_curvatures,
        'adv_principal_directions': adv_principal_directions,
        'adv_mean_curvatures': adv_mean_curvatures,
        'adv_origin_indices': adv_origin_indices
    }
    np.savez(filename, data=save_dict)

In [None]:
def test_model_and_data(data_nat, data_mad):
    """Print how many stored labels disagree with the models' predictions.

    `data_nat` is checked against `model_natural` and `data_mad` against
    `model_madry`. For each pair, two counts are printed: clean images whose
    prediction differs from 'labels', and adversarial images (only those with a
    finite perturbation length) whose prediction differs from 'adv_class'.
    """
    diff_str = '_diff' if autodiff else '_nodiff'
    for model, data, name in zip([model_natural, model_madry], [data_nat, data_mad], ['natural'+diff_str, 'madry'+diff_str]):
        image_shape = data['images'][0,...].shape
        clean_count = 0
        adv_count = 0
        clean_model_predictions = torch.argmax(model(torchify(data['images'])), dim=1).detach().cpu().numpy()
        for img_idx in range(data['images'].shape[0]):
            if clean_model_predictions[img_idx] != data['labels'][img_idx]:
                clean_count += 1
            # All adversarial images of this origin are classified in one batch.
            adv_imgs = torchify(data['advs'][img_idx, ...].reshape((-1,)+image_shape))
            adv_model_predictions = torch.argmax(model(adv_imgs), dim=1).detach().cpu().numpy()
            for adv_idx in range(data['adv_class'].shape[1]):
                if not np.isfinite(data['pert_lengths'][img_idx, adv_idx]):
                    continue  # this slot has no finite perturbation length
                if adv_model_predictions[adv_idx] != data['adv_class'][img_idx, adv_idx]:
                    adv_count += 1
        print(f'{name}: number of clean images with bad predictions = {clean_count}; number of adv images with bad predictions = {adv_count}')

test_model_and_data(data_paired_natural, data_paired_madry)
test_model_and_data(data_natural, data_madry)

In [None]:
def _box_style(color):
    """Return the boxplot keyword arguments that draw every box element in `color`."""
    return dict(
        boxprops=dict(color=color, linewidth=1.5, alpha=0.7),
        whiskerprops=dict(color=color, alpha=0.7),
        capprops=dict(color=color, alpha=0.7),
        medianprops=dict(linestyle='--', linewidth=0.5, color=color),
        meanprops=dict(linestyle='-', linewidth=0.5, color=color),
    )


bar_width = 0.5
fig, axs = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12,7))
fig.subplots_adjust(top=0.8)

# Panels 0 and 1: one box per model, pooled over images and adversarial directions.
for data_idx, mean_curvatures in enumerate([paired_mean_curvatures, adv_mean_curvatures]):
    for model_idx in range(2):
        data = mean_curvatures[model_idx, :, :].reshape(-1)
        axs[data_idx].boxplot(data, sym='', positions=[model_idx], whis=(10, 90), widths=bar_width,
            meanline=True, showmeans=True, **_box_style(plot_colors[model_idx]))
    axs[data_idx].set_xticks([0, 1], minor=False)
    axs[data_idx].set_xticks([], minor=True)
    axs[data_idx].set_xticklabels(model_types)
    if data_idx == 0:
        axs[data_idx].set_ylabel('Mean curvature')
        axs[data_idx].set_title('Paired image boundary')
    else:
        axs[data_idx].set_title('Adversarial image boundary')

# Panel 2: one box per model and adversarial index; both models share each x position.
for model_idx in range(adv_mean_curvatures.shape[0]):
    for adv_idx in range(adv_mean_curvatures.shape[-1]):
        data = adv_mean_curvatures[model_idx, :, adv_idx].reshape(-1)
        axs[2].boxplot(data, sym='', positions=[adv_idx], whis=(10, 90), widths=bar_width,
            meanline=True, showmeans=True, **_box_style(plot_colors[model_idx]))
axs[2].set_title('Adversarial image boundary')
axs[2].set_xlabel('Dimension number')
axs[2].set_xticks([i for i in range(adv_mean_curvatures.shape[-1])], minor=False)
axs[2].set_xticks([], minor=True)
axs[2].set_xticklabels([str(i+1) for i in range(adv_mean_curvatures.shape[-1])])

def make_space_above(axes, topmargin=1):
    """Grow the figure by `topmargin` inches at the top to make room for titles.

    The axes keep their size; only the figure height and the subplot margins
    change. Adapted from https://stackoverflow.com/a/55768955/
    """
    figure = axes.flatten()[0].figure
    params = figure.subplotpars
    _width, height = figure.get_size_inches()

    new_height = height - (1 - params.top) * height + topmargin
    figure.subplots_adjust(bottom=params.bottom*height/new_height, top=1-topmargin/new_height)
    figure.set_figheight(new_height)


make_space_above(axs, topmargin=0.5)

fig.suptitle(f'Curvature at the decision boundary\nfor {num_images} images and the first {num_advs} adversarial directions', y=1.0)
plt.show()
#fig.savefig('../data/cifar_mean_curvature_boxplots.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01)

In [None]:
bar_width = 0.5
fig, axs = pplt.subplots(nrows=1, ncols=2, sharey=True, figwidth=figwidth)
titles = ['Test image boundary', 'Adversarial image boundary']
model_columns = pd.Index(model_types, name='')
for data_idx, mean_curvatures in enumerate([paired_mean_curvatures, adv_mean_curvatures]):
    # One column per model, one row per (image, adversarial direction) pair.
    per_model = mean_curvatures.reshape(-1, np.prod(mean_curvatures.shape[1:])).T
    data = pd.DataFrame(per_model, columns=model_columns)

    panel = axs[data_idx]
    panel.boxplot(
        data,
        fill=True,
        mean=True,
        cycle=pplt.Cycle(plot_colors),
        linewidth=0.5,
        meanlinestyle='-',
        medianlinestyle='--',
        marker='o',
        markersize=1.0,
    )
    panel.format(xticklabels=model_types, ylabel='Mean curvature', title=titles[data_idx], xgrid=False)
    # Reference line at zero curvature.
    panel.axhline(0.0, color='black', linestyle='dashed', linewidth=0.5)

pplt.show()
fig.savefig('../data/cifar_mean_curvature_boxplots.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01, dpi=dpi)

In [None]:
# NOTE: this rebinds the notebook-level num_images and num_advs to the stored array's sizes.
num_models, num_images, num_advs, num_dims = adv_principal_curvatures.shape

fig, ax = pplt.subplots(nrows=1, ncols=1, figwidth=figwidth_inch/2, dpi=dpi, sharey=False, sharex=False)

# Main axes: every individual curvature profile, natural model first, then robust.
for image_idx in range(num_images):
    for adv_idx in range(num_advs):
        for model_idx in (0, 1):
            ax.scatter(adv_principal_curvatures[model_idx, image_idx, adv_idx, :],
                       s=0.01, c=plot_colors[model_idx])

# Inset: profiles averaged over images and adversarial directions, zoomed around zero.
ix = ax.inset(
    bounds=[200, 0.50, 400, 1.5],
    transform='data', zoom=True,
    zoom_kw={'edgecolor': 'k', 'lw': 1, 'ls': '--'}
)
ix.format(
    xlim=(0, num_dims), ylim=(-0.02, 0.02), metacolor='red7',
    grid=False,
    linewidth=1.5, ticklabelweight='bold'
)
ix.plot([0, num_dims], [0, 0], lw=0.1, c='k')
for model_idx in (0, 1):
    ix.scatter(adv_principal_curvatures[model_idx, ...].mean(axis=(0, 1)),
               s=0.005, alpha=1.0, c=plot_colors[model_idx])

ax.format(
    title=f'Curvature profile, averaged across {num_images} images',
    xlim=(-5, num_dims+5),
    ylabel='Curvature',
    xlabel='Principal curvature direction',
    grid=False
)
for ax_loc in ('top', 'right'):
    ax.spines[ax_loc].set_color('none')
pplt.show()

fig.savefig('../data/cifar_curvature_profile.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01, dpi=dpi)

In [None]:
# TODO:
# y & x axis should be the same scale
# no need to have a dot on the y axis, should use an arrow instead
num_plot_images = 3
offset = 0

for model_idx, (model_, data_) in enumerate(zip([model_natural, ], [data_natural, ])):
    # NOTE(review): get_curvature allocates the directions as (..., image_size, num_dims),
    # so the last two names look swapped; neither is used in this cell.
    num_models, num_images, num_advs, num_directions, num_pixels = adv_principal_directions.shape

    fig, axs = pplt.subplots(nrows=3, ncols=num_plot_images, figwidth=figwidth, dpi=dpi, hspace=2, wspace=2)
    axs.format(ylabel='Principal curvature direction')

    # One column per image; rows show the first, the flattest and the last principal direction.
    for column_idx in range(num_plot_images):
        image_idx = offset + column_idx
        adv_idx = np.random.randint(low=0, high=num_advs)
        curvatures = adv_principal_curvatures[model_idx, image_idx, adv_idx, :]
        most_flat_index = np.argmin(np.abs(curvatures))
        curvature_indices = [0, most_flat_index, -1]
        dataset_image_idx = adv_origin_indices[model_idx, image_idx]
        origin = data_['images'][dataset_image_idx, ...]

        boundary_image = get_paired_boundary_image(
            model_,
            origin,
            data_['advs'][dataset_image_idx, adv_idx, ...],
            num_steps_per_iter,
            num_iters,
        )[0]
        origin_flat = origin.reshape(-1)
        adv1 = boundary_image.reshape(-1)
        boundary_dist = np.linalg.norm(adv1 - origin_flat)

        for ax_idx, curvature_idx in enumerate(curvature_indices):
            ax = axs[ax_idx, column_idx]
            principal_direction = adv_principal_directions[model_idx, image_idx, adv_idx, :, curvature_idx]
            # Source of error?
            adv2 = origin_flat + boundary_dist * principal_direction
            dec_advs, labels = pl.plot_dec_space(origin[None, ...], adv1, adv2, model_, offset=1.0,
                              n_grid=100, len_grid_scale=1.8, show_legend=False, show_advs=True,
                              overlay_inbounds=True, ax=ax)
            ax.legend(handles=labels, loc='upper left', ncols=1, title='predicted\nclass')
            ax.format(
                title=f'Curvature = {curvatures[curvature_idx]:.5f}',
                xlabel=f'Adversarial direction number {adv_idx}',
            )
    plt.show()
    fig.savefig('../data/cifar_curvature_visualizations.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01, dpi=dpi)

In [None]:
filename = '../data/cifar_subspace_curvatures_autodiff.npz'

if load:
    # NOTE: allow_pickle=True unpickles arbitrary Python objects; only load files you trust.
    data = np.load(filename, allow_pickle=True)['data'].item()
    rand_pcs = data['random_principal_curvatures']
    adv_pcs = data['adversarial_principal_curvatures']

else:
    dtype = torch.double
    subspace_size = num_advs-1
    # Take the sizes from the stored origin indices, not from names that only a plotting cell defines.
    num_models, num_images = adv_origin_indices.shape

    rand_pcs = np.zeros((num_models, num_images, subspace_size))
    adv_pcs = np.zeros((num_models, num_images, subspace_size))
    for model_idx, (model_, data_) in enumerate(zip((model_natural, model_madry), (data_natural, data_madry))):
        # The context manager closes the bar even if a computation below raises.
        with tqdm(total=num_images, leave=False) as pbar:
            for image_idx, origin_idx in enumerate(adv_origin_indices[model_idx, :]):
                for adv_idx in range(num_advs):
                    boundary_image, boundary_dir, pert_length = get_paired_boundary_image(
                            model=model_,
                            origin=data_['images'][origin_idx, ...],
                            alt_image=data_['advs'][origin_idx, adv_idx, ...],
                            num_steps_per_iter=num_steps_per_iter,
                            num_iters=num_iters)

                    clean_lbl = int(data_['labels'][origin_idx])
                    adv_lbl = int(data_['adv_class'][origin_idx, adv_idx])

                    activation, gradient = paired_activation_and_gradient(model_, torchify(boundary_image[None, ...]), clean_lbl, adv_lbl)
                    gradient = gradient.reshape(-1).type(dtype)
                    def func(x):
                        return paired_activation(model_, x, clean_lbl, adv_lbl)
                    hessian = torch.autograd.functional.hessian(func, torchify(boundary_image[None,...]))
                    hessian = hessian.reshape((int(boundary_image.size), int(boundary_image.size))).type(dtype)

                    # Random subspace: rows taken from make_orth_basis, seeded with the unit gradient.
                    dirs = [(gradient / torch.linalg.norm(gradient)).detach().cpu().numpy()]
                    n_pixels = gradient.numel()
                    n_iterations = 3
                    random_basis = torch.from_numpy(make_orth_basis(dirs, n_pixels, n_iterations)[:subspace_size, :]).type(dtype)

                    curvature = curve_utils.local_response_curvature_isoresponse_surface(gradient, hessian, projection_subspace_of_interest=random_basis)
                    rand_subspace_curvatures = curvature[1].detach().cpu().numpy()

                    # Adversarial subspace: every stored direction except the current one.
                    # The two slices cover the first and the last index as well (one of them is empty).
                    adv_dirs = data_['dirs'][origin_idx, :num_advs, ...].reshape(num_advs, n_pixels)
                    adv_dirs = np.concatenate((adv_dirs[:adv_idx], adv_dirs[adv_idx+1:]))

                    adv_basis = torch.from_numpy(adv_dirs).type(dtype)
                    curvature = curve_utils.local_response_curvature_isoresponse_surface(gradient, hessian, projection_subspace_of_interest=adv_basis)
                    adv_subspace_curvatures = curvature[1].detach().cpu().numpy()

                    # NOTE(review): the slot has no adv_idx axis, so each pass overwrites the
                    # previous one and only the last adversarial index is kept; confirm intent.
                    rand_pcs[model_idx, image_idx, :] = rand_subspace_curvatures
                    adv_pcs[model_idx, image_idx, :] = adv_subspace_curvatures
                pbar.update(1)

        # Rewritten after each model finishes; the last write holds both models.
        save_dict = {
            'random_principal_curvatures':rand_pcs,
            'adversarial_principal_curvatures':adv_pcs,
        }
        np.savez(filename, data=save_dict)

In [None]:
fig, axs = pplt.subplots(nrows=1, ncols=2, figwidth=figwidth_inch, dpi=dpi)

# Both panels share one layout: per-model mean over images with a standard-deviation bar.
panels = [(rand_pcs, 'Random subspaces'), (adv_pcs, 'Adversarial subspaces')]
for panel_idx, (subspace_pcs, panel_title) in enumerate(panels):
    panel = axs[panel_idx]
    for model_idx in range(2):
        panel.scatter(subspace_pcs[model_idx, ...].mean(axis=0), s=2.0, c=plot_colors[model_idx],
                      bardata=np.std(subspace_pcs[model_idx, ...], axis=0), barc=plot_colors[model_idx],
                      barlw=0.5, capsize=0.0,)
    panel.axhline(0.0, color='black', linestyle='dashed', linewidth=0.5)
    panel.format(title=panel_title)
    for ax_loc in ['top', 'right']:
        panel.spines[ax_loc].set_color('none')

axs.format(
    ylabel='Curvature',
    xlabel='Principal curvature directions',
    grid=False
)

legend_handles = [mpatches.Patch(color=plot_colors[0], label='Natural'),
                  mpatches.Patch(color=plot_colors[1], label='Adversarial')]
axs[0].legend(handles=legend_handles, loc='upper right', ncols=1, frame=False)

pplt.show()

fig.savefig('../data/cifar_subspace_curvatures.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01, dpi=dpi)

In [None]:
import svgutils.compose as sc
from IPython.display import SVG, Image 

In [None]:
# Disabled: would combine three saved SVG panels into a single figwidth x figwidth
# svgutils figure, save it as 'compose.svg' and display it.
# NOTE(review): the subspace-curvature plotting cell in this notebook saves
# '../data/cifar_subspace_curvatures.pdf', not the '../data/subspace_curvatures.svg'
# referenced here — confirm the file names/formats before re-enabling.
#sc.Figure(figwidth, figwidth, 
#    sc.Panel(sc.SVG('../data/mean_curvature_boxplots.svg')),
#    sc.Panel(sc.SVG('../data/subspace_curvatures.svg')),
#    sc.Panel(sc.SVG('../data/curvature_profile.svg'))
#    ).save('compose.svg')
#SVG('compose.svg')

In [None]:
import xml.etree.ElementTree as etree
import re
from six import StringIO
import requests

import svgpathtools as svgpt
from svgpath2mpl import parse_path

In [None]:
# Location of the curvature diagram SVG; the same path is parsed with svgpathtools further down.
diagram = '../data/curvature_diagram.svg'
# Pass the path explicitly as `filename`: given positionally it is the `data`
# argument, and if the file is missing IPython then tries to parse the path
# string itself as SVG markup, failing with an XML parse error instead of a
# clear FileNotFoundError.
imported_diagram = SVG(filename=diagram)

In [None]:
# Display the loaded SVG object as this cell's output.
imported_diagram

In [None]:
# Parse the diagram file with svgpathtools. The result is unpacked into three values;
# judging by the names and later use: the path objects, one attribute mapping per path
# (read via ['d'] / .get(...) below), and the attributes of the SVG document itself.
curve_paths, curve_attributes, svg_attributes = svgpt.svg2paths2(diagram)

In [None]:
# Inspect the third value returned by svg2paths2 (presumably the root <svg> attributes — confirm).
svg_attributes

In [None]:
# Inspect the attribute record of the first parsed path.
curve_attributes[0]

In [None]:
# Inspect the attribute record of the fourth parsed path (requires at least four paths in the diagram).
curve_attributes[3]

In [None]:
def normalize_hex(c):
    """Expand a shorthand ``#rgb`` colour string to its ``#rrggbb`` form.

    Any value that is not a four-character string starting with ``#`` is
    returned unchanged (e.g. ``'none'``, named colours, full-length hex).
    """
    if not (c.startswith('#') and len(c) == 4):
        return c
    # Double each of the three digits following the '#'.
    return '#' + ''.join(digit * 2 for digit in c[1:])

# Convert every SVG path of the diagram into a matplotlib Path and draw them as
# one PathCollection, using the fill/stroke presentation attributes of each path
# (shorthand #rgb colours are expanded to #rrggbb first).
paths = [parse_path(attrib['d']) for attrib in curve_attributes]
facecolors = [normalize_hex(attrib.get('fill', 'none')) for attrib in curve_attributes]
edgecolors = [normalize_hex(attrib.get('stroke', 'none')) for attrib in curve_attributes]
# SVG spells this attribute "stroke-width"; the former 'stroke_width' key never
# matched, so every path silently got the default width. The value is text and
# may carry a unit suffix (e.g. "2px"), so keep only the leading number and fall
# back to 1 when the attribute is absent or not numeric.
# NOTE(review): styling stored inside a `style="..."` attribute is not read here.
stroke_width_matches = [re.match(r'\d*\.?\d+(?:[eE][-+]?\d+)?', attrib.get('stroke-width', ''))
                        for attrib in curve_attributes]
linewidths = [float(match.group()) if match else 1 for match in stroke_width_matches]
collection = mpl.collections.PathCollection(paths, 
                                      edgecolors=edgecolors, 
                                      linewidths=linewidths,
                                      facecolors=facecolors)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)
# Draw the paths in data coordinates; the axis limits are not set in this cell
# (see the commented-out lines below).
collection.set_transform(ax.transData)
ax.add_artist(collection)
#ax.set_xlim([0, 960])
#ax.set_ylim([540, 0])

In [None]:
# Draw the diagram's SVG paths as a matplotlib PathCollection inside a fixed
# 960 x 540 window, using each path's raw fill/stroke presentation attributes.
paths = [parse_path(attrib['d']) for attrib in curve_attributes]
facecolors = [attrib.get('fill', 'none') for attrib in curve_attributes]
edgecolors = [attrib.get('stroke', 'none') for attrib in curve_attributes]
# SVG spells this attribute "stroke-width"; the former 'stroke_width' key never
# matched, so every path silently got the default width. The value is text and
# may carry a unit suffix (e.g. "2px"), so keep only the leading number and fall
# back to 1 when the attribute is absent or not numeric.
# NOTE(review): styling stored inside a `style="..."` attribute is not read here.
stroke_width_matches = [re.match(r'\d*\.?\d+(?:[eE][-+]?\d+)?', attrib.get('stroke-width', ''))
                        for attrib in curve_attributes]
linewidths = [float(match.group()) if match else 1 for match in stroke_width_matches]
collection = mpl.collections.PathCollection(paths, 
                                      edgecolors=edgecolors, 
                                      linewidths=linewidths,
                                      facecolors=facecolors)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)
collection.set_transform(ax.transData)
ax.add_artist(collection)
# Hard-coded canvas size; the y-limits are passed as [540, 0], i.e. inverted so
# that y grows downwards (presumably to match the SVG coordinate system).
ax.set_xlim([0, 960])
ax.set_ylim([540, 0])

In [None]:
# Download an example SVG and draw its <path> elements with matplotlib.
# Bound the wait and reject HTTP error responses so that a network problem or
# an error page is reported as such instead of as an XML parse failure.
r = requests.get('http://thenewcode.com/assets/images/thumbnails/homer-simpson.svg', timeout=30)
r.raise_for_status()
tree = etree.parse(StringIO(r.text))
root = tree.getroot()
# Keep only the leading integer of the root width/height attributes
# (drops any trailing unit such as "px").
width = int(re.match(r'\d+', root.attrib['width']).group())
height = int(re.match(r'\d+', root.attrib['height']).group())
path_elems = root.findall('.//{http://www.w3.org/2000/svg}path')
paths = [parse_path(elem.attrib['d']) for elem in path_elems]
facecolors = [elem.attrib.get('fill', 'none') for elem in path_elems]
edgecolors = [elem.attrib.get('stroke', 'none') for elem in path_elems]
# SVG spells this attribute "stroke-width"; the former 'stroke_width' key never
# matched, so every path silently got the default width. The value is text and
# may carry a unit suffix, so keep only the leading number and fall back to 1
# when the attribute is absent or not numeric.
stroke_width_matches = [re.match(r'\d*\.?\d+(?:[eE][-+]?\d+)?', elem.attrib.get('stroke-width', ''))
                        for elem in path_elems]
linewidths = [float(match.group()) if match else 1 for match in stroke_width_matches]
collection = mpl.collections.PathCollection(paths, 
                                      edgecolors=edgecolors, 
                                      linewidths=linewidths,
                                      facecolors=facecolors)
fig = plt.figure(figsize=(10,10))
ax = fig.add_subplot(111)
collection.set_transform(ax.transData)
ax.add_artist(collection)
# Use the document size as axis limits; y-limits are passed as [height, 0], i.e. inverted.
ax.set_xlim([0, width])
ax.set_ylim([height, 0])