In [None]:
import wandb
import pickle
import os
import numpy as np
from collections import defaultdict

def compute_majority_votes(model: str, dataset: str, num_runs: int = 10, project_path_base: str = "hugomilosz-imperial-college-london") -> dict:
    """
    Compute majority vote binary predictions from multiple wandb runs.

    Parameters:
    - model (str): Model name from AVAILABLE_MODELS
    - dataset (str): Dataset name from AVAILABLE_DATASETS
    - num_runs (int): Number of runs to aggregate over
    - project_path_base (str): Base path for wandb project

    Returns:
    - dict: {method_name: majority_vote_array}
    """
    project_name = f"{model}_{dataset}_analysis"
    run_names = [f"{model}_{dataset}_run{i}" for i in range(num_runs)]
    project_path = f"{project_path_base}/{project_name}"

    method_binaries = defaultdict(list)

    api_run = wandb.init(project=project_name, job_type="aggregate_eval_summary")

    for run_name in run_names:
        run_id = run_name.split("_")[-1]
        artifact_path = f"{project_path}/eval_summary_{run_id}:v0"
        artifact = api_run.use_artifact(artifact_path, type='pickle')
        artifact_dir = artifact.download()
        
        with open(os.path.join(artifact_dir, f"eval_summary_{run_id}.pkl"), "rb") as f:
            eval_summary = pickle.load(f)

        binary_scores = eval_summary["binary_scores"]

        for method_name, epoch_values in binary_scores.items():
            if not epoch_values:
                continue
            last_epoch_array = epoch_values[-1]
            method_binaries[method_name].append(np.array(last_epoch_array))

    majority_vote_dict = {}

    for method, binary_arrays in method_binaries.items():
        stacked = np.stack(binary_arrays)
        majority = (np.sum(stacked, axis=0) >= (len(binary_arrays) / 2)).astype(int)
        majority_vote_dict[method] = majority

    api_run.finish()
    return majority_vote_dict

[34m[1mwandb[0m: Currently logged in as: [33mhugomilosz[0m ([33mhugomilosz-imperial-college-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Downloading large artifact eval_dict_run0:v0, 99.62MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact eval_dict_run1:v0, 99.62MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact eval_dict_run2:v0, 99.62MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5
[34m[1mwandb[0m: Downloading large artifact eval_dict_run3:v0, 99.62MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5
[34m[1mwandb[0m: Downloading large artifact eval_dict_run4:v0, 99.62MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5
[34m[1mwandb[0m: Downloading large artifact eval_dict_run5:v0, 99.62MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4
[34m[1mwandb[0m: Downloading large artifact eval_dict_run6:v0, 99.62MB. 1 files... 
[34m[1mwand

In [16]:
votes = compute_majority_votes(model="bert-tiny", dataset="multi_nli")

array(False)

In [None]:
import pandas as pd

methods = list(votes.keys())
overlap_matrix = np.zeros((len(methods), len(methods)))

for i, method_i in enumerate(methods):
    for j, method_j in enumerate(methods):
        easy_i = votes[method_i] == 1
        easy_j = votes[method_j] == 1
        overlap = np.sum(np.logical_and(easy_i, easy_j))
        total_easy = np.sum(easy_i)
        overlap_matrix[i, j] = overlap / total_easy if total_easy > 0 else 0.0

# To get it nicely formatted
overlap_df = pd.DataFrame(overlap_matrix, index=methods, columns=methods)
print(overlap_df)

                          aum        datamap          el2n     grand  \
aum                  1.000000       0.615773      0.216838       1.0   
datamap              0.619622       1.000000      0.009117       1.0   
el2n                 0.912893       0.038143      1.000000       1.0   
grand                0.708293       0.703893      0.168240       1.0   
loss                 0.926675       0.027609      0.940217       1.0   
forgetting      278148.000000  276420.000000  66068.000000  392702.0   
regularisation       0.851895       0.542786      0.263827       1.0   

                        loss  forgetting  regularisation  
aum                 0.221911         1.0        0.762601  
datamap             0.006653         1.0        0.488930  
el2n                0.947902         1.0        0.994294  
grand               0.169615         1.0        0.634051  
loss                1.000000         1.0        0.993124  
forgetting      66608.000000         1.0   248993.000000  
regularisa