In [None]:
# Move the working directory up one level, presumably to the project root so that the
# `utils` package imported further down resolves — TODO confirm.
# NOTE: not idempotent; re-running this cell moves up another level, so run it once per kernel.
%cd ..

: 

In [None]:
# Reload imported modules automatically before each cell execution, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2
# Ask the inline backend for high-resolution ("retina") figure output.
%config InlineBackend.figure_format = 'retina'

: 

In [None]:
import numpy as np
from sklearn.utils import resample
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Switch matplotlib's interactive mode off so figures are not drawn automatically when
# they are created; the plotting helper defined below returns the Figure object instead.
plt.ioff()


class Bootstrap:
    """
    Leave-one-model-out refits of the benchmark regression, followed by bootstrap
    resampling of the fitted weights.

    Typical use: build the object, call `compute_statistic(func)` (which generates the
    samples on first use) and feed its result to `summary`.
    """

    def __init__(self, 
                 loss,
                 lambda_: float = None,
                 n_iterations: int = 100, 
                 random_state: int = None,
                 norm_weights: bool = True):
        """
        Initializes the Bootstrap object.

        Parameters:
        loss (str): One of 'robust', 'ridge' or 'lasso'; forwarded to the regression kwargs builder.
        lambda_ (float): Forwarded with `loss` to the regression kwargs builder
            (presumably the regularisation strength — confirm against utils.data).
        n_iterations (int): Number of bootstrap iterations.
        random_state (int): Seed for reproducibility.
        norm_weights (bool): Forwarded to `get_weights_from_fit_results`.
        """

        assert loss in ['robust', 'ridge', 'lasso'], f"unsupported loss: {loss!r}"
        self.loss = loss
        self.lambda_ = lambda_
        self.n_iterations = n_iterations
        self.random_state = random_state
        self.bootstrap_samples = []
        self.norm_weights = norm_weights


    def generate_data(self):
        """
        Refits the regression for every BBH subtask on every "one model removed" subset.

        WARNING: currently only the ONE-model-removed subsets are fitted. The
        two-models-removed subsets are built and counted, but not used.

        Returns:
        pandas.DataFrame: One block of rows per (subtask, sublist) fit with the columns
            'subtask', 'sublist_id' and one column per entry of
            regression_kwargs['metric_list'] holding the fitted weight(s).
            Empty if no fit was run.
        """
        # Project-local imports are kept inside the method so the class can be defined
        # before the working directory / import path is set up.
        from utils.data import get_full_post_training_df, regression_raw_benchmarks_difference_cot_naive_get_kwargs, regression_raw_benchmarks_difference_cot_naive, get_weights_from_fit_results
        from utils.constants import BBH_SUBTASKS

        merged_eval, mmlu_subtasks = get_full_post_training_df()
        models_list = merged_eval['Model'].unique().tolist()

        # The first 2 models are Llama2, used for the equiv scale so they can't be removed
        lists_with_one_model_removed = [models_list[:i] + models_list[i+1:] for i in range(2, len(models_list))] 
        print('Number of lists with one model removed', len(lists_with_one_model_removed))

        lists_with_two_models_removed = [models_list[:i] + models_list[i+1:j] + models_list[j+1:] for i in range(2, len(models_list)) for j in range(i+1, len(models_list))]
        print('Number of lists with two models removed', len(lists_with_two_models_removed))

        # Collect one frame per fit and concatenate once at the end: concatenating
        # inside the loop copies all previous rows on every iteration (quadratic).
        frames = []
        for bbh_subtask in BBH_SUBTASKS:
            regression_kwargs = regression_raw_benchmarks_difference_cot_naive_get_kwargs(bbh_task=bbh_subtask,
                                                                                          other_subtasks=mmlu_subtasks,
                                                                                          loss=self.loss,
                                                                                          lambda_=self.lambda_)

            for s, model_sublist in enumerate(lists_with_one_model_removed):
                sub_merged_eval = merged_eval[merged_eval['Model'].isin(model_sublist)]
                fit_results = regression_raw_benchmarks_difference_cot_naive(base_llm_eval_with_post_training=sub_merged_eval,
                                                                             regression_kwargs=regression_kwargs)
                weights = get_weights_from_fit_results(fit_results=fit_results,
                                                       regression_kwargs=regression_kwargs,
                                                       norm_weights=self.norm_weights)

                record = {
                    'subtask': bbh_subtask,
                    'sublist_id': s
                }
                for i, weight in enumerate(weights):
                    record[regression_kwargs['metric_list'][i]] = weight
                # NOTE(review): pd.DataFrame(dict) needs at least one array-like value;
                # this assumes each weight is array-like rather than a scalar — confirm.
                frames.append(pd.DataFrame(record))

        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)


    def generate_samples(self):
        """
        Generates bootstrap samples and appends them to `self.bootstrap_samples`.

        Iteration `i` uses the fit results of sublist `i % n_sublists` and resamples
        the rows of their float64 columns with replacement.
        """
        from utils.constants import BBH_SUBTASKS
        all_fit_results = self.generate_data()
        n_sublists = len(all_fit_results['sublist_id'].unique())

        # A local generator produces the same draws as seeding the global NumPy RNG
        # with the same seed, without altering global random state for other code.
        rng = np.random.RandomState(self.random_state)
        for i in range(self.n_iterations):
            for bbh_subtask in BBH_SUBTASKS:
                subtask_data = all_fit_results[(all_fit_results['subtask'] == bbh_subtask) & (all_fit_results['sublist_id'] == i%n_sublists)]
                float_columns = subtask_data.select_dtypes(include='float64')
                data_array = float_columns.to_numpy()

            # NOTE(review): the resampling sits outside the subtask loop, so only the
            # LAST entry of BBH_SUBTASKS is ever sampled. Kept as in the original
            # because the intended aggregation across subtasks is not clear — confirm.
            sample = resample(data_array, replace=True, random_state=rng)
            self.bootstrap_samples.append(sample)


    def compute_statistic(self, statistic_func):
        """
        Computes a statistic over the bootstrap samples.

        The samples are generated on first use if none exist yet.

        Parameters:
        statistic_func (function): Function to compute the statistic.

        Returns:
        list: List of statistic values for each bootstrap sample.
        """
        if not self.bootstrap_samples:
            self.generate_samples()

        return [statistic_func(sample) for sample in self.bootstrap_samples]


    def summary(self, results):
        """
        Summarizes the bootstrap results.

        Entries whose mask is True (0 lies inside their confidence interval) are
        excluded from the mean, the standard error and the percentile interval.
        Positions where every entry is excluded are reported as NaN.

        Parameters:
        results (list): List of (statistic, zero_in_ci_mask) pairs, one per bootstrap
            sample, e.g. the output of `compute_statistic` with a statistic function
            that returns such a pair.

        Returns:
        dict: 'stats' (list of masked statistic arrays), 'mean', 'std_error',
            'confidence_interval' (2.5th and 97.5th percentiles) and 'false_counts'
            (per position, how many masks are False).

        Raises:
        ValueError: If `results` is empty.
        """
        if len(results) == 0:
            raise ValueError("summary() needs at least one bootstrap result")

        stats = np.asarray([result[0] for result in results], dtype=float)
        # Must be an ndarray: the `~` operator is not defined for Python lists.
        zero_in_ci_mask = np.asarray([result[1] for result in results], dtype=bool)
        masked = [np.ma.masked_where(mask, stat) for stat, mask in zip(stats, zero_in_ci_mask)]

        # Count, per position, the samples whose mask is False,
        # i.e. the number of times 0 is NOT in the confidence interval
        false_counts = np.sum(~zero_in_ci_mask, axis=0)

        # Aggregate through a real masked array: np.mean / np.std / np.percentile on a
        # *list* of masked arrays convert it to a plain ndarray and ignore the masks.
        stacked = np.ma.masked_array(stats, mask=zero_in_ci_mask)
        mean_stat = np.ma.filled(stacked.mean(axis=0), np.nan)
        std_error = np.ma.filled(stacked.std(axis=0), np.nan)
        # np.percentile is not mask-aware; mark excluded values as NaN instead.
        confidence_interval = np.nanpercentile(stacked.filled(np.nan), [2.5, 97.5], axis=0)

        return {
            'stats': masked,
            'mean': mean_stat,
            'std_error': std_error,
            'confidence_interval': confidence_interval,
            'false_counts': false_counts
        }
    

