In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, './..')
sys.path.insert(0, '../data')

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import matplotlib.gridspec as gridspec
import proplot as pplt
from tqdm import tqdm

import numpy as np
import torch
from torchvision import datasets, transforms

from models import model, eval
import plots as pl
from utils import dev, load_data, classification

sys.path.insert(0, './../../')

import response_contour_analysis.utils.model_handling as model_utils
import response_contour_analysis.utils.dataset_generation as data_utils
import response_contour_analysis.utils.histogram_analysis as hist_utils
import response_contour_analysis.utils.principal_curvature as curve_utils
import response_contour_analysis.utils.plotting as plot_utils

# Choose the compute device once; later cells move models and tensors to DEVICE.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(DEVICE)

cuda


# Load data & models

In [2]:
seed = 0

# Keys read from both adversarial-example dictionaries, in the order they are unpacked below.
adv_data_keys = ('advs', 'pert_lengths', 'adv_class', 'dirs', 'images', 'labels')

# Each file is loaded with allow_pickle=True and unwrapped with .item(), giving a
# dict-like object indexed by the keys above.
# NOTE: pickle loading can execute code embedded in the file - only load trusted data.
data_natural = np.load(f'../data/natural_{seed}.npy', allow_pickle=True).item()
(advs_nat, pert_lengths_nat, classes_nat,
 dirs_nat, images_nat, labels_nat) = (data_natural[key] for key in adv_data_keys)

data_madry = np.load(f'../data/robust_{seed}.npy', allow_pickle=True).item()
(advs_madry, pert_lengths_madry, classes_madry,
 dirs_madry, images_madry, labels_madry) = (data_madry[key] for key in adv_data_keys)

In [3]:
def _load_madry_checkpoint(checkpoint_path):
    """Build a `model.madry()` network, load its weights from `checkpoint_path` and move it to DEVICE."""
    net = model.madry()
    net.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    net.to(DEVICE)
    return net


model_natural = _load_madry_checkpoint(f'./../models/natural_{seed}.pt')
model_madry = _load_madry_checkpoint(f'./../models/robust_{seed}.pt')
model_random = _load_madry_checkpoint('./../models/random.pt')
# Display the architecture (same network class for all three checkpoints).
model_random

madry(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxPool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxPool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3136, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=10, bias=True)
  (relu): ReLU()
)

# Plane curvature for one pair of directions

In [None]:
# Settings for the SR1 Hessian approximation used in get_mean_curvatures.
hess_params = {
    'hessian_num_pts': 1e4,
    'hessian_lr': 1e-3,
    'hessian_random_walk': False,
    'return_points': False,
}

# Line search used to locate the decision boundary (see get_boundary_image).
buffer_portion = 0.25  # sampling-window offset as a fraction of the adversarial distance
num_eps = 1000         # number of points sampled along each line
batch_size = 1000      # maximum points per forward pass

# Number of images to analyse; use labels_nat.size for the full set.
num_images = 100

