In [1]:
from datasets import Dataset, DatasetDict
import ast
from tqdm import tqdm
import pandas as pd
import json
import os
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime

# Seaborn settings
sns.set_context("notebook")
sns.set_palette("colorblind")
sns.color_palette("pastel")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb
api = wandb.Api()

In [17]:
project_paths = [
    "dri-ice/Composable_Interventions",
]
all_runs_frame = None
filter_dict = { "state": "Finished" }
for project_name in project_paths:
    runs = api.runs(project_name, filters=filter_dict)
    for run in tqdm(runs, desc=f"Loading runs for project: {project_name}"):
        if "hparam_search" not in run.config["tag"].lower():
            continue

        try:
            # run_start_datetime = datetime.fromtimestamp(run.summary_metrics["_timestamp"])
            # start_cutoff = datetime.strptime("2024-05-20 12:00:00", "%Y-%m-%d %H:%M:%S")
            # if run_start_datetime > start_cutoff:
            #     continue

            config_frame = pd.DataFrame([run.config])
            summary_frame = pd.DataFrame([run.summary_metrics])
            combined_frame = pd.concat([config_frame, summary_frame], axis=1)
            if all_runs_frame is None:
                all_runs_frame = combined_frame
            else:
                all_runs_frame = pd.concat([all_runs_frame, combined_frame])
        except:
            print(f"Failed to load run {run.id}")
            continue

# Filter out runs that don't evaluate on the whole QA dataset. 
all_runs_frame = all_runs_frame[all_runs_frame["qa_question_count_limit"].apply(lambda x: x == None)]

all_runs_frame = all_runs_frame.sort_values("_timestamp", ascending=False)
for column in all_runs_frame.columns:
    all_runs_frame[column] = all_runs_frame[column].apply(lambda x: str(x) if isinstance(x, dict) or isinstance(x, list) else x)

all_runs_frame

Loading runs for project: dri-ice/Composable_Interventions: 100%|██████████| 1590/1590 [00:01<00:00, 1249.61it/s]


Unnamed: 0,tag,edit,save,seed,dtype,ga_lr,wandb,wbits,device,method,...,wmdp_bio stderr,mmlu,_wandb,Locality,wmdp_bio,wmdp_cyber,Generalization recall,PPL edits,Success recall,PPl edits unmasked
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.001000,online,16,0,none,...,0.012299,0.246048,{'runtime': 1901},0.000000,0.260016,0.243080,0.000000,99045222653174644806585299340716146688.0,0.000000,217096868239425592020145041484632031232.0
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.001000,online,16,0,none,...,0.012382,0.255092,{'runtime': 1789},0.000000,0.265515,0.245596,0.000000,33239153044387153113264297686859776.0,0.000000,584509637511607917409729553367040.0
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.001000,online,16,0,none,...,0.012099,0.246546,{'runtime': 1796},0.000000,0.247447,0.243080,0.000000,Infinity,0.000000,Infinity
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.001000,online,16,0,none,...,0.012099,0.246546,{'runtime': 1793},0.000000,0.247447,0.243080,0.000000,Infinity,0.000000,Infinity
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.001000,online,16,0,none,...,0.012099,0.246546,{'runtime': 1795},0.000000,0.247447,0.243080,0.000000,Infinity,0.000000,Infinity
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.000005,online,16,0,none,...,0.012649,0.621706,{'runtime': 1813},0.027140,0.715632,0.441369,0.020667,40004.007812,0.008889,528.27356
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.000005,online,16,0,none,...,0.012690,0.620353,{'runtime': 1806},0.025928,0.712490,0.440866,0.016667,39372.539062,0.008889,529.268188
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.000005,online,16,0,none,...,0.012670,0.620496,{'runtime': 1779},0.027140,0.714061,0.438853,0.020667,40381.101562,0.004444,527.113037
0,ga_llama3_hparam_search,{},out/,42,torch.bfloat16,0.000005,online,16,0,none,...,0.012690,0.620567,{'runtime': 1757},0.027140,0.712490,0.440866,0.020667,40353.078125,0.008889,521.927307


