In [None]:
import os
import sys
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import seml.database as db_utils
import torch

from localized_smoothing.segmentation.eval import (
        calc_certified_pixel_accuracy_naive,
        calc_certified_ratios_naive,
        calc_pixel_accuracy,
        calc_mean_iou,
        calc_certified_ratios_collective,
        calc_certified_pixel_accuracy_center,
        calc_certified_ratios_center,
        calc_certified_pixel_accuracy_collective)

# `utils` is resolved relative to the path appended here, so this must run first.
sys.path.append('../../')

from utils import load_results

In [None]:
# Names of the MongoDB collections holding the two experiment series.
collection = 'cert_images_pascal_localized_training'
collection_iid = 'cert_images_pascal_iid'

# MongoDB connection settings. Every entry can be overridden through an
# environment variable so that real credentials never have to be written into
# (and committed with) the notebook; the literals are placeholder fallbacks.
jk_config = {
    'username': os.environ.get('MONGODB_USERNAME', 'your_username'),
    'password': os.environ.get('MONGODB_PASSWORD', 'your_password'),
    'host': os.environ.get('MONGODB_HOST', 'host_ip'),
    # Environment values are strings; the config has always carried an int port.
    'port': int(os.environ.get('MONGODB_PORT', 27017)),
    'db_name': os.environ.get('MONGODB_DB_NAME', 'your_db_name')
}

col = db_utils.get_collection(collection, mongodb_config=jk_config)
col_iid = db_utils.get_collection(collection_iid, mongodb_config=jk_config)


In [None]:
def area_under_curve(x,
                     y,
                     pre=True):
    """Area under a step function through the points (x[i], y[i]).

    Each interval ``x[i+1] - x[i]`` is weighted with one of its endpoint
    values of ``y``.

    Args:
        x: Sample positions; consecutive differences are the step widths.
        y: Sample values, one per position in ``x``.
        pre: If truthy, each interval uses the value at its right endpoint
            (``y[1:]``); otherwise the value at its left endpoint (``y[:-1]``).

    Returns:
        The dot product of the step widths with the selected heights.
    """
    heights = y[1:] if pre else y[:-1]
    widths = np.diff(x)
    return widths @ heights

In [None]:
def get_experiments(col, restrictions=None):
    """Fetch all completed experiments from a collection that match a filter.

    Args:
        col: Collection object exposing ``count_documents(filter)`` and
            ``find(filter, projection)``.
        restrictions: Optional filter dict. It is not modified; a copy with
            ``'status': 'COMPLETED'`` added (overriding any status given) is
            used for the query.

    Returns:
        Whatever ``col.find`` returns for the query, projected onto the
        ``result``, ``stats`` and ``host`` fields.

    Raises:
        ValueError: If no document matches the query.
    """
    # Build a fresh dict: mutating the argument would leak 'status' into the
    # caller's dict, and a mutable default would be shared across calls.
    query = dict(restrictions) if restrictions is not None else {}
    query['status'] = 'COMPLETED'

    if col.count_documents(query) == 0:
        raise ValueError('No matches!')

    return col.find(query, {'result': 1, 'stats': 1, 'host': 1})

