In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
# Global matplotlib style for every figure in this notebook.
plt.rcParams.update({
    # text and line sizes
    'font.size': 15,
    'lines.linewidth': 2,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    # hide the top and right axes spines
    'axes.spines.top': False,
    'axes.spines.right': False,
    # resolution used when figures are saved
    'savefig.dpi': 1200,
})

import numpy as np
from numpy.fft import fft2, fftshift, fftfreq
import torch

# Root directories handed to the result managers below; `store` also holds the
# exported model files and the cached figure data.
store_dir = 'store'
datasets_dir = 'datasets'

# Gather attack results and perturbation power spectra

In [2]:
from roarena.attack import AttackManager
from jarvis.vision import prepare_datasets
from jarvis.utils import time_str
import time, os

# Attack-result manager opened with read_only=True (presumably it only reads
# previously stored attack results here -- confirm against roarena).
a_manager = AttackManager(store_dir, datasets_dir, read_only=True)
# CIFAR-10 dataset; `dataset[i][0]` is used below as the clean image tensor for
# sample i (which split is returned is not visible here -- TODO confirm).
dataset = prepare_datasets('CIFAR10', datasets_dir)

# Baseline models to analyse: short figure tag -> exported model file.
# The tags are also the keys of every per-model result dictionary below.
model_paths = dict([
    ('Base-Res', 'store/exported/chenyaofo_resnet56.pt'),
    ('Base-VGG', 'store/exported/chenyaofo_vgg19_bn.pt'),
    ('Base-MobNet', 'store/exported/chenyaofo_mobilenetv2_x1_4.pt'),
    ('Base-ShufNet', 'store/exported/chenyaofo_shufflenetv2_x2_0.pt'),
    ('Base-RepVGG', 'store/exported/chenyaofo_repvgg_a2.pt'),
])

# Gathered attack results are cached in `filename`; delete the file to
# re-gather them from `a_manager`.
filename = os.path.join(store_dir, 'figs-data', 'CIFAR10_extra_baseline_advs.pt')


def _log_power_spectrum(clean, adv):
    """Return the log mean power spectrum of the perturbations `clean - adv`.

    The code treats axes (2, 3) as spatial and averages over axes (0, 1), so
    the inputs are presumably (N, C, H, W) arrays -- TODO confirm.
    Steps: remove each (image, channel) spatial mean, take a 2-D FFT over the
    last two axes, average the squared magnitudes over axes (0, 1), take the
    log, and `fftshift` so the zero frequency sits at the center. The
    zero-frequency bin is set to NaN.
    """
    diffs = clean - adv
    diffs -= diffs.mean(axis=(2, 3), keepdims=True)
    powers = np.abs(fft2(diffs))**2
    # After mean removal the DC term is ~0; mask it rather than take log(~0).
    powers[..., 0, 0] = np.nan
    return fftshift(np.log(powers.mean(axis=(0, 1))))


if os.path.exists(filename):
    # The cache holds numpy arrays, which `torch.load` refuses to unpickle under
    # the `weights_only=True` default of torch>=2.6. This is a full pickle load,
    # so only load cache files written by this notebook.
    try:
        saved = torch.load(filename, weights_only=False)
    except TypeError:  # torch<1.13 has no `weights_only` argument
        saved = torch.load(filename)
    s_idxs = saved['s_idxs']
    imgs = saved['imgs']
    dists = saved['dists']
    advs = saved['advs']
    logps = saved['logps']