## GA Parameters

In [22]:
list(all_runs_frame.columns)

['tag',
 'edit',
 'save',
 'seed',
 'dtype',
 'ga_lr',
 'wandb',
 'wbits',
 'device',
 'method',
 'dataset',
 'unlearn',
 'alg_name',
 'compress',
 'edit_set',
 'ckpt_path',
 'ga_epochs',
 'load_ckpt',
 'save_ckpt',
 'stats_dir',
 'batch_size',
 'max_length',
 'model_name',
 'save_model',
 'compression',
 'edit_dataset',
 'ga_data_path',
 'rmu_layer_id',
 'wandb_entity',
 'ga_batch_size',
 'interventions',
 'wandb_project',
 'eval_zero_shot',
 'model_parallel',
 'sparsity_ratio',
 'unlearn_method',
 'number_of_edits',
 'ga_forget_corpora',
 'ga_retain_corpora',
 'compression_dataset',
 'ga_test_sample_size',
 'rmu_max_num_batches',
 'ga_train_sample_size',
 'qa_question_count_limit',
 'PPl QA',
 '_timestamp',
 'mmlu accuracy',
 'Local recall',
 'PPL',
 'FLOPs',
 'Metrics',
 'mmlu stderr',
 'wmdp_cyber accuracy',
 'Generalization',
 'Rewrite accuracy',
 'wmdp_bio accuracy',
 '_step',
 'Latency',
 'wmdp_cyber stderr',
 '_runtime',
 'Average bits',
 'wmdp_bio stderr',
 'mmlu',
 '_wandb',


In [37]:
ga_columns = ["ga_train_sample_size", "ga_test_sample_size", "ga_lr", "ga_batch_size", "mmlu accuracy", "wmdp_bio accuracy", "wmdp_cyber accuracy"]
ga_data = all_runs_frame[ga_columns]
ga_data["mean wmdp accuracy"] = ga_data[["wmdp_bio accuracy", "wmdp_cyber accuracy"]].mean(axis=1)

# Display all pandas rows
pd.set_option('display.max_rows', None)

# Only look at examples where MMU > 40
ga_data = ga_data[(ga_data["mmlu accuracy"] > 0.40) & (ga_data["mean wmdp accuracy"] < 0.4)]
ga_data.sort_values("mean wmdp accuracy", ascending=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ga_data["mean wmdp accuracy"] = ga_data[["wmdp_bio accuracy", "wmdp_cyber accuracy"]].mean(axis=1)


Unnamed: 0,ga_train_sample_size,ga_test_sample_size,ga_lr,ga_batch_size,mmlu accuracy,wmdp_bio accuracy,wmdp_cyber accuracy,mean wmdp accuracy
0,25,,0.0001,8,0.534183,0.260801,0.409663,0.335232
0,25,,0.0001,8,0.534397,0.260801,0.410669,0.335735
0,25,,0.0001,8,0.535964,0.261587,0.411173,0.33638


## RMU Parameters

In [18]:
rmu_columns = ["rmu_alpha", "rmu_layer_id", "rmu_max_num_batches", "mmlu accuracy", "wmdp_bio accuracy", "wmdp_cyber accuracy"]
rmu_data = all_runs_frame[rmu_columns]
rmu_data["rmu_alpha"] = rmu_data["rmu_alpha"].apply(lambda x: ast.literal_eval(x)[0])
rmu_data = rmu_data.sort_values("rmu_layer_id")
rmu_data["mean wmdp accuracy"] = rmu_data[["wmdp_bio accuracy", "wmdp_cyber accuracy"]].mean(axis=1)
rmu_data

KeyError: "['rmu_alpha'] not in index"

In [None]:
rmu_data[(rmu_data["mean wmdp accuracy"] < .30) & (rmu_data["mmlu accuracy"] > .55)].sort_values("mean wmdp accuracy", ascending=True)

In [None]:
display(rmu_data.value_counts("rmu_alpha"))

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

# line plot where the 