In [None]:
def get_results_dict_iid(exp, cert_type='argmax_holm', abstain=True, n_images=100, n_classes=21, time_only=True):
    """Summarise one i.i.d.-smoothing experiment as a flat dict.

    The certificate file at ``exp['result']['cert_file']`` is read with
    ``torch.load``; its ``'config'`` entry supplies the smoothing and sampling
    parameters.

    Args:
        exp: Experiment document with ``result``, ``stats`` and ``host`` fields.
        cert_type: Passed through to the ``calc_*`` evaluation helpers.
        abstain: Passed through to the accuracy / IoU helpers.
        n_images: Passed through to the evaluation helpers.
        n_classes: Passed through to the evaluation helpers.
        time_only: If truthy (the default), only the configuration values plus
            ``'time'`` and ``'vram'`` are returned and no metric is computed.

    Returns:
        A dict with the configuration values, run time and GPU memory and,
        unless ``time_only`` is set, the metrics, the budgets and one
        ``auc_<metric>_<pre|post>`` entry per certified curve.
    """
    res = torch.load(exp['result']['cert_file'])
    config = res.pop('config')
    budgets = res['budgets']

    results_dict = {}

    dist_params = config['distribution_params']
    results_dict['std'] = dist_params['std_min']
    results_dict['grid_height'] = dist_params['grid_shape'][0]
    results_dict['grid_width'] = dist_params['grid_shape'][1]

    sample_params = config['sample_params']
    results_dict['n_samples_pred'] = sample_params['n_samples_pred']
    results_dict['n_samples_cert'] = sample_params['n_samples_cert']

    if time_only:
        # Cheap path: resource usage only, none of the metric helpers run.
        results_dict['time'] = exp['stats']['real_time']
        results_dict['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']
        return results_dict

    results_dict['accuracy'] = calc_pixel_accuracy(res, cert_type, False, abstain, n_images, n_classes)
    results_dict['iou'] = calc_mean_iou(res, cert_type, False, abstain=abstain, n_images=n_images)
    results_dict['iou_center'] = calc_mean_iou(res, 'center_bonferroni', True, abstain=False, n_images=n_images)
    results_dict['budgets'] = budgets
    results_dict['certified_ratios'] = calc_certified_ratios_naive(res, cert_type, n_images, n_classes)
    results_dict['certified_accuracies'] = calc_certified_pixel_accuracy_naive(res, cert_type, n_images, n_classes)
    # NOTE(review): 166*250 is presumably the number of pixels per evaluated image - confirm.
    results_dict['certified_ratios_center'] = calc_certified_ratios_center(res, n_pixels=(166*250))
    results_dict['certified_accuracies_center'] = calc_certified_pixel_accuracy_center(res, n_images=n_images)
    results_dict['time'] = exp['stats']['real_time']
    results_dict['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    # One area-under-curve value per certified curve and per step convention.
    for metric in ('ratios', 'accuracies', 'ratios_center', 'accuracies_center'):
        curve = results_dict[f'certified_{metric}']
        for pre in (True, False):
            suffix = 'pre' if pre else 'post'
            results_dict[f'auc_{metric}_{suffix}'] = area_under_curve(
                results_dict['budgets'],
                curve,
                pre
            )

    return results_dict

In [None]:
def get_results_dict_collective(exp, cert_type='argmax_holm', abstain=True, n_images=100, n_classes=21, store_accumulate_gradients=True):
    """Summarise one collective-certification experiment as a flat dict.

    The certificate file at ``exp['result']['cert_file']`` is read with
    ``torch.load``; its ``'config'`` entry supplies the smoothing, sampling
    and (optionally) training parameters.

    Args:
        exp: Experiment document with ``result``, ``stats`` and ``host`` fields.
        cert_type: Passed through to the ``calc_*`` evaluation helpers.
        abstain: Passed through to the accuracy / IoU helpers.
        n_images: Passed through to the evaluation helpers.
        n_classes: Passed through to the evaluation helpers.
        store_accumulate_gradients: If truthy, also record the
            ``accumulate_gradients`` training parameter as ``'acc_grads'``.

    Returns:
        A dict with the configuration values, the metrics, the budgets, run
        time and GPU memory, and one ``auc_<metric>_<subset>_<pre|post>``
        entry per collective certified curve.
    """
    res = torch.load(exp['result']['cert_file'])
    config = res.pop('config')
    budgets = res['budgets']

    results_dict = {}

    dist_params = config['distribution_params']
    results_dict['std'] = dist_params['std_min']
    results_dict['std_max'] = dist_params['std_max']
    results_dict['grid_height'] = dist_params['grid_shape'][0]
    results_dict['grid_width'] = dist_params['grid_shape'][1]

    sample_params = config['sample_params']
    results_dict['n_samples_pred'] = sample_params['n_samples_pred']
    results_dict['n_samples_cert'] = sample_params['n_samples_cert']

    if store_accumulate_gradients:
        training_params = config['train_loading']['restrictions']['training_params']
        results_dict['acc_grads'] = training_params['accumulate_gradients']

    results_dict['accuracy'] = calc_pixel_accuracy(res, cert_type, False, abstain, n_images, n_classes)
    results_dict['iou'] = calc_mean_iou(res, cert_type, False, abstain=abstain, n_images=n_images)
    results_dict['budgets'] = budgets
    results_dict['certified_ratios_all'] = calc_certified_ratios_collective(res, cert_type, True, False, n_images, n_classes)
    results_dict['certified_ratios_correct'] = calc_certified_ratios_collective(res, cert_type, False, True, n_images, n_classes)
    results_dict['certified_accuracies_all'] = calc_certified_pixel_accuracy_collective(res, cert_type, True, False, n_images, n_classes)
    results_dict['certified_accuracies_correct'] = calc_certified_pixel_accuracy_collective(res, cert_type, False, True, n_images, n_classes)
    results_dict['naive_certified_ratios'] = calc_certified_ratios_naive(res, cert_type, n_images, n_classes)
    results_dict['naive_certified_accuracies'] = calc_certified_pixel_accuracy_naive(res, cert_type, n_images, n_classes)
    results_dict['time'] = exp['stats']['real_time']
    results_dict['vram'] = exp['host']['gpus']['gpus'][0]['total_memory']

    # One area-under-curve value per collective curve and per step convention.
    for metric in ('ratios', 'accuracies'):
        for subset in ('all', 'correct'):
            curve = results_dict[f'certified_{metric}_{subset}']
            for pre in (True, False):
                suffix = 'pre' if pre else 'post'
                results_dict[f'auc_{metric}_{subset}_{suffix}'] = area_under_curve(
                    results_dict['budgets'],
                    curve,
                    pre
                )

    return results_dict


In [None]:
# Skip the i.i.d. runs that used 512 prediction samples.
exps_iid = get_experiments(
    col_iid,
    {'config.sample_params.n_samples_pred': {'$ne': 512}},
)
results_iid = load_results(
    get_results_dict_iid,
    exps_iid,
    './data/pascal_iid',
    overwrite=False,  # Set to True if you want to load your own data
)

In [None]:
# All completed runs of the locally trained, collectively certified models.
exps = get_experiments(col)
results_collective = load_results(
    get_results_dict_collective,
    exps,
    './data/pascal_collective_locally_trained',
    overwrite=False,  # Set to True if you want to load your own data
)

In [None]:
# Preview only the first rows: the frame holds array-valued columns (budgets,
# certified curves), so displaying it in full floods the notebook output.
results_collective.head()

In [None]:
def filter_dominated(ious, cert_accs):
    """Drop every point that is strictly dominated in both coordinates.

    Point ``i`` is removed when some point ``j`` has both a strictly larger
    IoU and a strictly larger certified accuracy. Ties in either coordinate
    do not count as domination.

    Args:
        ious: 1-D array of IoU values.
        cert_accs: 1-D array of certified-accuracy values, aligned with ``ious``.

    Returns:
        Tuple ``(ious, cert_accs)`` restricted to the non-dominated points,
        in their original order.
    """
    # Entry [i, j] is True when point i is strictly below point j.
    lower_iou = ious[:, None] < ious[None, :]
    lower_acc = cert_accs[:, None] < cert_accs[None, :]

    keep = ~(lower_iou & lower_acc).any(axis=1)

    return ious[keep], cert_accs[keep]

In [None]:
def plot(results_iid, results_collective, n_samples_iid, n_samples_collective, std_min=None, pareto_collective=True, pareto_iid=True, pareto_center=True,
         highlight_index_iid=None, highlight_index_collective=None):
    """Scatter mIOU against the post-step certified-accuracy AUC for three methods.

    Draws into the current pyplot figure (which is cleared first): the
    collective results ('Localized LP'), the center-smoothing columns of the
    i.i.d. results ('CenterSmooth') and the i.i.d. results themselves
    ('SegCertify$^*$').

    Args:
        results_iid: DataFrame of i.i.d. results.
        results_collective: DataFrame of collective results.
        n_samples_iid: Only i.i.d. rows with this ``n_samples_pred`` are
            plotted for the center and i.i.d. series.
        n_samples_collective: Only collective rows with this
            ``n_samples_pred`` are plotted.
        std_min: If given, collective rows are restricted to this ``std`` and
            the first i.i.d. row with the same ``std`` is appended to the
            collective series (taken before the ``n_samples_iid`` filter).
        pareto_collective: Apply ``filter_dominated`` to the collective series.
        pareto_iid: Apply ``filter_dominated`` to the i.i.d. series.
        pareto_center: Apply ``filter_dominated`` to the center series.
        highlight_index_iid: Optional index into the (filtered) i.i.d. points
            to redraw larger and print.
        highlight_index_collective: Optional index into the (filtered)
            collective points to redraw larger and print.
    """
    plt.clf()
    plt.cla()

    color_collective, color_iid, color_center = sns.color_palette('colorblind', 3)
    restrict_std = std_min is not None

    # Collective

    collective = results_collective
    if restrict_std:
        collective = collective.loc[collective['std'] == std_min]
    collective = collective.loc[collective['n_samples_pred'] == n_samples_collective]

    ious_collective = collective['iou'].to_numpy()
    if restrict_std:
        # The i.i.d. run with the same std joins the collective series.
        iid_same_std = results_iid.loc[results_iid['std'] == std_min]
        ious_collective = np.append(ious_collective, iid_same_std['iou'].iloc[0])

    aucs_collective = collective['auc_accuracies_correct_post'].to_numpy()
    if restrict_std:
        aucs_collective = np.append(aucs_collective, iid_same_std['auc_accuracies_post'].iloc[0])

    if pareto_collective:
        ious_collective, aucs_collective = filter_dominated(ious_collective, aucs_collective)

    plt.scatter(ious_collective, aucs_collective, label='Localized LP', marker='.', s=10, color=color_collective)

    # center

    iid = results_iid.loc[results_iid['n_samples_pred'] == n_samples_iid]

    iou_center = iid['iou_center'].to_numpy()
    auc_center = iid['auc_accuracies_center_post'].to_numpy()

    if pareto_center:
        iou_center, auc_center = filter_dominated(iou_center, auc_center)

    plt.scatter(iou_center, auc_center, label='CenterSmooth', s=6, marker='x', color=color_center)

    # iid

    iou_iid = iid['iou'].to_numpy()
    auc_iid = iid['auc_accuracies_post'].to_numpy()

    if pareto_iid:
        iou_iid, auc_iid = filter_dominated(iou_iid, auc_iid)

    plt.scatter(iou_iid, auc_iid, label='SegCertify$^*$', s=4, marker='*', color=color_iid)

    if highlight_index_iid is not None:
        x_mark, y_mark = iou_iid[highlight_index_iid], auc_iid[highlight_index_iid]
        plt.scatter(x_mark, y_mark, s=30, marker='*', color=color_iid)
        print(x_mark, y_mark)

    if highlight_index_collective is not None:
        x_mark, y_mark = ious_collective[highlight_index_collective], aucs_collective[highlight_index_collective]
        plt.scatter(x_mark, y_mark, marker='.', s=50, color=color_collective)
        print(x_mark, y_mark)

    plt.xlabel('mIOU')
    plt.ylabel('Avg. cert. radius')
    plt.legend()

In [None]:
# Same number of samples

sns.set()
pal = sns.color_palette('colorblind', 3)

# Distinct std values present in the collective results, in ascending order.
std_mins = np.sort(list(set(results_collective['std'])))
print(std_mins)

# Only collective runs whose grid height differs from its width are plotted.
plot(
    results_iid,
    results_collective.loc[
        results_collective['grid_height'] != results_collective['grid_width']
    ],
    n_samples_iid=820,
    n_samples_collective=820,
    highlight_index_iid=7,
    highlight_index_collective=6,
)


In [None]:
# More samples for baseline

sns.set()
pal = sns.color_palette('colorblind', 3)

# Distinct std values present in the collective results, in ascending order.
std_mins = np.sort(list(set(results_collective['std'])))
print(std_mins)

# Only collective runs whose grid height differs from its width are plotted.
plot(
    results_iid,
    results_collective.loc[
        results_collective['grid_height'] != results_collective['grid_width']
    ],
    n_samples_iid=12288,
    n_samples_collective=820,
)
