In [1]:
from tqdm import tqdm
import pandas as pd
import numpy as np
import json
import os
import ast
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime
from IPython.display import display, Latex

import wandb
api = wandb.Api()

tqdm.pandas()

# Plotting Constants

In [2]:
# Set the font family to serif
plt.rcParams["font.family"] = "serif"

# Seaborn settings
sns.set_context("notebook")
sns.set_palette("colorblind")
sns.color_palette("pastel")

# plotting constants
TITLE_FONT_SIZE = 18
LEGEND_FONT_SIZE = 12
WSPACE = 0.3
FIGURE_HEIGHT = 3
LINE_WIDTH = 2
FIG_SIZE = 3
MARKER_SIZE = 8
X_LABEL_ROTATION = 20

# Set colors for compositons with compression
colors = {"Wanda": "C1", "SparseGPT": "C2", "AWQ": "C3", "GPTQ": "C4"}

# Pull and Dedup Data

In [3]:
setting_columns = [
    # Overall
    "tag",
    # "seed",
    "_timestamp",

    # Interventions
    "interventions", "edit", "unlearn", "compression", "model_name",

    # Editing
    "edit_set", 
    "edit_dataset", "number_of_edits",

    # Unlearning
    "rmu_layer_id",

    # Compression
    "wbits", "compression_dataset", "sparsity_ratio",
]
evaluation_columns = [
    "qa_question_count_limit",  # An artifical max number of questions to ask during evaluation. Should be none when not debugging.
    "mmlu accuracy",            # The accuracy of the model on the MMLU dataset. This measures overall model utility. Llama-3 should be ~62%
    "wmdp_bio accuracy",        # The accuracy of the model on the WMDP bio split. This is the unlearning target. Should be ~25% when RMU is applied.
    "wmdp_cyber accuracy",      # The accuracy of the model on the WMDP cyber split. This is the unlearning target. Should be ~25% when RMU is applied.
    "PPL",                      # TODO:
    "PPL edits",                # Perplexity for the edits. Should be low when editing is applied.
    "PPl QA",                   # Perplexity for the QA. Should be low when QA is applied.
    "Generalization",           # TODO: 
    "FLOPs",                    # TODO: 
    "Success recall",           # TODO:
    "Generalization recall",    # TODO:
    "Locality",                 # TODO:
    "Average bits",             # TODO:
    "Rewrite accuracy",         # TODO:
    "PPl edits unmasked",       # TODO:
    "Local recall",             # TODO:
    "Latency",                  # TODO:
]
relevant_columns = setting_columns + evaluation_columns

In [4]:
# Composable_Interventions has all the results
project_paths = ["dri-ice/Composable_Interventions",]

filter_dict = { "state": "finished" }
data_frames = []
for project_path in project_paths:
    runs = api.runs(project_path, filters=filter_dict)

    # Iterate over eachrun and capture the c        onfig and summary metrics
    for run in tqdm(runs, desc=project_path):
        try:
            run_start_datetime = datetime.fromtimestamp(run.summary_metrics["_timestamp"])
            start_cutoff = datetime.strptime("2024-05-18 00:00:00", "%Y-%m-%d %H:%M:%S")
            end_cutoff = datetime.strptime("2024-10-29 00:00:00", "%Y-%m-%d %H:%M:%S")
            if run_start_datetime < start_cutoff or run_start_datetime > end_cutoff:
                continue

            skip_tags = ["test", "hparam_search"]
            should_skip = False
            for tag in skip_tags:
                if tag in run.config["tag"].lower():
                    should_skip = True

            if should_skip:
                continue

            config_frame = pd.DataFrame([run.config])
            summary_frame = pd.DataFrame([run.summary_metrics])
            combined_frame = pd.concat([config_frame, summary_frame], axis=1)
            data_frames.append(combined_frame)
        except Exception as e:
            print(f"Error processing run {run.id}: {e}")


dri-ice/Composable_Interventions:   0%|          | 0/8765 [00:00<?, ?it/s]

Error processing run 27f8pxs0: '_timestamp'
Error processing run xr5mede5: '_timestamp'
Error processing run n0iel6ok: '_timestamp'


dri-ice/Composable_Interventions:  26%|██▋       | 2301/8765 [00:17<01:03, 101.39it/s]

Error processing run arid375k: '_timestamp'
Error processing run r6kpsu09: '_timestamp'
Error processing run sdhehb2z: '_timestamp'
Error processing run 2nv88i8v: '_timestamp'
Error processing run 31j4yjsr: '_timestamp'
Error processing run o1ai36xl: '_timestamp'
Error processing run 7t3n8sq1: '_timestamp'
Error processing run cc3cmdlj: '_timestamp'
Error processing run 1wj0u6cj: '_timestamp'
Error processing run 64ed5z4t: '_timestamp'
Error processing run 71jdht68: '_timestamp'
Error processing run 2do500pc: '_timestamp'
Error processing run lrh5z3wp: '_timestamp'
Error processing run luhstpn5: '_timestamp'
Error processing run isna6rgu: '_timestamp'
Error processing run um0dxn3y: '_timestamp'
Error processing run mje2wvj7: '_timestamp'
Error processing run evuxnltk: '_timestamp'


dri-ice/Composable_Interventions:  57%|█████▋    | 5001/8765 [00:39<00:38, 98.72it/s] 

Error processing run 1tr3e1m4: '_timestamp'
Error processing run 1z3gnedp: '_timestamp'
Error processing run 54ohzkld: '_timestamp'
Error processing run aq2oaxur: '_timestamp'
Error processing run 7oaha4p2: '_timestamp'
Error processing run a1ri8lbk: '_timestamp'
Error processing run bqi05lpe: '_timestamp'
Error processing run h512597d: '_timestamp'
Error processing run iv5ul1n7: '_timestamp'
Error processing run lmaany12: '_timestamp'
Error processing run s64e9fxo: '_timestamp'
Error processing run wt59w8pr: '_timestamp'
Error processing run 117s16id: '_timestamp'
Error processing run 2jggmjk7: '_timestamp'
Error processing run sblvym9w: '_timestamp'
Error processing run u2uaxi9s: '_timestamp'
Error processing run x21qup6h: '_timestamp'
Error processing run m9q8fawv: '_timestamp'
Error processing run n1nj8xq5: '_timestamp'
Error processing run j50re1is: '_timestamp'
Error processing run 9pywqyfi: '_timestamp'
Error processing run n36d9wfy: '_timestamp'
Error processing run tsostb8c: '

dri-ice/Composable_Interventions:  71%|███████▏  | 6251/8765 [00:50<00:18, 134.45it/s]

Error processing run ustm079w: '_timestamp'


dri-ice/Composable_Interventions:  72%|███████▏  | 6301/8765 [00:50<00:18, 131.92it/s]

Error processing run x4zeo5mn: '_timestamp'
Error processing run 4cx23357: '_timestamp'
Error processing run pzbmjnlx: '_timestamp'
Error processing run sgw1bwq3: '_timestamp'
Error processing run w44vemk4: '_timestamp'
Error processing run 4u2krnmr: '_timestamp'
Error processing run ezsnt0g5: '_timestamp'
Error processing run ilvfkzv1: '_timestamp'
Error processing run syb2xqv3: '_timestamp'


dri-ice/Composable_Interventions:  77%|███████▋  | 6751/8765 [00:54<00:20, 100.35it/s]

Error processing run b9l540xx: '_timestamp'
Error processing run choi1k81: '_timestamp'


dri-ice/Composable_Interventions:  78%|███████▊  | 6801/8765 [00:54<00:18, 104.36it/s]

Error processing run eqa5mopp: '_timestamp'


dri-ice/Composable_Interventions:  78%|███████▊  | 6851/8765 [00:55<00:17, 108.74it/s]

Error processing run 9buz1998: '_timestamp'
Error processing run 886uc8iw: '_timestamp'
Error processing run r9giysve: '_timestamp'
Error processing run mhvullqf: '_timestamp'
Error processing run e3h9gw6u: '_timestamp'
Error processing run xl827wv8: '_timestamp'
Error processing run sj3fikwq: '_timestamp'
Error processing run 5iuwtu93: '_timestamp'
Error processing run z0zdi9rt: '_timestamp'
Error processing run o96r1eri: '_timestamp'
Error processing run qpbntab7: '_timestamp'
Error processing run 0zajxvii: '_timestamp'
Error processing run 70jjkyi8: '_timestamp'
Error processing run mr5hepax: '_timestamp'
Error processing run tyyr84xe: '_timestamp'
Error processing run ubwgkh20: '_timestamp'
Error processing run s9y8jjqc: '_timestamp'
Error processing run wxhmtvc3: '_timestamp'
Error processing run j8q837ii: '_timestamp'
Error processing run y6su2u0y: '_timestamp'
Error processing run 50u5183e: '_timestamp'
Error processing run 1sv3orzy: '_timestamp'
Error processing run 4h1svpsz: '

dri-ice/Composable_Interventions: 100%|██████████| 8765/8765 [01:13<00:00, 119.63it/s]


In [5]:
def should_keep_frame(frame):
    if frame["edit_dataset"] == "zsre":
        return True
    
    if "edit" not in frame["interventions"]:
        return True
    
    print(f"Skipping {frame['tag']} for edit dataset {frame['edit_dataset']}")
    return False

# Sort by "tag" and "_timestamp" in descending order to have the most recent run first
all_runs_df = pd.concat(data_frames, ignore_index=True)[relevant_columns]
all_runs_df["interventions"] = all_runs_df["interventions"].astype(str)

# Keep only the current edit dataset
all_runs_df = all_runs_df[all_runs_df.progress_apply(lambda x: should_keep_frame(x), axis=1)]

# WARNING: WHAT DOES EDIT SET 50 MEAN COMPARED TO EDIT SET 1?
# all_runs_df = all_runs_df[all_runs_df["edit_set"] == 50]
# all_runs_df_sorted = all_runs_df.sort_values(by=["tag", "_timestamp"], ascending=[True, False])
all_runs_df["date"] = pd.to_datetime(all_runs_df["_timestamp"], unit="s")
all_runs_df_sorted = all_runs_df.sort_values(by=["_timestamp"], ascending=[False])
all_runs_df_sorted["Avg WMDP"] = (all_runs_df_sorted["wmdp_bio accuracy"] + all_runs_df_sorted["wmdp_cyber accuracy"]) / 2
all_runs_df_sorted = all_runs_df_sorted[all_runs_df_sorted["qa_question_count_limit"].isnull()]

  all_runs_df = pd.concat(data_frames, ignore_index=True)[relevant_columns]
100%|██████████| 6726/6726 [00:00<00:00, 399242.70it/s]

Skipping lora-to-SparseGPT0.45% for edit dataset mquake
Skipping lora-to-SparseGPT0.45% for edit dataset counterfact
Skipping lora_Edit for edit dataset mquake
Skipping lora-to-SparseGPT0.25% for edit dataset mquake
Skipping lora-to-SparseGPT0.25% for edit dataset counterfact
Skipping lora-to-Wanda0.25% for edit dataset counterfact
Skipping lora-to-Wanda0.25% for edit dataset mquake
Skipping Wanda0.25%-to-lora for edit dataset counterfact
Skipping Wanda0.25%-to-lora for edit dataset mquake
Skipping lora_Edit for edit dataset counterfact
Skipping lora-to-SparseGPT0.65% for edit dataset counterfact
Skipping lora-to-SparseGPT0.65% for edit dataset mquake
Skipping lora-to-Wanda0.65% for edit dataset mquake
Skipping lora-to-Wanda0.65% for edit dataset counterfact
Skipping lora-to-AWQ2bit for edit dataset mquake
Skipping AWQ4bit-to-lora for edit dataset mquake
Skipping AWQ8bit-to-lora for edit dataset counterfact
Skipping ft_Edit for edit dataset counterfact
Skipping ft-to-Wanda0.25% for edi




In [8]:
# Sort by the recency column, for example, "date"
all_runs_df_sorted = all_runs_df_sorted.sort_values(by="date")

# Drop duplicates, keeping only the most recent occurrence for each "tag" and "edit_set"
latest_runs_df = all_runs_df_sorted.drop_duplicates(subset=["model_name", "tag", "edit_set"], keep="last")

# Define a function to calculate standard error
def standard_error(x):
    return x.std() / np.sqrt(len(x))

# Group by the "tag" column and calculate the mean for numerical columns
# grouped_df = latest_runs_df.groupby(["model_name", "tag"]).agg(["mean", standard_error])
numerical_columns = latest_runs_df.select_dtypes(include=[np.number]).columns
print(f"Numerical columns: {list(numerical_columns)}")
grouped_df = latest_runs_df.groupby(["model_name", "tag"]).agg({ col: ["mean", standard_error] for col in numerical_columns })

# Flatten the multi-level columns
grouped_df.columns = [f"{col[0]}_{col[1]}" for col in grouped_df.columns]

# Split the columns into means and standard errors
mean_columns = [col for col in grouped_df.columns if col.endswith("_mean")]
se_columns = [col for col in grouped_df.columns if col.endswith("_standard_error")]

# Create separate DataFrames for means and standard errors
mean_df = grouped_df[mean_columns].rename(columns=lambda x: x.replace("_mean", ""))
se_df = grouped_df[se_columns].rename(columns=lambda x: x.replace("_standard_error", "_se"))

# Merge the means and standard errors back into one DataFrame
all_runs_df_sorted_averaged = pd.concat([mean_df, se_df], axis=1).copy()

# Reset index if needed
all_runs_df_sorted_averaged.reset_index(inplace=True)

# Add non-numerical columns from the latest_runs_df
non_numerical_columns = latest_runs_df.select_dtypes(exclude=[np.number]).drop_duplicates(subset=["model_name", "tag"])
all_runs_df_sorted_averaged = all_runs_df_sorted_averaged.merge(non_numerical_columns, on=["model_name", "tag"], how="left")

