# In [None]:
import os
import pickle

import numpy as np
from tqdm import tqdm

from utils import plot_pct_change
from niagara.probabilistic_modeling.optimize_cascade import make_full_data, score_cascade, \
    compute_results_grids, get_expected_uncumulated_costs, smooth_outliers
from early_abs_setup import setup_data

# Cascades to evaluate, as (chain name, [index of first model, index of second model]).
# Every increasing index pair lo < hi is enumerated in lexicographic order:
# llama_chain over indices 0-4 ([0,1], [0,2], ..., [3,4]), then
# qwen_oai_chain over indices 0-3 ([0,1], [0,2], ..., [2,3]).
CASCADES = [
    (chain, [lo, hi])
    for chain, n_indices in (("llama_chain", 5), ("qwen_oai_chain", 4))
    for lo in range(n_indices)
    for hi in range(lo + 1, n_indices)
]

# Accumulators filled by the evaluation loop: outlier statistics for the
# "early_abs" thresholds and for the "final_model_abs" thresholds (the "_no" list).
outlier_stats, outlier_stats_no = [], []

benchmark_names = [
    "mmlu",
    "medmcqa",
    "triviaqa",
    "xsum",
    "gsm8k",
    "truthfulqa",
]

# Key of the results grid that gets averaged and reported.
metric = "error"

# Filled per cascade key -> benchmark name -> policy -> averaged metric.
average_train_results, average_test_results = {}, {}

# Evaluate every cascade in CASCADES on every benchmark, comparing the
# "early_abs" thresholds against the "final_model_abs" thresholds (the "_no"
# variables) on the train and test splits. Then pickle the averaged metric
# and the outlier-smoothing statistics to ./performance_data/.

# Radius passed to `smooth_outliers`; set to None to skip smoothing.
# Bound once, before the loop, because it also names the output files after the
# loop (previously it was only bound inside the loop, so an empty CASCADES list
# raised NameError at save time).
outlier_threshold = 10

# Create the output directory up front, so a missing folder fails fast instead of
# crashing after the (long) evaluation loop and discarding its results.
os.makedirs("./performance_data", exist_ok=True)