In [14]:
def get_boundary_image(model, origin, direction, length, origin_class, adv_class, num_eps, buffer_portion, batch_size, device=None):
    """Return the point on a line through `origin` that is closest to a two-class decision boundary.

    Points ``origin + eps * direction`` are sampled for ``num_eps`` evenly spaced
    values of ``eps`` in ``[buffer_portion * length, (1 + buffer_portion) * length]``.
    They are evaluated with ``model`` in batches of at most ``batch_size`` points.
    The point whose score difference ``score[origin_class] - score[adv_class]`` has
    the smallest magnitude is returned.

    Args:
        model: callable mapping a float tensor of shape (n, 1, side, side) to
            per-class scores of shape (n, num_classes).
        origin: starting image. It is flattened internally and assumed to hold a
            square single-channel image (side = int(sqrt(origin.size))).
        direction: search direction with the same number of elements as `origin`.
        length: reference distance along `direction` (e.g. the distance to a
            known adversarial example), used to place the sampling window.
        origin_class, adv_class: indices of the two classes whose boundary is sought.
        num_eps: number of points sampled along the line.
        buffer_portion: window offset as a fraction of `length`.
        batch_size: maximum number of points sent to `model` per forward pass.
        device: device used for the forward passes; defaults to the notebook's DEVICE.

    Returns:
        np.ndarray of shape (1, side, side): the selected sample.
    """
    if device is None:
        device = DEVICE
    origin = np.asarray(origin).reshape(1, -1)
    direction = np.asarray(direction).reshape(1, -1)
    side = int(np.sqrt(origin.size))
    # NOTE(review): the window starts at buffer_portion * length rather than
    # (1 - buffer_portion) * length, so it is not centred on `length` - confirm intended.
    eps_vals = np.linspace(buffer_portion * length, (1 + buffer_portion) * length, num_eps).reshape(-1, 1)
    adv_line = (origin + direction * eps_vals).reshape(-1, 1, side, side)
    line_tensor = torch.from_numpy(adv_line).to(device).float()
    # torch.split takes the chunk *size*, so pass batch_size directly.
    with torch.no_grad():
        model_outputs = np.concatenate(
            [model(batch).detach().cpu().numpy() for batch in torch.split(line_tensor, batch_size)],
            axis=0,
        )
    decision_scores = model_outputs[:, int(origin_class)] - model_outputs[:, int(adv_class)]
    boundary_index = np.abs(decision_scores).argmin()
    return adv_line[boundary_index]


def _steps_until_label_change(model, start, directions, step_size, labels, max_steps):
    """Count, per sample, the steps of `step_size` along `directions` until the argmax label leaves `labels`.

    Returns a long tensor of shape (batch,) with values >= 1 (for a non-empty batch).
    Raises RuntimeError if some sample has not changed label after `max_steps` steps.
    """
    num_samples = start.shape[0]
    per_sample = (-1,) + (1,) * (start.dim() - 1)
    steps = torch.zeros(num_samples, dtype=torch.long, device=start.device)
    searching = torch.ones(num_samples, dtype=torch.bool, device=start.device)
    for _ in range(max_steps):
        # Only samples whose label has not changed yet take another step.
        steps += searching.long()
        # Recompute from `start` (not by accumulation) so the probed point is exactly
        # the point the caller reconstructs from `steps`.
        probe = start + steps.view(per_sample).to(start.dtype) * directions * step_size
        searching &= torch.argmax(model(probe), dim=1) == labels
        if not searching.any():
            return steps
    raise RuntimeError(f'label did not change within {max_steps} steps of size {step_size}')


def get_random_boundary_image(origins, directions, eps=0.01, model=None, max_steps=10000):
    """Walk from each origin along its direction until the model's predicted label changes.

    First, a coarse search takes steps of size `eps`. The search is then repeated
    with steps of `eps * 0.01`, starting from the last coarse point that kept the
    original label. The result for each sample is the first fine-grid point whose
    label differs from the label at its origin. Samples in the batch are searched
    independently.

    Args:
        origins: batch of starting images, shape (batch, ...).
        directions: perturbation directions, broadcastable to `origins`.
        eps: coarse step size.
        model: callable mapping a batch of images to per-class scores of shape
            (batch, num_classes). Required.
        max_steps: cap on the number of steps in each search phase.

    Returns:
        Tensor shaped like `origins` holding the located points.

    Raises:
        ValueError: if `model` is not given.
        RuntimeError: if a label does not change within `max_steps` steps.
    """
    if model is None:
        raise ValueError('get_random_boundary_image requires a `model` to classify the perturbed images')
    fine_eps = eps * 0.01
    per_sample = (-1,) + (1,) * (origins.dim() - 1)
    with torch.no_grad():
        correct_labels = torch.argmax(model(origins), dim=1)
        coarse_steps = _steps_until_label_change(model, origins, directions, eps, correct_labels, max_steps)
        # Step back one coarse step: the last point that still had the original label.
        last_unchanged = origins + (coarse_steps - 1).view(per_sample).to(origins.dtype) * directions * eps
        fine_steps = _steps_until_label_change(model, last_unchanged, directions, fine_eps, correct_labels, max_steps)
        return last_unchanged + fine_steps.view(per_sample).to(origins.dtype) * directions * fine_eps
    

