In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, './..')
sys.path.insert(0, '../data')

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import matplotlib.gridspec as gridspec
import proplot as pplt
from tqdm import tqdm

import numpy as np
import torch
from torchvision import datasets, transforms

from models import model, eval
import plots as pl
from utils import dev, load_data, classification

sys.path.insert(0, './../../')

import response_contour_analysis.utils.model_handling as model_utils
import response_contour_analysis.utils.dataset_generation as data_utils
import response_contour_analysis.utils.histogram_analysis as hist_utils
import response_contour_analysis.utils.principal_curvature as curve_utils
import response_contour_analysis.utils.plotting as plot_utils

# check device
# Prefer the GPU when PyTorch reports one; the models and the input tensors
# built in this notebook are moved to DEVICE before every forward pass.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

# Experiment parameters

In [None]:
# Disabled configuration for a 2-D response-window experiment; nothing below
# this commented-out section reads `experiment_params`.
#experiment_params = dict()

#experiment_params['target_model_id'] = 0
#experiment_params['data_shape'] = images_nat[0,...].shape
#experiment_params['window_scale'] = 2.0
#experiment_params['num_edge_images'] = 30
#experiment_params['target'] = 0.0
#experiment_params['target_is_act'] = True

#experiment_params['num_images'] = int(experiment_params['num_edge_images']**2)
#experiment_params['x_range'] = (-experiment_params['window_scale'], experiment_params['window_scale'])
#experiment_params['y_range'] = experiment_params['x_range']
#experiment_params['device'] = DEVICE
#yx_range = experiment_params['yx_range'] = (experiment_params['y_range'], experiment_params['x_range'])

# Settings forwarded to curve_utils.sr1_hessian by get_curvature when
# autodiff is False: hessian_num_pts -> n_points, hessian_lr -> learning_rate,
# hessian_random_walk -> random_walk. get_curvature also writes a
# 'hessian_dist' entry into this dict for every boundary image.
hess_params = dict()
# NOTE(review): this is a float (1e4), not an int -- confirm sr1_hessian accepts that.
hess_params['hessian_num_pts'] = 1e4
hess_params['hessian_lr'] = 1e-3
hess_params['hessian_random_walk'] = False
# NOTE(review): get_curvature passes a literal return_points=False and never reads this entry.
hess_params['return_points'] = False
# Number of successive refinements of the interpolation line in get_paired_boundary_image.
num_iters = 2 # for paired image boundary search
# Number of interpolated images per line in that search.
num_steps_per_iter = 10#100 # for paired image boundary search
# get_adv_boundary_image scans perturbation sizes from buffer_portion * length
# to (1 + buffer_portion) * length along the adversarial direction.
buffer_portion = 0.25
# Number of perturbation sizes sampled along that scan.
num_eps = 1000
# Forwarded to get_adv_boundary_image to control how the scanned images are batched.
batch_size = 1000
# Number of correctly classified images get_curvature processes per model.
num_images = 50#labels_nat.size
# 'hessian_dist' is set to the perturbation length times this factor.
hess_radius_mult = 0.1
# Number of adversarial directions analysed per image.
num_advs = 8
# True: madry_diff models, double-precision tensors and an exact
# torch.autograd Hessian. False: madry models, float tensors and the SR1 estimate.
autodiff = True

# Load data & models

In [None]:
seed = 0

# load data
# Each .npy file holds one pickled dict, hence allow_pickle=True followed by
# .item(). Unpickling executes code from the file, so only load trusted files.
data_natural = np.load(f'../data/natural_{seed}.npy', allow_pickle=True).item()
data_madry = np.load(f'../data/robust_{seed}.npy', allow_pickle=True).item()

# Dict entries pulled out as individual variables, in the order of the targets below.
field_keys = ('advs', 'pert_lengths', 'adv_class', 'dirs', 'images', 'labels')

(advs_nat, pert_lengths_nat, classes_nat,
 dirs_nat, images_nat, labels_nat) = (data_natural[key] for key in field_keys)

