In [None]:
%load_ext autoreload
%autoreload 2
import os
import pickle

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.utils import shuffle

from confpred.cp import run_cp
from confpred import ConformalPredictor,SparseScore,SoftmaxScore,RAPSScore,RAPSPredictor,LimitScore


In [None]:
# Tuned RAPS regularisation parameters; the experiment loop reads the
# 'lam_reg' and 'k_reg' columns, keyed by (dataset, random_state, alpha).
raps_params = (
    pd.read_csv('./data/results_analysis/raps_optimal_vit_test.csv')
    .set_index(['dataset', 'random_state', 'alpha'])
)

In [None]:
# Tuned opt-entmax parameter; the experiment loop reads the 'best_lambda'
# column, keyed by (dataset, random_state, alpha).
# NOTE: this cell previously used ROOT_DIR, which is only assigned in a later
# cell, so it raised NameError on Restart & Run All. It now uses the same
# relative './data' root as the RAPS-parameter cell.
opt_entmax_params = (
    pd.read_csv('./data/results_analysis/optimal_entmax_vit_all_alpha.csv')
    .set_index(['dataset', 'random_state', 'alpha'])
)

In [None]:
# ---- Configuration -----------------------------------------------------------
dataset_list = ['ImageNet','CIFAR10','NewsGroups','CIFAR100']
seed = '23'  # kept as a string: it is only interpolated into file names / logs
model_loss = 'softmax'
alpha_list = np.round(np.linspace(0.01,0.1,10),4)  # 0.01, 0.02, ..., 0.10
random_states = [1,12,123,1234,12345]
ROOT_DIR = '.'
score_list = ['limit','opt_entmax','RAPS','sparsemax','softmax','entmax']
transformation='logits'
# Fraction of the loaded test predictions used for calibration; the remainder
# is the evaluation split.
CAL_FRACTION = 0.4
# Scores that are evaluated / predicted with use_temperature=True.
TEMPERATURE_SCORES = {'entmax','sparsemax','opt_entmax'}
SET_PRED_DIR = f'{ROOT_DIR}/data/set_prediction/new'
# Create the output directory up front so the run cannot fail on the first
# pickle.dump after calibration work has already been done.
os.makedirs(SET_PRED_DIR, exist_ok=True)


def _model_type(dataset):
    """Return the model-type prefix used in the prediction file name for `dataset`."""
    if dataset == 'NewsGroups':
        return 'bert'
    if dataset == 'CIFAR10':
        return 'cnn'
    return 'vit'


def _build_predictor(score, lam_reg, k_reg, lambd):
    """Return a fresh (uncalibrated) conformal predictor for the score name `score`.

    lam_reg / k_reg are only used for 'RAPS'; lambd is only used for 'opt_entmax'.
    Raises ValueError on an unknown score name instead of silently reusing the
    predictor built on a previous iteration.
    """
    if score == 'sparsemax':
        return ConformalPredictor(SparseScore(2))
    if score == 'softmax':
        return ConformalPredictor(SoftmaxScore())
    if score == 'entmax':
        return ConformalPredictor(SparseScore(1.5))
    if score == 'limit':
        return ConformalPredictor(LimitScore())
    if score == 'opt_entmax':
        return ConformalPredictor(SparseScore(lambd))
    if score == 'RAPS':
        return RAPSPredictor(RAPSScore(lam_reg=lam_reg,k_reg=k_reg))
    raise ValueError(f'Unknown score: {score!r}')


# ---- Results table: one row per (dataset, random_state, alpha, score) ---------
summary_results = (
    pd.DataFrame({'dataset':dataset_list})
    .merge(pd.DataFrame({'random_state':random_states}), how='cross')
    .merge(pd.DataFrame({'alpha':alpha_list}), how='cross')
    .merge(pd.DataFrame({'score':score_list}), how='cross')
    .assign(avg_size=np.nan, coverage=np.nan)
    .set_index(['dataset','random_state','alpha','score'])
    .sort_index()
)

# ---- Experiment loop -----------------------------------------------------------
for dataset in dataset_list:
    model_type = _model_type(dataset)
    # pickle.load can execute arbitrary code: only load trusted, locally produced files.
    path = f'{ROOT_DIR}/data/predictions/{model_type}_{dataset}_test_{model_loss}_{transformation}_{seed}_proba.pickle'
    with open(path, 'rb') as f:
        test_preds_og = pickle.load(f)
    path = f'{ROOT_DIR}/data/predictions/{dataset}_{seed}_test_true.pickle'
    with open(path, 'rb') as f:
        test_true_enc_og = pickle.load(f)
    # The calibration size depends only on the dataset, not on the shuffle seed.
    cal_size = np.ceil(test_true_enc_og.shape[0]*CAL_FRACTION).astype(int)
    for random_state in random_states:
        # Distinct names so the shuffled inputs are not later shadowed by the
        # set predictions written below.
        proba_shuffled, true_shuffled = shuffle(test_preds_og,test_true_enc_og,random_state = random_state)
        cal_proba = proba_shuffled[:cal_size]
        test_proba = proba_shuffled[cal_size:]
        cal_true_enc = true_shuffled[:cal_size]
        test_true_enc = true_shuffled[cal_size:]
        for alpha in alpha_list:
            # List column selector: the documented way to pick several columns with .loc.
            lam_reg, k = raps_params.loc[(dataset,random_state,alpha),['lam_reg','k_reg']]
            k = int(k)
            lambd = opt_entmax_params.loc[(dataset,random_state,alpha),'best_lambda']
            for score in score_list:
                cp = _build_predictor(score, lam_reg, k, lambd)
                use_temperature = score in TEMPERATURE_SCORES
                print(f'Running {dataset} {score} {seed} {random_state} alpha={alpha}')
                cp.calibrate(cal_true_enc, cal_proba, alpha)
                avg_set_size, coverage = cp.evaluate(test_true_enc, test_proba,use_temperature=use_temperature)
                key = (dataset,random_state,alpha,score)
                summary_results.loc[key,'avg_size'] = avg_set_size
                summary_results.loc[key,'coverage'] = coverage
                cal_preds = cp.predict(cal_proba,use_temperature=use_temperature)
                test_preds = cp.predict(test_proba,use_temperature=use_temperature)

                # write predictions to file
                prefix = f'{SET_PRED_DIR}/{dataset}_{score}_{seed}_{random_state}_alpha{alpha}'
                with open(f'{prefix}_cal_pred.pickle', 'wb') as handle:
                    pickle.dump(cal_preds, handle)
                with open(f'{prefix}_test_pred.pickle', 'wb') as handle:
                    pickle.dump(test_preds, handle)

# Save summary results to file
summary_results.reset_index().to_csv(f'{ROOT_DIR}/data/results_analysis/summary_results.csv', index=False)