def random_vector(size, rng=None):
    """Return a random unit-length direction in R^size.

    The vector is an isotropic Gaussian sample divided by its L2 norm, which gives
    a direction uniformly distributed on the unit sphere.

    Args:
        size: dimensionality of the vector.
        rng: optional `np.random.Generator` / `RandomState` to draw from. Defaults
            to the global `np.random` state. With the global state, the draws are
            the same as `size` successive `np.random.normal()` calls.

    Returns:
        1-D float array of length `size` with unit L2 norm.
    """
    sampler = np.random if rng is None else rng
    components = sampler.normal(size=size)
    return components / np.linalg.norm(components)

def _make_decision_fn(model, clean_id, adv_id):
    """Return f(x) -> (logit[clean_id] - logit[adv_id], df/dx) in the form passed to curve_utils.sr1_hessian."""
    def decision_fn(x):
        model.zero_grad()
        logits = model(x)  # one forward pass serves both logits
        acts_diff = logits[:, clean_id] - logits[:, adv_id]
        grad = torch.autograd.grad(acts_diff, x)[0]
        return acts_diff, grad
    return decision_fn


def get_mean_curvatures(condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, hess_radius_mult):
    """Mean principal curvature of each model's decision boundary near each image.

    The following is done for every (model, data) pair, for each of the first
    `num_images` images, and for each of the first `num_advs` adversarial directions:
      1. get_boundary_image finds the point on the image->adversary line that is
         closest to the boundary between the clean and adversarial classes.
      2. The Hessian of f(x) = logit[clean] - logit[adv] at that point is
         approximated with curve_utils.sr1_hessian. SR1 settings are read from the
         notebook-level `hess_params` dict, which is not modified.
      3. The mean of the principal curvatures from
         curve_utils.local_response_curvature is recorded.

    Args:
        condition_zip: iterable of (model, data) pairs. `data` is a dict with keys
            'advs', 'adv_class', 'dirs', 'images', 'labels'.
        num_images: number of leading images to analyse per condition.
        num_advs: number of leading adversarial directions to analyse per image.
        num_eps, batch_size, buffer_portion: forwarded to get_boundary_image.
        hess_radius_mult: the SR1 sampling radius is this fraction of the distance
            between the image and its adversarial example.

    Returns:
        np.ndarray of shape (num_conditions, num_images, num_advs).
    """
    conditions = list(condition_zip)
    mean_curvatures = np.zeros((len(conditions), num_images, num_advs))
    for model_idx, (model, data) in enumerate(conditions):
        advs = data['advs']
        classes = data['adv_class']
        dirs = data['dirs']
        images = data['images']
        labels = data['labels']
        for image_idx in tqdm(range(num_images), desc=f'condition {model_idx}'):
            clean_id = int(labels[image_idx])
            img = images[image_idx, ...].reshape(-1)
            for adv_idx in range(num_advs):
                adv_id = int(classes[image_idx, adv_idx])
                adv_length = np.linalg.norm(advs[image_idx, adv_idx, ...].reshape(-1) - img)
                boundary_image = get_boundary_image(
                    model=model,
                    origin=img,
                    direction=dirs[image_idx, adv_idx],
                    length=adv_length,
                    origin_class=clean_id,
                    adv_class=adv_id,
                    num_eps=num_eps,
                    buffer_portion=buffer_portion,
                    batch_size=batch_size)
                torch_image = torch.from_numpy(boundary_image).to(DEVICE).float()[None, ...]
                decision_fn = _make_decision_fn(model, clean_id, adv_id)
                # The shape operator needs the gradient of the same function whose
                # Hessian is approximated (the class-score difference), not of a single logit.
                _, boundary_grad = decision_fn(torch_image.clone().requires_grad_(True))
                gradient = boundary_grad.detach().reshape(-1)
                hessian = curve_utils.sr1_hessian(
                    decision_fn, torch_image,
                    distance=adv_length * hess_radius_mult,  # radius around the boundary point
                    n_points=hess_params['hessian_num_pts'],
                    random_walk=hess_params['hessian_random_walk'],
                    learning_rate=hess_params['hessian_lr'],
                    return_points=False,
                    progress=False)
                _, principal_curvatures, _ = curve_utils.local_response_curvature(gradient, hessian)
                mean_curvatures[model_idx, image_idx, adv_idx] = np.mean(principal_curvatures.detach().cpu().numpy())
    return mean_curvatures