(advs_madry, pert_lengths_madry, classes_madry,
 dirs_madry, images_madry, labels_madry) = (data_madry[key] for key in field_keys)

In [None]:
# load models
# All three conditions share one architecture; `autodiff` picks which
# constructor from the project's `model` module is used.
build_model = model.madry_diff if autodiff else model.madry
model_natural, model_madry, model_random = build_model(), build_model(), build_model()

# (network, checkpoint path) pairs, restored in the original order.
checkpoints = (
    (model_natural, f'./../models/natural_{seed}.pt'),
    (model_madry, f'./../models/robust_{seed}.pt'),
    (model_random, './../models/random.pt'),
)
for network, weights_path in checkpoints:
    network.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    network.to(DEVICE)

# Echo the last network so the cell still displays its repr.
model_random

# Plane curvature for one pair of dirs

In [None]:
def torchify(img):
    """Convert a numpy array into a tensor on DEVICE.

    The tensor is double precision when the notebook-level `autodiff` flag is
    true and single precision otherwise.
    """
    tensor_type = torch.DoubleTensor if autodiff else torch.FloatTensor
    return torch.from_numpy(img).type(tensor_type).to(DEVICE)


def get_adv_boundary_image(model, origin, direction, length, origin_class, adv_class, num_eps, buffer_portion, batch_size, autodiff):
    """Find the image on the line `origin + eps * direction` closest to the decision boundary.

    `num_eps` perturbation sizes are sampled uniformly between
    `buffer_portion * length` and `(1 + buffer_portion) * length`. The returned
    image is the sample where the model outputs for `origin_class` and
    `adv_class` are closest to each other.

    Args:
        model: callable mapping a batch tensor to per-class scores.
        origin: numpy array holding the clean image (any shape; it is flattened).
        direction: numpy array with as many elements as `origin`.
        length: reference perturbation length along `direction`.
        origin_class: class index of the clean image.
        adv_class: class index on the other side of the boundary.
        num_eps: number of perturbation sizes to evaluate.
        buffer_portion: fraction of `length` that sets the start and the
            overshoot of the scanned interval.
        batch_size: maximum number of images per forward pass.
        autodiff: if true the images are fed as double tensors, else float.

    Returns:
        numpy array of shape (1, side, side), with side = int(sqrt(origin.size)).
        This layout assumes a single-channel square image.
    """
    eps_vals = np.linspace(buffer_portion * length, (1 + buffer_portion) * length, num_eps).reshape(-1, 1)
    origin = origin.reshape(1, origin.size)
    direction = direction.reshape(1, direction.size)
    side = int(np.sqrt(origin.size))
    adv_line = (origin + direction * eps_vals).reshape(-1, 1, side, side)

    tensor_type = torch.DoubleTensor if autodiff else torch.FloatTensor
    line_tensor = torch.from_numpy(adv_line).type(tensor_type).to(DEVICE)
    # torch.split takes the size of each chunk, not the number of chunks, so
    # it must be given batch_size to get ceil(num_eps / batch_size) batches.
    input_batches = torch.split(line_tensor, max(int(batch_size), 1))

    # Gather the scores without assuming a fixed number of classes; keep them
    # in float64 as the previous float64 accumulator did.
    model_outputs = np.concatenate(
        [model(batch).detach().cpu().numpy().astype(np.float64) for batch in input_batches],
        axis=0)

    decision_scores = model_outputs[:, int(origin_class)] - model_outputs[:, int(adv_class)]
    boundary_index = np.abs(decision_scores).argmin()
    return adv_line[boundary_index]


