In [12]:
import itertools
import os

import numpy as np
import torch

In [1]:
from roarena.corruption import CorruptionJob
from roarena.corruption import CORRUPTIONS

# Corruption-evaluation job, opened with readonly=True ('store' and 'datasets'
# are presumably the job store directory and dataset root — confirm against
# roarena.corruption.CorruptionJob). The next cell reads results from it via
# job.configs.get_key(...) and job.previews[...].
# NOTE(review): the attack cell later rebinds `job` to an AttackJob, so
# re-running the corruption cell after that one will query the wrong job.
job = CorruptionJob('store', 'datasets', readonly=True)

In [2]:
# Gather corruption accuracies for each animal's 'base' and 'neural'
# checkpoints and save them to store/figs-data/{animal}_corruption.pt.
#
# Saved structure, per checkpoint label:
#   accs[label][0]        -> the checkpoint's own 'acc' entry
#                            (presumably clean accuracy — confirm)
#   accs[label][s], s=1..5 -> list of preview accuracies, one per corruption,
#                            in CORRUPTIONS order
CORRUPTION_MODEL_PATHS = {
    'mouse': {
        'base': 'store/models/CIFAR10-G_Baseline.pt',
        'neural': 'store/models/CIFAR10-G_Mouse.pt',
    },
    'monkey': {
        'base': 'store/models/VGG19_Baseline.pt',
        'neural': 'store/models/VGG19_MTL.pt',
    },
}

# torch.save does not create missing parent directories.
os.makedirs('store/figs-data', exist_ok=True)
for animal in ['mouse', 'monkey']:
    # Dict lookup raises KeyError for an unknown animal instead of silently
    # reusing the previous iteration's model_paths.
    model_paths = CORRUPTION_MODEL_PATHS[animal]

    accs = {}
    for label, model_path in model_paths.items():
        # Only the stored 'acc' value is needed; map to CPU so checkpoints saved
        # with CUDA tensors load on CPU-only machines and don't occupy GPU memory.
        saved = torch.load(model_path, map_location='cpu')
        accs[label] = {0: saved['acc']}
        for severity in range(1, 6):
            accs[label][severity] = []
            for corruption in CORRUPTIONS:
                config = {
                    'model_pth': model_path,
                    'severity': severity,
                    'corruption': corruption,
                }
                key = job.configs.get_key(config)
                accs[label][severity].append(job.previews[key]['acc'])
    torch.save(accs, f'store/figs-data/{animal}_corruption.pt')

In [48]:
from roarena.attack import AttackJob
from jarvis.vision import prepare_datasets

# NOTE(review): this rebinds `job` (previously the CorruptionJob), so the
# corruption cell must not be re-run after this one without re-running its
# own setup cell.
job = AttackJob('store', 'datasets', readonly=True)

metric = 'LI'
targeted = True
N_BATCHES = 50  # number of leading loader batches to gather adversarial examples from

# Per animal: the checkpoints to compare and the dataset name passed to
# prepare_datasets for the images the attacks were run on.
ATTACK_SETUPS = {
    'mouse': (
        {
            'base': 'store/models/CIFAR10-G_Baseline.pt',
            'neural': 'store/models/CIFAR10-G_Mouse.pt',
        },
        'CIFAR10-Gray',
    ),
    'monkey': (
        {
            'base': 'store/models/VGG19_Baseline.pt',
            'neural': 'store/models/VGG19_MTL.pt',
        },
        'TinyImageNet-Gray',
    ),
}


def _select_closest_advs(adv_list, dist_list):
    """For each sample, pick the adversarial example with the smallest distance.

    Args:
        adv_list: one entry per attack run; each entry is indexed by sample first.
        dist_list: one entry per attack run; each entry is a 1-D sequence of
            per-sample distances aligned with the matching adv_list entry.

    Returns:
        Array whose i-th row is adv_list[j][i] for the run j minimizing
        dist_list[j][i].
    """
    advs_arr = np.asarray(adv_list)
    dists_arr = np.asarray(dist_list)
    best_run = np.argmin(dists_arr, axis=0)  # (n_samples,) winning run per sample
    # Advanced indexing pairs best_run[i] with sample i in a single gather.
    return advs_arr[best_run, np.arange(best_run.shape[0])]


# torch.save does not create missing parent directories.
os.makedirs('store/figs-data', exist_ok=True)
for animal in ['mouse', 'monkey']:
    # Dict lookup raises KeyError for an unknown animal instead of reusing
    # stale model_paths/dataset from the previous iteration.
    model_paths, dataset_name = ATTACK_SETUPS[animal]
    dataset = prepare_datasets(dataset_name, 'datasets')
    # Relies on shuffle=False (the DataLoader default) so batch b_idx lines up
    # with results stored under 'batch_idx' — assumes the attack job used the
    # same batch size of 20.
    loader = torch.utils.data.DataLoader(dataset, batch_size=20)

    imgs, advs = {}, {}
    for label, model_path in model_paths.items():
        print(f'gathering advs for {model_path}...')
        imgs[label] = []
        advs[label] = []
        # `_labels` rather than `_`: assigning `_` shadows IPython's last-output variable.
        for b_idx, (_imgs, _labels) in enumerate(itertools.islice(loader, N_BATCHES)):
            cond = {
                'model_pth': model_path,
                'batch_idx': b_idx,
                'metric': metric, 'targeted': targeted,
                'name': 'BB',
            }
            _advs, _dists = [], []
            for key, _config in job.conditioned(cond):
                result = job.results[key]
                _advs.append(result['advs'])
                _dists.append(result['dists'])
            if not _dists:
                # No stored attack for this batch: skip the images too so
                # imgs[label] and advs[label] stay row-aligned.
                continue
            imgs[label].append(_imgs)
            advs[label].append(_select_closest_advs(_advs, _dists))

            if (b_idx+1)%10==0:
                print('{:2d} batches loaded'.format(b_idx+1))
        if not imgs[label]:
            # np.concatenate([]) would fail with an unhelpful message.
            raise ValueError(
                f'no {metric} attack results found for {model_path} '
                f'in the first {N_BATCHES} batches'
            )
        imgs[label] = np.concatenate(imgs[label])
        advs[label] = np.concatenate(advs[label])
    # Saved as {'imgs': {label: array}, 'advs': {label: array}}, rows aligned.
    torch.save({'imgs': imgs, 'advs': advs}, f'store/figs-data/{animal}_advs.pt')

gathering advs for store/models/CIFAR10-G_Baseline.pt...
gathering advs for store/models/CIFAR10-G_Mouse.pt...
gathering advs for store/models/VGG19_Baseline.pt...
gathering advs for store/models/VGG19_MTL.pt...
