### Setup: imports and helper functions

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
# Make the parent of the current working directory importable so that the
# project-local `utils.*` modules resolve. Insert exactly the string that is
# checked, so re-running this cell does not keep prepending duplicate entries.
root = Path().resolve().parent
if str(root) not in sys.path:
    sys.path.insert(0, str(root))
# Standard Libraries
import numpy as np
import matplotlib.patches as mpatches
%matplotlib inline
import os
import pickle

import matplotlib.transforms as transforms
import string
from utils.utils import *
from utils.plots import *




def add_letters(fig, ax, dx=-35/72., dy=15/72.):
    """Label subplots with bold lowercase panel letters (a, b, c, ...).

    Each letter is placed at the top-left corner of its axes (axes coordinates
    (0, 1)) and then shifted by a fixed physical offset.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure owning the axes; its ``dpi_scale_trans`` interprets the offset in inches.
    ax : Axes or array-like of Axes
        Subplots to label, in lettering order. A single Axes or a
        multi-dimensional array (as returned by ``plt.subplots``) is flattened
        row-major before labelling.
    dx, dy : float
        Horizontal / vertical offset of the letter in inches
        (defaults: -35 pt and +15 pt).

    Raises
    ------
    ValueError
        If more axes are given than there are lowercase letters (26).
    """
    axes_list = list(np.ravel(ax))
    if len(axes_list) > len(string.ascii_lowercase):
        raise ValueError(
            f'add_letters supports at most {len(string.ascii_lowercase)} panels, '
            f'got {len(axes_list)}')

    letterkwargs = dict(weight='bold', va='top', ha='left')

    # Offset expressed in inches via dpi_scale_trans, so every letter sits at the
    # same physical distance from its axes corner regardless of the axes' size.
    offset = transforms.ScaledTranslation(
            dx, dy, fig.dpi_scale_trans)

    for letter, axis in zip(string.ascii_lowercase, axes_list):
        axis.text(0, 1, letter, transform=axis.transAxes + offset,
                  **letterkwargs)

## Load the realized and estimated metrics

In [None]:
RESULTS_FOLDER_NAME = '../results'


def _load_results_pickle(filename):
    """Unpickle and return ``filename`` from ``RESULTS_FOLDER_NAME``.

    SECURITY: ``pickle.load`` can execute arbitrary code; only point this at
    result files produced by this project.
    """
    with open(os.path.join(RESULTS_FOLDER_NAME, filename), 'rb') as f:
        return pickle.load(f)


# Load the realized metrics and the estimates of every method for all models.
realized_metrics_dict = _load_results_pickle('realized_metrics_dict.pkl')
ATC_metrics_dict = _load_results_pickle('ATC_metrics_dict.pkl')
CBPE_metrics_dict = _load_results_pickle('CBPE_metrics_dict.pkl')
CMATC_metrics_dict = _load_results_pickle('CM_ATC_metrics_dict.pkl')
DoC_metrics_dict = _load_results_pickle('DoC_metrics_dict.pkl')
CMDoC_metrics_dict = _load_results_pickle('CM_DoC_metrics_dict.pkl')

FIGURE_WIDTH = 4.803 # inches
MODEL_NAMES = list(realized_metrics_dict.keys())

METRICS = ['bal_accuracy','recall', 'specificity', 'auc', 'accuracy', 'precision', 'f1_score',]
SEEDS = [f'seed_{i}' for i in range(1, 6)]

In [None]:
# Absolute error |realized - estimated| per model, metric, calibration setting
# and test set, for every performance estimator. Only seed 1 is used for now,
# and the two OOD test sets are pooled under a single 'ood_test' key.
MAE_SEED = 'seed_1'  # For now only look at seed 1
MAE_CALIBRATIONS = ['uncal', 'TS', 'CWTS']
ESTIMATOR_DICTS = {
    'CBPE': CBPE_metrics_dict,
    'ATC': ATC_metrics_dict,
    'CMATC': CMATC_metrics_dict,
    'DoC': DoC_metrics_dict,
    'CMDoC': CMDoC_metrics_dict,
}
# Estimators evaluated under each calibration setting. Only CBPE is read under
# 'CWTS'; appending errors for the other methods there would reuse estimates
# from a different setting/test set. 'TS' is not evaluated in this cell, so its
# lists stay empty.
METHODS_PER_CALIB = {'uncal': list(ESTIMATOR_DICTS), 'CWTS': ['CBPE']}