def get_paired_boundary_image(model, origin, alt_image, num_steps_per_iter, num_iters):
    """Locate the decision boundary on the straight line from `origin` to `alt_image`.

    The line is sampled with `num_steps_per_iter` images and scanned for the
    first sample whose predicted label differs from that of the line's first
    image. The segment between that sample and its predecessor is then
    re-sampled, and the scan repeated, `num_iters` times in total.

    Args:
        model: callable mapping a batch tensor to per-class scores.
        origin: numpy image of shape (channels, rows, cols).
        alt_image: numpy image with as many elements as `origin`.
        num_steps_per_iter: number of samples per line.
        num_iters: number of scan/refine rounds (at least 1).

    Returns:
        (pert_image, direction, pert_length): the first label-flipped image of
        the finest line, the unit vector pointing from `origin` to it, and its
        Euclidean distance from `origin`. The first two have `origin`'s shape.

    Raises:
        ValueError: if `num_iters` is smaller than 1, or if the predicted
            label never changes along a scanned line.
    """
    num_channels, num_rows, num_cols = origin.shape
    input_shape = [1, num_channels, num_rows, num_cols]

    def predict(flat_img):
        return torch.argmax(model(torchify(flat_img.reshape(input_shape))))

    def find_pert(img_line):
        # Index and image of the first sample labelled differently from sample 0.
        correct_lbl = predict(img_line[0, ...])
        for step_idx in range(1, img_line.shape[0]):
            if predict(img_line[step_idx, ...]) != correct_lbl:
                return step_idx, img_line[step_idx, ...]
        raise ValueError('the predicted label never changes between origin and alt_image')

    img_line = np.linspace(origin.reshape(-1), alt_image.reshape(-1), num_steps_per_iter)
    pert_image = None
    for _ in range(num_iters):
        step_idx, pert_image = find_pert(img_line)
        # Zoom into the segment that brackets the label change.
        img_line = np.linspace(img_line[step_idx - 1, ...], img_line[step_idx, ...], num_steps_per_iter)
    if pert_image is None:
        raise ValueError('num_iters must be at least 1')

    # The result must come from this search (a local), not from a same-named
    # notebook global, and needs origin's shape before the subtraction.
    pert_image = pert_image.reshape(origin.shape)
    # The direction points from the origin towards the boundary, so that
    # origin + direction * pert_length reproduces pert_image.
    delta_image = pert_image - origin
    pert_length = np.linalg.norm(delta_image)
    direction = delta_image / pert_length
    return pert_image, direction, pert_length


def get_random_boundary_image(origin, eps=0.01, max_dist=4, model=None):
    """Walk from `origin` along a random unit direction until the predicted label changes.

    The search runs in two passes: coarse steps of size `eps` until the label
    flips, then fine steps of size `eps * 0.01` starting from the last coarse
    step that still had the original label.

    Args:
        origin: torch tensor of shape (channels, rows, cols). It is passed to
            the model as is, without adding a batch axis.
        eps: coarse step size.
        max_dist: largest distance from `origin` that is searched. This
            argument was previously accepted but ignored, which left the
            search unbounded.
        model: callable mapping an image tensor to class scores. If omitted,
            the notebook-level name `model` is used, as before.

    Returns:
        Tensor shaped like `origin`: the first image found whose predicted
        label differs from that of `origin`.

    Raises:
        TypeError: if no callable model is available.
        RuntimeError: if the label does not change within `max_dist`.
    """
    if model is None:
        model = globals().get('model')
    if not callable(model):
        raise TypeError('get_random_boundary_image needs a callable model; pass model=<network>')

    num_channels, num_rows, num_cols = origin.shape
    direction = random_vector(num_channels * num_rows * num_cols)
    # random_vector returns a flat numpy array; match the image's dtype,
    # device and shape so the two can be added.
    direction = torch.as_tensor(direction, dtype=origin.dtype, device=origin.device).reshape(origin.shape)

    def image_at(dist):
        return origin + direction * dist

    correct_label = torch.argmax(model(origin))

    # Coarse pass: first multiple of eps at which the label differs.
    num_steps = 0
    pert_label = correct_label
    while pert_label == correct_label:
        num_steps += 1
        if num_steps * eps > max_dist:
            raise RuntimeError(f'no label change within max_dist={max_dist} of the origin')
        pert_label = torch.argmax(model(image_at(num_steps * eps)))

    # Fine pass: restart from the last coarse step that kept the label.
    fine_ratio = 0.01
    small_eps = eps * fine_ratio
    base_dist = (num_steps - 1) * eps
    for num_small_steps in range(1, int(round(1 / fine_ratio)) + 1):
        candidate = image_at(base_dist + num_small_steps * small_eps)
        if torch.argmax(model(candidate)) != correct_label:
            return candidate
    # Rounding kept every fine step on the original side; use the coarse hit.
    return image_at(num_steps * eps)
    