# Display the resulting DataFrame
all_runs_df_sorted_averaged

Numerical columns: ['_timestamp', 'edit_set', 'number_of_edits', 'rmu_layer_id', 'wbits', 'sparsity_ratio', 'mmlu accuracy', 'wmdp_bio accuracy', 'wmdp_cyber accuracy', 'Generalization', 'Success recall', 'Generalization recall', 'Locality', 'Average bits', 'Rewrite accuracy', 'Local recall', 'Latency', 'Avg WMDP']


Unnamed: 0,model_name,tag,_timestamp,edit_set,number_of_edits,rmu_layer_id,wbits,sparsity_ratio,mmlu accuracy,wmdp_bio accuracy,...,compression,edit_dataset,compression_dataset,qa_question_count_limit,PPL,PPL edits,PPl QA,FLOPs,PPl edits unmasked,date
0,01-ai/Yi-1.5-9B-Chat,AWQ2bit,1.727392e+09,1.0,50.0,-1.0,2.0,0.0,0.229454,0.246661,...,awq,zsre,c4,,8100276.5,8778113,7557045,-1,6354282,2024-09-26 23:14:58.334940910
1,01-ai/Yi-1.5-9B-Chat,AWQ2bit_MEMIT,1.727429e+09,1.0,50.0,-1.0,2.0,0.0,0.229454,0.246661,...,awq,zsre,c4,,8133088.5,8773501,7564110.5,-1,6352639.5,2024-09-27 09:23:35.068500519
2,01-ai/Yi-1.5-9B-Chat,AWQ3bit,1.727376e+09,1.0,50.0,-1.0,3.0,0.0,0.606110,0.611940,...,awq,zsre,c4,,7.125771,306.892822,105.87944,-1,125.973007,2024-09-26 18:43:24.532109499
3,01-ai/Yi-1.5-9B-Chat,AWQ3bit_MEMIT,1.727420e+09,1.0,50.0,-1.0,3.0,0.0,0.603618,0.606441,...,awq,zsre,c4,,7.163979,682.019836,108.887817,-1,214.402084,2024-09-27 06:52:38.689813137
4,01-ai/Yi-1.5-9B-Chat,AWQ4bit,1.727392e+09,1.0,50.0,-1.0,4.0,0.0,0.669634,0.662215,...,awq,zsre,c4,,6.47934,356.460541,84.696404,-1,112.177864,2024-09-26 23:09:41.927449703
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
546,mistralai/Mistral-7B-Instruct-v0.3,rmu-awq8bit,1.726225e+09,1.0,50.0,6.0,8.0,0.0,0.585743,0.306363,...,awq,zsre,c4,,4.734568,90627.210938,300.767761,-1,448.873871,2024-09-13 10:59:33.378082514
547,mistralai/Mistral-7B-Instruct-v0.3,rmu-ft,1.726134e+09,1.0,50.0,6.0,16.0,0.0,0.494873,0.282797,...,none,zsre,c4,,4.805961,56338.621094,857.478333,1.82 TFLOPS,1146.820679,2024-09-12 09:35:43.696657896
548,mistralai/Mistral-7B-Instruct-v0.3,rmu-lora,1.726146e+09,1.0,50.0,6.0,16.0,0.0,0.567227,0.294580,...,none,zsre,c4,,10.45087,249094.265625,3847.212158,1.79 TFLOPS,3661.797852,2024-09-12 12:55:28.729677916
549,mistralai/Mistral-7B-Instruct-v0.3,rmu-memit,1.726228e+09,1.0,50.0,6.0,16.0,0.0,0.560675,0.293009,...,none,zsre,c4,,4.74005,16454.134766,254.244461,1.82 TFLOPS,629.994202,2024-09-13 11:48:50.830707312


In [9]:
# The experiment tags can be inconsistent, so we need to manually map them to standard names
def set_tag(experiment_row):
    if experiment_row["interventions"] in [None, np.nan]:
        return "NONE"

    intervention_categories = None
    if isinstance(experiment_row["interventions"], str):
        intervention_categories = ast.literal_eval(experiment_row["interventions"])
    else:
        intervention_categories = experiment_row["interventions"]

    interventions = []
    for category in intervention_categories:
        category = "compression" if category == "compress" else category
        intervention = experiment_row[category].upper()
        if intervention in ["AWQ", "GPTQ"]:
            intervention += str(int(experiment_row["wbits"])) + "bit"
        if intervention in ["WANDA", "SPARSEGPT"]:
            intervention += str(int(experiment_row["sparsity_ratio"] * 100)) + "%"

        interventions.append(intervention)
    
    if len(interventions) == 0:
        interventions.append("NONE")

    return "_".join(interventions)

all_runs_df_sorted_averaged["tag"] = all_runs_df_sorted_averaged.progress_apply(set_tag, axis=1)
print(all_runs_df_sorted_averaged[all_runs_df_sorted_averaged["model_name"] == "mistralai/Mistral-7B-Instruct-v0.3"].value_counts(["tag", "model_name"]).sort_index())
# print(all_runs_df_sorted_averaged[["tag", "model_name"]].value_counts())

100%|██████████| 551/551 [00:00<00:00, 68601.92it/s]

tag             model_name                        
AWQ2bit         mistralai/Mistral-7B-Instruct-v0.3    1
AWQ2bit_MEMIT   mistralai/Mistral-7B-Instruct-v0.3    1
AWQ2bit_RMU     mistralai/Mistral-7B-Instruct-v0.3    1
AWQ3bit         mistralai/Mistral-7B-Instruct-v0.3    1
AWQ3bit_MEMIT   mistralai/Mistral-7B-Instruct-v0.3    1
                                                     ..
WANDA65%_MEMIT  mistralai/Mistral-7B-Instruct-v0.3    1
WANDA65%_RMU    mistralai/Mistral-7B-Instruct-v0.3    1
WANDA75%        mistralai/Mistral-7B-Instruct-v0.3    1
WANDA75%_MEMIT  mistralai/Mistral-7B-Instruct-v0.3    1
WANDA75%_RMU    mistralai/Mistral-7B-Instruct-v0.3    1
Name: count, Length: 78, dtype: int64





In [10]:
# set to None if empty dict
all_runs_df_sorted_averaged["edit"] = all_runs_df_sorted_averaged["edit"].apply(lambda x: None if x == {} else x)
all_runs_df_sorted_averaged["unlearn"] = all_runs_df_sorted_averaged["unlearn"].apply(lambda x: None if x == {} else x)
all_runs_df_sorted_averaged["compression"] = all_runs_df_sorted_averaged["compression"].apply(lambda x: None if x == {} else x)

In [11]:
# Drop duplicates, keeping only the first occurrence (which is the most recent due to sorting)
# all_runs_df_deduplicated = all_runs_df_sorted.drop_duplicates(subset=[col for col in setting_columns if col not in ["_timestamp", "tag", "date"]], keep="first")
all_runs_df_deduplicated = all_runs_df_sorted_averaged.drop_duplicates(subset=["model_name", "tag"], keep="first")
all_runs_df_deduplicated["interventions"] = all_runs_df_deduplicated["interventions"].apply(lambda x : ast.literal_eval(x))

rename_dict = {
    "meta-llama/Meta-Llama-3-8B" : "Llama-3 (8b)",
    "mistralai/Mistral-7B-Instruct-v0.3": "Mistral (7b)",
    "01-ai/Yi-1.5-9B-Chat": "Yi 1.5 (9b)",
    "microsoft/Phi-3-mini-4k-instruct": "Phi-3 (3.8b)",
    "ft" : "Fine-tune",
    "memit" : "MEMIT",
    "lora" : "LoRA",
    "wanda" : "Wanda",
    "sparsegpt" : "SparseGPT",
    "gptq" : "GPTQ",
    "awq" : "AWQ",
    "rmu" : "RMU",
    "ga": "GA",
    "gd": "GD",
}
all_runs_df_deduplicated["model_name"] = all_runs_df_deduplicated["model_name"].apply(lambda x : rename_dict.get(x, None))
all_runs_df_deduplicated["edit"] = all_runs_df_deduplicated["edit"].apply(lambda x : rename_dict.get(x, None))
all_runs_df_deduplicated["compression"] = all_runs_df_deduplicated["compression"].apply(lambda x : rename_dict.get(x, None))
all_runs_df_deduplicated["unlearn"] = all_runs_df_deduplicated["unlearn"].apply(lambda x : rename_dict.get(x, None))
all_runs_df_deduplicated = all_runs_df_deduplicated
display(all_runs_df_deduplicated.value_counts(["model_name", "tag"]).sort_index())
print(f"Number of experiments: {len(all_runs_df_deduplicated)}")
print(f"Experiments by Model: {all_runs_df_deduplicated['model_name'].value_counts()}")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  all_runs_df_deduplicated["interventions"] = all_runs_df_deduplicated["interventions"].apply(lambda x : ast.literal_eval(x))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  all_runs_df_deduplicated["model_name"] = all_runs_df_deduplicated["model_name"].apply(lambda x : rename_dict.get(x, None))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/index

model_name    tag           
Llama-3 (8b)  AWQ2bit           1
              AWQ2bit_FT        1
              AWQ2bit_GA        1
              AWQ2bit_GD        1
              AWQ2bit_LORA      1
                               ..
Yi 1.5 (9b)   WANDA55%_MEMIT    1
              WANDA65%          1
              WANDA65%_MEMIT    1
              WANDA75%          1
              WANDA75%_MEMIT    1
Name: count, Length: 549, dtype: int64

Number of experiments: 549
Experiments by Model: model_name
Llama-3 (8b)    311
Phi-3 (3.8b)    123
Mistral (7b)     78
Yi 1.5 (9b)      37
Name: count, dtype: int64


# Check for Missing Ablation Experments

## Mistral Ablation

In [12]:
unlearning_interventions = []
editing_interventions = ["memit"]
pruning_interventions = ["wanda"]
pruning_levels = [0.25, 0.35, 0.45, 0.55, 0.65, 0.75]
quant_interventions = ["awq"]
quant_levels = [2, 3, 4, 5, 6, 8]
experiment_combinations = []

# Unlearning and Editing
for unlearner in unlearning_interventions:
    for editor in editing_interventions:
        experiment_combinations.append({
            "interventions": ["unlearn", "edit"], "edit": editor, "unlearn": unlearner, "wbits": None, "sparsity_ratio": None, })
        experiment_combinations.append({
            "interventions": ["edit", "unlearn"], "edit": editor, "unlearn": unlearner, "wbits": None, "sparsity_ratio": None, })

# Unlearning and compression
for unlearner in unlearning_interventions:
    for pruner in pruning_interventions:
        for pruning_level in pruning_levels:
            experiment_combinations.append({
                "interventions": ["unlearn", "compression"], "edit": None, "unlearn": unlearner, "compression": pruner, "sparsity_ratio": pruning_level, })
            experiment_combinations.append({
                "interventions": ["compression", "unlearn"], "edit": None, "unlearn": unlearner, "compression": pruner, "sparsity_ratio": pruning_level, })

    for quantizer in quant_interventions:
        for quant_level in quant_levels:
            experiment_combinations.append({
                "interventions": ["unlearn", "compression"], "edit": None, "unlearn": unlearner, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })
            experiment_combinations.append({
                "interventions": ["compression", "unlearn"], "edit": None, "unlearn": unlearner, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })

# Editing and compression
for editor in editing_interventions:
    for pruner in pruning_interventions:
        for pruning_level in pruning_levels:
            experiment_combinations.append({
                "interventions": ["edit", "compression"], "edit": editor, "unlearn": None, "compression": pruner, "sparsity_ratio": pruning_level, })
            experiment_combinations.append({
                "interventions": ["compression", "edit"], "edit": editor, "unlearn": None, "compression": pruner, "sparsity_ratio": pruning_level, })

    for quantizer in quant_interventions:
        for quant_level in quant_levels:
            experiment_combinations.append({
                "interventions": ["edit", "compression"], "edit": editor, "unlearn": None, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })
            experiment_combinations.append({
                "interventions": ["compression", "edit"], "edit": editor, "unlearn": None, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })

# No interventions
experiment_combinations.append({
    "interventions": [], "edit": None, "unlearn": None, "compression": None, "wbits": None, "sparsity_ratio": None })

# Just edit
for editor in editing_interventions:
    experiment_combinations.append({
        "interventions": ["edit"], "edit": editor, "unlearn": None, "compression": None, "wbits": None, "sparsity_ratio": None })

# Just unlearn
for unlearner in unlearning_interventions:
    experiment_combinations.append({
        "interventions": ["unlearn"], "edit": None, "unlearn": unlearner, "compression": None, "wbits": None, "sparsity_ratio": None })

# Just pruning
for pruner in pruning_interventions:
    for pruning_level in pruning_levels:
        experiment_combinations.append({
            "interventions": ["compression"], "edit": None, "unlearn": None, "compression": pruner, "sparsity_ratio": pruning_level, })

# Just quantization
for quantizer in quant_interventions:
    for quant_level in quant_levels:
        experiment_combinations.append({
            "interventions": ["compression"], "edit": None, "unlearn": None, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })

print(f"Total experiment combinations for Mistral ablation: {len(experiment_combinations)}")
# model_name = "Mistral (7b)"
model_name = "Yi 1.5 (9b)"
mistral_tags = set(all_runs_df_deduplicated[all_runs_df_deduplicated["model_name"] == model_name]["tag"].unique())
experiment_statuses = []
for experiment in experiment_combinations:
    experiment_tag = set_tag(experiment)
    experiment_statuses.append({
        "experiment_tag": experiment_tag,
        "first_intervention": experiment_tag.split("_")[0],
        "second_intervention": experiment_tag.split("_")[1] if len(experiment_tag.split("_")) > 1 else None,
        "completed": experiment_tag in mistral_tags,
    })

experiment_statuses_df = pd.DataFrame(experiment_statuses)
experiment_statuses_df.to_csv("mistral_experiment_statuses.csv", index=False)