MAE_dict = {
    method: {
        calib: {distr: {metric: [] for metric in METRICS} for distr in ['id_test', 'ood_test']}
        for calib in MAE_CALIBRATIONS
    }
    for method in ESTIMATOR_DICTS
}

for model_name in MODEL_NAMES:
    realized_by_distr = realized_metrics_dict[model_name][MAE_SEED]
    for metric in METRICS:
        for calib, methods in METHODS_PER_CALIB.items():
            for distr in ['id_test', 'ood1_test', 'ood2_test']:
                distr_key = 'ood_test' if 'ood' in distr else distr
                realized = realized_by_distr[distr][metric]
                for method in methods:
                    estimate = ESTIMATOR_DICTS[method][model_name][calib][MAE_SEED][distr][metric]
                    MAE_dict[method][calib][distr_key][metric].append(np.abs(realized - estimate))

In [None]:
# Headline number: for each uncalibrated estimator, average its absolute errors
# (over models, and for OOD over both pooled test sets), then report the
# mean +- std of those per-method averages for ID and OOD separately.
SUMMARY_CALIB = 'uncal'
SUMMARY_METRIC = 'accuracy'
SUMMARY_METHODS = ['CBPE', 'ATC', 'CMATC', 'DoC', 'CMDoC']

id_mae_metric_dict = {metric: [] for metric in METRICS}
ood_mae_metric_dict = {metric: [] for metric in METRICS}
for method in SUMMARY_METHODS:
    errors_by_distr = MAE_dict[method][SUMMARY_CALIB]
    id_mae_metric_dict[SUMMARY_METRIC].append(np.mean(errors_by_distr['id_test'][SUMMARY_METRIC]))
    ood_mae_metric_dict[SUMMARY_METRIC].append(np.mean(errors_by_distr['ood_test'][SUMMARY_METRIC]))

summary_name = SUMMARY_METRIC.capitalize()
id_maes = id_mae_metric_dict[SUMMARY_METRIC]
ood_maes = ood_mae_metric_dict[SUMMARY_METRIC]
print(f'MAE ID {summary_name} = {np.mean(id_maes)} +- {np.std(id_maes)}')
print(f'MAE OOD {summary_name} = {np.mean(ood_maes)} +- {np.std(ood_maes)}')

## Paper figure: MAE of each performance estimator (ID vs. OOD) and calibration error


In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['hatch.linewidth'] = 0.5
cmap_b = plt.get_cmap('Blues')  # same hue, different lightness
cmap_r = plt.get_cmap('Reds')

color_mapping = {'test': 'k',
                'validation': 'grey',
                'CBPE': 'g',
                'ATC': cmap_b(0.55),
                'CMATC': cmap_b(0.9),
                'DoC': cmap_r(0.55),
                'CMDoC': cmap_r(0.9),
                }

# Panel titles that differ from the raw metric key; other metrics use the key itself.
METRIC_TITLES = {'bal_accuracy': 'bal. accuracy', 'f1_score': 'F1-score', 'auc': 'AUC', 'precision': 'PPV'}
# (key in MAE_dict, x tick label) in plotting order. Method i is drawn at
# x = GROUP_SPACING*i (ID bar) and x = GROUP_SPACING*i + BAR_WIDTH (OOD bar).
FIG2_METHODS = [('CBPE', 'CBPE'), ('ATC', 'ATC'), ('CMATC', 'CM-\nATC'), ('DoC', 'DoC'), ('CMDoC', 'CM-\nDoC')]
BAR_WIDTH = 0.2
GROUP_SPACING = 0.5
OOD_HATCH = '//////'