def random_vector(size):
    """Return a random unit-length vector with `size` components.

    The components are drawn from a standard normal distribution using
    numpy's global random state, then divided by the vector's Euclidean norm.

    Args:
        size: number of components.

    Returns:
        1-D numpy array of length `size` with unit Euclidean norm.
    """
    # One vectorised draw replaces the per-component Python loops.
    components = np.random.normal(size=size)
    return components / np.linalg.norm(components)


def get_curvature(condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, hess_radius_mult, autodiff=False):
    """Compute decision-boundary curvature for every (model, data) condition.

    For each model, images are visited in order until `num_images` of them
    have been classified correctly. For each of those images and each of the
    first `num_advs` adversarial directions:

    1. the boundary image is located with `get_adv_boundary_image`;
    2. the gradient at the boundary image is obtained from
       `model_utils.unit_activation_and_gradient` for the clean class;
    3. the Hessian of (clean-class output - adversarial-class output) is
       computed exactly with `torch.autograd.functional.hessian` when
       `autodiff` is true, or estimated with `curve_utils.sr1_hessian`
       otherwise;
    4. `curve_utils.local_response_curvature(gradient, hessian)` yields the
       quantities that are stored.

    Args:
        condition_zip: iterable of (model, data_dict) pairs. Each dict must
            provide 'advs', 'adv_class', 'dirs', 'images' and 'labels',
            indexed consistently along their first axis.
        num_images: number of correctly classified images to process per model.
        num_advs: number of adversarial directions to process per image.
        num_eps, batch_size, buffer_portion: forwarded to `get_adv_boundary_image`.
        hess_radius_mult: factor applied to the adversarial perturbation
            length to obtain `hess_params['hessian_dist']`.
        autodiff: selects double precision and the exact Hessian when true.

    Returns:
        Tuple of numpy arrays (shape_operators, principal_curvatures,
        principal_directions, mean_curvatures), with leading axes
        (model, image, adversarial direction). The first and third arrays
        carry two trailing axes of length img_size, so memory grows with
        img_size ** 2 per entry.

    Raises:
        IndexError: if a data set runs out before `num_images` correctly
            classified images were found.

    Side effects:
        Overwrites `hess_params['hessian_dist']` in the notebook-level dict.
    """
    models, adv_data = zip(*condition_zip)
    num_models = len(models)
    img_size = int(np.prod(adv_data[0]['images'][0, ...].shape))
    tensor_type = torch.DoubleTensor if autodiff else torch.FloatTensor

    def to_batch(array):
        # numpy image -> tensor on DEVICE with a leading batch axis.
        return torch.from_numpy(array).type(tensor_type).to(DEVICE)[None, ...]

    shape_operators = np.zeros((num_models, num_images, num_advs, img_size, img_size))
    principal_curvatures = np.zeros((num_models, num_images, num_advs, img_size))
    principal_directions = np.zeros((num_models, num_images, num_advs, img_size, img_size))
    mean_curvatures = np.zeros((num_models, num_images, num_advs))
    for model_idx, (model, data) in enumerate(zip(models, adv_data)):
        advs = data['advs']
        classes = data['adv_class']
        dirs = data['dirs']
        images = data['images']
        labels = data['labels']
        # Rows that exist in every array; indexing past this would fail.
        num_available = min(len(labels), len(images), len(advs), len(dirs), len(classes))
        pbar = tqdm(total=num_images, leave=True)
        image_idx = 0
        processed_images = 0
        while processed_images < num_images:
            if image_idx >= num_available:
                pbar.close()
                raise IndexError(
                    f'only {processed_images} of the requested {num_images} correctly classified '
                    f'images were found in the {num_available} available entries')
            clean_id = int(labels[image_idx])
            adv_ids = classes[image_idx, ...]
            img_dirs = dirs[image_idx, ...]
            img_advs = advs[image_idx, ...]
            img = images[image_idx, ...]
            if torch.argmax(model(to_batch(img))) == clean_id: # model labeled the clean image correctly
                img = img.reshape(-1)
                for adv_idx in range(num_advs):
                    adv_id = int(adv_ids[adv_idx])
                    adv_length = np.linalg.norm(img_advs[adv_idx, ...].reshape(-1) - img)
                    boundary_image = get_adv_boundary_image(
                        model=model,
                        origin=img,
                        direction=img_dirs[adv_idx],
                        length=adv_length,
                        origin_class=clean_id,
                        adv_class=adv_ids[adv_idx],
                        num_eps=num_eps,
                        buffer_portion=buffer_portion,
                        batch_size=batch_size,
                        autodiff=autodiff)
                    hess_params['hessian_dist'] = adv_length * hess_radius_mult # radius around the target image
                    torch_image = to_batch(boundary_image)
                    activation, gradient = model_utils.unit_activation_and_gradient(model, torch_image, clean_id)
                    gradient = gradient.reshape(-1)

                    def logit_diff(x):
                        # Clean-class minus adversarial-class output, from one forward pass.
                        model.zero_grad()
                        outputs = model(x)
                        return outputs[:, clean_id] - outputs[:, adv_id]

                    if autodiff:
                        hessian = torch.autograd.functional.hessian(logit_diff, torch_image)
                        # Flatten to (pixels, pixels) for any input size.
                        num_pixels = int(torch_image.numel())
                        hessian = hessian.reshape((num_pixels, num_pixels))
                    else:
                        def func(x):
                            acts_diff = logit_diff(x)
                            grad = torch.autograd.grad(acts_diff, x)[0]
                            return acts_diff, grad
                        hessian = curve_utils.sr1_hessian(
                            func, torch_image,
                            distance=hess_params['hessian_dist'],
                            n_points=hess_params['hessian_num_pts'],
                            random_walk=hess_params['hessian_random_walk'],
                            learning_rate=hess_params['hessian_lr'],
                            return_points=False,
                            progress=False)
                    curvature = curve_utils.local_response_curvature(gradient, hessian)
                    principal = curvature[1].detach().cpu().numpy()
                    shape_operators[model_idx, processed_images, adv_idx, ...] = curvature[0].detach().cpu().numpy()
                    principal_curvatures[model_idx, processed_images, adv_idx, :] = principal
                    principal_directions[model_idx, processed_images, adv_idx, ...] = curvature[2].detach().cpu().numpy()
                    mean_curvatures[model_idx, processed_images, adv_idx] = np.mean(principal)
                pbar.update(1)
                processed_images += 1
            image_idx += 1
        pbar.close()
    return shape_operators, principal_curvatures, principal_directions, mean_curvatures