# Percent of experiments completed
print(f"Percent of experiments completed: {experiment_statuses_df['completed'].mean() * 100:.2f}%")

# Missing experiments
missing_experiments = experiment_statuses_df[experiment_statuses_df["completed"] == False].sort_values(by="experiment_tag")
display(missing_experiments)

Total experiment combinations for Mistral ablation: 38
Percent of experiments completed: 97.37%


Unnamed: 0,experiment_tag,first_intervention,second_intervention,completed
25,MEMIT,MEMIT,,False


## Yi Ablation

In [13]:
unlearning_interventions = []
editing_interventions = ["memit"]
pruning_interventions = ["wanda"]
pruning_levels = [0.25, 0.35, 0.45, 0.55, 0.65, 0.75]
quant_interventions = ["awq"]
quant_levels = [2, 3, 4, 5, 6, 8]
experiment_combinations = []

# Unlearning and Editing
for unlearner in unlearning_interventions:
    for editor in editing_interventions:
        experiment_combinations.append({
            "interventions": ["unlearn", "edit"], "edit": editor, "unlearn": unlearner, "wbits": None, "sparsity_ratio": None, })
        experiment_combinations.append({
            "interventions": ["edit", "unlearn"], "edit": editor, "unlearn": unlearner, "wbits": None, "sparsity_ratio": None, })

# Unlearning and compression
for unlearner in unlearning_interventions:
    for pruner in pruning_interventions:
        for pruning_level in pruning_levels:
            experiment_combinations.append({
                "interventions": ["unlearn", "compression"], "edit": None, "unlearn": unlearner, "compression": pruner, "sparsity_ratio": pruning_level, })
            experiment_combinations.append({
                "interventions": ["compression", "unlearn"], "edit": None, "unlearn": unlearner, "compression": pruner, "sparsity_ratio": pruning_level, })

    for quantizer in quant_interventions:
        for quant_level in quant_levels:
            experiment_combinations.append({
                "interventions": ["unlearn", "compression"], "edit": None, "unlearn": unlearner, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })
            experiment_combinations.append({
                "interventions": ["compression", "unlearn"], "edit": None, "unlearn": unlearner, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })

# Editing and compression
for editor in editing_interventions:
    for pruner in pruning_interventions:
        for pruning_level in pruning_levels:
            experiment_combinations.append({
                "interventions": ["edit", "compression"], "edit": editor, "unlearn": None, "compression": pruner, "sparsity_ratio": pruning_level, })
            experiment_combinations.append({
                "interventions": ["compression", "edit"], "edit": editor, "unlearn": None, "compression": pruner, "sparsity_ratio": pruning_level, })

    for quantizer in quant_interventions:
        for quant_level in quant_levels:
            experiment_combinations.append({
                "interventions": ["edit", "compression"], "edit": editor, "unlearn": None, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })
            experiment_combinations.append({
                "interventions": ["compression", "edit"], "edit": editor, "unlearn": None, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })

# No interventions
experiment_combinations.append({
    "interventions": [], "edit": None, "unlearn": None, "compression": None, "wbits": None, "sparsity_ratio": None })

# Just edit
for editor in editing_interventions:
    experiment_combinations.append({
        "interventions": ["edit"], "edit": editor, "unlearn": None, "compression": None, "wbits": None, "sparsity_ratio": None })

# Just unlearn
for unlearner in unlearning_interventions:
    experiment_combinations.append({
        "interventions": ["unlearn"], "edit": None, "unlearn": unlearner, "compression": None, "wbits": None, "sparsity_ratio": None })

# Just pruning
for pruner in pruning_interventions:
    for pruning_level in pruning_levels:
        experiment_combinations.append({
            "interventions": ["compression"], "edit": None, "unlearn": None, "compression": pruner, "sparsity_ratio": pruning_level, })

# Just quantization
for quantizer in quant_interventions:
    for quant_level in quant_levels:
        experiment_combinations.append({
            "interventions": ["compression"], "edit": None, "unlearn": None, "compression": quantizer, "wbits": quant_level, "sparsity_ratio": None, })

print(f"Total experiment combinations for Yi ablation: {len(experiment_combinations)}")
yi_tags = set(all_runs_df_deduplicated[all_runs_df_deduplicated["model_name"] == "Yi 1.5 (9b)"]["tag"].unique())
experiment_statuses = []
for experiment in experiment_combinations:
    experiment_tag = set_tag(experiment)
    experiment_statuses.append({
        "experiment_tag": experiment_tag,
        "first_intervention": experiment_tag.split("_")[0],
        "second_intervention": experiment_tag.split("_")[1] if len(experiment_tag.split("_")) > 1 else None,
        "completed": experiment_tag in yi_tags,
    })

experiment_statuses_df = pd.DataFrame(experiment_statuses)
experiment_statuses_df.to_csv("mistral_experiment_statuses.csv", index=False)

# Percent of experiments completed
print(f"Percent of experiments completed: {experiment_statuses_df['completed'].mean() * 100:.2f}%")

# Missing experiments
missing_experiments = experiment_statuses_df[experiment_statuses_df["completed"] == False].sort_values(by="experiment_tag")
display(missing_experiments)

Total experiment combinations for Yi ablation: 38
Percent of experiments completed: 97.37%


Unnamed: 0,experiment_tag,first_intervention,second_intervention,completed
25,MEMIT,MEMIT,,False


## Check Main Results Experiment Status

In [14]:
# Math for determining number of interventions
awq_settings = 6
gptq_settings = 4 # only support quantize to [2, 3, 4, 8] bits.
wanda_count = 6
sparsegpt_count = 6
editor_settings = 3
composition_factor = 2

editor_count = composition_factor * (awq_settings + gptq_settings + wanda_count + sparsegpt_count + 1) * editor_settings
print(editor_count // 2)

rmu_count = composition_factor * (awq_settings + gptq_settings + wanda_count + sparsegpt_count + editor_settings)
print(rmu_count)

69
50


In [15]:
all_runs_df_deduplicated["unlearn"].value_counts()

unlearn
RMU    82
GD     54
GA     51
Name: count, dtype: int64

In [16]:
NUM_UNLEARNING = 3
NUM_EDITING = 3
NUM_COMPRESSION = 4 + 6 + 6 + 6
combination_of_unlearning = 2 * NUM_UNLEARNING * NUM_COMPRESSION
combination_of_unlearning

132

In [17]:

main_results = all_runs_df_deduplicated[all_runs_df_deduplicated["model_name"] == "Llama-3 (8b)"]

categories = {
    "No Intervention": main_results[main_results["interventions"].apply(lambda x: x == [])].copy(),
    "Editing": main_results[main_results["interventions"].apply(lambda x: x == ["edit"])].copy(),
    "Compression": main_results[main_results["interventions"].apply(lambda x: x == ["compress"])].copy(),
    "Edit to Compression": main_results[main_results["interventions"].apply(lambda x: x == ["edit", "compress"])].copy(),
    "Compression to Edit": main_results[main_results["interventions"].apply(lambda x: x == ["compress", "edit"])].copy(),
    "Unlearn": main_results[main_results["interventions"].apply(lambda x: x == ["unlearn"])].copy(),
    "Edit to Unlearn": main_results[main_results["interventions"].apply(lambda x: x == ["edit", "unlearn"])].copy(),
    "Unlearn to Edit": main_results[main_results["interventions"].apply(lambda x: x == ["unlearn", "edit"])].copy(),
    "Compress to Unlearn": main_results[main_results["interventions"].apply(lambda x: x == ["compress", "unlearn"])].copy(),
    "Unlearn to Compress": main_results[main_results["interventions"].apply(lambda x: x == ["unlearn", "compress"])].copy()
}

assert len(categories["No Intervention"]) == 1, f"{len(categories['No Intervention'])} != 1"
assert len(categories["Editing"]) == 3, f"{len(categories['Editing'])} != 3"

# display(categories["Compression"])
assert len(categories["Compression"]) == (awq_settings + gptq_settings + wanda_count + sparsegpt_count), f"{len(categories['Compression'])} != {awq_settings + gptq_settings + wanda_count + sparsegpt_count}"

# assert len(categories["Edit to Compression"]) == editor_count // 2, f"{len(categories['Edit to Compression'])} != {editor_count // 2}"

assert len(categories["Compression to Edit"]) == (editor_count // 2 ) - 3, f"{len(categories['Compression to Edit'])} != {editor_count // 2}" # TODO: Fix this by getting the latest results
assert len(categories["Unlearn"]) == 3, f"{len(categories['Unlearn'])} != 3"
assert len(categories["Edit to Unlearn"]) == 9, f"{len(categories['Edit to Unlearn'])} != 9"
assert len(categories["Unlearn to Edit"]) == 9, f"{len(categories['Unlearn to Edit'])} != 9"

display(categories["Compress to Unlearn"])
assert len(categories["Compress to Unlearn"]) == combination_of_unlearning // 2, f"{len(categories['Compress to Unlearn'])} != {combination_of_unlearning // 2}"

display(categories["Unlearn to Compress"])
assert len(categories["Unlearn to Compress"]) == combination_of_unlearning // 2, f"{len(categories['Unlearn to Compress'])} != {rmucombination_of_unlearning_count // 2}"

Unnamed: 0,model_name,tag,_timestamp,edit_set,number_of_edits,rmu_layer_id,wbits,sparsity_ratio,mmlu accuracy,wmdp_bio accuracy,...,compression,edit_dataset,compression_dataset,qa_question_count_limit,PPL,PPL edits,PPl QA,FLOPs,PPl edits unmasked,date
126,Llama-3 (8b),AWQ2bit_GA,1.718194e+09,1.000000,50.0,-1.000000,2.0,0.00,0.268908,0.240377,...,AWQ,zsre,c4,,Infinity,Infinity,Infinity,-1,Infinity,2024-06-12 12:02:27.059515953
127,Llama-3 (8b),AWQ2bit_GD,1.718204e+09,1.000000,50.0,-1.000000,2.0,0.00,0.268908,0.240377,...,AWQ,zsre,c4,,4496028624899119061335958001549312.0,2617349.75,6866961.5,-1,5905774.5,2024-06-12 14:48:16.610423088
128,Llama-3 (8b),AWQ2bit_RMU,1.717570e+09,9.545455,50.0,3.181818,2.0,0.00,0.268908,0.240377,...,AWQ,zsre,c4,,1749321.75,1055937.75,999726.5,-1,915356.5,2024-05-20 19:19:14.861625671
129,Llama-3 (8b),AWQ3bit_GA,1.718194e+09,1.000000,50.0,-1.000000,3.0,0.00,0.272896,0.241948,...,AWQ,zsre,c4,,Infinity,Infinity,Infinity,-1,Infinity,2024-06-12 12:03:07.964953661
130,Llama-3 (8b),AWQ3bit_GD,1.718204e+09,1.000000,50.0,-1.000000,3.0,0.00,0.341832,0.245876,...,AWQ,zsre,c4,,166895.90625,83131.453125,20670.785156,-1,3257.42627,2024-06-12 14:47:07.710373163
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
343,Llama-3 (8b),WANDA65%_GD,1.718204e+09,1.000000,50.0,-1.000000,4.0,0.65,0.229098,0.249018,...,Wanda,zsre,c4,,16268.005859,5423145.5,81961160,447.25 GFLOPS,89329664,2024-06-12 14:53:38.419730186
344,Llama-3 (8b),WANDA65%_RMU,1.716764e+09,18.666667,50.0,3.666667,4.0,0.65,0.229360,0.249542,...,Wanda,zsre,c4,,45.534767,83373.453125,1477.523804,760 GFLOPS,1549.510742,2024-05-20 19:27:06.522696495
345,Llama-3 (8b),WANDA75%_GA,1.718194e+09,1.000000,50.0,-1.000000,4.0,0.75,0.229811,0.249018,...,Wanda,zsre,c4,,8125971838094910145582358934348365824.0,176520989218511062933962752.0,31936578699555629957120.0,357.84 GFLOPS,31529243875858664390656.0,2024-06-12 12:09:17.554701090
346,Llama-3 (8b),WANDA75%_GD,1.718204e+09,1.000000,50.0,-1.000000,4.0,0.75,0.246902,0.247447,...,Wanda,zsre,c4,,10226426,743558.8125,145133.09375,357.84 GFLOPS,128967.648438,2024-06-12 14:55:21.376342773


Unnamed: 0,model_name,tag,_timestamp,edit_set,number_of_edits,rmu_layer_id,wbits,sparsity_ratio,mmlu accuracy,wmdp_bio accuracy,...,compression,edit_dataset,compression_dataset,qa_question_count_limit,PPL,PPL edits,PPl QA,FLOPs,PPl edits unmasked,date
170,Llama-3 (8b),GA_AWQ2bit,1.718191e+09,1.000000,50.0,-1.000000,2.0,0.00,0.246546,0.247447,...,AWQ,zsre,c4,,1769133.875,1243236.75,1118986.25,-1,1245522.5,2024-06-12 11:13:37.028556824
171,Llama-3 (8b),GA_AWQ3bit,1.718191e+09,1.000000,50.0,-1.000000,3.0,0.00,0.441675,0.492537,...,AWQ,zsre,c4,,Infinity,Infinity,Infinity,-1,Infinity,2024-06-12 11:17:15.184798717
172,Llama-3 (8b),GA_AWQ4bit,1.718191e+09,1.000000,50.0,-1.000000,4.0,0.00,0.465318,0.524745,...,AWQ,zsre,c4,,Infinity,Infinity,Infinity,-1,Infinity,2024-06-12 11:20:13.570003510
173,Llama-3 (8b),GA_AWQ5bit,1.718191e+09,1.000000,50.0,-1.000000,5.0,0.00,0.498576,0.556952,...,AWQ,zsre,c4,,Infinity,Infinity,Infinity,-1,Infinity,2024-06-12 11:20:19.390620708
174,Llama-3 (8b),GA_AWQ6bit,1.718191e+09,1.000000,50.0,-1.000000,6.0,0.00,0.489389,0.559309,...,AWQ,zsre,c4,,Infinity,Infinity,Infinity,-1,Infinity,2024-06-12 11:21:07.118664503
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
307,Llama-3 (8b),RMU_WANDA35%,1.717690e+09,5.500000,50.0,3.000000,4.0,0.35,0.532602,0.258759,...,Wanda,zsre,c4,,6.360727,58684.835938,459.011627,1.3 TFLOPS,647.534851,2024-06-06 13:50:28.648183107
308,Llama-3 (8b),RMU_WANDA45%,1.717546e+09,10.000000,50.0,3.200000,4.0,0.45,0.478009,0.272270,...,Wanda,zsre,c4,,7.707963,25395.089844,538.735168,1.12 TFLOPS,620.10907,2024-05-20 19:02:13.503230572
309,Llama-3 (8b),RMU_WANDA55%,1.717684e+09,5.500000,50.0,3.000000,4.0,0.55,0.350570,0.259309,...,Wanda,zsre,c4,,13.729619,46809.050781,728.273926,938.82 GFLOPS,732.329468,2024-06-06 14:15:58.615832090
310,Llama-3 (8b),RMU_WANDA65%,1.717552e+09,9.545455,50.0,3.181818,4.0,0.65,0.229500,0.242020,...,Wanda,zsre,c4,,52.669495,78055.304688,1446.553223,760 GFLOPS,1535.829468,2024-05-20 19:06:01.385294199


In [18]:
# Define intervention names and types
intervention_names = [intervention for intervention in list(main_results["edit"].unique()) + list(main_results["unlearn"].unique()) + list(main_results["compression"].unique()) if intervention is not None]
intervention_type = {
    "LoRA": "edit",
    "MEMIT": "edit",
    "Fine-tune": "edit",
    "SparseGPT": "compression",
    "Wanda": "compression",
    "GPTQ": "compression",
    "AWQ": "compression",
    "RMU": "unlearn",
    "GA": "unlearn",
    "GD": "unlearn",
}

# Initialize heatmap main_results frames with default values
default_value = None
mmlu_oi_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
wmdp_oi_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
edit_oi_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
generalization_oi_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
locality_oi_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)

# Initialize max value main_results frames
mmlu_mce_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
wmdp_mce_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
edit_mce_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
generalization_mce_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)
locality_mce_main_results = pd.DataFrame(index=intervention_names, columns=intervention_names, dtype=float, data=default_value)