def _mean_seed1_score(score_key, distrs):
    """Mean of realized_metrics_dict[model]['seed_1'][distr][score_key] over all models and `distrs`."""
    return np.mean([[realized_metrics_dict[model_name]['seed_1'][distr][score_key] for distr in distrs]
                    for model_name in MODEL_NAMES])


with plt.style.context('../config/plot_style.txt'): # type: ignore
    fig, axs = plt.subplots(2, 4, figsize=(FIGURE_WIDTH, 0.5*FIGURE_WIDTH), layout='constrained', sharey=True, sharex=False)
    axs = axs.flatten()
    add_letters(fig, axs, dx=-1/72, dy=8/72)  # Add letters to the subplots

    # One panel per metric: MAE of each uncalibrated estimator, ID vs. pooled OOD.
    method_ticks = [GROUP_SPACING * i + BAR_WIDTH / 2 for i in range(len(FIG2_METHODS))]
    for n, metric in enumerate(METRICS):
        ax = axs[n]
        ax.set_title(METRIC_TITLES.get(metric, metric))
        ax.set_yticks([0, 0.2, 0.4])
        for k, distr in enumerate(['id_test', 'ood_test']):
            hatch = OOD_HATCH if distr == 'ood_test' else ''
            for i, (method, _) in enumerate(FIG2_METHODS):
                mae = np.mean(MAE_dict[method]['uncal'][distr][metric])
                ax.bar(GROUP_SPACING * i + BAR_WIDTH * k, mae, width=BAR_WIDTH, linewidth=0.5, label=method,
                       color=color_mapping[method], edgecolor='black', hatch=hatch)
        ax.set_xticks(method_ticks)
        # Only the bottom row (panel index >= 4) shows method names.
        tick_labels = [label for _, label in FIG2_METHODS] if n >= 4 else []
        ax.set_xticklabels(tick_labels, rotation=0, ha='center')

    # Last panel: mean RBS / ACE over models (seed 1), using the un-prefixed score
    # keys (presumably the uncalibrated models -- confirm against the results
    # generation). Plain bars are ID; hatched bars pool both OOD test sets.
    ax_cal = axs[-1]
    ax_cal.set_title('Calibration')
    calibration_scores = [('rbs', 'RBS', 'grey', 0.0), ('ACE', 'ACE', 'orange', 0.5)]
    for score_key, label, color, x in calibration_scores:
        ax_cal.bar(x, _mean_seed1_score(score_key, ['id_test']), width=BAR_WIDTH, linewidth=0.5,
                   label=label, color=color, edgecolor='black')
    for score_key, _, color, x in calibration_scores:
        ax_cal.bar(x + BAR_WIDTH, _mean_seed1_score(score_key, ['ood1_test', 'ood2_test']), width=BAR_WIDTH,
                   linewidth=0.5, color=color, edgecolor='black', hatch='///')
    ax_cal.set_xticks([x + BAR_WIDTH / 2 for *_, x in calibration_scores])
    ax_cal.set_xticklabels([label for _, label, _, _ in calibration_scores], rotation=0, ha='center')

    for ax in axs.flat:
        ax.grid(False)
    # Shared legend: plain bars are in-distribution, hatched bars out-of-distribution.
    labels = ['in-distribution', 'out-of-distribution']
    handles = [
        mpatches.Patch(facecolor='white', edgecolor='black', linewidth=0.5, label='in-distribution'),
        mpatches.Patch(facecolor='white', edgecolor='black', linewidth=0.5, hatch=OOD_HATCH, label='out-of-distribution'),
    ]
    axs[0].legend(handles, labels, loc='best', ncol=1, fontsize=5, frameon=False)
    axs[0].set_ylabel('MAE')
    axs[4].set_ylabel('MAE')
    axs[-1].set_ylabel('Calibration Error', labelpad=0)

    # Save inside the style context so any save-time rcParams defined by the style
    # file (e.g. savefig.* or pdf.* settings) apply. Save this figure explicitly
    # rather than relying on pyplot's "current figure".
    fig.savefig('../figures/Fig2_MAE_performance_comparison.pdf', bbox_inches='tight')