In [None]:
# Manual version of the paired-image boundary search, plotted next to the
# result of get_paired_boundary_image for comparison.

# Pair image 0 with the first later image that carries a different label.
img1 = images_nat[0, ...]
img_class1 = int(labels_nat[0])
class_idx = 1
img_class2 = int(labels_nat[class_idx])
while img_class1 == img_class2:
    class_idx += 1
    img_class2 = int(labels_nat[class_idx])
img2 = images_nat[class_idx, ...]
# Straight line between the two flattened images, num_steps_per_iter samples.
img_line = np.linspace(img1.reshape(-1), img2.reshape(-1), num_steps_per_iter)

num_channels, num_rows, num_cols = images_nat[0,...].shape
input_shape = [1, num_channels, num_rows, num_cols]
# Reference label: the natural model's prediction for img1.
correct_label = torch.argmax(model_natural(torchify(img1.reshape(input_shape))))

# Coarse scan: walk along the line until the prediction changes. When the
# loop exits, step_idx is one past the index of the first changed sample.
pert_label = correct_label.clone()
step_idx = 0
while pert_label == correct_label:
    pert_image = img_line[step_idx, ...].reshape(input_shape)
    pert_label = torch.argmax(model_natural(torchify(pert_image)))
    step_idx += 1

# Fine scan: re-sample between the last unchanged sample (step_idx - 2) and
# the first changed one (step_idx - 1), then repeat the walk.
sm_img_line = np.linspace(img_line[step_idx - 2, ...], img_line[step_idx-1, ...], num_steps_per_iter)
pert_label = torch.argmax(model_natural(torchify(sm_img_line[0]).reshape(input_shape)))
small_step_idx = 1
while pert_label == correct_label:
    pert_image = sm_img_line[small_step_idx, ...].reshape(input_shape)
    pert_label = torch.argmax(model_natural(torchify(pert_image)))
    small_step_idx += 1