# Populate the heatmap and max value main_results frames
for first_intervention in intervention_names:
    for second_intervention in intervention_names:
        first_intervention_type = intervention_type[first_intervention]
        second_intervention_type = intervention_type[second_intervention]
        if first_intervention_type == second_intervention_type:
            continue

        compositions = main_results[(main_results[first_intervention_type] == first_intervention) & (main_results[second_intervention_type] == second_intervention)]
        if first_intervention in ["SparseGPT", "Wanda"] or second_intervention in ["SparseGPT", "Wanda"]:
            compositions = compositions[compositions["sparsity_ratio"] == 0.25]
        elif first_intervention in ["GPTQ", "AWQ"] or second_intervention in ["GPTQ", "AWQ"]:
            compositions = compositions[compositions["wbits"] == 4]
        
        assert len(compositions) == 2, f"Expected 2 compositions for {first_intervention} and {second_intervention}, but found {len(compositions)}"
        
        # Calculate OIs
        mmlu_diff = abs(compositions["mmlu accuracy"].iloc[0] - compositions["mmlu accuracy"].iloc[1]).round(4)
        mmlu_oi_main_results[first_intervention][second_intervention] = mmlu_diff
        
        avg_wmdp_diff = abs(((compositions.iloc[0]["wmdp_cyber accuracy"] + compositions.iloc[0]["wmdp_bio accuracy"]) / 2) - ((compositions.iloc[1]["wmdp_cyber accuracy"] + compositions.iloc[1]["wmdp_bio accuracy"]) / 2)).round(4)
        wmdp_oi_main_results[first_intervention][second_intervention] = avg_wmdp_diff
        
        edit_diff = abs(compositions["Rewrite accuracy"].iloc[0] - compositions["Rewrite accuracy"].iloc[1]).round(4)
        edit_oi_main_results[first_intervention][second_intervention] = edit_diff

        generalization_diff = abs(compositions["Generalization"].iloc[0] - compositions["Generalization"].iloc[1]).round(4)
        generalization_oi_main_results[first_intervention][second_intervention] = generalization_diff

        locality_diff = abs(compositions["Locality"].iloc[0] - compositions["Locality"].iloc[1]).round(4)
        locality_oi_main_results[first_intervention][second_intervention] = locality_diff
        
        # Calculate MCE values
        mmlu_mce = 1 - max(compositions["mmlu accuracy"].iloc[0], compositions["mmlu accuracy"].iloc[1]).round(4)
        mmlu_mce_main_results[first_intervention][second_intervention] = mmlu_mce
        
        avg_wmdp_acc = min((compositions.iloc[0]["wmdp_cyber accuracy"] + compositions.iloc[0]["wmdp_bio accuracy"]) / 2, (compositions.iloc[1]["wmdp_cyber accuracy"] + compositions.iloc[1]["wmdp_bio accuracy"]) / 2).round(4)
        wmdp_mce_main_results[first_intervention][second_intervention] = avg_wmdp_acc
        
        edit_mce = 1 - max(compositions["Rewrite accuracy"].iloc[0], compositions["Rewrite accuracy"].iloc[1]).round(4)
        edit_mce_main_results[first_intervention][second_intervention] = edit_mce

        generalization_mce = 1 - max(compositions["Generalization"].iloc[0], compositions["Generalization"].iloc[1]).round(4)
        generalization_mce_main_results[first_intervention][second_intervention] = generalization_mce

        locality_mce = 1 - max(compositions["Locality"].iloc[0], compositions["Locality"].iloc[1]).round(4)
        locality_mce_main_results[first_intervention][second_intervention] = locality_mce

# Display the results
print("MMLU OI")
display(mmlu_oi_main_results)

print("MMLU MCE Values")
display(mmlu_mce_main_results)

print("WMDP OI")
display(wmdp_oi_main_results)

print("WMDP MCE Values")
display(wmdp_mce_main_results)

print("Rewrite OI")
display(edit_oi_main_results)

print("Rewrite MCE Values")
display(edit_mce_main_results)

print("Generalization OI")
display(generalization_oi_main_results)

print("Generalization MCE Values")
display(generalization_mce_main_results)

print("Locality OI")
display(locality_oi_main_results)

print("Locality MCE Values")
display(locality_mce_main_results)

MMLU OI


You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  mmlu_oi_main_results[first_intervention][second_intervention] = mmlu_diff
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work t

Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.0071,0.182,0.0122,0.0174,0.0032,0.0037,0.0053
LoRA,,,,0.0706,0.1414,0.039,0.0023,0.012,0.0001,0.0009
MEMIT,,,,0.0409,0.2225,0.0025,0.0084,0.0139,0.004,0.0034
GA,0.0071,0.0706,0.0409,,,,0.0174,0.0921,0.058,0.0234
GD,0.182,0.1414,0.2225,,,,0.07,0.2398,0.0927,0.01
RMU,0.0122,0.039,0.0025,,,,0.0085,0.05,0.008,0.0148
AWQ,0.0174,0.0023,0.0084,0.0174,0.07,0.0085,,,,
GPTQ,0.0032,0.012,0.0139,0.0921,0.2398,0.05,,,,
SparseGPT,0.0037,0.0001,0.004,0.058,0.0927,0.008,,,,
Wanda,0.0053,0.0009,0.0034,0.0234,0.01,0.0148,,,,


MMLU MCE Values


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.4745,0.41,0.4305,0.4094,0.4138,0.3978,0.3976
LoRA,,,,0.6435,0.413,0.4443,0.4144,0.4189,0.4039,0.4036
MEMIT,,,,0.5074,0.4207,0.4381,0.4142,0.4109,0.4037,0.4044
GA,0.4745,0.6435,0.5074,,,,0.5347,0.5246,0.4917,0.4591
GD,0.41,0.413,0.4207,,,,0.5803,0.4358,0.4227,0.42
RMU,0.4305,0.4443,0.4381,,,,0.4383,0.4205,0.4341,0.4288
AWQ,0.4094,0.4144,0.4142,0.5347,0.5803,0.4383,,,,
GPTQ,0.4138,0.4189,0.4109,0.5246,0.4358,0.4205,,,,
SparseGPT,0.3978,0.4039,0.4037,0.4917,0.4227,0.4341,,,,
Wanda,0.3976,0.4036,0.4044,0.4591,0.42,0.4288,,,,


WMDP OI


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.0046,0.0001,0.045,0.0144,0.009,0.0026,0.0039
LoRA,,,,0.0661,0.2437,0.0022,0.0006,0.0097,0.0019,0.0008
MEMIT,,,,0.0508,0.0205,0.0086,0.0023,0.005,0.0016,0.0014
GA,0.0046,0.0661,0.0508,,,,0.0232,0.088,0.03,0.0285
GD,0.0001,0.2437,0.0205,,,,0.0012,0.023,0.2218,0.1733
RMU,0.045,0.0022,0.0086,,,,0.015,0.1801,0.0241,0.0332
AWQ,0.0144,0.0006,0.0023,0.0232,0.0012,0.015,,,,
GPTQ,0.009,0.0097,0.005,0.088,0.023,0.1801,,,,
SparseGPT,0.0026,0.0019,0.0016,0.03,0.2218,0.0241,,,,
Wanda,0.0039,0.0008,0.0014,0.0285,0.1733,0.0332,,,,


WMDP MCE Values


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.4664,0.2937,0.2751,0.5341,0.5423,0.5525,0.5586
LoRA,,,,0.2762,0.2993,0.2947,0.5472,0.5341,0.5434,0.5469
MEMIT,,,,0.3966,0.2557,0.2928,0.5516,0.5477,0.557,0.5573
GA,0.4664,0.2762,0.3966,,,,0.4257,0.3443,0.4318,0.4586
GD,0.2937,0.2993,0.2557,,,,0.2434,0.2438,0.275,0.3475
RMU,0.2751,0.2947,0.2928,,,,0.2733,0.2667,0.285,0.2854
AWQ,0.5341,0.5472,0.5516,0.4257,0.2434,0.2733,,,,
GPTQ,0.5423,0.5341,0.5477,0.3443,0.2438,0.2667,,,,
SparseGPT,0.5525,0.5434,0.557,0.4318,0.275,0.285,,,,
Wanda,0.5586,0.5469,0.5573,0.4586,0.3475,0.2854,,,,


Rewrite OI


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.0686,0.6705,0.009,0.0376,0.357,0.0345,0.0259
LoRA,,,,0.996,0.5587,0.0004,0.0835,0.3706,0.0148,0.0798
MEMIT,,,,0.4821,0.3992,0.0125,0.0099,0.0323,0.0706,0.0711
GA,0.0686,0.996,0.4821,,,,0.0,0.0,0.0,0.0
GD,0.6705,0.5587,0.3992,,,,0.0117,0.0067,0.0044,0.0
RMU,0.009,0.0004,0.0125,,,,0.0059,0.0162,0.0141,0.0005
AWQ,0.0376,0.0835,0.0099,0.0,0.0117,0.0059,,,,
GPTQ,0.357,0.3706,0.0323,0.0,0.0067,0.0162,,,,
SparseGPT,0.0345,0.0148,0.0706,0.0,0.0044,0.0141,,,,
Wanda,0.0259,0.0798,0.0711,0.0,0.0,0.0005,,,,


Rewrite MCE Values


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.9314,0.005,0.0017,0.0053,0.2061,0.0092,0.013
LoRA,,,,0.004,0.004,0.004,0.004,0.4199,0.1374,0.0519
MEMIT,,,,0.5179,0.0696,0.0313,0.071,0.1554,0.059,0.0541
GA,0.9314,0.004,0.5179,,,,1.0,1.0,1.0,1.0
GD,0.005,0.004,0.0696,,,,0.9883,0.9933,0.9933,1.0
RMU,0.0017,0.004,0.0313,,,,0.9759,0.9794,0.9784,0.9804
AWQ,0.0053,0.004,0.071,1.0,0.9883,0.9759,,,,
GPTQ,0.2061,0.4199,0.1554,1.0,0.9933,0.9794,,,,
SparseGPT,0.0092,0.1374,0.059,1.0,0.9933,0.9784,,,,
Wanda,0.013,0.0519,0.0541,1.0,1.0,0.9804,,,,


Generalization OI


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.0401,0.5607,0.0285,0.0801,0.3564,0.0231,0.03
LoRA,,,,0.7835,0.4808,0.0393,0.14,0.2672,0.0713,0.0059
MEMIT,,,,0.4096,0.4077,0.0359,0.0033,0.1089,0.0357,0.0549
GA,0.0401,0.7835,0.4096,,,,0.0,0.0,0.0,0.0
GD,0.5607,0.4808,0.4077,,,,0.0,0.0067,0.0067,0.0067
RMU,0.0285,0.0393,0.0359,,,,0.0027,0.0023,0.0032,0.0057
AWQ,0.0801,0.14,0.0033,0.0,0.0,0.0027,,,,
GPTQ,0.3564,0.2672,0.1089,0.0,0.0067,0.0023,,,,
SparseGPT,0.0231,0.0713,0.0357,0.0,0.0067,0.0032,,,,
Wanda,0.03,0.0059,0.0549,0.0,0.0067,0.0057,,,,


