In [1]:
from pprint import PrettyPrinter
pp = PrettyPrinter(compact=True, indent=4, depth=3)
from functools import partial
import json
from collections import defaultdict
from operator import itemgetter
from pathlib import Path
import pandas as pd
import torch
import numpy as np
from tdigest import TDigest

%matplotlib inline
import matplotlib.pyplot as plt; plt.style.use('bmh')
import matplotlib as mpl
from importlib import reload
import src.conformal  as cp
import src.temperature as ts
import src.helpers as helpers
reload(helpers)
reload(cp)
reload(ts)
from matplotlib import rcParams
rcParams['font.family'] = 'serif'
# With font.family = 'serif', matplotlib resolves fonts from the 'font.serif' list
# ('font.sans-serif' is only consulted for the 'sans-serif' family), so the Times
# preference must go there. Keep the default serif list as fallbacks if Times is missing.
rcParams['font.serif'] = ['Times', *rcParams['font.serif']]

In [2]:
# Positional accessors: fst(x) -> x[0], snd(x) -> x[1] (e.g. as predicates over dict items).
fst = itemgetter(0)
snd = itemgetter(1)

In [141]:
# Destination for saved figures; created on first run, no error if it already exists.
fig_dir = Path('figures')
Path.mkdir(fig_dir, exist_ok=True)

# Root folder that is globbed for per-experiment result directories.
path_to_experiments = Path('experiments')

In [231]:
# Dataset to analyse -- uncomment exactly one.
# dataset = 'mnist'
# dataset = 'svhn'
# dataset = 'fashion'
# dataset = 'cifar10'
# dataset = 'cifar100'
dataset = 'bloodmnist'
# dataset = 'dermamnist'
# dataset = 'pathmnist'
# dataset = 'tissuemnist'
# dataset = 'fitzpatrick'

experiment_names = ['tct', 'fedavg', 'central', 'tct_iid', 'fedavg_iid']
if dataset == 'fitzpatrick':
    model = 'resnet18'
    # partition = 'skin_type_partition'
    partition = 'three_label_partition'
    # The val/test metadata frames are read from the TCT run's directory.
    _fitz_dir = path_to_experiments / f'fitzpatrick_tct_{model}_pretrained_{partition}'
    _val_df = pd.read_csv(_fitz_dir / 'val_df.csv')
    _test_df = pd.read_csv(_fitz_dir / 'test_df.csv')
    df = pd.concat([_val_df, _test_df]).reset_index()
    num_classes = 114
    clients_class_map = None
else:
    # Only Fitzpatrick runs use a metadata partition. Bind it anyway so later cells that
    # read `partition` never pick up a stale value left over from an earlier kernel session.
    partition = None
    clients_class_map = helpers.get_client_map(dataset)
    num_classes = sum(map(len, clients_class_map.values()))
    model = 'small_resnet14'


def _experiment_glob(name):
    """Return the directory glob pattern for experiment `name`.

    `name` is an algorithm ('central', 'tct', 'fedavg'), optionally suffixed with '_iid'.
    IID variants insert '_iid_partition' after the model name; Fitzpatrick runs additionally
    end in '_pretrained_<partition>'.
    """
    algo, _, iid = name.partition('_')
    pattern = f'{dataset}_{algo}_{model}' + ('_iid_partition' if iid else '')
    if dataset == 'fitzpatrick':
        pattern += f'_pretrained_{partition}'
    return pattern


# Load order / key order: central, tct, fedavg, tct_iid, fedavg_iid.
experiments = {
    name: helpers.load_scores(*path_to_experiments.glob(_experiment_glob(name)), dataset=dataset)
    for name in ('central', 'tct', 'fedavg', 'tct_iid', 'fedavg_iid')
}
# Keep only experiments whose loaded scores are truthy (presumably empty when no directory matched).
experiments = dict(filter(snd, experiments.items()))

