In [1]:
import wandb
import pickle
import os
import numpy as np
from collections import defaultdict

def compute_majority_votes(
    model: str,
    dataset: str,
    num_runs: int = 10,
    project_path_base: str = "hugomilosz-imperial-college-london",
    artifact_version: str = "v0",
) -> dict:
    """
    Compute majority-vote binary predictions from multiple wandb runs.

    For each i in ``range(num_runs)`` this fetches the artifact
    ``{project_path_base}/{model}_{dataset}_analysis/eval_summary_run{i}:{artifact_version}``.
    It then unpickles ``eval_summary_run{i}.pkl`` from the downloaded directory and keeps
    the last entry of every method's list in ``eval_summary["binary_scores"]``. Finally, the
    per-run arrays of each method are combined element-wise by vote.

    Args:
        model: Model name; part of the wandb project and artifact naming scheme.
        dataset: Dataset name; part of the wandb project and artifact naming scheme.
        num_runs: Number of runs (``run0`` .. ``run{num_runs - 1}``) to aggregate.
        project_path_base: wandb entity that owns the ``{model}_{dataset}_analysis`` project.
        artifact_version: Version/alias of each ``eval_summary_run{i}`` artifact to fetch.
            Defaults to ``"v0"``, which was previously hard-coded.

    Returns:
        Dict mapping method name -> integer 0/1 numpy array. An element is 1 when at least
        half of the runs that reported the method have a 1 there, so exact ties resolve to 1.
        Methods whose score list is empty in every run are omitted.

    Raises:
        Propagates errors from wandb (e.g. a missing artifact), ``FileNotFoundError`` if a
        downloaded artifact lacks the expected ``.pkl`` file, and ``ValueError`` from
        ``np.stack`` if runs disagree on a method's array shape. The wandb run opened
        here is finished in every case.
    """
    project_name = f"{model}_{dataset}_analysis"
    project_path = f"{project_path_base}/{project_name}"

    method_binaries = defaultdict(list)

    api_run = wandb.init(project=project_name, job_type="aggregate_eval_summary")
    try:
        for i in range(num_runs):
            # Runs are named "{model}_{dataset}_run{i}"; their artifacts are keyed by the "run{i}" suffix.
            run_id = f"run{i}"
            artifact_path = f"{project_path}/eval_summary_{run_id}:{artifact_version}"
            artifact = api_run.use_artifact(artifact_path, type='pickle')
            artifact_dir = artifact.download()

            # NOTE(security): pickle.load can execute arbitrary code; only use artifacts from a trusted project.
            with open(os.path.join(artifact_dir, f"eval_summary_{run_id}.pkl"), "rb") as f:
                eval_summary = pickle.load(f)

            for method_name, epoch_values in eval_summary["binary_scores"].items():
                # len() instead of truthiness so numpy-array score histories don't raise.
                if epoch_values is None or len(epoch_values) == 0:
                    continue
                # Only the final epoch's binary labels take part in the vote.
                method_binaries[method_name].append(np.asarray(epoch_values[-1]))

        majority_vote_dict = {}
        for method, binary_arrays in method_binaries.items():
            stacked = np.stack(binary_arrays)
            # ">=" half: an exact tie (e.g. 5 of 10 runs) counts as 1.
            majority_vote_dict[method] = (np.sum(stacked, axis=0) >= (len(binary_arrays) / 2)).astype(int)
    finally:
        # Always close the run, even on failure, so no dangling active wandb run is left in the kernel.
        api_run.finish()

    return majority_vote_dict

In [None]:
# Aggregate the default 10 bert-base / MNLI runs into one majority-vote 0/1 array per method.
votes = compute_majority_votes(dataset="mnli", model="bert-base")

[34m[1mwandb[0m: Currently logged in as: [33mhugomilosz[0m ([33mhugomilosz-imperial-college-london[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [3]:
import pandas as pd
import numpy as np

# Pairwise Jaccard overlap between the methods' "easy" sets (examples whose majority vote is 1):
#   overlap[i, j] = |E_i ∩ E_j| / |E_i ∪ E_j|, defined as 0.0 when both sets are empty.
# One matrix product replaces re-deriving both masks inside an O(M^2) Python double loop.
methods = list(votes.keys())
if methods:
    # (n_methods, n_examples) 0/1 matrix; np.stack raises if methods disagree on the number of examples.
    easy = np.stack([np.asarray(votes[m]).ravel() == 1 for m in methods]).astype(np.int64)
    intersection = easy @ easy.T          # |E_i ∩ E_j| for every pair
    set_sizes = easy.sum(axis=1)          # |E_i|
    union = set_sizes[:, None] + set_sizes[None, :] - intersection
    overlap_matrix = np.divide(intersection, union, out=np.zeros(intersection.shape), where=union > 0)
else:
    overlap_matrix = np.zeros((0, 0))

overlap_df = pd.DataFrame(overlap_matrix, index=methods, columns=methods)
overlap_df

                     aum   datamap      el2n     grand      loss  forgetting  \
aum             1.000000  0.151295  0.039604  0.028169  0.039604    0.192698   
datamap         0.151295  1.000000  0.015102  0.024811  0.015102    0.908909   
el2n            0.039604  0.015102  1.000000  0.195122  1.000000    0.020284   
grand           0.028169  0.024811  0.195122  1.000000  0.195122    0.029412   
loss            0.039604  0.015102  1.000000  0.195122  1.000000    0.020284   
forgetting      0.192698  0.908909  0.020284  0.029412  0.020284    1.000000   
regularisation  0.205882  0.502045  0.021583  0.032200  0.021583    0.546371   

                regularisation  
aum                   0.205882  
datamap               0.502045  
el2n                  0.021583  
grand                 0.032200  
loss                  0.021583  
forgetting            0.546371  
regularisation        1.000000  