else:
    metric, targeted, shuffle_mode, shuffle_tag = 'LI', True, 'elm', 0
    num_samples = 1000  # sample indices 0..num_samples-1 are tried for each model

    s_idxs, imgs, dists, advs, logps = {}, {}, {}, {}, {}
    for tag, model_path in model_paths.items():
        print(f'Gathering attack results for {model_path}...')
        tic = time.time()
        s_idxs[tag], imgs[tag], dists[tag], advs[tag] = [], [], [], []
        for sample_idx in range(num_samples):
            try:
                _min_dists, _, _advs = a_manager.best_attack(
                    model_path, metric, targeted, shuffle_mode, shuffle_tag, sample_idx,
                    min_probs=[0.1], return_advs=True,
                )
            except Exception:
                # Samples for which `best_attack` raises are skipped. Catching
                # `Exception` (not a bare `except`) lets KeyboardInterrupt and
                # SystemExit still stop the loop.
                continue
            s_idxs[tag].append(sample_idx)
            imgs[tag].append(dataset[sample_idx][0].numpy())
            dists[tag].append(_min_dists[0])
            advs[tag].append(_advs[0])
        s_idxs[tag] = np.array(s_idxs[tag])
        imgs[tag] = np.array(imgs[tag])
        dists[tag] = np.array(dists[tag])
        advs[tag] = np.array(advs[tag])
        toc = time.time()
        print('{} images gathered ({})'.format(len(s_idxs[tag]), time_str(toc-tic)))

        if len(s_idxs[tag]) == 0:
            # Otherwise the spectrum computation fails with an opaque AxisError.
            raise RuntimeError(f'No attack results gathered for {model_path}')
        logps[tag] = _log_power_spectrum(imgs[tag], advs[tag])

    # Create the cache directory on first run so `torch.save` does not fail.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save({
        'model_paths': model_paths,
        's_idxs': s_idxs, 'imgs': imgs, 'dists': dists, 'advs': advs,
        'logps': logps,
    }, filename)

# Gather corruption and Ein-Mon test results

In [3]:
from roarena.corruption import CorruptionManager
from roarena.corruption import CORRUPTIONS
from roarena.einmon import EinMonManager

c_manager = CorruptionManager(store_dir, datasets_dir, read_only=True)
e_manager = EinMonManager(store_dir, datasets_dir, read_only=True)

filename = os.path.join(store_dir, 'figs-data', 'CIFAR10_extra_baseline_einmon_accs.pt')

severities = range(1, 6)  # corruption severity levels queried for every model
# `alpha` values queried from `e_manager` for every model.
alpha_values = [0, 7, 12, 16, 19, 22, 25, 32, 42, 57, 70, 85, 100]

# Corruption results. `accs[tag]` is the mean accuracy over every
# (severity, corruption) pair. `accs_s[tag][severity]` lists the accuracies for
# that severity (ordered as CORRUPTIONS). `accs_c[tag][corruption]` lists them
# for that corruption (ordered by severity).
accs, accs_s, accs_c = {}, {}, {}
for tag, model_path in model_paths.items():
    # A separate table so that `accs[tag]` only ever holds the scalar mean.
    acc_table = np.empty((len(severities), len(CORRUPTIONS)))
    accs_s[tag] = {s: [] for s in severities}
    accs_c[tag] = {c: [] for c in CORRUPTIONS}
    for i, severity in enumerate(severities):
        for j, corruption in enumerate(CORRUPTIONS):
            config = {'model_path': model_path, 'severity': severity, 'corruption': corruption}
            key = c_manager.configs.get_key(config)
            acc_table[i, j] = c_manager.previews[key]['acc']
            accs_s[tag][severity].append(acc_table[i, j])
            accs_c[tag][corruption].append(acc_table[i, j])
    accs[tag] = np.mean(acc_table)

# Ein-Mon results. For each alpha, collect the 'acc_low' / 'acc_high' previews of
# every completed `e_manager` run matching (model_path, alpha): one list per alpha.
alphas, accs_low, accs_high = {}, {}, {}
for tag, model_path in model_paths.items():
    # A fresh list per tag, so the saved entries are not aliases of one object.
    alphas[tag] = list(alpha_values)
    accs_low[tag], accs_high[tag] = [], []
    for alpha in alphas[tag]:
        _accs_low, _accs_high = [], []
        for key in e_manager.completed(cond={'model_path': model_path, 'alpha': alpha}):
            _accs_low.append(e_manager.previews[key]['acc_low'])
            _accs_high.append(e_manager.previews[key]['acc_high'])
        accs_low[tag].append(_accs_low)
        accs_high[tag].append(_accs_high)

# Create the output directory on first run so `torch.save` does not fail.
os.makedirs(os.path.dirname(filename), exist_ok=True)
torch.save({
    'model_paths': model_paths,
    'alphas': alphas, 'accs_low': accs_low, 'accs_high': accs_high, 'accs': accs,
    'accs_c': accs_c, 'accs_s': accs_s,
}, filename)