### Precursor Cells

In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to project .py files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
# Make the parent of the current working directory importable, so that the
# project-local `utils` package imported further down in this cell resolves.
root = Path().resolve().parent
# Insert the same absolute path that the membership test checks. Inserting the
# relative '..' instead never satisfies the guard, so every re-run of this cell
# would prepend a duplicate entry (and '..' silently changes meaning if the
# working directory changes later).
if str(root) not in sys.path:
    sys.path.insert(0, str(root))

# Standard Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import os

import pickle as pkl


# Self-written utility modules
from utils.plots import *
from utils.utils import LabelShiftEvaluator, CovariateShiftEvaluatorCXR



# 1) CheXpert Controlled Shifts

## a) Label Shift

In [None]:
## Prepare and Load the Data ##

# Folder that receives the outputs of the controlled label-shift experiment
save_folder = '../outputs/controlled_label_shift'
os.makedirs(save_folder, exist_ok=True)

# Pickled scores and labels of the model trained on CheXpert (Pleural Effusion)
PATH_TO_DATA = '../results'
with open(os.path.join(PATH_TO_DATA, 'uncal_scores_labs.pkl'), 'rb') as f:
    uncal_scores_labs_seed = pkl.load(f)

# Validation split: serves as the in-distribution (I.D.) dataset
val_labels, val_probs = (
    uncal_scores_labs_seed['id_labs_val'],
    uncal_scores_labs_seed['id_out_val'],
)

# Test split: serves as the OOD dataset, i.e. the data the model is evaluated
# on once the controlled dataset shifts have been applied
labels_list, probs_list = (
    uncal_scores_labs_seed['id_out_labs'],
    uncal_scores_labs_seed['id_out_test'],
)

In [None]:
# --- Experiment configuration ---
METRICS = [
    'accuracy',
    'bal_accuracy',
    'precision',
    'recall',
    'specificity',
    'auc',
    'f1_score',
]
ESTIMATION_METHODS = ['test', 'validation', 'CBPE', 'ATC', 'CMATC', 'DoC', 'CMDoC']
SHIFTS = np.linspace(0.05, 0.95, 19)  # 19 evenly spaced values from 0.05 to 0.95
TOTAL_SAMPLES = 1000
N_RESAMPLES = 50

# --- Build the evaluator from the configuration above ---
unequal_evaluator = LabelShiftEvaluator(
    keys=METRICS,
    estimation_methods=ESTIMATION_METHODS,
    shifts=SHIFTS,
    total_samples=TOTAL_SAMPLES,
    n_resamples=N_RESAMPLES,
    save_folder=save_folder,
    optional_name='chexpert_pleff',  # extra tag added to the output file name
)

# --- Run it: test split is evaluated, validation split is the reference ---
unequal_df = unequal_evaluator.evaluate(
    val_labels=val_labels,
    val_probs=val_probs,
    test_labels=labels_list,
    test_probs=probs_list,
)


### Paper Figure

In [None]:
# Paper figure for the controlled label-shift experiment.
# The CSV is presumably the file written by the label-shift evaluation cell
# (same folder and 'chexpert_pleff' tag) -- confirm the exact file name.
path = '../outputs/controlled_label_shift/metrics_per_shift_chexpert_pleff.csv'

# Use a dedicated name for the figure directory. Rebinding `save_folder` here
# would make a later re-run of the label-shift evaluation cell send its outputs
# to '../figures/' instead of '../outputs/controlled_label_shift'.
figures_folder = '../figures/'
# Ensure the target directory exists before the plotting helper is given it.
os.makedirs(figures_folder, exist_ok=True)

# Load the data
df = pd.read_csv(path)

# Plot the controlled label shift
plot_controlled_shift(
    df,
    figures_folder,
    optional_name='chexpert_controlled',
    shift_type='label',
    methods_to_plot=['test', 'CBPE', 'CMATC', 'CMDoC', 'ATC', 'DoC'],
    FIGURE_WIDTH=4.803,  # NOTE(review): presumably the figure width in inches -- confirm
    metrics=['accuracy', 'bal_accuracy', 'precision', 'recall', 'specificity', 'auc', 'f1_score', 'ACE / RBS'],
)

## b) Covariate Shift Pleural Effusion

In [None]:
## Prepare and Load the Data for Controlled Covariate Shift ##
e_methods = ['test', 'validation', 'CBPE', 'CMATC', 'CMDoC', 'ATC', 'DoC']
path = '../results/artifact_array.npy'
save_folder = '../outputs/controlled_covariate_shift'
BIAS_LEVEL = 0.8

# Create the save folder
os.makedirs(save_folder, exist_ok=True)

# Fail fast when the artifact file is missing. Without this check
# `test_artifacts` would be left unbound (NameError further down on a fresh
# kernel) or, worse, silently reuse a stale array from an earlier run.
if not os.path.exists(path):
    raise FileNotFoundError(f'Test artifact array not found at {path}')

# Load from path
print(f'Loading test artifact from {path}')
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects;
# only load artifact files that come from a trusted source.
test_artifacts = np.load(path, allow_pickle=True)

# Configure the covariate-shift evaluator
e = CovariateShiftEvaluatorCXR(
    keys=['accuracy', 'bal_accuracy', 'precision', 'recall', 'specificity', 'auc', 'f1_score'],
    estimation_methods=e_methods,
    shift_ratios=np.linspace(0, 1, 11),  # 11 evenly spaced ratios: 0.0, 0.1, ..., 1.0
    total_samples=1000,
    n_resamples=50,
    save_folder=save_folder,
    bias_level=BIAS_LEVEL,
    test_artifacts=test_artifacts,
    optional_name=f'chexpert_{BIAS_LEVEL}',  # optional name for the output file
)
df = e.evaluate()


### Paper Figure

In [None]:
# Paper figure for the controlled covariate-shift experiment (bias level 0.8).
save_folder = '../figures/'
path_to_covariate_shift_df = '../outputs/controlled_covariate_shift/metrics_per_covshift_Pleff_chexpert_0.8.csv'

# Read the metrics table from CSV
df = pd.read_csv(path_to_covariate_shift_df)

# Draw the controlled covariate shift figure
plot_controlled_shift(
    df,
    save_folder,
    shift_type='covariate',
    optional_name='chexpert_controlled',
    FIGURE_WIDTH=4.803,
    methods_to_plot=['test', 'CBPE', 'ATC', 'CMATC', 'DoC', 'CMDoC'],
    metrics=[
        'bal_accuracy', 'recall', 'specificity', 'auc',
        'accuracy', 'precision', 'f1_score', 'ACE / RBS',
    ],
)