Generalization MCE Values


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.9599,0.1892,0.1848,0.1507,0.401,0.221,0.2066
LoRA,,,,0.2165,0.2915,0.2852,0.2592,0.5944,0.4767,0.4623
MEMIT,,,,0.5904,0.1067,0.0703,0.1095,0.1977,0.1349,0.1039
GA,0.9599,0.2165,0.5904,,,,1.0,1.0,1.0,1.0
GD,0.1892,0.2915,0.1067,,,,1.0,0.9933,0.9933,0.9933
RMU,0.1848,0.2852,0.0703,,,,0.9772,0.9801,0.977,0.9765
AWQ,0.1507,0.2592,0.1095,1.0,1.0,0.9772,,,,
GPTQ,0.401,0.5944,0.1977,1.0,0.9933,0.9801,,,,
SparseGPT,0.221,0.4767,0.1349,1.0,0.9933,0.977,,,,
Wanda,0.2066,0.4623,0.1039,1.0,0.9933,0.9765,,,,


Locality OI


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,0.0,0.0413,0.0337,0.0369,0.0687,0.0299,0.0325
LoRA,,,,0.0255,0.0344,0.0075,0.0271,0.0066,0.001,0.0081
MEMIT,,,,0.0005,0.0119,0.0008,0.002,0.0008,0.0029,0.0014
GA,0.0,0.0255,0.0005,,,,0.0,0.0,0.0,0.0
GD,0.0413,0.0344,0.0119,,,,0.0384,0.0072,0.0059,0.0279
RMU,0.0337,0.0075,0.0008,,,,0.0032,0.0063,0.001,0.0005
AWQ,0.0369,0.0271,0.002,0.0,0.0384,0.0032,,,,
GPTQ,0.0687,0.0066,0.0008,0.0,0.0072,0.0063,,,,
SparseGPT,0.0299,0.001,0.0029,0.0,0.0059,0.001,,,,
Wanda,0.0325,0.0081,0.0014,0.0,0.0279,0.0005,,,,


Locality MCE Values


Unnamed: 0,Fine-tune,LoRA,MEMIT,GA,GD,RMU,AWQ,GPTQ,SparseGPT,Wanda
Fine-tune,,,,1.0,0.8858,0.8666,0.8545,0.8931,0.8941,0.8952
LoRA,,,,0.9745,0.9172,0.9475,0.93,0.9476,0.9583,0.9606
MEMIT,,,,0.9995,0.9537,0.9652,0.9639,0.9648,0.9649,0.9643
GA,1.0,0.9745,0.9995,,,,1.0,1.0,1.0,1.0
GD,0.8858,0.9172,0.9537,,,,0.9616,0.9897,0.9789,0.9698
RMU,0.8666,0.9475,0.9652,,,,0.9658,0.9719,0.9644,0.9623
AWQ,0.8545,0.93,0.9639,1.0,0.9616,0.9658,,,,
GPTQ,0.8931,0.9476,0.9648,1.0,0.9897,0.9719,,,,
SparseGPT,0.8941,0.9583,0.9649,1.0,0.9789,0.9644,,,,
Wanda,0.8952,0.9606,0.9643,1.0,0.9698,0.9623,,,,


# Tables

In [19]:
compression_order = ["Wanda", "SparseGPT", "AWQ", "GPTQ"]
editor_order = ["Fine-tune", "MEMIT", "LoRA"]
unlearn_order = ["GA", "GD", "RMU"]

In [20]:
def format_value(value):
    if pd.isnull(value):
        return ''
    elif value > .995:
        return '1'
    else:
        return f'{value:.2f}'[1:] if value < 1 else f'{value:.2f}'

def latex_bold_if_min(value: str, max_value: float):
    return f'\\textbf{{{value}}}' if value == format_value(min_value) else value

## KE ←→ Compression

### Single Row Table

In [21]:
def generate_latex_table_ke_mc(edit_mce_df, edit_oi_df, gen_mce_df, gen_oi_df, mmlu_mce_df, mmlu_oi_df, edit_interventions, mmlu_interventions):
    latex_code = r"""
    \begin{tabular}{lcccccccccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{Edit Success}} & \multicolumn{6}{c}{\textbf{Generalization}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13} \cmidrule(lr){14-19}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13} \cmidrule(lr){14-16} \cmidrule(lr){17-19}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
"""
    table_values = []

    for compressor in compression_order:
        if compressor == "AWQ":
            latex_code += r"        \cdashlinelr{1-19}" + "\n"

        row_values = []
        table_row = f"        {compressor}"
        for metrics_category in [(edit_mce_df, edit_oi_df), (gen_mce_df, gen_oi_df), (mmlu_mce_df, mmlu_oi_df)]:
            for sub_metric in metrics_category:
                for editor in editor_order:
                    table_row += f" & {format_value(sub_metric[editor][compressor])}"
                    row_values.append(sub_metric[editor][compressor])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)
    
    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"
    
    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''

    print(latex_code)


generate_latex_table_ke_mc(
    edit_mce_main_results,
    edit_oi_main_results,
    generalization_mce_main_results,
    generalization_oi_main_results,
    mmlu_mce_main_results,
    mmlu_oi_main_results,
    editor_order,
    compression_order,
)


    \begin{tabular}{lcccccccccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{Edit Success}} & \multicolumn{6}{c}{\textbf{Generalization}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13} \cmidrule(lr){14-19}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13} \cmidrule(lr){14-16} \cmidrule(lr){17-19}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf

### Multi Row Table

### First Row: KE Metrics

In [22]:
def generate_latex_table_ke_mc_edit_only(edit_mce_df, edit_oi_df, gen_mce_df, gen_oi_df, edit_interventions, compression_order):
    latex_code = r"""
    \begin{tabular}{lcccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{Edit Success}} & \multicolumn{6}{c}{\textbf{Generalization}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
"""
    table_values = []

    for compressor in compression_order:
        if compressor == "AWQ":
            latex_code += r"        \cdashlinelr{1-13}" + "\n"

        row_values = []
        table_row = r"        \textbf{" + compressor + "}"
        for metrics_category in [(edit_mce_df, edit_oi_df), (gen_mce_df, gen_oi_df)]:
            for sub_metric in metrics_category:
                for editor in editor_order:
                    table_row += f" & {format_value(sub_metric[editor][compressor])}"
                    row_values.append(sub_metric[editor][compressor])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)
    
    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"
    
    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''

    print(latex_code)


generate_latex_table_ke_mc_edit_only(
    edit_mce_main_results,
    edit_oi_main_results,
    generalization_mce_main_results,
    generalization_oi_main_results,
    editor_order,
    compression_order,
)



    \begin{tabular}{lcccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{Edit Success}} & \multicolumn{6}{c}{\textbf{Generalization}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
        \textbf{Wanda} & .01 & .05 & .05 & .03 & .07 & .08 & .21 & .10 & .46 & .03 & .05 & .01 \\
        \textbf{SparseGPT} & .01 & .06 & .14 & .03 & .07 & .01 & .22 & .13 & .48 & .02 & .04 & .07 \\
        \cdashlinelr{1-13}
        \textbf{AWQ} & .01 & .07 & .00 & .04

#### Second Row: Locality & MMLU

In [23]:
def generate_latex_table_ke_locality_and_mmlu(locality_mce_df, locality_oi_df, mmlu_mce_df, mmlu_oi_df, edit_interventions, mmlu_interventions):
    latex_code = r"""
    \begin{tabular}{lcccccccccccc}
        & \multicolumn{6}{c}{\textbf{Locality}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
"""

    table_values = []

    for compressor in compression_order:
        if compressor == "AWQ":
            latex_code += r"        \cdashlinelr{1-13}" + "\n"

        row_values = []
        table_row = r"        \textbf{" + compressor + "}"
        for metrics_category in [(locality_mce_df, locality_oi_df), (mmlu_mce_df, mmlu_oi_df)]:
            for sub_metric in metrics_category:
                for editor in editor_order:
                    table_row += f" & {format_value(sub_metric[editor][compressor])}"
                    row_values.append(sub_metric[editor][compressor])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)

    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"

    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''
    
    print(latex_code)

generate_latex_table_ke_locality_and_mmlu(
    locality_mce_main_results,
    locality_oi_main_results,
    mmlu_mce_main_results,
    mmlu_oi_main_results,
    editor_order,
    compression_order,
)


    \begin{tabular}{lcccccccccccc}
        & \multicolumn{6}{c}{\textbf{Locality}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
        \textbf{Wanda} & .90 & .96 & .96 & .03 & .00 & .01 & .40 & .40 & .40 & .01 & .00 & .00 \\
        \textbf{SparseGPT} & .89 & .96 & .96 & .03 & .00 & .00 & .40 & .40 & .40 & .00 & .00 & .00 \\
        \cdashlinelr{1-13}
        \textbf{AWQ} & .85 & .96 & .93 & .04 & .00 & .03 & .41 & .41 & .41 

### Single Row: MMLU

In [24]:
def generate_latex_table_ke_mc_mmlu_only(mmlu_mce_df, mmlu_oi_df, editor_order, compression_order):
    latex_code = r"""
    \begin{tabular}{lcccccc}
        & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
"""
    table_values = []

    for compressor in compression_order:
        if compressor == "AWQ":
            latex_code += r"        \cdashlinelr{1-7}" + "\n"

        row_values = []
        table_row = r"        \textbf{" + compressor + "}"
        for metrics_category in [(mmlu_mce_df, mmlu_oi_df)]:
            for sub_metric in metrics_category:
                for editor in editor_order:
                    table_row += f" & {format_value(sub_metric[editor][compressor])}"
                    row_values.append(sub_metric[editor][compressor])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)
    
    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"
    
    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''

    print(latex_code)


generate_latex_table_ke_mc_mmlu_only(
    mmlu_mce_main_results,
    mmlu_oi_main_results,
    editor_order,
    compression_order,
)



    \begin{tabular}{lcccccc}
        & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
        \textbf{Wanda} & .40 & .40 & .40 & .01 & .00 & .00 \\
        \textbf{SparseGPT} & .40 & .40 & .40 & .00 & .00 & .00 \\
        \cdashlinelr{1-7}
        \textbf{AWQ} & .41 & .41 & .41 & .02 & .01 & .00 \\
        \textbf{GPTQ} & .41 & .41 & .42 & .00 & .01 & .01 \\
        \midrule
        \textit{Average} & .40 & .41 & .41 & .01 & .01 & .00 \\
        \bottomrule \\
    \end{tabular}



## MU ←→ MC

In [25]:
# Have WMDP and MMLU in the same table
def generate_latex_table_mu_mc():
    latex_code = r"""
    \begin{tabular}{lcccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{WMDP}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{GA} & \textbf{GD} & \textbf{RMU} & \textbf{GA} & \textbf{GD} & \textbf{RMU} & \textbf{GA} & \textbf{GD} & \textbf{RMU} & \textbf{GA} & \textbf{GD} & \textbf{RMU} \\
        \midrule
"""
    table_values = []

    for compressor in compression_order:
        if compressor == "AWQ":
            latex_code += r"        \cdashlinelr{1-13}" + "\n"

        row_values = []
        table_row = r"        \textbf{" + compressor + "}"
        for metrics_category in [(wmdp_mce_main_results, wmdp_oi_main_results), (mmlu_mce_main_results, mmlu_oi_main_results)]:
            for sub_metric in metrics_category:
                for unlearner in unlearn_order:
                    table_row += f" & {format_value(sub_metric[unlearner][compressor])}"
                    row_values.append(sub_metric[unlearner][compressor])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)
    
    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"

    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''

    print(latex_code)


generate_latex_table_mu_mc()


    \begin{tabular}{lcccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{WMDP}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{GA} & \textbf{GD} & \textbf{RMU} & \textbf{GA} & \textbf{GD} & \textbf{RMU} & \textbf{GA} & \textbf{GD} & \textbf{RMU} & \textbf{GA} & \textbf{GD} & \textbf{RMU} \\
        \midrule
        \textbf{Wanda} & .46 & .35 & .29 & .03 & .17 & .03 & .46 & .42 & .43 & .02 & .01 & .01 \\
        \textbf{SparseGPT} & .43 & .28 & .28 & .03 & .22 & .02 & .49 & .42 & .43 & .06 & .09 & .01 \\
        \cdashlinelr{1-13}
        \textbf{AWQ} & .43 & .24 & .27 & .02 & .00 & .01 & .53 & .58 & .44 & .

## KE ←→ MU

### First Row: KE Metrics

In [26]:
def generate_latex_table_ke_mu_ke_metrics():
    latex_code = r"""
    \begin{tabular}{lcccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{Edit Success}} & \multicolumn{6}{c}{\textbf{Generalization}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
"""
    table_values = []

    for unlearner in unlearn_order:
        row_values = []
        table_row = r"        \textbf{" + unlearner + "}"
        for metrics_category in [(edit_mce_main_results, edit_oi_main_results), (generalization_mce_main_results, generalization_oi_main_results)]:
            for sub_metric in metrics_category:
                for editor in editor_order:
                    table_row += f" & {format_value(sub_metric[editor][unlearner])}"
                    row_values.append(sub_metric[editor][unlearner])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)
    
    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"
    
    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''

    print(latex_code)


