In [None]:
import numpy as np
import seml.database as db_utils
import torch

from localized_smoothing.segmentation.eval import (
        calc_certified_pixel_accuracy_naive,
        calc_certified_ratios_naive,
        calc_pixel_accuracy,
        calc_mean_iou,
        calc_certified_ratios_collective,
        calc_certified_pixel_accuracy_center,
        calc_certified_ratios_center,
        calc_certified_pixel_accuracy_collective)

from itertools import product

import matplotlib.pyplot as plt

import sys

import seaborn as sns



sys.path.append('../../')

from utils import load_results

import os

In [None]:
# Names of the MongoDB collections that hold the experiment documents.
collection = 'cert_images_pascal_localized_training'
collection_iid = 'cert_images_pascal_iid'

# MongoDB connection settings.
# Each field can be supplied through the environment variable named below so
# that real credentials never have to be written into (and committed with) the
# notebook. When a variable is unset, the original placeholder value is used.
jk_config = {
    'username': os.environ.get('MONGODB_USERNAME', 'your_username'),
    'password': os.environ.get('MONGODB_PASSWORD', 'your_password'),
    'host': os.environ.get('MONGODB_HOST', 'host_ip'),
    # Environment values are strings, so convert to keep the port an int.
    'port': int(os.environ.get('MONGODB_PORT', 27017)),
    'db_name': os.environ.get('MONGODB_DB_NAME', 'your_db_name')
}

# Collection handles used by the loading cells further down.
col = db_utils.get_collection(collection, mongodb_config=jk_config)
col_iid = db_utils.get_collection(collection_iid, mongodb_config=jk_config)


In [None]:
def area_under_curve(x,
                     y,
                     pre=True):
    """Area under a step curve through the points ``(x[i], y[i])``.

    Every interval ``[x[i], x[i + 1]]`` contributes its width times one of the
    two ``y`` values at its ends.

    Args:
        x: Sample positions (anything ``np.diff`` accepts).
        y: Curve values, one per entry of ``x``; must support slicing.
        pre: If True, each interval uses the value at its right end
            (``y[1:]``); otherwise the value at its left end (``y[:-1]``).

    Returns:
        The sum over all intervals of width times selected height.
    """
    widths = np.diff(x)
    heights = y[1:] if pre else y[:-1]
    return widths @ heights

In [None]:
def get_experiments(col, restrictions=None):
    """Return the completed experiments of a collection that match a query.

    Args:
        col: Collection object providing ``count_documents(query)`` and
            ``find(query, projection)``.
        restrictions: Optional query dict. It is copied before use, so neither
            the caller's dict nor a shared default is modified. Any ``'status'``
            entry it contains is overridden with ``'COMPLETED'``.

    Returns:
        Whatever ``col.find`` returns for the query, projected onto the
        ``result``, ``stats`` and ``host`` fields.

    Raises:
        ValueError: If no document matches the query.
    """
    # Work on a private copy: the original mutated its (default) argument, so
    # the status filter leaked into the caller's dict and into the shared
    # default dict used by later calls.
    query = dict(restrictions) if restrictions else {}
    query['status'] = 'COMPLETED'

    if col.count_documents(query) == 0:
        raise ValueError('No matches!')

    return col.find(query, {'result': 1, 'stats': 1, 'host': 1})