In [232]:
def accuracy(scores, targets):
    """Fraction of rows whose highest-scoring column equals the target label.

    Args:
        scores: per-example class scores, expected 2-D (rows = examples); numpy arrays
            and torch tensors both work since only `.argmax`, `.sum`, `.item` are used.
        targets: integer labels, one per row of `scores`.

    Returns:
        Top-1 accuracy as a Python float.
    """
    hits = scores.argmax(1) == targets
    num_examples = targets.shape[0]
    return (hits.sum() / num_examples).item()

In [233]:
reload(helpers)
num_accuracy_trials = 100  # random calibration/evaluation splits to average accuracy over
trial_val_acc = defaultdict(list)
trial_test_acc = defaultdict(list)
for _ in range(num_accuracy_trials):
    # Same split call as the conformal-trials cell, so Fitzpatrick splits also receive the metadata frame.
    trial = helpers.get_new_trial(experiments, fitzpatrick_df=df if dataset == 'fitzpatrick' else None)

    for name, scores in trial['experiments'].items():
        trial_val_acc[name].append(accuracy(scores['val_scores'], scores['val_targets']))
        trial_test_acc[name].append(accuracy(scores['test_scores'], scores['test_targets']))

# Mean top-1 accuracy per experiment across all trials.
for name, test_accs in trial_test_acc.items():
    print(f"\n{name.upper().center(20, '=')}")
    print('val\t', f'{np.mean(trial_val_acc[name]):.2f}')
    print('test\t', f'{np.mean(test_accs):.2f}')
                                          


val	 0.97
test	 0.96

val	 0.88
test	 0.88

val	 0.20
test	 0.20

val	 0.97
test	 0.97

=====FEDAVG_IID=====
val	 0.97
test	 0.97


In [234]:
reload(cp)
reload(helpers)
reload(ts)
num_trials = 10

# Conformal parameters (loop-invariant, so set once instead of on every trial).
# alphas = np.arange(0.05, 1, 0.05)
alphas = np.arange(0.10, 1, 0.10)
# Round away arange float noise so alphas can later be matched exactly (e.g. `a in alphas`).
alphas = list(map(lambda x: np.round(x, 2), alphas))
allow_empty_sets = False # set to True for upper marginal bound
method = 'lac' # naive, lac, aps, raps

# The partition values set in the loading cell are 'three_label_partition' / 'skin_type_partition';
# the old check against 'three_partition_label' could never match. The `dataset` test short-circuits
# so `partition` is never looked up for non-Fitzpatrick datasets.
use_three_label_partition = dataset == 'fitzpatrick' and partition == 'three_label_partition'

f = itemgetter('temp_val_scores', 'val_targets', 'temp_test_scores', 'test_targets')

tct_trials = {}
fedavg_trials = {}
central_trials = {}
# experiment name -> (per-trial results store, evaluate per client?)
conformal_runs = {
    'tct': (tct_trials, True),
    'fedavg': (fedavg_trials, True),
    'central': (central_trials, False),
}

for i in range(num_trials):

    # randomly split into calibration and evaluation sets
    trial = helpers.get_new_trial(experiments, fitzpatrick_df=df if dataset == 'fitzpatrick' else None)
    trial_experiments = trial['experiments']
    val_df = trial['val_df']
    test_df = trial['test_df']

    # apply aggregate temperature scaling
    ts.client_temp_scale(
        trial_experiments, clients_class_map,
        val_df=val_df, test_df=test_df,
        use_three_partition_label=use_three_label_partition,
    )

    # partition validation data into clients
    if dataset == 'fitzpatrick':
        # NOTE(review): column name 'three_partition_label' kept from the original; confirm it exists in val_df.
        _partition = 'three_partition_label' if use_three_label_partition else 'aggregated_fitzpatrick_scale'
        # NOTE(review): missing values in this column appear in unique() and become a 'nan' client key.
        client_index_map = {
            str(part): (val_df[_partition] == part).values for part in sorted(val_df[_partition].unique())
        }
    else:
        # Client mask: validation rows (of the TCT run) whose target is any of that client's classes.
        client_index_map = {
            client: sum(trial_experiments['tct']['val_targets'] == c for c in classes).bool()
            for client, classes in clients_class_map.items()
        }

    for name, (trial_store, decentral) in conformal_runs.items():
        client_kwargs = {'client_index_map': client_index_map} if decentral else {}
        trial_store[i] = cp.get_coverage_size_over_alphas(
            *f(trial_experiments[name]), method=method,
            allow_empty_sets=allow_empty_sets, alphas=alphas,
            decentral=decentral, **client_kwargs,
        )

    print(f'finished trial={i}')
    

