In [None]:
import os
import sys
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import seml.database as db_utils
import torch

from localized_smoothing.segmentation.eval import (
        calc_certified_pixel_accuracy_naive,
        calc_certified_ratios_naive,
        calc_pixel_accuracy,
        calc_mean_iou,
        calc_certified_ratios_collective,
        calc_certified_pixel_accuracy_center,
        calc_certified_ratios_center,
        calc_certified_pixel_accuracy_collective)

# `utils` is expected two directories above the working directory; make it
# importable before importing from it.
sys.path.append('../../')

from utils import load_results


In [None]:
collection = 'cert_images_pascal_masked'

# MongoDB connection settings are read from the environment so that real
# credentials never have to be typed into (and committed with) the notebook.
# The fallbacks are the original placeholders.
jk_config = {
    'username': os.environ.get('MONGODB_USERNAME', 'your_username'),
    'password': os.environ.get('MONGODB_PASSWORD', 'your_password'),
    'host': os.environ.get('MONGODB_HOST', 'host_ip'),
    'port': int(os.environ.get('MONGODB_PORT', 27017)),
    'db_name': os.environ.get('MONGODB_DB_NAME', 'your_db_name'),
}

col = db_utils.get_collection(collection, mongodb_config=jk_config)


In [None]:
def area_under_curve(x,
                     y,
                     pre=True):
    """Riemann-sum area under a step curve sampled at points ``x``.

    Each interval ``[x[i], x[i+1]]`` has width ``x[i+1] - x[i]``. Its height is
    ``y[i+1]`` (right endpoint) when ``pre`` is True and ``y[i]`` (left
    endpoint) otherwise.

    Args:
        x: 1-D sequence of sample positions (e.g. certification budgets).
        y: values at ``x``; must have the same length as ``x``.
        pre: choose the right (True) or left (False) endpoint as height.

    Returns:
        The weighted sum of interval widths and heights.
    """
    widths = np.diff(x)
    heights = y[1:] if pre else y[:-1]
    return widths @ heights

In [None]:
def get_experiments(col, restrictions=None):
    """Query completed experiments from a seml/MongoDB collection.

    Args:
        col: collection object exposing ``count_documents`` and ``find``.
        restrictions: optional MongoDB filter. It is copied, never mutated;
            ``status`` is always forced to ``'COMPLETED'``.

    Returns:
        The result of ``col.find`` projected onto the ``result``, ``stats`` and
        ``host`` fields. For pymongo this is a cursor that can be iterated
        only once.

    Raises:
        ValueError: if no completed experiment matches the filter.
    """
    # Work on a copy: the old mutable default ``{}`` was shared across calls,
    # and the caller's dict was silently modified.
    query = dict(restrictions or {})
    query['status'] = 'COMPLETED'

    if col.count_documents(query) == 0:
        raise ValueError('No matches!')

    return col.find(query, {'result': 1, 'stats': 1, 'host': 1})