# Panels 0-2: start, middle and end of the coarse line. Panel 3: boundary
# image from the manual search above.
# NOTE(review): matplotlib's own colormap is spelled 'Greys'; the lower-case
# name presumably resolves through proplot's registry -- confirm.
fig, ax = plt.subplots(1, 6)
ax[0].imshow(img_line[0, ...].reshape(images_nat[0,...].shape[-2:]), cmap='greys', vmin=0, vmax=1)
ax[1].imshow(img_line[num_steps_per_iter//2, ...].reshape(images_nat[0,...].shape[-2:]), cmap='greys', vmin=0, vmax=1)
ax[2].imshow(img_line[-1, ...].reshape(images_nat[0,...].shape[-2:]), cmap='greys', vmin=0, vmax=1)
ax[3].imshow(pert_image.reshape(images_nat[0,...].shape[-2:]), cmap='greys', vmin=0, vmax=1)

# Panel 4: boundary image returned by get_paired_boundary_image.
# Panel 5: manual result minus that image.
new_pert_image = get_paired_boundary_image(model_natural, img1, img2, num_steps_per_iter, num_iters)[0]
ax[4].imshow(new_pert_image.reshape(images_nat[0,...].shape[-2:]), cmap='greys', vmin=0, vmax=1)
ax[5].imshow(pert_image.reshape(images_nat[0,...].shape[-2:]) - new_pert_image.reshape(images_nat[0,...].shape[-2:]), cmap='greys', vmin=0, vmax=1)

fig.tight_layout()
plt.show()

# Numeric summary: the two labels and the flipped prediction, the line's
# shape, the manual boundary's distance from img1, and the agreement between
# the manual and the function result.
print(img_class1, img_class2, pert_label)
print(img_line.shape)
print(np.linalg.norm(pert_image.reshape(-1) - img1.reshape(-1)))
print(np.allclose(pert_image.reshape(-1), new_pert_image.reshape(-1)))
print(np.linalg.norm(pert_image.reshape(-1) - new_pert_image.reshape(-1)))

In [None]:
def generate_paired_dict(data_dict, model, num_images):
    """Build a boundary data set by pairing each image with images of other classes.

    Images are visited in order. Each image that `model` classifies correctly
    is paired with `num_adv_directions` partner images (the second dimension
    of `data_dict['adv_class']`). Partners are taken from a random permutation
    of the data set and must have a different label that `model` also
    predicts correctly. For each pair the decision boundary on the connecting
    line is located with `get_paired_boundary_image`, using the notebook-level
    `num_steps_per_iter` and `num_iters`.

    Args:
        data_dict: dict with 'images' (indexed along axis 0), 'labels'
            (numpy array) and 'adv_class' (2-D array).
        model: callable mapping a batch tensor to per-class scores.
        num_images: number of correctly classified images to include.

    Returns:
        dict with keys 'images', 'labels', 'dirs', 'advs', 'pert_lengths' and
        'adv_class'. Row i of every entry refers to the same image: 'images'
        and 'labels' hold only the images that were kept, so the dict can be
        indexed by image position like the adversarial data dicts.

    Raises:
        ValueError: if the data set does not contain enough correctly
            classified images, or enough valid partners for one image.
    """
    images = data_dict['images']
    labels = data_dict['labels']
    num_total = labels.size
    num_adv_directions = data_dict['adv_class'].shape[1]
    shuffled_indices = np.random.permutation(num_total)

    def predict(img):
        return torch.argmax(model(torchify(img[None, ...]))).item()

    kept_indices = []
    dirs = []
    advs = []
    pert_lengths = []
    adv_class = []
    img_idx = 0
    while len(kept_indices) < num_images:
        if img_idx >= num_total:
            raise ValueError(
                f'only {len(kept_indices)} of the requested {num_images} images are classified correctly')
        # The origin and its label must both come from the current index.
        origin = images[img_idx, ...]
        correct_class = int(labels[img_idx])
        if predict(origin) == correct_class:
            sub_adv_class = []
            sub_advs = []
            sub_dirs = []
            sub_pert_lengths = []
            search_index = 0
            while len(sub_advs) < num_adv_directions:
                if search_index >= num_total:
                    raise ValueError(f'not enough partner images for image {img_idx}')
                alt_index = shuffled_indices[search_index]
                search_index += 1
                alt_class = int(labels[alt_index])
                if alt_class == correct_class:
                    continue  # same class: no boundary between the two images
                alt_img = images[alt_index, ...]
                if predict(alt_img) != alt_class:
                    continue  # partner itself is misclassified
                boundary_image, boundary_dir, pert_length = get_paired_boundary_image(
                    model, origin, alt_img, num_steps_per_iter, num_iters)
                sub_adv_class.append(predict(boundary_image))
                sub_advs.append(boundary_image.reshape(1, -1))
                sub_dirs.append(boundary_dir.reshape(1, -1))
                sub_pert_lengths.append(pert_length)
            adv_class.append(np.stack(sub_adv_class, axis=0))
            advs.append(np.stack(sub_advs, axis=0))
            dirs.append(np.stack(sub_dirs, axis=0))
            pert_lengths.append(np.stack(sub_pert_lengths, axis=0))
            kept_indices.append(img_idx)
        img_idx += 1

    output_dict = {}
    # Fancy indexing copies, so the caller's arrays are left untouched.
    output_dict['images'] = images[kept_indices]
    output_dict['labels'] = labels[kept_indices]
    output_dict['dirs'] = np.stack(dirs, axis=0)
    output_dict['advs'] = np.stack(advs, axis=0)
    output_dict['pert_lengths'] = np.stack(pert_lengths, axis=0)
    output_dict['adv_class'] = np.stack(adv_class, axis=0)
    return output_dict

# Build one paired-image boundary data set per model. The returned dicts use
# the same keys as the adversarial data dicts loaded from disk.
data_paired_natural = generate_paired_dict(data_natural, model_natural, num_images)
data_paired_madry = generate_paired_dict(data_madry, model_madry, num_images)

In [None]:
# (model, data) conditions for the two boundary types. They are stored as
# lists because a bare zip() is a one-shot iterator: a second use of the same
# name, for example after re-running part of this cell, would see no conditions.
paired_condition_zip = list(zip([model_natural, model_madry], [data_paired_natural, data_paired_madry]))
adv_condition_zip = list(zip([model_natural, model_madry], [data_natural, data_madry]))

# Curvature at the adversarial-image boundary, then at the paired-image boundary.
adv_shape_operators, adv_principal_curvatures, adv_principal_directions, adv_mean_curvatures = get_curvature(
    adv_condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, hess_radius_mult, autodiff)
paired_shape_operators, paired_principal_curvatures, paired_principal_directions, paired_mean_curvatures = get_curvature(
    paired_condition_zip, num_images, num_advs, num_eps, batch_size, buffer_portion, hess_radius_mult, autodiff)

#np.savez('../../outputs/mean_curvatures_sr1.npz', data=mean_curvatures)

In [None]:
# Box plots of the mean curvature per model: left panel for the paired-image
# boundary, right panel for the adversarial-image boundary.
colors = ['blue', 'orange']
bar_width = 0.5
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True)
for data_idx, mean_curvatures in enumerate([paired_mean_curvatures, adv_mean_curvatures]):
    panel = axs[data_idx]
    for model_idx, model_color in enumerate(colors):
        boxprops = dict(color=model_color, linewidth=1.5, alpha=0.7)
        whiskerprops = dict(color=model_color, alpha=0.7)
        capprops = dict(color=model_color, alpha=0.7)
        medianprops = dict(linestyle='--', linewidth=0.5, color=model_color)
        # Defined as before; it is not passed to boxplot below.
        meanpointprops = dict(marker='o', markeredgecolor='black',
                              markerfacecolor=model_color)
        meanprops = dict(linestyle='-', linewidth=0.5, color=model_color)
        # Pool all images and adversarial directions of this model.
        data = mean_curvatures[model_idx, :, :].reshape(-1)
        panel.boxplot(
            data, sym='', positions=[model_idx], whis=(5, 95), widths=bar_width,
            meanline=True, showmeans=True, boxprops=boxprops, whiskerprops=whiskerprops,
            capprops=capprops, medianprops=medianprops, meanprops=meanprops)
    panel.set_xticks([0, 1], minor=False)
    panel.set_xticks([], minor=True)
    panel.set_xticklabels(['Naturally trained', 'Adversarially trained'])
    if data_idx == 0:
        # The y axis is shared, so only the left panel is labelled.
        panel.set_ylabel('Mean curvature')
    panel.set_title('Paired image boundary' if data_idx == 0 else 'Adversarial image boundary')
fig.suptitle(f'Curvature at the decision boundary\nfor the {num_images} images and the first {num_advs} adversarial directions')
plt.show()

In [None]:
### Mean vs dim number
# One box per adversarial direction and model. Both models are drawn at the
# same x positions, so a legend is needed to tell the colours apart.
colors = ['blue', 'orange']
model_names = ['Naturally trained', 'Adversarially trained']
bar_width = 0.5
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True)
panels = [
    ('Paired image boundary', paired_mean_curvatures),
    ('Adversarial image boundary', adv_mean_curvatures),
]
for data_idx, (panel_title, mean_curvatures) in enumerate(panels):
    ax = axs[data_idx]
    num_dirs = mean_curvatures.shape[-1]
    for model_idx in range(mean_curvatures.shape[0]):
        color = colors[model_idx]
        # Styling shared by every box of this model.
        box_style = dict(
            boxprops=dict(color=color, linewidth=1.5, alpha=0.7),
            whiskerprops=dict(color=color, alpha=0.7),
            capprops=dict(color=color, alpha=0.7),
            medianprops=dict(linestyle='--', linewidth=0.5, color=color),
            meanprops=dict(linestyle='-', linewidth=0.5, color=color),
        )
        for adv_idx in range(num_dirs):
            data = mean_curvatures[model_idx, :, adv_idx].reshape(-1)
            ax.boxplot(data, sym='', positions=[adv_idx], whis=(5, 95), widths=bar_width,
                       meanline=True, showmeans=True, **box_style)
        # Empty proxy line: adds a legend entry without drawing any data.
        ax.plot([], [], color=color, label=model_names[model_idx])
    if data_idx == 0:
        # The y axis is shared, so only the left panel is labelled.
        ax.set_ylabel('Mean curvature')
    ax.set_title(panel_title)
    ax.set_xlabel('Dimension number')
    ax.set_xticks(list(range(num_dirs)), minor=False)
    ax.set_xticks([], minor=True)
    ax.set_xticklabels([str(i) for i in range(num_dirs)])
    ax.legend()
fig.suptitle(f'Curvature at the decision boundary for {num_images} images and the first {num_advs} adversarial directions')
plt.show()