finished trial=0
finished trial=1
finished trial=2
finished trial=3
finished trial=4


KeyError: 'nan'

In [None]:
reload(helpers)
# Collapse each method's per-trial metric dicts into summary statistics
# (the plotting cell reads the 'mean' entry of each result).
tct_results, fedavg_results, central_results = (
    helpers.combine_trials(per_trial)
    for per_trial in (tct_trials, fedavg_trials, central_trials)
)

In [None]:
# plotting settings
# plotting settings
fontsize = 24
style = 'o:'
markersize = 10

# choose experiments (an entry set to None is skipped)
exp_1 = central_results['mean']
exp_2 = fedavg_results['mean']
exp_3 = tct_results['mean']
exp_1_label = 'centralized'
exp_2_label = 'decentralized (baseline)'
exp_3_label = 'decentralized (ours)'
curves = [(exp_1, exp_1_label), (exp_2, exp_2_label), (exp_3, exp_3_label)]

fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
for exp, label in curves:
    if exp is None:
        continue
    # Left panel: coverage vs alpha; right panel: set size vs alpha.
    for panel, metric in zip(ax, ('coverage', 'size')):
        panel.plot(*zip(*exp[metric].items()), style, label=label, markersize=markersize)

# Target coverage line 1 - alpha over the plotted alpha range. It is derived from the results
# themselves rather than the global `alphas`, which the summary cell below used to overwrite.
plotted_alphas = [a for exp, _ in curves if exp is not None for a in exp['coverage']]
if plotted_alphas:
    lo, hi = min(plotted_alphas), max(plotted_alphas)
    ax[0].plot([lo, hi], [1 - lo, 1 - hi], '--', color='black')

ax[0].set_xlim(0, 1)
ax[1].set_xlim(0, 1)
ax[0].set_ylim(0.001, 1)
ax[1].set_ylim(0, num_classes)
# ax[1].set_ylim(0, 25)
ax[0].set_xlabel(r'$\alpha$', fontsize=fontsize)
ax[1].set_xlabel(r'$\alpha$', fontsize=fontsize)
ax[0].set_ylabel('coverage', fontsize=fontsize)
ax[1].set_ylabel('set size', fontsize=fontsize)
ax[0].legend(fancybox=True, fontsize=fontsize-8, loc='lower left')
ax[1].legend(fancybox=True, fontsize=fontsize-8, loc='upper right')

fig.tight_layout()
fig.savefig(fig_dir / f'experiment-1-{dataset}.eps', bbox_inches='tight')
plt.show()

In [None]:
# Alpha levels to tabulate. Use a separate name so the global `alphas` produced by the
# conformal-trials cell is not clobbered when other cells are re-run afterwards.
# report_alphas = [0.05, 0.1, 0.15, 0.2]
report_alphas = [0.1, 0.2, 0.3]
precision = 3  # decimal places for printed values
# Conformal settings (method, allow_empty_sets) are configured in the trial-loop cell;
# setting them here had no effect on the results being printed.


def print_results(title, results, alphas, precision):
    """Print every statistic/metric of `results` at the alpha levels in `alphas`.

    As iterated below, `results` maps statistic name -> metric name -> {alpha: value}
    (presumably the structure returned by `helpers.combine_trials` -- confirm).
    """
    print('\n\n', title.center(40, '='))
    for stat, metrics in results.items():
        print('\n', stat.center(30, '='))
        for metric, by_alpha in metrics.items():
            print(metric.center(20, '='))
            for a, value in by_alpha.items():
                if a in alphas:
                    print(a, '\t', round(value, precision))


for title, results in [('Central', central_results), ('TCT', tct_results), ('Fedavg', fedavg_results)]:
    print_results(title, results, report_alphas, precision)