## Appendix A: Effect of calibration on CBPE

In [None]:
# Reload the realized metrics together with the CBPE estimates computed for each
# calibration setting. NOTE: this rebinds `CBPE_metrics_dict`, replacing the dict
# loaded for the main figure -- re-run the main loading cell before re-running
# the main-figure cells.
# SECURITY: pickle.load can execute arbitrary code; only load trusted result files.
with open(os.path.join(RESULTS_FOLDER_NAME, 'realized_metrics_dict.pkl'), 'rb') as f:
    realized_metrics_dict = pickle.load(f)
with open(os.path.join(RESULTS_FOLDER_NAME, 'CBPE_cali_metrics_dict.pkl'), 'rb') as f:
    CBPE_metrics_dict = pickle.load(f)

In [None]:
# Absolute error |realized - estimated| of CBPE for each calibration setting,
# per model and metric, using seed 1 only. The two OOD test sets are pooled
# under 'ood_test'. NOTE: this rebinds `MAE_dict` (CBPE only), replacing the
# multi-method dict used by the main figure. 'DE' is kept as a key but is not
# computed here, so its lists stay empty.
MAE_SEED = 'seed_1'  # For now only look at seed 1
MAE_CALIBRATIONS = ['uncal', 'TS', 'CWTS']

MAE_dict = {'CBPE': {calib: {distr: {metric: [] for metric in METRICS} for distr in ['id_test', 'ood_test']}
                     for calib in ['uncal', 'TS', 'CWTS', 'DE']},
            }

for model_name in MODEL_NAMES:
    realized_by_distr = realized_metrics_dict[model_name][MAE_SEED]
    for metric in METRICS:
        for calib in MAE_CALIBRATIONS:
            estimated_by_distr = CBPE_metrics_dict[model_name][calib][MAE_SEED]
            for distr in ['id_test', 'ood1_test', 'ood2_test']:
                distr_key = 'ood_test' if 'ood' in distr else distr
                error = np.abs(realized_by_distr[distr][metric] - estimated_by_distr[distr][metric])
                MAE_dict['CBPE'][calib][distr_key][metric].append(error)
                        

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['hatch.linewidth'] = 0.5

cmap_g = plt.get_cmap('Greens')

color_mapping = {'test': 'k',
                'validation': 'grey',
                'CBPE': 'g',
                'CBPE_TS': cmap_g(0.4),
                'CBPE_CWTS': cmap_g(0.6)
                }

# Panel titles that differ from the raw metric key; other metrics use the key itself.
METRIC_TITLES = {'bal_accuracy': 'bal. accuracy', 'f1_score': 'F1-score', 'auc': 'AUC', 'precision': 'PPV'}
# One bar group per CBPE calibration variant:
# (key in MAE_dict['CBPE'], colour key, x of the ID bar, x tick label).
# 'DE' exists as a key in MAE_dict['CBPE'] but is not plotted.
CBPE_VARIANTS = [
    ('uncal', 'CBPE', 0.0, 'CBPE'),
    ('TS', 'CBPE_TS', 0.5, 'CBPE\nts'),
    ('CWTS', 'CBPE_CWTS', 1.0, 'CBPE\ncsts'),
]
# Score-key prefixes in realized_metrics_dict for the calibration panel; one bar group each.
CALIB_PREFIXES = ['', 'TS_', 'CWTS_']
BAR_WIDTH = 0.2
OOD_HATCH = '//////'


def _mean_seed1_score(score_key, distrs):
    """Mean of realized_metrics_dict[model]['seed_1'][distr][score_key] over all models and `distrs`."""
    return np.mean([[realized_metrics_dict[model_name]['seed_1'][distr][score_key] for distr in distrs]
                    for model_name in MODEL_NAMES])