generate_latex_table_ke_mu_ke_metrics()



    \begin{tabular}{lcccccccccccc}
        \toprule
        & \multicolumn{6}{c}{\textbf{Edit Success}} & \multicolumn{6}{c}{\textbf{Generalization}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
        \textbf{GA} & .93 & .52 & .00 & .07 & .48 & 1 & .96 & .59 & .22 & .04 & .41 & .78 \\
        \textbf{GD} & .01 & .07 & .00 & .67 & .40 & .56 & .19 & .11 & .29 & .56 & .41 & .48 \\
        \textbf{RMU} & .00 & .03 & .00 & .01 & .01 & .00 & .18 & .07 & .29 & .03 & 

### Second Row: MU Metrics

In [27]:
# Have WMDP and MMLU in the same table
def generate_latex_table_ke_mu_mu_metrics():
    latex_code = r"""
    \begin{tabular}{lcccccccccccc}
        & \multicolumn{6}{c}{\textbf{WMDP}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
"""
    table_values = []

    for unlearner in unlearn_order:
        row_values = []
        table_row = r"        \textbf{" + unlearner + "}"
        for metrics_category in [(wmdp_mce_main_results, wmdp_oi_main_results), (mmlu_mce_main_results, mmlu_oi_main_results)]:
            for sub_metric in metrics_category:
                for editor in editor_order:
                    table_row += f" & {format_value(sub_metric[editor][unlearner])}"
                    row_values.append(sub_metric[editor][unlearner])
        
        table_row += r" \\"
        latex_code += table_row + "\n"
        table_values.append(row_values)
    
    latex_code += r"        \midrule" + "\n"
    avg_row = r"        \textit{Average}"
    for col_avg in np.array(table_values).mean(0).tolist():
        avg_row += f" & {format_value(col_avg)}"

    latex_code += avg_row + r" \\" + "\n"
    latex_code += r'''        \bottomrule \\
    \end{tabular}
'''

    print(latex_code)


generate_latex_table_ke_mu_mu_metrics()


    \begin{tabular}{lcccccccccccc}
        & \multicolumn{6}{c}{\textbf{WMDP}} & \multicolumn{6}{c}{\textbf{MMLU}} \\
        \cmidrule(lr){2-7} \cmidrule(lr){8-13}
        & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{MCE ($\downarrow$)}} & \multicolumn{3}{c}{\textbf{OI ($\downarrow$)}} \\
        \cmidrule(lr){2-4} \cmidrule(lr){5-7} \cmidrule(lr){8-10} \cmidrule(lr){11-13}
        \textbf{Method} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} & \textbf{FT} & \textbf{MEMIT} & \textbf{LoRA} \\
        \midrule
        \textbf{GA} & .47 & .40 & .28 & .00 & .05 & .07 & .47 & .51 & .64 & .01 & .04 & .07 \\
        \textbf{GD} & .29 & .26 & .30 & .00 & .02 & .24 & .41 & .42 & .41 & .18 & .22 & .14 \\
        \textbf{RMU} & .28 & .29 & .29 & .04 & .01 & .00 & .43 & .44 & .44 & .01 & .00 & .04 \\
        \midrule
   

## Appendix: Detailed Results

- None    : Done
- KE ←→ MC: Done
- MC ←→ KE: Done
- KE ←→ MU: Todo
- MU ←→ KE: Todo
- MU ←→ MC: Todo
- MC ←→ MU: Todo


In [28]:
technique_formatting_map = {
    "awq": "AWQ",
    "gptq": "GPTQ",
    "sparsegpt": "SparseGPT",
    "wanda": "Wanda",
    "ft": "FT",
    "memit": "MEMIT",
    "lora": "LoRA",
    "ga": "GA",
    "gd": "GD",
    "rmu": "RMU",
}
appendix_compositions_order = [
    [],
    ["compress"],
    ["edit"],
    ["edit", "compress"],
    ["compress", "edit"],
    ["unlearn"],
    ["unlearn", "compress"],
    ["compress", "unlearn"],
    ["edit", "unlearn"],
    ["unlearn", "edit"],
]
appendix_table_columns_map = {
    "tag": "Composition",
    "Rewrite accuracy": "Edit Success",
    "Generalization": "Generalization",
    "Locality": "Locality",
    "Average bits": "Avg. Bits",
    "Avg WMDP": "Avg. WMDP",
    "mmlu accuracy": "MMLU",
    "PPL": "WikiText PPL",
}
appendix_technique_ordering = {
    "edit": ["Fine-tune", "MEMIT", "LoRA"],
    "compress": ["SparseGPT", "Wanda", "GPTQ", "AWQ"],
    "unlearn": ["GA", "GD", "RMU"],
}


def get_composition_label(row):
    composition = row["interventions"]
    if composition == []:
        return "None"
    
    first_intervention_type = composition[0] if composition[0] != "compress" else "compression"
    first_intervention = technique_formatting_map.get(row[first_intervention_type], row[first_intervention_type])
    if first_intervention in ["SparseGPT", "Wanda"]:
        first_intervention += " " + str(row["sparsity_ratio"])
    elif first_intervention in ["GPTQ", "AWQ"]:
        first_intervention += " (" + str(int(row["wbits"])) + "bit) "
    
    if len(composition) == 1:
        return first_intervention
    
    second_intervention_type = composition[1] if composition[1] != "compress" else "compression"
    second_intervention = technique_formatting_map.get(row[second_intervention_type], row[second_intervention_type])
    if second_intervention in ["SparseGPT", "Wanda"]:
        second_intervention += " " + str(row["sparsity_ratio"])
    elif second_intervention in ["GPTQ", "AWQ"]:
        second_intervention += " (" + str(int(row["wbits"])) + "bit) "
    
    return first_intervention + r"$\rightarrow$" + second_intervention


appendix_results = main_results.copy()
appendix_results["interventions"] = appendix_results["interventions"].apply(lambda x : ast.literal_eval(x) if isinstance(x, str) else x)
appendix_results["Label"] = appendix_results.apply(get_composition_label, axis=1)
appendix_results


Unnamed: 0,model_name,tag,_timestamp,edit_set,number_of_edits,rmu_layer_id,wbits,sparsity_ratio,mmlu accuracy,wmdp_bio accuracy,...,edit_dataset,compression_dataset,qa_question_count_limit,PPL,PPL edits,PPl QA,FLOPs,PPl edits unmasked,date,Label
37,Llama-3 (8b),AWQ2bit_FT,1.719118e+09,9.545455,50.0,-1.000000,2.0,0.00,0.258743,0.243591,...,zsre,c4,,33638.4375,338052.34375,102475.617188,-1,78554.226562,2024-05-20 17:38:54.680141449,AWQ (2bit) $\rightarrow$Fine-tune
38,Llama-3 (8b),AWQ2bit_LORA,1.718724e+09,5.500000,50.0,-1.000000,2.0,0.00,0.262028,0.242419,...,zsre,c4,,127311.320312,52947.800781,352363.28125,-1,96682.617188,2024-06-14 09:58:55.207093239,AWQ (2bit) $\rightarrow$LoRA
39,Llama-3 (8b),AWQ2bit_MEMIT,1.721601e+09,10.666667,50.0,-1.000000,2.0,0.00,0.264049,0.241424,...,zsre,c4,,1735678.75,996271.5625,1198751.125,-1,1074956.375,2024-05-20 17:01:28.464071751,AWQ (2bit) $\rightarrow$MEMIT
40,Llama-3 (8b),AWQ3bit_FT,1.720819e+09,5.500000,50.0,-1.000000,3.0,0.00,0.508567,0.596465,...,zsre,c4,,7.562892,16993.777344,648.620483,-1,822.444336,2024-06-17 08:37:54.834238052,AWQ (3bit) $\rightarrow$Fine-tune
41,Llama-3 (8b),AWQ3bit_LORA,1.719290e+09,5.500000,50.0,-1.000000,3.0,0.00,0.510034,0.604242,...,zsre,c4,,55.692047,751300.5625,7932.938477,-1,16216.796875,2024-06-17 07:55:40.395966053,AWQ (3bit) $\rightarrow$LoRA
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
343,Llama-3 (8b),WANDA65%_GD,1.718204e+09,1.000000,50.0,-1.000000,4.0,0.65,0.229098,0.249018,...,zsre,c4,,16268.005859,5423145.5,81961160,447.25 GFLOPS,89329664,2024-06-12 14:53:38.419730186,Wanda 0.65$\rightarrow$GD
344,Llama-3 (8b),WANDA65%_RMU,1.716764e+09,18.666667,50.0,3.666667,4.0,0.65,0.229360,0.249542,...,zsre,c4,,45.534767,83373.453125,1477.523804,760 GFLOPS,1549.510742,2024-05-20 19:27:06.522696495,Wanda 0.65$\rightarrow$RMU
345,Llama-3 (8b),WANDA75%_GA,1.718194e+09,1.000000,50.0,-1.000000,4.0,0.75,0.229811,0.249018,...,zsre,c4,,8125971838094910145582358934348365824.0,176520989218511062933962752.0,31936578699555629957120.0,357.84 GFLOPS,31529243875858664390656.0,2024-06-12 12:09:17.554701090,Wanda 0.75$\rightarrow$GA
346,Llama-3 (8b),WANDA75%_GD,1.718204e+09,1.000000,50.0,-1.000000,4.0,0.75,0.246902,0.247447,...,zsre,c4,,10226426,743558.8125,145133.09375,357.84 GFLOPS,128967.648438,2024-06-12 14:55:21.376342773,Wanda 0.75$\rightarrow$GD


In [29]:
appendix_results.columns

Index(['model_name', 'tag', '_timestamp', 'edit_set', 'number_of_edits',
       'rmu_layer_id', 'wbits', 'sparsity_ratio', 'mmlu accuracy',
       'wmdp_bio accuracy', 'wmdp_cyber accuracy', 'Generalization',
       'Success recall', 'Generalization recall', 'Locality', 'Average bits',
       'Rewrite accuracy', 'Local recall', 'Latency', 'Avg WMDP',
       '_timestamp_se', 'edit_set_se', 'number_of_edits_se', 'rmu_layer_id_se',
       'wbits_se', 'sparsity_ratio_se', 'mmlu accuracy_se',
       'wmdp_bio accuracy_se', 'wmdp_cyber accuracy_se', 'Generalization_se',
       'Success recall_se', 'Generalization recall_se', 'Locality_se',
       'Average bits_se', 'Rewrite accuracy_se', 'Local recall_se',
       'Latency_se', 'Avg WMDP_se', 'interventions', 'edit', 'unlearn',
       'compression', 'edit_dataset', 'compression_dataset',
       'qa_question_count_limit', 'PPL', 'PPL edits', 'PPl QA', 'FLOPs',
       'PPl edits unmasked', 'date', 'Label'],
      dtype='object')

### Appendix Table: Single Intervention

In [30]:
list(appendix_table_columns_map.keys())[1:]

['Rewrite accuracy',
 'Generalization',
 'Locality',
 'Average bits',
 'Avg WMDP',
 'mmlu accuracy',
 'PPL']

In [31]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# None
appendix_no_compositons = appendix_results[appendix_results["interventions"].apply(lambda x: len(x) == 0)]
latex_code += r"   {None}"
for col in list(appendix_table_columns_map.keys())[1:]:
    latex_code += f" & {appendix_no_compositons[col].mean():.2f}"

latex_code += r" \\" + "\n"

# Edit Only
latex_code += r"    \cdashlinelr{1-9}" + "\n"

appendix_edit_only = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["edit"])]
display(appendix_edit_only.value_counts("edit"))
for edit_technique in appendix_technique_ordering["edit"]:
    appendix_edit_only_technique = appendix_edit_only[appendix_edit_only["edit"] == edit_technique]
    assert len(appendix_edit_only_technique) > 0, f"No data found for {edit_technique}"
    technique_row_label = "FT" if edit_technique == "Fine-tune" else edit_technique
    latex_code += f"    {{{technique_row_label}}}$\\rightarrow$None"
    for col in list(appendix_table_columns_map.keys())[1:]:
        latex_code += f"& {round(appendix_edit_only_technique[col].mean(), 2)}"
    
    latex_code += r" \\" + "\n"

# Compress Only
latex_code += r"    \cdashlinelr{1-9}" + "\n"

appendix_compress_only = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["compress"])]
for compress_technique in appendix_technique_ordering["compress"]:
    appendix_compress_only_technique = appendix_compress_only[appendix_compress_only["compression"] == compress_technique]
    assert len(appendix_compress_only_technique) > 0, f"No data found for {compress_technique}"

    compression_strength_column = "sparsity_ratio" if compress_technique in ["SparseGPT", "Wanda"] else "wbits"
    compression_strength_ordering = sorted(appendix_compress_only_technique[compression_strength_column].unique())
    for compression_strength in compression_strength_ordering:
        technique_row_label = compress_technique
        current_compression = appendix_compress_only_technique[appendix_compress_only_technique[compression_strength_column] == compression_strength]
        if compress_technique in ["SparseGPT", "Wanda"]:
            technique_row_label += " (" + str(current_compression["sparsity_ratio"].iloc[0]) + ") "
        elif compress_technique in ["GPTQ", "AWQ"]:
            technique_row_label += " (" + str(int(current_compression["wbits"].iloc[0])) + "-Bit) "
        
        latex_code += f"    {{{technique_row_label}}}$\\rightarrow$None"
        for col in list(appendix_table_columns_map.keys())[1:]:
            latex_code += f" & {round(current_compression[col].mean(), 2)}"
        
        latex_code += r" \\" + "\n"

# Unlearn Only
latex_code += r"    \cdashlinelr{1-9}" + "\n"

appendix_unlearn_only = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["unlearn"])]
for unlearn_technique in appendix_technique_ordering["unlearn"]:
    appendix_unlearn_only_technique = appendix_unlearn_only[appendix_unlearn_only["unlearn"] == unlearn_technique]
    assert len(appendix_unlearn_only_technique) > 0, f"No data found for {unlearn_technique}"
    technique_row_label = unlearn_technique
    latex_code += f"    {{{technique_row_label}}}$\\rightarrow$None"
    for col in list(appendix_table_columns_map.keys())[1:]:
        latex_code += f" & {round(appendix_unlearn_only_technique[col].mean(), 2)}"
    
    latex_code += r" \\" + "\n"

# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)

edit
Fine-tune    1
LoRA         1
MEMIT        1
Name: count, dtype: int64

TypeError: Could not convert string 'Infinity' to numeric

### Appendix Table: KE ←→ MC

#### Edit First