In [23]:
# Inspect the layout of the stored natural-model adversarial examples (recorded output: (500, 30, 1, 784)).
advs_nat.shape

(500, 30, 1, 784)

In [22]:
# Build a "random directions" condition: the same clean images and labels as the
# natural data, with the adversarial directions replaced by random unit vectors.
# Dedicated names are used so this cell does not overwrite the global
# `num_images` / `num_advs` settings that later cells pass to get_mean_curvatures.
n_rand_images, n_rand_advs, _, vector_length = data_natural['dirs'].shape

np.random.seed(seed)  # make the random directions reproducible
data_rand = {
    'images': data_natural['images'].copy(),
    'labels': data_natural['labels'].copy(),
}
rand_dirs = np.stack([random_vector(vector_length)[None, :] for _ in range(n_rand_advs)], axis=0)
# NOTE: np.repeat shares the same set of directions across every image.
data_rand['dirs'] = np.repeat(rand_dirs[None, ...], n_rand_images, axis=0)

# TODO: fill data_rand['advs'] / data_rand['adv_class'] with the label-change points
# along these directions (get_random_boundary_image on batched image tensors).
# The previous call here referenced undefined names `origin` / `direction`.
data_rand['dirs'].shape

(500, 30, 1, 784)

In [None]:
# Select the (model, data) conditions to analyse.
# data_rand has no 'advs' / 'adv_class' entries yet, so the random-direction
# condition cannot be run.
random_dirs = False
if random_dirs:
    raise NotImplementedError('Boundary points along random directions are not available yet.')
condition_zip = zip([model_natural, model_madry], [data_natural, data_madry])

hess_radius_mult = 0.1  # SR1 sampling radius as a fraction of the adversarial distance
num_advs = 2
# Shape (num_conditions, num_images, num_advs); consumed by the plot cell below.
mean_curvatures = get_mean_curvatures(condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, hess_radius_mult)

In [None]:
# Box plot of the mean boundary curvature per model, pooling all images and
# adversarial directions. Whiskers span the 5th-95th percentiles, fliers are
# hidden (sym=''), the dashed line is the median and the solid line the mean.
condition_labels = ['Naturally trained', 'Adversarially trained']
colors = ['blue', 'orange']
bar_width = 0.5
fig, ax = plt.subplots(nrows=1, ncols=1)
for position, color in enumerate(colors):
    ax.boxplot(
        mean_curvatures[position, :, :].reshape(-1),
        sym='', positions=[position], whis=(5, 95), widths=bar_width,
        meanline=True, showmeans=True,
        boxprops=dict(color=color, linewidth=1.5, alpha=0.7),
        whiskerprops=dict(color=color, alpha=0.7),
        capprops=dict(color=color, alpha=0.7),
        medianprops=dict(linestyle='--', linewidth=0.5, color=color),
        meanprops=dict(linestyle='-', linewidth=0.5, color=color),
    )
ax.set_title(f'Curvature at the decision boundary for the first {mean_curvatures.shape[-1]} adversarial directions')
ax.set_ylabel('Mean curvature')
ax.set_xticks([0, 1], minor=False)
ax.set_xticks([], minor=True)
ax.set_xticklabels(condition_labels)
plt.show()