for cascade in tqdm(CASCADES):
    chain_name, model_indices = cascade
    model_idx_str = "".join(str(x) for x in model_indices)
    # Result key, e.g. "llama_chain01" for ("llama_chain", [0, 1]).
    cascade_key = chain_name + model_idx_str

    average_train_results[cascade_key] = {name: {} for name in benchmark_names}
    average_test_results[cascade_key] = {name: {} for name in benchmark_names}

    for benchmark_name in benchmark_names:
        ### Load the pickled model results and precomputed optimal thresholds
        PROB_MODEL_FILENAME = f"./optimal_thresholds_data/{benchmark_name}/prob_model_results_{chain_name}.pkl"
        OPT_THOLDS_FILENAME = f"./optimal_thresholds_data/{benchmark_name}/optimal_tholds_{chain_name}_{model_idx_str}.pkl"

        # NOTE(review): prob_model is loaded here but not used anywhere below.
        with open(PROB_MODEL_FILENAME, "rb") as file:
            prob_model = pickle.load(file)

        with open(OPT_THOLDS_FILENAME, "rb") as file:
            optimal_thresholds = pickle.load(file)

        ###### Evaluate The Cascade ######

        lambda_abs_grid = optimal_thresholds["lambda_abs_grid"]
        lambda_cost_grid = optimal_thresholds["lambda_cost_grid"]
        error_type = optimal_thresholds["error_type"]
        T_2d, S_2d = [optimal_thresholds["early_abs"][x] for x in ["T", "S"]]
        T_2d_no, S_2d_no = [optimal_thresholds["final_model_abs"][x] for x in ["T", "S"]]

        ### Get The Train and Test Data

        # NOTE(review): setup_data only depends on (benchmark_name, chain_name), so it is
        # recomputed for every cascade of the same chain. Caching it would save time,
        # provided nothing downstream mutates its outputs — confirm before adding.
        all_setup_data = setup_data(NAME=benchmark_name, CHAIN_NAME=chain_name)

        # Per-model cost (`cpm_tokens`) keyed by model name.
        raw_model_costs = {
            model_name: all_setup_data['chain'].models[i].cpm_tokens
            for i, model_name in enumerate(all_setup_data['chain'].model_names)
        }
        results_train = all_setup_data['raw_results']['train']
        results_test = all_setup_data['raw_results']['test']
        expected_uncumulated_costs_train = get_expected_uncumulated_costs(raw_model_costs, results_train)
        expected_uncumulated_costs_test = get_expected_uncumulated_costs(raw_model_costs, results_test)

        calibrated_conf_train = all_setup_data['calibrated_conf']['train']
        calibrated_conf_test = all_setup_data['calibrated_conf']['test']
        corr_train = all_setup_data['corr']['train']
        corr_test = all_setup_data['corr']['test']
        fit_stats = all_setup_data['logreg_fit_stats']

        test_data = {
            'calib_conf': make_full_data(calibrated_conf_test),
            'corr': make_full_data(corr_test),
        }

        train_data = {
            'calib_conf': make_full_data(calibrated_conf_train),
            'corr': make_full_data(corr_train),
        }

        ### Smooth the optimal thresholds

        if outlier_threshold is not None:
            T_2d, outliers_T_2d = smooth_outliers(T_2d, r=outlier_threshold)
            S_2d, outliers_S_2d = smooth_outliers(S_2d, r=outlier_threshold)
            T_2d_no, outliers_T_2d_no = smooth_outliers(T_2d_no, r=outlier_threshold)
            S_2d_no, outliers_S_2d_no = smooth_outliers(S_2d_no, r=outlier_threshold)

            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"{np.mean(outliers_T_2d)} outliers smoothed! (early T)")
            tqdm.write(f"{np.mean(outliers_S_2d)} outliers smoothed! (early S)")
            tqdm.write(f"{np.mean(outliers_T_2d_no)} outliers smoothed! (final T)")
            tqdm.write(f"{np.mean(outliers_S_2d_no)} outliers smoothed! (final S)")

            # Flat lists, appended as [T, S] pairs in loop order.
            outlier_stats += [np.mean(outliers_T_2d), np.mean(outliers_S_2d)]
            outlier_stats_no += [np.mean(outliers_T_2d_no), np.mean(outliers_S_2d_no)]

        ### Compute Results Grids

        # Arguments shared by all four evaluations below.
        grid_kwargs = {
            "model_indices": model_indices,
            "lambda_cost_grid": lambda_cost_grid,
            "lambda_abs_grid": lambda_abs_grid,
            "error_type": error_type,
        }

        results_grids_train_early_abs = compute_results_grids(
            T_2d=T_2d,
            S_2d=S_2d,
            data=train_data,
            expected_uncumulated_costs=expected_uncumulated_costs_train,
            **grid_kwargs,
        )
        results_grids_train_final_model_abs = compute_results_grids(
            T_2d=T_2d_no,
            S_2d=S_2d_no,
            data=train_data,
            expected_uncumulated_costs=expected_uncumulated_costs_train,
            **grid_kwargs,
        )
        results_grids_test_early_abs = compute_results_grids(
            T_2d=T_2d,
            S_2d=S_2d,
            data=test_data,
            expected_uncumulated_costs=expected_uncumulated_costs_test,
            **grid_kwargs,
        )
        results_grids_test_final_model_abs = compute_results_grids(
            T_2d=T_2d_no,
            S_2d=S_2d_no,
            data=test_data,
            expected_uncumulated_costs=expected_uncumulated_costs_test,
            **grid_kwargs,
        )

        plot_pct_change(
            results_grids_test_early_abs,
            results_grids_test_final_model_abs,
            lambda_cost_grid,
            lambda_abs_grid,
            filename=f"pct_chg_heatmap_{benchmark_name}_{chain_name}_{model_idx_str}_{metric}_outlier={outlier_threshold}.pdf",
            save_fig=True,
        )

        # Average the chosen metric over the whole (lambda_cost, lambda_abs) grid.
        early_abs_val_train = np.mean(results_grids_train_early_abs[metric])
        final_model_abs_val_train = np.mean(results_grids_train_final_model_abs[metric])

        early_abs_val_test = np.mean(results_grids_test_early_abs[metric])
        final_model_abs_val_test = np.mean(results_grids_test_final_model_abs[metric])

        train_entry = average_train_results[cascade_key][benchmark_name]
        test_entry = average_test_results[cascade_key][benchmark_name]
        train_entry['early_abs'] = early_abs_val_train
        train_entry['final_model_abs'] = final_model_abs_val_train
        test_entry['early_abs'] = early_abs_val_test
        test_entry['final_model_abs'] = final_model_abs_val_test

### Save to file
with open(f"./performance_data/overall_results_{metric}_outlier={outlier_threshold}.pkl", "wb") as file:
    pickle.dump({
        "train": average_train_results,
        "test": average_test_results,
    }, file)

with open(f"./performance_data/outlier_stats_outlier={outlier_threshold}.pkl", "wb") as file:
    pickle.dump({
        "early_abs": outlier_stats,
        "final_model_abs": outlier_stats_no,
    }, file)