In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
# Notebook-wide figure style: font and tick-label sizes, line width, hidden
# top/right spines, and the resolution used when figures are saved.
plt.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': 15,
    'lines.linewidth': 2,
    'savefig.dpi': 1200,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
})

import numpy as np
import torch

In [2]:
# Exported BlurResNet18_[1.5] checkpoints; they differ only in the run-ID
# suffix of the file name, which later cells use as the model tag.
model_paths = [
    f'store/exported/BlurResNet18_[1.5]_{run_id}.pt'
    for run_id in ('A73EED68', '1E1A0C95', '2A9BC19E', 'CCA1DE26', '78B6B7CD')
]

In [3]:
from roarena.einmon import EinMonJob

# Handle to the EinMon job store in 'store/e-tests'. A later cell queries its
# completed runs (`completed`), their configs (`configs`) and checkpoints (`ckpts`).
# NOTE(review): read_only=True presumably keeps this notebook from creating or
# modifying jobs in the store — confirm against roarena.einmon.
e_job = EinMonJob('store/e-tests', 'datasets', read_only=True)

In [10]:
# For every model and seed, estimate the alpha at which acc_low - acc_high
# crosses zero, by linear interpolation over the completed alpha sweep, and
# save the estimates as {model tag: [estimate for seed 0, ..., seed 5]}.
f_revs = {}
for model_path in model_paths:
    # Run-ID suffix of the checkpoint file name, e.g. '..._A73EED68.pt' -> 'A73EED68'.
    tag = model_path.split('_')[-1].split('.')[0]
    f_revs[tag] = []
    for seed in range(6):  # one crossover estimate per seed, appended in seed order
        alphas, accs = [], []
        for key in e_job.completed(cond={'model_path': model_path, 'seed': seed}):
            alphas.append(e_job.configs[key]['alpha'])
            ckpt = e_job.ckpts[key]
            accs.append(ckpt['acc_low']-ckpt['acc_high'])
        if not accs:
            # Fail loudly with context instead of an opaque unpacking error below.
            raise ValueError(f"No completed runs for {model_path} with seed {seed}.")
        # np.interp assumes its x-coordinates are increasing and does not check it:
        # a gap that decreases with alpha would silently yield a meaningless value.
        # Sorting the pairs by gap (ties broken by alpha) inverts the gap->alpha
        # relation for either monotone direction; for an increasing gap the result
        # is the same as interpolating in alpha order.
        # If 0 lies outside the observed gaps, np.interp clamps to the endpoint alpha.
        accs, alphas = zip(*sorted(zip(accs, alphas)))
        f_revs[tag].append(np.interp(0, accs, alphas))

torch.save(f_revs, 'store/figs-data/CIFAR10_reverse-freqs.pt')

In [6]:
from jarvis.vision import prepare_datasets
from roarena.attack import AttackJob

dataset = prepare_datasets('CIFAR10', 'datasets')
# The first 1000 dataset images (labels discarded), converted to numpy and
# stacked along a new leading axis.
imgs = np.stack([img.numpy() for img, _ in (dataset[i] for i in range(1000))])

# Read-only handle to the attack-job store queried in the next cell.
a_job = AttackJob('store/a-tests', 'datasets', read_only=True)

In [7]:
import time

# For each model, gather the best attack result for the first 1000 samples and
# save the distances and adversarial examples to store/figs-data/.
# Each model takes roughly an hour (see recorded output).
for model_path in model_paths:
    print(f"Collecting advs for {model_path}...")
    # perf_counter is monotonic, so an hour-long measurement is not skewed by
    # system clock adjustments (unlike time.time()).
    tic = time.perf_counter()
    dists = []
    advs = []
    for sample_idx in range(1000):
        _min_dists, _, _advs = a_job.best_attack(
            model_path, 'LI', True, 'elm', 0,
            sample_idx=sample_idx, min_probs=[0.5], return_advs=True,
        )
        # Presumably one entry per requested min_prob; only 0.5 was requested,
        # so keep the first entry — confirm against AttackJob.best_attack.
        dists.append(_min_dists[0])
        advs.append(_advs[0])
    dists = np.array(dists)
    advs = np.stack(advs)
    toc = time.perf_counter()
    print("{:.1f} secs".format(toc-tic))

    # File is keyed by the run-ID suffix of the checkpoint name.
    file_name = 'store/figs-data/CIFAR10-advs_{}.pt'.format(model_path.split('_')[-1].split('.')[0])
    torch.save({'dists': dists, 'advs': advs}, file_name)

Collecting advs for store/exported/BlurResNet18_[1.5]_A73EED68.pt...
3778.3 secs
Collecting advs for store/exported/BlurResNet18_[1.5]_1E1A0C95.pt...
3614.3 secs
Collecting advs for store/exported/BlurResNet18_[1.5]_2A9BC19E.pt...
3596.6 secs
Collecting advs for store/exported/BlurResNet18_[1.5]_CCA1DE26.pt...
3849.9 secs
Collecting advs for store/exported/BlurResNet18_[1.5]_78B6B7CD.pt...
3688.6 secs