def average_weights(sample):
    """
    Row-wise mean of a 2-D sample plus a flag telling whether 0 lies inside each
    row's 2.5%-97.5% percentile interval.

    Parameters:
    sample (array-like): 2-D data; every reduction runs along axis 1.

    Returns:
    tuple: (row means, boolean mask). A mask entry is True when the two interval
        bounds of that row have opposite signs, i.e. 0 is strictly inside the interval.
    """
    lower_bound, upper_bound = np.percentile(sample, [2.5, 97.5], axis=1)
    row_means = np.mean(sample, axis=1)
    # Opposite-signed bounds give a negative product: 0 is inside the interval, and
    # the value is meant to be removed from the mean computation downstream.
    zero_in_ci_mask = (lower_bound * upper_bound) < 0
    return row_means, zero_in_ci_mask


def plot_raw_benchmarks_weights(weights: np.ndarray,
                                metric_list: list,
                                vmax=1
):
    """
    Draws a one-row annotated heatmap of weights, one cell per metric.

    Parameters:
    weights (array-like): Weights to display; flattened into a single row, so any
        number of metrics is supported.
    metric_list (list): One x-axis label per weight, in the same order.
    vmax (float): The colour scale is symmetric and spans [-vmax, vmax].

    Returns:
    matplotlib.figure.Figure: The figure holding the heatmap.

    Raises:
    ValueError: If the number of labels differs from the number of weights.
    """
    # -1 lets NumPy infer the row length instead of hard-coding 31 metrics.
    weights_row = np.reshape(weights, (1, -1))
    n_weights = weights_row.shape[1]
    if len(metric_list) != n_weights:
        raise ValueError(
            f"metric_list has {len(metric_list)} labels but weights has {n_weights} values"
        )

    fig, ax = plt.subplots(figsize=(18,1))
    # Passing the labels to heatmap keeps exactly one tick per cell; relabelling the
    # automatically chosen ticks afterwards can fail when seaborn thins them out.
    sns.heatmap(weights_row, annot=True, fmt='.2f',
                cmap='coolwarm', vmin=-vmax, vmax=vmax,
                ax=ax, annot_kws={"fontsize":8},
                xticklabels=list(metric_list))

    plt.setp(ax.get_xticklabels(), rotation=90)
    return fig