In [None]:
def get_results_dict_iid(exp, cert_type='argmax_holm', abstain=True, n_images=100, n_classes=21):
    """Summarise one i.i.d.-smoothing certification experiment as a flat dict.

    The file referenced by ``exp['result']['cert_file']`` is read with
    ``torch.load``. Its ``'config'`` entry supplies the smoothing and sampling
    settings; the remaining content is handed to the ``calc_*`` evaluation
    helpers. Finally, the area under each certified curve is added for both
    step conventions of ``area_under_curve``.

    Args:
        exp: Experiment document. Must provide ``result.cert_file``,
            ``stats.real_time`` and ``host.gpus.gpus[0].total_memory``.
        cert_type: Forwarded to the ``calc_*`` helpers. The ``iou_center``
            entry ignores it and always uses ``'center_bonferroni'``.
        abstain: Forwarded to ``calc_pixel_accuracy`` and ``calc_mean_iou``.
        n_images: Forwarded to the ``calc_*`` helpers.
        n_classes: Forwarded to the ``calc_*`` helpers.

    Returns:
        dict with the configuration values, the metrics, ``time``, ``vram``
        and the ``auc_<metric>_<pre|post>`` entries.
    """
    # NOTE(review): torch.load unpickles the file - only use on trusted files.
    cert_data = torch.load(exp['result']['cert_file'])
    cfg = cert_data.pop('config')
    budgets = cert_data['budgets']

    # Smoothing distribution and sampling settings of the run.
    dist_cfg = cfg['distribution_params']
    summary = {
        'std': dist_cfg['std_min'],
        'grid_height': dist_cfg['grid_shape'][0],
        'grid_width': dist_cfg['grid_shape'][1],
    }
    sample_cfg = cfg['sample_params']
    summary['n_samples_pred'] = sample_cfg['n_samples_pred']
    summary['n_samples_cert'] = sample_cfg['n_samples_cert']

    # Metrics, inserted in a fixed order (the key order is part of the result).
    summary['accuracy'] = calc_pixel_accuracy(
        cert_data, cert_type, False, abstain, n_images, n_classes)
    summary['iou'] = calc_mean_iou(
        cert_data, cert_type, False, abstain=abstain, n_images=n_images)
    summary['iou_center'] = calc_mean_iou(
        cert_data, 'center_bonferroni', True, abstain=False, n_images=n_images)
    summary['budgets'] = budgets
    summary['certified_ratios'] = calc_certified_ratios_naive(
        cert_data, cert_type, n_images, n_classes)
    summary['certified_accuracies'] = calc_certified_pixel_accuracy_naive(
        cert_data, cert_type, n_images, n_classes)
    # n_pixels is hard-coded to 166 * 250 - presumably the resolution of the
    # evaluated images; TODO confirm.
    summary['certified_ratios_center'] = calc_certified_ratios_center(
        cert_data, n_pixels=(166*250))
    summary['certified_accuracies_center'] = calc_certified_pixel_accuracy_center(
        cert_data, n_images=n_images)
    summary['time'] = exp['stats']['real_time']
    summary['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    # Area under every certified curve, for both step conventions.
    for metric in ('ratios', 'accuracies', 'ratios_center', 'accuracies_center'):
        for pre in (True, False):
            suffix = 'pre' if pre else 'post'
            summary[f'auc_{metric}_{suffix}'] = area_under_curve(
                summary['budgets'],
                summary[f'certified_{metric}'],
                pre
            )

    return summary

In [None]:
def get_results_dict_collective(exp, cert_type='argmax_holm', abstain=True, n_images=100, n_classes=21, store_accumulate_gradients=True):
    """Summarise one collective (localized smoothing) experiment as a flat dict.

    The file referenced by ``exp['result']['cert_file']`` is read with
    ``torch.load``. Its ``'config'`` entry supplies the smoothing and sampling
    settings; the remaining content is handed to the ``calc_*`` evaluation
    helpers. Finally, the area under each collective certified curve is added
    for both step conventions of ``area_under_curve``.

    Args:
        exp: Experiment document. Must provide ``result.cert_file``,
            ``stats.real_time`` and ``host.gpus.gpus[0].total_memory``.
        cert_type: Forwarded to the ``calc_*`` helpers.
        abstain: Forwarded to ``calc_pixel_accuracy`` and ``calc_mean_iou``.
        n_images: Forwarded to the ``calc_*`` helpers.
        n_classes: Forwarded to the ``calc_*`` helpers.
        store_accumulate_gradients: If True, also store the
            ``accumulate_gradients`` training parameter found under
            ``config['train_loading']['restrictions']['training_params']``
            as ``'acc_grads'``.

    Returns:
        dict with the configuration values, the metrics, ``time``, ``vram``
        and the ``auc_<metric>_<subset>_<pre|post>`` entries.
    """
    # NOTE(review): torch.load unpickles the file - only use on trusted files.
    cert_data = torch.load(exp['result']['cert_file'])
    cfg = cert_data.pop('config')
    budgets = cert_data['budgets']

    # Smoothing distribution and sampling settings of the run.
    dist_cfg = cfg['distribution_params']
    summary = {
        'std': dist_cfg['std_min'],
        'std_max': dist_cfg['std_max'],
        'grid_height': dist_cfg['grid_shape'][0],
        'grid_width': dist_cfg['grid_shape'][1],
    }
    sample_cfg = cfg['sample_params']
    summary['n_samples_pred'] = sample_cfg['n_samples_pred']
    summary['n_samples_cert'] = sample_cfg['n_samples_cert']

    if store_accumulate_gradients:
        training_params = cfg['train_loading']['restrictions']['training_params']
        summary['acc_grads'] = training_params['accumulate_gradients']

    # Metrics, inserted in a fixed order (the key order is part of the result).
    summary['accuracy'] = calc_pixel_accuracy(
        cert_data, cert_type, False, abstain, n_images, n_classes)
    summary['iou'] = calc_mean_iou(
        cert_data, cert_type, False, abstain=abstain, n_images=n_images)
    summary['budgets'] = budgets
    # The two positional booleans select the '_all' / '_correct' variants;
    # their exact meaning is defined by the eval helpers - TODO confirm.
    summary['certified_ratios_all'] = calc_certified_ratios_collective(
        cert_data, cert_type, True, False, n_images, n_classes)
    summary['certified_ratios_correct'] = calc_certified_ratios_collective(
        cert_data, cert_type, False, True, n_images, n_classes)
    summary['certified_accuracies_all'] = calc_certified_pixel_accuracy_collective(
        cert_data, cert_type, True, False, n_images, n_classes)
    summary['certified_accuracies_correct'] = calc_certified_pixel_accuracy_collective(
        cert_data, cert_type, False, True, n_images, n_classes)
    summary['naive_certified_ratios'] = calc_certified_ratios_naive(
        cert_data, cert_type, n_images, n_classes)
    summary['naive_certified_accuracies'] = calc_certified_pixel_accuracy_naive(
        cert_data, cert_type, n_images, n_classes)
    summary['time'] = exp['stats']['real_time']
    summary['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    # Area under every collective certified curve, for both step conventions.
    for metric in ('ratios', 'accuracies'):
        for subset in ('all', 'correct'):
            for pre in (True, False):
                suffix = 'pre' if pre else 'post'
                summary[f'auc_{metric}_{subset}_{suffix}'] = area_under_curve(
                    summary['budgets'],
                    summary[f'certified_{metric}_{subset}'],
                    pre
                )

    return summary


In [None]:
# Completed i.i.d. experiments, skipping the runs with 512 prediction samples.
iid_restrictions = {'config.sample_params.n_samples_pred': {'$ne': 512}}
exps_iid = get_experiments(col_iid, iid_restrictions)

results_iid = load_results(
    get_results_dict_iid,
    exps_iid,
    './data/pascal_iid',
    overwrite=False,  # Set to True if you want to load your own data
)

In [None]:
# All completed experiments of the localized-training collection.
exps = get_experiments(col)

results_collective = load_results(
    get_results_dict_collective,
    exps,
    './data/pascal_collective_locally_trained',
    overwrite=False,  # Set to True if you want to load your own data
)

In [None]:
# Keep only the i.i.d. runs that used 820 prediction samples and renumber the
# remaining rows from zero.
uses_820_pred_samples = results_iid['n_samples_pred'] == 820
results_iid = results_iid.loc[uses_820_pred_samples].reset_index(drop=True)
results_iid

In [None]:
# Display the results table loaded for the localized-training collection.
results_collective

In [None]:
def plot(budgets_iid, cert_acc_iid, budgets_collective, cert_acc_collective, cert_acc_collective_naive):
    """Plot certified accuracy against adversarial budget for three certificates.

    Draws, on the current pyplot figure (which is cleared first), the curves
    labelled 'Localized LP', 'Localized Naïve' (dashed, same colour) and
    'SegCertify$^*$'. The x-axis runs from 0 to 1.1 times the larger of the two
    budgets at which the i.i.d. curve and the 'Localized LP' curve first reach
    zero. A curve that never reaches zero contributes its last budget instead.

    Args:
        budgets_iid: Budgets of the i.i.d. curve; must support boolean-mask
            indexing (e.g. a numpy array).
        cert_acc_iid: Certified accuracies of the i.i.d. curve, one per budget.
        budgets_collective: Budgets shared by both localized curves.
        cert_acc_collective: Certified accuracies of the 'Localized LP' curve.
        cert_acc_collective_naive: Certified accuracies of the
            'Localized Naïve' curve.
    """
    def _first_zero_budget(budgets, cert_acc):
        # Budget at which the curve first hits zero. Fall back to the last
        # budget when it never does, instead of raising IndexError on an
        # empty selection.
        zero_budgets = budgets[cert_acc == 0]
        if len(zero_budgets) > 0:
            return zero_budgets[0]
        return budgets[-1]

    plt.clf()
    plt.cla()

    pal = sns.color_palette('colorblind', 2)

    plt.plot(budgets_collective, cert_acc_collective, label='Localized LP', color=pal[0])

    plt.plot(budgets_collective, cert_acc_collective_naive, label='Localized Naïve', color=pal[0], linestyle='--')

    plt.plot(budgets_iid, cert_acc_iid, label='SegCertify$^*$', color=pal[1])

    xlim = max(_first_zero_budget(budgets_iid, cert_acc_iid),
               _first_zero_budget(budgets_collective, cert_acc_collective))

    plt.xlim(0, xlim * 1.1)

    plt.legend()

    # Raw string: '\e' is an invalid escape sequence in a normal string
    # literal and triggers a warning on current Python versions.
    plt.xlabel(r'Adversarial budget $\epsilon$')
    plt.ylabel('Certified accuracy')

In [None]:
# Baseline curve: the i.i.d. run whose 'std' column equals 0.2.
iid_selection = results_iid['std'] == 0.2
results_iid = results_iid.loc[iid_selection]

iid_row = results_iid[['std', 'budgets', 'certified_accuracies']].values[0]
std, budgets_iid, cert_acc_iid = iid_row

# Localized curve: the run with 'std' 0.15 and 'std_max' 1.0.
collective_selection = (results_collective['std'] == 0.15) & (results_collective['std_max'] == 1.0)
results_collective = results_collective.loc[collective_selection]

collective_columns = ['std', 'std_max', 'budgets', 'certified_accuracies_correct', 'naive_certified_accuracies']
collective_row = results_collective[collective_columns].values[0]
std_min, std_max, budgets_collective, cert_acc_collective, cert_acc_collective_naive = collective_row

# Apply the seaborn default theme before drawing.
sns.set()

# Sorted distinct 'std' values remaining after the filter above.
std_mins = np.sort(list(set(results_collective['std'])))

plot(budgets_iid, cert_acc_iid, budgets_collective, cert_acc_collective, cert_acc_collective_naive)


#mplt.savefig(f'./figures/0_2_vs_0_15_1_0', format='pgf', preview='png', dpi=512, tight={'pad': 0.5})
#mplt.savefig(f'./figures/0_2_vs_0_15_1_0', format='pdf', preview='png', dpi=512, tight={'pad': 0.5})