In [None]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# Editing -> Compression
appendix_edit_compress = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["edit", "compress"])]
for edit_technique in appendix_technique_ordering["edit"]:
    appendix_edit_compress_edit_technique = appendix_edit_compress[appendix_edit_compress["edit"] == edit_technique]
    assert len(appendix_edit_compress_edit_technique) > 0, f"No data found for {edit_technique}"
    for compress_technique in appendix_technique_ordering["compress"]:
        appendix_edit_compress_technique_frame = appendix_edit_compress_edit_technique[appendix_edit_compress_edit_technique["compression"] == compress_technique]
        assert len(appendix_edit_compress_technique_frame) > 0, f"No data found for {compress_technique}"
        compression_strength_column = "sparsity_ratio" if compress_technique in ["SparseGPT", "Wanda"] else "wbits"
        compression_strength_ordering = sorted(set([round(strength, 2) for strength in appendix_edit_compress_technique_frame[compression_strength_column] if strength not in [0, 16]]))
        for compression_strength in compression_strength_ordering:
            technique_row_label = compress_technique
            current_compression = appendix_edit_compress_technique_frame[appendix_edit_compress_technique_frame[compression_strength_column] == compression_strength]
            if compress_technique in ["SparseGPT", "Wanda"]:
                technique_row_label += " (" + str(current_compression["sparsity_ratio"].iloc[0]) + ") "
            elif compress_technique in ["GPTQ", "AWQ"]:
                technique_row_label += " (" + str(int(current_compression["wbits"].iloc[0])) + "-Bit) "
            
            latex_code += f"    {{{formatted_edit_technique}}}$\\rightarrow${{{technique_row_label}}}"
            for col in list(appendix_table_columns_map.keys())[1:]:
                latex_code += f" & {round(current_compression[col].mean(), 2)}"
            
            latex_code += r" \\" + "\n"
    
    if edit_technique != appendix_technique_ordering["edit"][-1]:
        latex_code += r"    \cdashlinelr{1-8}" + "\n"

# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)


### Compression First

In [None]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# Editing -> Compression
appendix_compress_edit = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["compress", "edit"])]
for compress_technique in appendix_technique_ordering["compress"]:
    appendix_compress_edit_technique_frame = appendix_compress_edit[appendix_compress_edit["compression"] == compress_technique]
    assert len(appendix_compress_edit_technique_frame) > 0, f"No data found for {compress_technique}"

    compression_strength_column = "sparsity_ratio" if compress_technique in ["SparseGPT", "Wanda"] else "wbits"
    compression_strength_ordering = sorted(set([round(strength, 2) for strength in appendix_compress_edit_technique_frame[compression_strength_column] if strength not in [0, 16]]))
    print(f"Technique: {compress_technique}, Strengths: {compression_strength_ordering}")

    for edit_technique in appendix_technique_ordering["edit"]:
        appendix_compress_edit_edit_technique = appendix_compress_edit_technique_frame[appendix_compress_edit_technique_frame["edit"] == edit_technique]
        assert len(appendix_compress_edit_edit_technique) > 0, f"No data found for {edit_technique}"
        
        for compression_strength in compression_strength_ordering:
            technique_row_label = compress_technique
            current_compression = appendix_compress_edit_edit_technique[round(appendix_compress_edit_edit_technique[compression_strength_column], 2) == compression_strength]
            if compress_technique in ["SparseGPT", "Wanda"]:
                technique_row_label += " (" + str(compression_strength) + ") "
            elif compress_technique in ["GPTQ", "AWQ"]:
                technique_row_label += " (" + str(int(compression_strength)) + "-Bit) "
            
            latex_code += f"    {{{technique_row_label}}}$\\rightarrow${{{formatted_edit_technique}}}"
            for col in list(appendix_table_columns_map.keys())[1:]:
                latex_code += f" & {round(current_compression[col].mean(), 2)}"
            
            latex_code += r" \\" + "\n"
    
    if compress_technique != appendix_technique_ordering["compress"][-1]:
        latex_code += r"    \cdashlinelr{1-8}" + "\n"

# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)

### Appendix Table: MU ←→ MC

#### Unlearn First

In [None]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# Unlearn -> Compression
appendix_unlearn_compress = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["unlearn", "compress"])]
for unlearn_technique in appendix_technique_ordering["unlearn"]:
    appendix_unlearn_compress_unlearn_technique = appendix_unlearn_compress[appendix_unlearn_compress["unlearn"] == unlearn_technique]
    assert len(appendix_unlearn_compress_unlearn_technique) > 0, f"No data found for {unlearn_technique}"
    for compress_technique in appendix_technique_ordering["compress"]:
        appendix_unlearn_compress_technique_frame = appendix_unlearn_compress_unlearn_technique[appendix_unlearn_compress_unlearn_technique["compression"] == compress_technique]
        assert len(appendix_unlearn_compress_technique_frame) > 0, f"No data found for {compress_technique}"
        compression_strength_column = "sparsity_ratio" if compress_technique in ["SparseGPT", "Wanda"] else "wbits"
        compression_strength_ordering = sorted(set([round(strength, 2) for strength in appendix_unlearn_compress_technique_frame[compression_strength_column] if strength not in [0, 16]]))
        for compression_strength in compression_strength_ordering:
            technique_row_label = compress_technique
            current_compression = appendix_unlearn_compress_technique_frame[appendix_unlearn_compress_technique_frame[compression_strength_column] == compression_strength]
            if compress_technique in ["SparseGPT", "Wanda"]:
                technique_row_label += " (" + str(current_compression["sparsity_ratio"].iloc[0]) + ") "
            elif compress_technique in ["GPTQ", "AWQ"]:
                technique_row_label += " (" + str(int(current_compression["wbits"].iloc[0])) + "-Bit) "
            
            latex_code += f"    {{{formatted_unlearn_technique}}}$\\rightarrow${{{technique_row_label}}}"
            for col in list(appendix_table_columns_map.keys())[1:]:
                latex_code += f" & {round(current_compression[col].mean(), 2)}"
            
            latex_code += r" \\" + "\n"
    
    if unlearn_technique != appendix_technique_ordering["unlearn"][-1]:
        latex_code += r"    \cdashlinelr{1-8}" + "\n"
    
# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)

#### Compression First

In [None]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# Compression -> Unlearn
appendix_compress_unlearn = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["compress", "unlearn"])]
for compress_technique in appendix_technique_ordering["compress"]:
    appendix_compress_unlearn_technique_frame = appendix_compress_unlearn[appendix_compress_unlearn["compression"] == compress_technique.lower()]
    assert len(appendix_compress_unlearn_technique_frame) > 0, f"No data found for {compress_technique}"

    compression_strength_column = "sparsity_ratio" if compress_technique in ["SparseGPT", "Wanda"] else "wbits"
    compression_strength_ordering = sorted(set([round(strength, 2) for strength in appendix_compress_unlearn_technique_frame[compression_strength_column] if strength not in [0, 16]]))

    for unlearn_technique in appendix_technique_ordering["unlearn"]:
        formatted_unlearn_technique = unlearn_technique
        appendix_compress_unlearn_unlearn_technique = appendix_compress_unlearn_technique_frame[appendix_compress_unlearn_technique_frame["unlearn"] == formatted_unlearn_technique.lower()]
        assert len(appendix_compress_unlearn_unlearn_technique) > 0, f"No data found for {unlearn_technique}"
        
        for compression_strength in compression_strength_ordering:
            technique_row_label = compress_technique
            current_compression = appendix_compress_unlearn_unlearn_technique[round(appendix_compress_unlearn_unlearn_technique[compression_strength_column], 2) == compression_strength]
            if compress_technique in ["SparseGPT", "Wanda"]:
                technique_row_label += " (" + str(compression_strength) + ") "
            elif compress_technique in ["GPTQ", "AWQ"]:
                technique_row_label += " (" + str(int(compression_strength)) + "-Bit) "
            
            latex_code += f"    {{{technique_row_label}}}$\\rightarrow${{{formatted_unlearn_technique}}}"
            for col in list(appendix_table_columns_map.keys())[1:]:
                assert len(current_compression) == 1, f"Multiple rows found for {compress_technique} -> {unlearn_technique} -> {compression_strength}"
                latex_code += f" & {round(current_compression[col].mean(), 2)}"
            
            latex_code += r" \\" + "\n"
    
    if compress_technique != appendix_technique_ordering["compress"][-1]:
        latex_code += r"    \cdashlinelr{1-8}" + "\n"
    
# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)

### Appendix Table: KE ←→ MU

#### Edit First

In [None]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# Edit -> Unlearn
appendix_compress_edit = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["edit", "unlearn"])]
for edit_technique in appendix_technique_ordering["edit"]:
    formatted_edit_technique = "FT" if edit_technique == "Fine-tune" else edit_technique
    appendix_compress_edit_edit_technique = appendix_compress_edit[appendix_compress_edit["edit"] == formatted_edit_technique.lower()]
    assert len(appendix_compress_edit_edit_technique) > 0, f"No data found for {edit_technique}"
    for unlearn_technique in appendix_technique_ordering["unlearn"]:
        formatted_unlearn_technique = unlearn_technique
        appendix_compress_edit_technique_frame = appendix_compress_edit_edit_technique[appendix_compress_edit_edit_technique["unlearn"] == formatted_unlearn_technique.lower()]
        assert len(appendix_compress_edit_technique_frame) > 0, f"No data found for {unlearn_technique}"

        # No compression strength for this composition
        technique_row_label = edit_technique
        latex_code += f"    {{{formatted_edit_technique}}}$\\rightarrow${{{formatted_unlearn_technique}}}"
        for col in list(appendix_table_columns_map.keys())[1:]:
            latex_code += f" & {round(appendix_compress_edit_technique_frame[col].mean(), 2)}"

        latex_code += r" \\" + "\n"
        
    if edit_technique != appendix_technique_ordering["edit"][-1]:
        latex_code += r"    \cdashlinelr{1-8}" + "\n"

# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)

#### Unlearn First

In [None]:
latex_code = r"""
\begin{tabular}{lcccccccc}
    \toprule
    & \multicolumn{3}{c}{Editing} & \multicolumn{1}{c}{Compression} & \multicolumn{1}{c}{Unlearning} & \multicolumn{2}{c}{Utility} \\
    \cmidrule(lr){2-4} \cmidrule(lr){5-5} \cmidrule(lr){6-6} \cmidrule(lr){7-8}
    & Edit Success & Generalization & Locality & Avg. Bits & Avg. WMDP & MMLU & WikiText PPL \\
    \midrule
"""

# Unlearn -> Edit
appendix_unlearn_edit = appendix_results[appendix_results["interventions"].apply(lambda x: x == ["unlearn", "edit"])]
for unlearn_technique in appendix_technique_ordering["unlearn"]:
    formatted_unlearn_technique = unlearn_technique
    appendix_unlearn_edit_unlearn_technique = appendix_unlearn_edit[appendix_unlearn_edit["unlearn"] == formatted_unlearn_technique.lower()]
    assert len(appendix_unlearn_edit_unlearn_technique) > 0, f"No data found for {unlearn_technique}"
    for edit_technique in appendix_technique_ordering["edit"]:
        formatted_edit_technique = "FT" if edit_technique == "Fine-tune" else edit_technique
        appendix_unlearn_edit_technique_frame = appendix_unlearn_edit_unlearn_technique[appendix_unlearn_edit_unlearn_technique["edit"] == formatted_edit_technique.lower()]
        assert len(appendix_unlearn_edit_technique_frame) > 0, f"No data found for {edit_technique}"
        
        # No compression strength for this composition
        technique_row_label = unlearn_technique
        latex_code += f"    {{{formatted_unlearn_technique}}}$\\rightarrow${{{formatted_edit_technique}}}"
        for col in list(appendix_table_columns_map.keys())[1:]:
            latex_code += f" & {round(appendix_unlearn_edit_technique_frame[col].mean(), 2)}"

        latex_code += r" \\" + "\n"
        
    if unlearn_technique != appendix_technique_ordering["unlearn"][-1]:
        latex_code += r"    \cdashlinelr{1-8}" + "\n"

# end of table
latex_code += r"    \bottomrule \\" + "\n"
latex_code += r"\end{tabular}"

# Pring the table
print(latex_code)

# Plots

In [50]:
def get_order_label(row):
    interventions = row["interventions"]
    first_method = ""
    second_method = ""
    if interventions[0] == "edit":
        first_method = row["edit"]
    elif interventions[0] == "compress":
        first_method = row["compression"]
    elif interventions[0] == "unlearn":
        first_method = row["unlearn"]
    
    if interventions[1] == "edit":
        second_method = row["edit"]
    elif interventions[1] == "compress":
        second_method = row["compression"]
    elif interventions[1] == "unlearn":
        second_method = row["unlearn"]
    
    return f"{first_method}→{second_method}"

def wrap_label(interventions):
    first_intervention, second_intervention = interventions[0], interventions[1]
    first_letter_upper = first_intervention[0].upper()
    second_letter_upper = second_intervention[0].upper()
    
    # EX: E $\rightarrow$ C
    return f"{first_letter_upper}$\\rightarrow${second_letter_upper}"


### Create mock records for baselines

In [None]:
# I want instances where editing has been applied but there is no unlearning or compression. In these cases, set wbits=16 and sparsity=0 
baseline_editors = main_results[(main_results["edit"].notnull()) & (main_results["unlearn"].isnull()) & (main_results["compression"].isnull()) & (main_results["interventions"].apply(lambda x: x == ["edit"]))].copy()
baseline_editors["wbits"] = 16
baseline_editors["sparsity_ratio"] = 0
news_records = []

# Edit and Compress
for editing_method in ["LoRA", "MEMIT", "Fine-tune"]:
    baseline_record = baseline_editors[baseline_editors["edit"] == editing_method]
    for compression_method in ["SparseGPT", "Wanda", "GPTQ", "AWQ"]:
        edit_first_record = baseline_record.copy()
        edit_first_record["compression"] = compression_method
        edit_first_record["interventions"] = [["edit", "compress"]]
        edit_first_record["sparsity_ratio"] = 0
        edit_first_record["wbits"] = 16
        news_records.append(edit_first_record)

        compress_first_record = baseline_record.copy()
        compress_first_record["compression"] = compression_method
        compress_first_record["interventions"] = [["compress", "edit"]]
        compress_first_record["sparsity_ratio"] = 0
        compress_first_record["wbits"] = 16
        news_records.append(compress_first_record)