with plt.style.context('../config/plot_style.txt'): # type: ignore
    fig, axs = plt.subplots(2, 4, figsize=(FIGURE_WIDTH, 0.5*FIGURE_WIDTH), layout='constrained', sharey=True, sharex=False)
    axs = axs.flatten()
    add_letters(fig, axs, dx=-1/72, dy=8/72)  # Add letters to the subplots

    # One panel per metric: MAE of each CBPE calibration variant, ID vs. pooled OOD.
    variant_ticks = [x0 + BAR_WIDTH / 2 for _, _, x0, _ in CBPE_VARIANTS]
    for n, metric in enumerate(METRICS):
        ax = axs[n]
        ax.set_title(METRIC_TITLES.get(metric, metric))
        ax.set_yticks([0, 0.2, 0.4])
        for calib, color_key, x0, _ in CBPE_VARIANTS:
            for k, distr in enumerate(['id_test', 'ood_test']):
                hatch = OOD_HATCH if distr == 'ood_test' else ''
                mae = np.mean(MAE_dict['CBPE'][calib][distr][metric])
                ax.bar(x0 + BAR_WIDTH * k, mae, width=BAR_WIDTH, linewidth=0.5, label='CBPE',
                       color=color_mapping[color_key], edgecolor='black', hatch=hatch)
        ax.set_xticks(variant_ticks)
        # Only the bottom row (panel index >= 4) shows variant names.
        tick_labels = [label for *_, label in CBPE_VARIANTS] if n >= 4 else []
        ax.set_xticklabels(tick_labels, rotation=0, ha='center')

    # Last panel: mean RBS / ACE over models (seed 1) for each calibration prefix,
    # group g occupying x in [g, g + 0.9]. Plain bars are ID; hatched bars pool
    # both OOD test sets.
    ax_cal = axs[-1]
    ax_cal.set_title('Calibration')
    calibration_scores = [('rbs', 'RBS', 'grey', 0.0), ('ACE', 'ACE', 'orange', 0.5)]
    for group, prefix in enumerate(CALIB_PREFIXES):
        for score_key, label, color, x in calibration_scores:
            ax_cal.bar(group + x, _mean_seed1_score(f'{prefix}{score_key}', ['id_test']), width=BAR_WIDTH,
                       linewidth=0.5, label=label, color=color, edgecolor='black')
        for score_key, _, color, x in calibration_scores:
            ax_cal.bar(group + x + BAR_WIDTH, _mean_seed1_score(f'{prefix}{score_key}', ['ood1_test', 'ood2_test']),
                       width=BAR_WIDTH, linewidth=0.5, color=color, edgecolor='black', hatch='///')
    ax_cal.set_xticks([0.1, 0.6, 1.1, 1.6, 2.1, 2.6,])
    ax_cal.set_xticklabels(['RBS', 'ACE', 'RBS\nts', 'ACE\nts', 'RBS\ncsts', 'ACE\ncsts'], rotation=0, ha='center')

    for ax in axs.flat:
        ax.grid(False)
    # Shared legend: plain bars are in-distribution, hatched bars out-of-distribution.
    labels = ['in-distribution', 'out-of-distribution']
    handles = [
        mpatches.Patch(facecolor='white', edgecolor='black', linewidth=0.5, label='in-distribution'),
        mpatches.Patch(facecolor='white', edgecolor='black', linewidth=0.5, hatch=OOD_HATCH, label='out-of-distribution'),
    ]
    axs[0].legend(handles, labels, loc='best', ncol=1, fontsize=5, frameon=False)
    axs[0].set_ylabel('MAE')
    axs[4].set_ylabel('MAE')
    axs[-1].set_ylabel('Calibration Error', labelpad=0)

    # Save inside the style context so any save-time rcParams defined by the style
    # file apply, and save this figure explicitly rather than pyplot's current one.
    fig.savefig('../figures/A2_CBPE_calibration.pdf', bbox_inches='tight')