In [None]:
def get_results_dict_iid(exp, cert_type='argmax_holm', abstain=True, n_images=100, n_classes=21):
    """Flatten one iid-smoothing certification run into a dict of metrics.

    Loads the certificate file named by ``exp['result']['cert_file']``, reads
    the smoothing/sampling hyper-parameters from its stored ``config``, and
    computes:
    - pixel accuracy and mean IoU;
    - naive certified ratios and certified pixel accuracies per budget;
    - left/right-endpoint areas under both curves.

    Args:
        exp: experiment document with ``result``, ``stats`` and ``host`` fields
            (as returned by ``get_experiments``).
        cert_type: base certificate to evaluate.
        abstain: forwarded to the accuracy / mIoU helpers.
        n_images: forwarded to the metric helpers.
        n_classes: forwarded to the metric helpers.

    Returns:
        dict with hyper-parameters, metrics, budgets, runtime, GPU memory and
        ``auc_{ratios,accuracies}_{pre,post}`` entries.
    """
    # NOTE: torch.load unpickles the file; only load trusted certificate files.
    cert_results = torch.load(exp['result']['cert_file'])
    # `config` is popped, so the metric helpers below see the results without it.
    cfg = cert_results.pop('config')
    budgets = cert_results['budgets']
    dist_params = cfg['distribution_params']
    sample_params = cfg['sample_params']

    summary = dict(
        std=dist_params['std_min'],
        grid_height=dist_params['grid_shape'][0],
        grid_width=dist_params['grid_shape'][1],
        n_samples_pred=sample_params['n_samples_pred'],
        n_samples_cert=sample_params['n_samples_cert'],
    )

    summary['accuracy'] = calc_pixel_accuracy(cert_results, cert_type, False, abstain, n_images, n_classes)
    summary['iou'] = calc_mean_iou(cert_results, cert_type, False, abstain=abstain, n_images=n_images)
    summary['budgets'] = budgets
    summary['certified_ratios'] = calc_certified_ratios_naive(cert_results, cert_type, n_images, n_classes)
    summary['certified_accuracies'] = calc_certified_pixel_accuracy_naive(cert_results, cert_type, n_images, n_classes)
    summary['time'] = exp['stats']['real_time']
    summary['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    for metric in ('ratios', 'accuracies'):
        for pre, tag in ((True, 'pre'), (False, 'post')):
            summary[f'auc_{metric}_{tag}'] = area_under_curve(budgets, summary[f'certified_{metric}'], pre)

    return summary

In [None]:
def get_result_dicts_center(exp, n_images=100, n_classes=21):
    """Flatten one center-smoothing certification run into a dict of metrics.

    Loads the certificate file named by ``exp['result']['cert_file']`` and
    computes mean IoU for the ``'center_bonferroni'`` certificate, certified
    pixel accuracy per budget, and the left/right-endpoint areas under that
    curve.

    Args:
        exp: experiment document with ``result``, ``stats`` and ``host`` fields.
        n_images: forwarded to the metric helpers.
        n_classes: forwarded to ``calc_certified_pixel_accuracy_center``.

    Returns:
        dict with hyper-parameters, ``iou``, ``budgets``,
        ``certified_accuracies``, runtime, GPU memory and
        ``auc_accuracies_{pre,post}``.
    """
    # NOTE: torch.load unpickles the file; only load trusted certificate files.
    cert_results = torch.load(exp['result']['cert_file'])
    # `config` is popped, so the metric helpers below see the results without it.
    cfg = cert_results.pop('config')
    budgets = cert_results['budgets']
    dist_params = cfg['distribution_params']
    sample_params = cfg['sample_params']

    summary = dict(
        std=dist_params['std_min'],
        grid_height=dist_params['grid_shape'][0],
        grid_width=dist_params['grid_shape'][1],
        n_samples_pred=sample_params['n_samples_pred'],
        n_samples_cert=sample_params['n_samples_cert'],
    )

    summary['iou'] = calc_mean_iou(cert_results, 'center_bonferroni', True, abstain=False, n_images=n_images)
    summary['budgets'] = budgets
    summary['certified_accuracies'] = calc_certified_pixel_accuracy_center(
        cert_results, n_images=n_images, n_classes=n_classes)
    summary['time'] = exp['stats']['real_time']
    summary['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    for pre, tag in ((True, 'pre'), (False, 'post')):
        summary[f'auc_accuracies_{tag}'] = area_under_curve(budgets, summary['certified_accuracies'], pre)

    return summary

In [None]:
def get_results_dict_collective(exp, cert_type='argmax_holm', abstain=True, n_images=100, n_classes=21, store_accumulate_gradients=True):
    """Flatten one localized (collective) certification run into a dict of metrics.

    Loads the certificate file named by ``exp['result']['cert_file']``, reads
    hyper-parameters (including ``std_max``) from its stored ``config``, and
    computes:
    - pixel accuracy and mean IoU;
    - collective certified ratios / pixel accuracies per budget, in an
      ``'all'`` and a ``'correct'`` variant;
    - left/right-endpoint areas under every such curve.

    Args:
        exp: experiment document with ``result``, ``stats`` and ``host`` fields.
        cert_type: base certificate to evaluate.
        abstain: forwarded to the accuracy / mIoU helpers.
        n_images: forwarded to the metric helpers.
        n_classes: forwarded to the metric helpers.
        store_accumulate_gradients: if True, also record the training run's
            ``accumulate_gradients`` setting as ``acc_grads``.

    Returns:
        dict with hyper-parameters, metrics, budgets, runtime, GPU memory and
        ``auc_{ratios,accuracies}_{all,correct}_{pre,post}`` entries.
    """
    # NOTE: torch.load unpickles the file; only load trusted certificate files.
    cert_results = torch.load(exp['result']['cert_file'])
    # `config` is popped, so the metric helpers below see the results without it.
    cfg = cert_results.pop('config')
    budgets = cert_results['budgets']
    dist_params = cfg['distribution_params']
    sample_params = cfg['sample_params']

    summary = dict(
        std=dist_params['std_min'],
        std_max=dist_params['std_max'],
        grid_height=dist_params['grid_shape'][0],
        grid_width=dist_params['grid_shape'][1],
        n_samples_pred=sample_params['n_samples_pred'],
        n_samples_cert=sample_params['n_samples_cert'],
    )

    if store_accumulate_gradients:
        train_restrictions = cfg['train_loading']['restrictions']
        summary['acc_grads'] = train_restrictions['training_params']['accumulate_gradients']

    summary['accuracy'] = calc_pixel_accuracy(cert_results, cert_type, False, abstain, n_images, n_classes)
    summary['iou'] = calc_mean_iou(cert_results, cert_type, False, abstain=abstain, n_images=n_images)
    summary['budgets'] = budgets

    # The two positional booleans per subset are passed unchanged:
    # (True, False) for 'all' and (False, True) for 'correct'.
    # TODO: document their meaning from localized_smoothing.segmentation.eval.
    subset_flags = (('all', (True, False)), ('correct', (False, True)))
    metric_fns = (('ratios', calc_certified_ratios_collective),
                  ('accuracies', calc_certified_pixel_accuracy_collective))

    for metric, metric_fn in metric_fns:
        for subset, flags in subset_flags:
            summary[f'certified_{metric}_{subset}'] = metric_fn(
                cert_results, cert_type, *flags, n_images, n_classes)

    summary['time'] = exp['stats']['real_time']
    summary['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    for metric, _ in metric_fns:
        for subset, _ in subset_flags:
            for pre, tag in ((True, 'pre'), (False, 'post')):
                summary[f'auc_{metric}_{subset}_{tag}'] = area_under_curve(
                    budgets, summary[f'certified_{metric}_{subset}'], pre)

    return summary


In [None]:
# Materialise the query once. A pymongo cursor can be iterated only once, so
# counting it with len(list(exps)) after load_results may have consumed it
# would report 0.
exps = list(get_experiments(col, {'config.distribution_params.mask_distance': 1, 'config.certification_params.base_certs': ['argmax_holm']}))
results_iid = load_results(get_results_dict_iid, exps, './data/pascal_masked_iid', overwrite=False)  # Set to True if you want to use your own results
print(len(exps))

In [None]:
# Materialise the query once. A pymongo cursor can be iterated only once, so
# counting it with len(list(exps)) after load_results may have consumed it
# would report 0.
exps = list(get_experiments(col, {'config.distribution_params.mask_distance': 1, 'config.certification_params.base_certs': ['argmax_holm']}))
results_collective = load_results(get_results_dict_collective, exps, './data/pascal_masked_collective', overwrite=False)  # Set to True if you want to use your own results
print(len(exps))

In [None]:
# Materialise the query once. A pymongo cursor can be iterated only once, so
# counting it with len(list(exps_center)) after load_results may have consumed
# it would report 0.
exps_center = list(get_experiments(col, {'config.distribution_params.mask_distance': 1, 'config.certification_params.base_certs': [],
                                         'config.certification_params.naive_certs': ['center_independent', 'center_bonferroni']}))
results_center = load_results(get_result_dicts_center, exps_center, './data/pascal_masked_center', overwrite=False)  # Set to True if you want to use your own results
print(len(exps_center))

In [None]:
# Inspect the iid (SegCertify*) results table.
results_iid

In [None]:
# Inspect the collective (Localized LP) results table.
results_collective

In [None]:
# Inspect the center-smoothing (CenterSmooth) results table.
results_center

In [None]:
# Distinct values of the 'std' column (the run config's std_min) in the
# collective results.
np.unique(results_collective['std'])

In [None]:
def filter_dominated(ious, cert_accs):
    """Drop points that are strictly worse than some other point on both metrics.

    Point ``i`` is removed when there is a point ``j`` with
    ``ious[j] > ious[i]`` and ``cert_accs[j] > cert_accs[i]``. Ties on either
    metric never cause removal. Higher is treated as better for both metrics.

    Args:
        ious: 1-D array-like of first-metric values (e.g. mIoU). Lists and
            pandas Series are accepted.
        cert_accs: 1-D array-like of second-metric values, same length.

    Returns:
        Tuple ``(ious_kept, cert_accs_kept)`` of numpy arrays, in input order.

    Raises:
        ValueError: if the inputs are not 1-D arrays of equal length.
    """
    # asarray: `x[:, np.newaxis]` fails on lists and pandas Series.
    ious = np.asarray(ious)
    cert_accs = np.asarray(cert_accs)
    if ious.ndim != 1 or ious.shape != cert_accs.shape:
        raise ValueError('ious and cert_accs must be 1-D and of equal length, '
                         f'got shapes {ious.shape} and {cert_accs.shape}')

    # worse[i, j] is True when point i is strictly below point j on both axes.
    worse = (ious[:, np.newaxis] < ious) & (cert_accs[:, np.newaxis] < cert_accs)
    dominated = worse.any(axis=1)

    return ious[~dominated], cert_accs[~dominated]

In [None]:
def plot(results_iid, results_center, results_collective, n_samples_iid, n_samples_center, n_samples_collective, std_min=None, pareto=True):
    """Scatter mIoU against certified-accuracy AUC for three certification methods.

    Draws Localized LP (collective), CenterSmooth and SegCertify* (iid) on one
    new figure. Each table is first restricted to rows whose 'n_samples_pred'
    equals the corresponding ``n_samples_*`` argument.

    Args:
        results_iid: iid results. Reads 'n_samples_pred', 'iou',
            'auc_accuracies_post' and, when ``std_min`` is given, 'std'.
        results_center: center-smoothing results. Reads 'n_samples_pred',
            'iou' and 'auc_accuracies_post'.
        results_collective: collective results. Reads 'n_samples_pred', 'iou'
            and 'auc_accuracies_correct_post'.
        n_samples_iid: 'n_samples_pred' value to keep for the iid table.
        n_samples_center: 'n_samples_pred' value to keep for the center table.
        n_samples_collective: 'n_samples_pred' value to keep for the
            collective table.
        std_min: if not None, the first sample-filtered iid row with
            ``std == std_min`` is added to the Localized LP points.
        pareto: if True, only points not dominated on both axes are drawn
            (see ``filter_dominated``).
    """
    # Filter every table before using it. Previously the collective filter ran
    # only after its points had been drawn, so it had no effect, and the
    # std_min lookup read the unfiltered iid table.
    results_iid = results_iid.loc[results_iid['n_samples_pred'] == n_samples_iid]
    results_center = results_center.loc[results_center['n_samples_pred'] == n_samples_center]
    results_collective = results_collective.loc[results_collective['n_samples_pred'] == n_samples_collective]

    pal = sns.color_palette('colorblind', 3)
    _, ax = plt.subplots()

    # Localized LP, optionally augmented with the matching iid point
    ious_collective = results_collective['iou'].to_numpy()
    aucs_collective = results_collective['auc_accuracies_correct_post'].to_numpy()
    if std_min is not None:
        iid_at_std = results_iid.loc[results_iid['std'] == std_min]
        ious_collective = np.append(ious_collective, iid_at_std['iou'].iloc[0])
        aucs_collective = np.append(aucs_collective, iid_at_std['auc_accuracies_post'].iloc[0])
    _scatter_frontier(ax, ious_collective, aucs_collective, pareto,
                      label='Localized LP', marker='.', color=pal[0])

    # CenterSmooth
    _scatter_frontier(ax,
                      results_center['iou'].to_numpy(),
                      results_center['auc_accuracies_post'].to_numpy(),
                      pareto, label='CenterSmooth', marker='x', color=pal[2])

    # SegCertify* (iid)
    _scatter_frontier(ax,
                      results_iid['iou'].to_numpy(),
                      results_iid['auc_accuracies_post'].to_numpy(),
                      pareto, label='SegCertify$^*$', marker='x', color=pal[1])

    ax.set_xlabel('mIOU')
    ax.set_ylabel('Avg. cert. radius')
    ax.legend()


def _scatter_frontier(ax, ious, aucs, pareto, **scatter_kwargs):
    """Scatter (mIoU, AUC) pairs on ``ax``, optionally keeping only non-dominated points."""
    if pareto:
        ious, aucs = filter_dominated(ious, aucs)
    ax.scatter(ious, aucs, s=20, **scatter_kwargs)


In [None]:
# 3 x 5, 820 samples (selects every non-square smoothing grid)
std_mins = np.unique(results_collective['std'])
print(std_mins)

plot(results_iid.query('grid_height != grid_width'),
     results_center.query('grid_height != grid_width'),
     results_collective.query('grid_height != grid_width'),
     820, 820, 820)

In [None]:
# 2x2, 820 samples (selects every square smoothing grid)
std_mins = np.unique(results_collective['std'])
print(std_mins)

plot(results_iid.query('grid_height == grid_width'),
     results_center.query('grid_height == grid_width'),
     results_collective.query('grid_height == grid_width'),
     820, 820, 820)

In [None]:
# 2x2, 3072 samples (selects every square smoothing grid)
# Each table is filtered on its *own* grid columns. The previous version
# indexed results_iid with a mask built from results_collective. With pandas
# boolean-Series indexing that either raises (unalignable index) or silently
# selects rows by the other table's labels.
std_mins = np.unique(results_collective['std'])
print(std_mins)

plot(results_iid.query('grid_height == grid_width'),
     results_center.query('grid_height == grid_width'),
     results_collective.query('grid_height == grid_width'),
     3072, 3072, 3072)