baseline_unlearners = main_results[(main_results["edit"].isnull()) & (main_results["unlearn"].notnull()) & (main_results["compression"].isnull()) & (main_results["interventions"].apply(lambda x: x == ["unlearn"]))].copy()

# Compress and Unlearn
for unlearn_method in ["RMU", "GA", "GD"]:
    baseline_record = baseline_unlearners[baseline_unlearners["unlearn"] == unlearn_method]

    for compression_method in ["SparseGPT", "Wanda", "GPTQ", "AWQ"]:
        compress_first_record = baseline_record.copy()
        compress_first_record["unlearn"] = unlearn_method
        compress_first_record["compression"] = compression_method
        compress_first_record["interventions"] = [["compress", "unlearn"]]
        news_records.append(compress_first_record)

        unlearn_first_record = baseline_record.copy()
        unlearn_first_record["unlearn"] = unlearn_method
        unlearn_first_record["compression"] = compression_method
        unlearn_first_record["interventions"] = [["unlearn", "compress"]]
        news_records.append(unlearn_first_record)

baseline_records = pd.concat(news_records)
data = pd.concat([main_results, baseline_records])
baseline_records

### Editing and Compresion Single Row

## Plot: KE ←→ Compression

In [54]:
compositions_by_col = {
    # MEMIT and WANDA + SparseGPT
    0: [("MEMIT→SparseGPT", "SparseGPT→MEMIT"), ("MEMIT→Wanda", "Wanda→MEMIT")],
    # LoRA and WANDA + SparseGPT
    1: [("LoRA→SparseGPT", "SparseGPT→LoRA"), ("LoRA→Wanda", "Wanda→LoRA")],
    # FT and WANDA + SparseGPT
    2: [("Fine-tune→SparseGPT", "SparseGPT→Fine-tune"), ("Fine-tune→Wanda", "Wanda→Fine-tune")],
    # MEMIT and GPTQ + AWQ
    3: [("MEMIT→GPTQ", "GPTQ→MEMIT"), ("MEMIT→AWQ", "AWQ→MEMIT")],
    # LoRA and GPTQ + AWQ
    4: [("LoRA→GPTQ", "GPTQ→LoRA"), ("LoRA→AWQ", "AWQ→LoRA")],
    # FT and GPTQ + AWQ
    5: [("Fine-tune→GPTQ", "GPTQ→Fine-tune"), ("Fine-tune→AWQ", "AWQ→Fine-tune")],
}

In [None]:
pruning_frame = data[((data["compression"] == "SparseGPT") | (data["compression"] == "Wanda")) & (data["edit"] != None) & (data["interventions"].apply(lambda x: len(x) > 1))]
pruning_frame["order"] = pruning_frame.apply(get_order_label, axis=1)
pruning_frame = pruning_frame.sort_values(by="order")

quantization_frame = data[((data["compression"] == "GPTQ") | (data["compression"] == "AWQ")) & (data["edit"] != None) & (data["interventions"].apply(lambda x: len(x) > 1))]
quantization_frame["order"] = quantization_frame.apply(get_order_label, axis=1)
quantization_frame = quantization_frame.sort_values(by="order")

row_metrics = {
    0: "Rewrite accuracy",
    1: "Generalization",
    2: "Locality",
    3: "mmlu accuracy",
}
row_labels = {
    0: r"Edit Success$ \uparrow$",
    1: r"Generalization$ \uparrow$",
    2: r"Locality$ \uparrow$",
    3: r"MMLU$ \uparrow$"
}
column_edit_methods = {
    0: "MEMIT",
    1: "LoRA",
    2: "Fine-tune",
    3: "MEMIT",
    4: "LoRA",
    5: "Fine-tune"
}
compositions_by_col = {
    # MEMIT and WANDA + SparseGPT
    0: [("MEMIT→SparseGPT", "SparseGPT→MEMIT"), ("MEMIT→Wanda", "Wanda→MEMIT")],
    # LoRA and WANDA + SparseGPT
    1: [("LoRA→SparseGPT", "SparseGPT→LoRA"), ("LoRA→Wanda", "Wanda→LoRA")],
    # FT and WANDA + SparseGPT
    2: [("Fine-tune→SparseGPT", "SparseGPT→Fine-tune"), ("Fine-tune→Wanda", "Wanda→Fine-tune")],
    # MEMIT and GPTQ + AWQ
    3: [("MEMIT→GPTQ", "GPTQ→MEMIT"), ("MEMIT→AWQ", "AWQ→MEMIT")],
    # LoRA and GPTQ + AWQ
    4: [("LoRA→GPTQ", "GPTQ→LoRA"), ("LoRA→AWQ", "AWQ→LoRA")],
    # FT and GPTQ + AWQ
    5: [("Fine-tune→GPTQ", "GPTQ→Fine-tune"), ("Fine-tune→AWQ", "AWQ→Fine-tune")],
}
final_row_index = len(row_labels) - 1

fig, axes = plt.subplots(len(row_labels), 6, figsize=(6 * FIG_SIZE, len(row_labels) * FIG_SIZE))
for row_index, y_metric in row_metrics.items():
    for col_index, plotting_frame in enumerate([pruning_frame, pruning_frame, pruning_frame, quantization_frame, quantization_frame, quantization_frame]):
        ax = axes[row_index][col_index]
        x_metric = "sparsity_ratio" if col_index < 3 else "wbits"
        plotting_frame = plotting_frame[plotting_frame["edit"] == column_edit_methods[col_index]]

        for composition in compositions_by_col[col_index]:
            compression_method = [method for method in composition[0].split("→") if method not in ["MEMIT", "LoRA", "Fine-tune"]][0]
            first_line = plotting_frame[plotting_frame["order"] == composition[0]]
            first_line["label"] = first_line["order"].apply(wrap_label)
            second_line = plotting_frame[plotting_frame["order"] == composition[1]].sort_values(x_metric)
            second_line["label"] = second_line["order"].apply(wrap_label)
            if compression_method in ["AWQ", "GPTQ"]:
                first_line = first_line.sort_values(x_metric, ascending=False)
                second_line = second_line.sort_values(x_metric, ascending=False)
            else:
                first_line = first_line.sort_values(x_metric)
                second_line = second_line.sort_values(x_metric)

            ax.plot(first_line[x_metric], first_line[y_metric], marker="o", markersize=MARKER_SIZE, color=colors[compression_method], label=f"{composition[0]}")
            ax.plot(second_line[x_metric], second_line[y_metric], markerfacecolor='none', marker="o", ls="--", markersize=MARKER_SIZE, color=colors[compression_method], label=f"{composition[1]}")
            ax.fill_between(
                x=first_line[x_metric], y1=first_line[y_metric], y2=second_line[y_metric],
                alpha=0.3,
                color=colors[compression_method]
            )

        if x_metric == "wbits":
            ax.set_xscale("log", base=2)
            ax.set_xticks([2, 4, 8, 16], ["2", "4", "8", "16"])

        if row_index != final_row_index:
            ax.set_ylim(0, 1.05)
        else:
            ax.set_ylim(0.2, 0.65)
            ax.axhline(y=0.25, color="gray", linestyle="--")

        if y_metric == "Locality":
            ax.set_ylim(0, .2)

        if row_index == 0:
            title = column_edit_methods[col_index] if column_edit_methods[col_index] != "Fine-tune" else "FT"
            ax.set_title(title, fontsize=TITLE_FONT_SIZE)
        else:
            ax.set_title("")

        if col_index == 0:
            ax.set_ylabel(row_labels[row_index], fontsize=TITLE_FONT_SIZE)
        else:
            ax.set_ylabel("")

        if row_index == len(row_labels) - 1:
            ax.set_xlabel("Sparsity" if col_index < 3 else "Bits", fontsize=TITLE_FONT_SIZE)
        else:
            ax.set_xlabel("")

        if row_index == final_row_index:
            ax.legend(fontsize=LEGEND_FONT_SIZE, frameon=False, loc="upper center", bbox_to_anchor=(0.5, -0.3), ncol=1)

fig.subplots_adjust(wspace=WSPACE, hspace=WSPACE)
plt.savefig("figures/main_results_editors_compression.pdf", bbox_inches="tight")

## Plot: Unlearning ←→ Compression

In [None]:
pruning_frame = data[((data["compression"] == "SparseGPT") | (data["compression"] == "Wanda")) & (data["unlearn"] != None) & (data["interventions"].apply(lambda x: len(x) > 1))]
pruning_frame["order"] = pruning_frame.apply(get_order_label, axis=1)
pruning_frame = pruning_frame.sort_values(by="order")
pruning_frame["unlearn"] = pruning_frame["unlearn"].apply(lambda x: x.upper() if x is not None else None)

quantization_frame = data[((data["compression"] == "GPTQ") | (data["compression"] == "AWQ")) & (data["unlearn"] != None) & (data["interventions"].apply(lambda x: len(x) > 1))]
quantization_frame["order"] = quantization_frame.apply(get_order_label, axis=1)
quantization_frame = quantization_frame.sort_values(by="order")
quantization_frame["unlearn"] = quantization_frame["unlearn"].apply(lambda x: x.upper() if x is not None else None)

# 4 columns and 3 rows
fig, axes = plt.subplots(2, 6, figsize=(6 * FIG_SIZE, 2 * FIG_SIZE))
row_metrics = {
    0: "Avg WMDP",
    1: "mmlu accuracy",
}
row_labels = {
    "Avg WMDP": r"WMDP $\downarrow$",
    "mmlu accuracy": r"MMLU $\uparrow$"
}
row_label_map = {
    0: "Avg WMDP",
    1: "mmlu accuracy"
}
column_unlearn_methods = {
    0: "GA",
    1: "GD",
    2: "RMU",
    3: "GA",
    4: "GD",
    5: "RMU",
}

compositions_by_col = {
    
    # GA and WANDA + SparseGPT
    0: [("GA→SparseGPT", "SparseGPT→GA"), ("GA→Wanda", "Wanda→GA")],
    # FT and WANDA + SparseGPT
    1: [("GD→SparseGPT", "SparseGPT→GD"), ("GD→Wanda", "Wanda→GD")],
    # RMU and WANDA + SparseGPT
    2: [("RMU→SparseGPT", "SparseGPT→RMU"), ("RMU→Wanda", "Wanda→RMU")],
    # GA and GPTQ + AWQ
    3: [("GA→GPTQ", "GPTQ→GA"), ("GA→AWQ", "AWQ→GA")],
    # FT and GPTQ + AWQ
    4: [("GD→GPTQ", "GPTQ→GD"), ("GD→AWQ", "AWQ→GD")],
    # RMU and GPTQ + AWQ
    5: [("RMU→GPTQ", "GPTQ→RMU"), ("RMU→AWQ", "AWQ→RMU")],
}
for row_index, y_metric in row_metrics.items():
    for col_index, plotting_frame in enumerate([pruning_frame, pruning_frame, pruning_frame, quantization_frame, quantization_frame, quantization_frame]):
        ax = axes[row_index][col_index]
        x_metric = "sparsity_ratio" if col_index < 3 else "wbits"
        plotting_frame = plotting_frame[plotting_frame["unlearn"] == column_unlearn_methods[col_index]]

        for composition in compositions_by_col[col_index]:
            compression_method = [method for method in composition[0].split("→") if method not in ["RMU", "GA", "GD"]][0]
            first_line = plotting_frame[plotting_frame["order"] == composition[0]]
            first_line["label"] = first_line["order"].apply(wrap_label)
            second_line = plotting_frame[plotting_frame["order"] == composition[1]].sort_values(x_metric)
            second_line["label"] = second_line["order"].apply(wrap_label)
            if compression_method in ["AWQ", "GPTQ"]:
                first_line = first_line.sort_values(x_metric, ascending=False)
                second_line = second_line.sort_values(x_metric, ascending=False)
            else:
                first_line = first_line.sort_values(x_metric)
                second_line = second_line.sort_values(x_metric)

            ax.plot(first_line[x_metric], first_line[y_metric], marker="o", markersize=MARKER_SIZE, color=colors[compression_method], label=f"{composition[0]}")
            ax.plot(second_line[x_metric], second_line[y_metric], markerfacecolor='none', marker="o", ls="--", markersize=MARKER_SIZE, color=colors[compression_method], label=f"{composition[1]}")
            ax.fill_between(
                x=first_line[x_metric], y1=first_line[y_metric], y2=second_line[y_metric],
                alpha=0.3,
                color=colors[compression_method]
            )

        ax.axhline(y=0.25, color="gray", linestyle="--")
        ax.set_ylim(0.20, 0.65)

        if x_metric == "wbits":
            ax.set_xscale("log", base=2)
            ax.set_xticks([2, 4, 8, 16], ["2", "4", "8", "16"])

        if row_index == 0:
            title = column_unlearn_methods[col_index]
            ax.set_title(title, fontsize=TITLE_FONT_SIZE)
        else:
            ax.set_title("")

        if col_index == 0:
            ax.set_ylabel(row_labels[list(row_labels.keys())[row_index]], fontsize=TITLE_FONT_SIZE)
        else:
            ax.set_ylabel("")

        if row_index == 1:
            ax.set_xlabel("Sparsity" if col_index < 3 else "Bits", fontsize=TITLE_FONT_SIZE)
        else:
            ax.set_xlabel("")

        if row_index == 1:
            ax.legend(fontsize=LEGEND_FONT_SIZE, frameon=False, loc="upper center", bbox_to_anchor=(0.5, -0.3), ncol=1)

fig.subplots_adjust(wspace=WSPACE, hspace=WSPACE)
plt.savefig("figures/main_results_unlearn_compression.pdf", bbox_inches="tight")