: 

In [4]:
# Ridge fit with lambda = 0.1 and a fixed seed; a single bootstrap iteration is run.
bootstrap = Bootstrap(loss='ridge',
                      lambda_=1e-1,
                      n_iterations=1,
                      random_state=42)

# Each entry is the (row means, zero-in-CI mask) pair returned by average_weights.
stats = bootstrap.compute_statistic(average_weights)

summary = bootstrap.summary(stats)
# 'false_counts' counts the entries whose zero-in-CI mask is False, i.e. the cases
# where 0 lies OUTSIDE the confidence interval, so the label says so.
print('Number of times 0 is outside the confidence interval', summary['false_counts'])

In [None]:
# Load the evaluation table and the MMLU subtask list at notebook level for inspection;
# Bootstrap.generate_data calls the same helper internally.
from utils.data import get_full_post_training_df
merged_eval, mmlu_subtasks = get_full_post_training_df()

# Disabled plotting call, kept for reference: it would draw the mean weights from
# `summary` against the 7 named benchmarks followed by the MMLU subtasks.
#plot_raw_benchmarks_weights(weights=summary['mean'],
#                            metric_list=['ARC-C', 'HellaSwag', 'Winograd', 'TruthfulQA', 'GSM8K', 'XWinograd', 'HumanEval']+mmlu_subtasks)

: 

In [None]:

# Display the distinct model names in order of first appearance. Bootstrap.generate_data
# treats the first two entries of this list as the models that are never removed.
merged_eval['Model'].unique().tolist()

: 