In [1]:
from tqdm import tqdm
import pandas as pd
import json
import os
import ast
import seaborn as sns
import matplotlib.pyplot as plt
from datetime import datetime

import wandb
api = wandb.Api()

In [2]:
# Count Interventions
# 31 RMU Compositions
# Each editor in

In [10]:
setting_columns = [
    # Overall
    "tag",
    # "seed",
    "_timestamp",

    # Interventions
    "interventions", "edit", "unlearn", "compression", "model_name",

    # Editing
    # "edit_set", 
    "edit_dataset", "number_of_edits",

    # Compression
    "wbits", "compression_dataset", "sparsity_ratio",
]
evaluation_columns = [
    "qa_question_count_limit",
    "mmlu accuracy",
    "wmdp_bio accuracy",
    "wmdp_cyber accuracy",
    "PPL",
    "PPL edits",
    "PPl QA",
    "Generalization",
    "FLOPs",
    "Success recall",
    "Generalization recall",
    "Locality",
    "Average bits",
    "Rewrite accuracy",
    "PPl edits unmasked",
    "Local recall",
    "Latency",
]
relevant_columns = setting_columns + evaluation_columns

In [11]:
# Replace 'username/project_name' with your specific project path
# Composable_Interventions
project_paths = [
    'dri-ice/Composable_Interventions',
    'dri-ice/AK_Tests'
]

filter_dict = { 
    "state": "finished",
    # "created_at": {"$gte": "2024-05-20"}
}
data_frames = []
for project_path in project_paths:
    runs = api.runs(project_path, filters=filter_dict)
    
    # Iterate over eachrun and capture the config and summary metrics
    for run in tqdm(runs, desc=project_path):
        try:
            run_start_datetime = datetime.fromtimestamp(run.summary_metrics["_timestamp"])
            start_cutoff = datetime.strptime("2024-05-18 00:00:00", "%Y-%m-%d %H:%M:%S")
            end_cutoff = datetime.strptime("2024-05-21 00:00:00", "%Y-%m-%d %H:%M:%S")
            if run_start_datetime < start_cutoff or run_start_datetime > end_cutoff:
                continue

            skip_tags = ["test", "hparam_search", "none"]
            should_skip = False
            for tag in skip_tags:
                if tag in run.config["tag"].lower():
                    should_skip = True
            
            if should_skip:
                continue

            config_frame = pd.DataFrame([run.config])
            summary_frame = pd.DataFrame([run.summary_metrics])
            combined_frame = pd.concat([config_frame, summary_frame], axis=1)
            data_frames.append(combined_frame)
        except Exception as e:
            print(f"Error processing run {run.id}: {e}")

# Sort by 'tag' and '_timestamp' in descending order to have the most recent run first
all_runs_df = pd.concat(data_frames, ignore_index=True)[relevant_columns]
all_runs_df["interventions"] = all_runs_df["interventions"].astype(str)

# WARNING: WHAT DOES EDIT SET 50 MEAN COMPARED TO EDIT SET 1?
# all_runs_df = all_runs_df[all_runs_df["edit_set"] == 50]
# all_runs_df_sorted = all_runs_df.sort_values(by=['tag', '_timestamp'], ascending=[True, False])
all_runs_df["date"] = pd.to_datetime(all_runs_df["_timestamp"], unit='s')
all_runs_df_sorted = all_runs_df.sort_values(by=['_timestamp'], ascending=[False])

dri-ice/Composable_Interventions:   0%|          | 0/150 [00:00<?, ?it/s]

dri-ice/Composable_Interventions: 100%|██████████| 150/150 [00:00<00:00, 672.82it/s] 


Error processing run n0iel6ok: '_timestamp'
Error processing run xr5mede5: '_timestamp'
Error processing run 27f8pxs0: '_timestamp'


dri-ice/AK_Tests: 100%|██████████| 1624/1624 [00:00<00:00, 1767.07it/s]


In [12]:
all_runs_df_sorted

Unnamed: 0,tag,_timestamp,interventions,edit,unlearn,compression,model_name,edit_dataset,number_of_edits,wbits,...,FLOPs,Success recall,Generalization recall,Locality,Average bits,Rewrite accuracy,PPl edits unmasked,Local recall,Latency,date
1,gptq4bit-rmu,1.716243e+09,"['compress', 'unlearn']",none,rmu,gptq,meta-llama/Meta-Llama-3-8B,zsre,50,4.0,...,-1,0.000000,0.000000,0.031642,4.25,0.000000,507.698822,0.031800,86.307608,2024-05-20 22:05:08.484089613
5,gptq2bit-rmu,1.716243e+09,"['compress', 'unlearn']",none,rmu,gptq,meta-llama/Meta-Llama-3-8B,zsre,50,2.0,...,-1,0.002857,0.006667,0.004933,2.25,0.002857,150067.625000,0.004262,86.962416,2024-05-20 22:03:10.347089291
3,gptq8bit-rmu,1.716243e+09,"['compress', 'unlearn']",none,rmu,gptq,meta-llama/Meta-Llama-3-8B,zsre,50,8.0,...,-1,0.012381,0.010000,0.032513,8.25,0.012821,445.737610,0.032110,86.466912,2024-05-20 22:02:09.332513094
0,rmu-gptq2bit,1.716240e+09,"['unlearn', 'compress']",none,rmu,gptq,meta-llama/Meta-Llama-3-8B,zsre,50,2.0,...,-1,0.000000,0.000000,0.024140,2.25,0.000000,23385.773438,0.023029,88.352600,2024-05-20 21:26:17.526590109
2,rmu-gptq4bit,1.716240e+09,"['unlearn', 'compress']",none,rmu,gptq,meta-llama/Meta-Llama-3-8B,zsre,50,4.0,...,-1,0.014524,0.019048,0.025801,4.25,0.014744,525.068787,0.025197,86.940684,2024-05-20 21:23:45.248690844
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
339,memit_Edit,1.716005e+09,['edit'],memit,none,none,meta-llama/Meta-Llama-3-8B,zsre,50,,...,1.92 TFLOPS,0.990000,0.963333,0.037923,16.00,0.990000,641.421265,0.035993,124.798484,2024-05-18 04:06:08.118285894
341,memit_Edit,1.716005e+09,['edit'],memit,none,none,meta-llama/Meta-Llama-3-8B,zsre,50,,...,1.92 TFLOPS,0.845000,0.739667,0.021008,16.00,0.845000,688.575256,0.020586,132.843399,2024-05-18 04:06:07.325703621
340,memit_Edit,1.716005e+09,['edit'],memit,none,none,meta-llama/Meta-Llama-3-8B,zsre,50,,...,1.92 TFLOPS,0.986667,0.976000,0.063423,16.00,0.989333,621.960205,0.061930,125.165364,2024-05-18 04:04:32.198839188
342,memit_Edit,1.716005e+09,['edit'],memit,none,none,meta-llama/Meta-Llama-3-8B,zsre,50,,...,1.92 TFLOPS,0.975000,0.950000,0.043431,16.00,0.975000,723.989441,0.042454,125.125785,2024-05-18 04:00:56.205363750


In [13]:
all_runs_df_sorted[(all_runs_df_sorted["wbits"] == 16) & (all_runs_df_sorted["compression"] == "none")]

Unnamed: 0,tag,_timestamp,interventions,edit,unlearn,compression,model_name,edit_dataset,number_of_edits,wbits,...,FLOPs,Success recall,Generalization recall,Locality,Average bits,Rewrite accuracy,PPl edits unmasked,Local recall,Latency,date
8,lora-rmu,1716237000.0,"['edit', 'unlearn']",lora,rmu,none,meta-llama/Meta-Llama-3-8B,zsre,50,16.0,...,1.79 TFLOPS,1.0,0.618571,0.057981,16.0,1.0,25660.986328,0.058719,512.197093,2024-05-20 20:31:09.500632286
11,rmu-lora,1716237000.0,"['unlearn', 'edit']",lora,rmu,none,meta-llama/Meta-Llama-3-8B,zsre,50,16.0,...,1.79 TFLOPS,1.0,0.615,0.05421,16.0,1.0,10560.212891,0.054559,515.848708,2024-05-20 20:28:42.994992018
6,memit-rmu,1716234000.0,"['edit', 'unlearn']",memit,rmu,none,meta-llama/Meta-Llama-3-8B,zsre,50,16.0,...,1.92 TFLOPS,0.952,0.932,0.018525,16.0,0.952,435.951874,0.018666,95.736313,2024-05-20 19:38:41.118072510
7,ft-rmu,1716234000.0,"['edit', 'unlearn']",ft,rmu,none,meta-llama/Meta-Llama-3-8B,zsre,50,16.0,...,1.92 TFLOPS,1.0,0.799857,0.144349,15.999969,1.0,542.750366,0.142683,96.74437,2024-05-20 19:38:26.891535759
9,rmu-memit,1716234000.0,"['unlearn', 'edit']",memit,rmu,none,meta-llama/Meta-Llama-3-8B,zsre,50,16.0,...,1.92 TFLOPS,0.985,0.95,0.034166,16.0,0.985,433.229065,0.033793,95.927453,2024-05-20 19:37:23.770762682
10,rmu-ft,1716234000.0,"['unlearn', 'edit']",ft,rmu,none,meta-llama/Meta-Llama-3-8B,zsre,50,16.0,...,1.92 TFLOPS,1.0,0.78,0.131125,15.999967,1.0,579.778748,0.129552,96.83709,2024-05-20 19:33:35.887049198


In [14]:
# Drop duplicates, keeping only the first occurrence (which is the most recent due to sorting)
all_runs_df_deduplicated = all_runs_df_sorted.drop_duplicates(subset=[col for col in setting_columns if col not in ["_timestamp", "tag", "date"]], keep="first")
# all_runs_df_deduplicated = all_runs_df_sorted
all_runs_df_deduplicated["interventions"] = all_runs_df_deduplicated["interventions"].apply(lambda x : ast.literal_eval(x))

rename_dict = {
    "meta-llama/Meta-Llama-3-8B" : "Llama-3 (8b)",
    "ft" : "Fine-tune",
    "memit" : "MEMIT",
    "lora" : "Lora",
    "wanda" : "Wanda",
    "sparsegpt" : "SparseGPT",
    "gptq" : "GPTQ",
    "awq" : "AWQ",
    "rmu" : "RMU",
}
metrics = all_runs_df_deduplicated
metrics["model_name"] = metrics["model_name"].apply(lambda x : rename_dict.get(x, None))
metrics["edit"] = metrics["edit"].apply(lambda x : rename_dict.get(x, None))
metrics["compression"] = metrics["compression"].apply(lambda x : rename_dict.get(x, None))
metrics["unlearn"] = metrics["unlearn"].apply(lambda x : rename_dict.get(x, None))
all_runs_df_deduplicated = metrics
all_runs_df_deduplicated.value_counts("tag")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  all_runs_df_deduplicated["interventions"] = all_runs_df_deduplicated["interventions"].apply(lambda x : ast.literal_eval(x))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  metrics["model_name"] = metrics["model_name"].apply(lambda x : rename_dict.get(x, None))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a

tag
AWQ2bit-to-ft              1
memit-to-SparseGPT0.45%    1
memit-to-GPTQ8bit          1
memit-to-GPTQ4bit          1
memit-to-GPTQ2bit          1
                          ..
SparseGPT0.65%-to-memit    1
SparseGPT0.65%-to-lora     1
SparseGPT0.65%-to-ft       1
SparseGPT0.45%-to-memit    1
wanda0.65\%-rmu            1
Name: count, Length: 107, dtype: int64

In [15]:
all_runs_df_deduplicated[all_runs_df_deduplicated["compression"] == "SparseGPT"]

Unnamed: 0,tag,_timestamp,interventions,edit,unlearn,compression,model_name,edit_dataset,number_of_edits,wbits,...,FLOPs,Success recall,Generalization recall,Locality,Average bits,Rewrite accuracy,PPl edits unmasked,Local recall,Latency,date
12,sparsegpt0.65\%-rmu,1716237000.0,"[compress, unlearn]",,RMU,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,759.86 GFLOPS,0.0,0.006667,0.022232,6.249982,0.0,1688.532227,0.020173,-1.0,2024-05-20 20:26:57.591763496
13,sparsegpt0.45\%-rmu,1716237000.0,"[compress, unlearn]",,RMU,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,1.12 TFLOPS,0.0,0.005,0.023685,9.249988,0.0,637.708252,0.023586,-1.0,2024-05-20 20:26:14.456046343
14,sparsegpt0.25\%-rmu,1716237000.0,"[compress, unlearn]",,RMU,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,1.47 TFLOPS,0.005714,0.025238,0.041017,12.249977,0.006154,481.484344,0.040736,-1.0,2024-05-20 20:24:09.416608095
21,rmu-sparsegpt0.45\%,1716234000.0,"[unlearn, compress]",,RMU,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,1.12 TFLOPS,0.020714,0.021667,0.036096,9.249988,0.021154,637.052124,0.036055,-1.0,2024-05-20 19:38:30.983223438
22,rmu-sparsegpt0.65\%,1716234000.0,"[unlearn, compress]",,RMU,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,759.86 GFLOPS,0.0,0.006667,0.024438,6.249982,0.0,1573.433228,0.022101,-1.0,2024-05-20 19:36:36.984816551
23,rmu-sparsegpt0.25\%,1716234000.0,"[unlearn, compress]",,RMU,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,1.47 TFLOPS,0.017714,0.012857,0.04096,12.249977,0.018154,480.60788,0.039888,-1.0,2024-05-20 19:35:36.819632053
87,SparseGPT0.65%-to-lora,1716164000.0,"[compress, edit]",Lora,,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,625.35 GFLOPS,0.316333,0.195222,0.013377,6.249982,0.316333,1176.053467,0.012868,-1.0,2024-05-20 00:21:08.571938038
88,SparseGPT0.45%-to-lora,1716164000.0,"[compress, edit]",Lora,,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,982.69 GFLOPS,0.712,0.329556,0.033835,9.249988,0.711231,1712.821167,0.034028,-1.0,2024-05-20 00:20:46.149566889
89,SparseGPT0.25%-to-lora,1716164000.0,"[compress, edit]",Lora,,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,1.34 TFLOPS,0.878667,0.528333,0.036167,12.249977,0.878667,2798.60498,0.036143,-1.0,2024-05-20 00:06:42.644053936
92,lora-to-SparseGPT0.45%,1716161000.0,"[edit, compress]",Lora,,SparseGPT,Llama-3 (8b),zsre,50,4.0,...,982.69 GFLOPS,0.224667,0.207333,0.03089,9.249987,0.224667,2524.844482,0.030844,-1.0,2024-05-19 23:29:09.522534609


In [23]:
data = all_runs_df_deduplicated
# print(data["interventions"])
# Select rows where the "interventions" column is exactly ["edit"]
temp = data[data["interventions"].apply(lambda x: x == ["edit", "compress"])]

categories = {
    "No Intervention": data[data["interventions"].apply(lambda x: x == [])].copy(),
    "Editing": data[data["interventions"].apply(lambda x: x == ["edit"])].copy(),
    "Compression": data[data["interventions"].apply(lambda x: x == ["compress"])].copy(),
    "Edit to Compression": data[data["interventions"].apply(lambda x: x == ["edit", "compress"])].copy(),
    "Compression to Edit": data[data["interventions"].apply(lambda x: x == ["compress", "edit"])].copy(),
    "Unlearn": data[data["interventions"].apply(lambda x: x == ["unlearn"])].copy(),
    "Edit to Unlearn": data[data["interventions"].apply(lambda x: x == ["edit", "unlearn"])].copy(),
    "Unlearn to Edit": data[data["interventions"].apply(lambda x: x == ["unlearn", "edit"])].copy(),
    "Compress to Unlearn": data[data["interventions"].apply(lambda x: x == ["compress", "unlearn"])].copy(),
    "Unlearn to Compress": data[data["interventions"].apply(lambda x: x == ["unlearn", "compress"])].copy()
}

assert len(categories["No Intervention"]) == 0 # Should be 1
assert len(categories["Editing"]) == 3 
assert len(categories["Compression"]) == 12
assert len(categories["Edit to Compression"]) == 32 # Should be 36 Missing LoRA Quant
display(categories["Compression to Edit"])
assert len(categories["Compression to Edit"]) == 0


Unnamed: 0,tag,_timestamp,interventions,edit,unlearn,compression,model_name,edit_dataset,number_of_edits,wbits,...,FLOPs,Success recall,Generalization recall,Locality,Average bits,Rewrite accuracy,PPl edits unmasked,Local recall,Latency,date
51,GPTQ2bit-to-memit,1716229000.0,"[compress, edit]",MEMIT,,GPTQ,Llama-3 (8b),zsre,50,2.0,...,-1,0.002857,0.005,0.010902,2.25,0.003333,165747.8,0.009673,88.586112,2024-05-20 18:13:41.270603180
62,GPTQ4bit-to-memit,1716229000.0,"[compress, edit]",MEMIT,,GPTQ,Llama-3 (8b),zsre,50,4.0,...,-1,0.780238,0.717095,0.033183,4.25,0.780238,484.5504,0.034467,87.214703,2024-05-20 18:09:37.249646425
61,GPTQ8bit-to-memit,1716229000.0,"[compress, edit]",MEMIT,,GPTQ,Llama-3 (8b),zsre,50,8.0,...,-1,0.976429,0.885095,0.027154,8.25,0.976429,481.2392,0.026999,87.147286,2024-05-20 18:09:00.776746988
30,AWQ8bit-to-ft,1716227000.0,"[compress, edit]",Fine-tune,,AWQ,Llama-3 (8b),zsre,50,8.0,...,-1,1.0,0.845714,0.160518,8.25,1.0,777.6783,0.15967,91.074401,2024-05-20 17:43:14.585030079
32,AWQ4bit-to-ft,1716227000.0,"[compress, edit]",Fine-tune,,AWQ,Llama-3 (8b),zsre,50,4.0,...,-1,1.0,0.845714,0.169455,4.25,1.0,540.7035,0.169852,90.909752,2024-05-20 17:40:15.012472868
33,AWQ2bit-to-ft,1716227000.0,"[compress, edit]",Fine-tune,,AWQ,Llama-3 (8b),zsre,50,2.0,...,-1,0.0,0.0,0.0,2.25,0.0,78554.23,0.0,90.894511,2024-05-20 17:38:54.680141449
52,AWQ2bit-to-memit,1716224000.0,"[compress, edit]",MEMIT,,AWQ,Llama-3 (8b),zsre,50,2.0,...,-1,0.0,0.0,0.0,2.25,0.0,1074956.0,0.0,90.283724,2024-05-20 17:01:28.464071751
57,AWQ4bit-to-memit,1716224000.0,"[compress, edit]",MEMIT,,AWQ,Llama-3 (8b),zsre,50,4.0,...,-1,0.941762,0.877476,0.035476,4.25,0.942641,476.1157,0.033445,91.889338,2024-05-20 17:01:01.431791544
58,AWQ8bit-to-memit,1716224000.0,"[compress, edit]",MEMIT,,AWQ,Llama-3 (8b),zsre,50,8.0,...,-1,0.965,0.955,0.029943,8.25,0.965,514.4988,0.030221,91.65965,2024-05-20 17:00:40.140754700
69,GPTQ8bit-to-lora,1716168000.0,"[compress, edit]",Lora,,GPTQ,Llama-3 (8b),zsre,50,4.0,...,-1,0.140889,0.045889,0.023484,4.25,0.140889,4293.976,0.023364,88.206251,2024-05-20 01:21:30.996167183


AssertionError: 

In [None]:
def format_flops(value):
    """ Format FLOPs with three significant figures and appropriate suffix. """
    try:
        if isinstance(value, str):
            value = clean_numeric_value(value)
        if abs(value) < 1e6:  # Less than 1 million (below Mega)
            return "{:.3g}k".format(value / 1e3)
        elif abs(value) < 1e9:  # Mega to Giga range
            return "{:.3g}M".format(value / 1e6)
        elif abs(value) < 1e12:  # Giga to Tera range
            return "{:.3g}G".format(value / 1e9)
        else:  # Tera and above
            return "{:.3g}T".format(value / 1e12)
    except Exception as e:
        print(f"Error formatting FLOPs value {value}: {e}")
        return "---"

def escape_latex_special_chars(s):
    """ Escape special characters in LaTeX strings. """
    return str(s).replace('%', '\\%').replace('_', '\\_').replace('&', '\\&').replace('#', '\\#').replace('$', '\\$')

def clean_numeric_value(value):
    """ Convert a string with units to a numeric value. """
    try:
        value = str(value)
        if ' TFLOPS' in value:
            return float(value.replace(' TFLOPS', '')) * 1e12
        if ' GFLOPS' in value:
            return float(value.replace(' GFLOPS', '')) * 1e9
        if ' MFLOPS' in value:
            return float(value.replace(' MFLOPS', '')) * 1e6
        if ' kFLOPS' in value:
            return float(value.replace(' kFLOPS', '')) * 1e3
        return pd.to_numeric(value, errors='coerce')
    except Exception as e:
        print(f"Error cleaning value {value}: {e}")
        return pd.NA

def categorize_and_generate_latex(data):
    # Define categories based on the provided criteria
    categories = {
    "No Intervention": data[data['interventions'].apply(lambda x: x == [])].copy(),
    "Editing": data[data['interventions'].apply(lambda x: x == ['edit'])].copy(),
    "Compression": data[data['interventions'].apply(lambda x: x == ['compress'])].copy(),
    "Edit to Compression": data[data['interventions'].apply(lambda x: x == ['edit', 'compress'])].copy(),
    "Compression to Edit": data[data['interventions'].apply(lambda x: x == ['compress', 'edit'])].copy(),
    "Unlearn": data[data['interventions'].apply(lambda x: x == ['unlearn'])].copy(),
    "Edit to Unlearn": data[data['interventions'].apply(lambda x: x == ['edit', 'unlearn'])].copy(),
    "Unlearn to Edit": data[data['interventions'].apply(lambda x: x == ['unlearn', 'edit'])].copy(),
    "Compress to Unlearn": data[data['interventions'].apply(lambda x: x == ['compress', 'unlearn'])].copy(),
    "Unlearn to Compress": data[data['interventions'].apply(lambda x: x == ['unlearn', 'compress'])].copy()
}
    # Clean numeric columns
    for col in ["FLOPs", "Latency"]:
        if col in data.columns:
            data.loc[:, col] = data[col].apply(clean_numeric_value)
            data.loc[:, col] = pd.to_numeric(data[col], errors='coerce')  # Ensure all values are numeric

    # Column mappings
    column_mappings = {
        "Success": "Rewrite accuracy",
        "Generalization": "Generalization",
        "Locality": "Locality",
        "Avg. Bits": "Average bits",
        "FLOPs": "FLOPs",
        "PPL": "PPL",
        "MMLU": "mmlu accuracy",
        "WMDP Bio": "wmdp_bio accuracy",
        "WMDP Cyber": "wmdp_cyber accuracy"
    }
    latex_columns = ["Success", "Generalization", "Locality", "Avg. Bits", "FLOPs", "PPL", "MMLU", "WMDP Bio", "WMDP Cyber"]

    # Initialize output string
    output_str = ""

    for category, group in categories.items():
        if group.empty:
            continue
        # output_str += f"\\textbf{{{category}}} \\\\ \\midline\n"
        for _, row in group.iterrows():
            # Calculate mean and std for each relevant column within the group
            stats = {}
            for latex_col, csv_col in column_mappings.items():
                if csv_col in row.index:
                    value = row[csv_col]
                    if pd.isna(value):
                        stats[latex_col] = "---"
                    else:
                        # Custom formatting for FLOPs and Latency
                        if latex_col == "FLOPs":
                            mean_str = format_flops(value)
                            stats[latex_col] = escape_latex_special_chars(mean_str)
                        elif latex_col == "Latency":
                            mean_str = f"{value:.3f}s"
                            stats[latex_col] = escape_latex_special_chars(mean_str)
                        else:
                            mean_str = f"{value:.3f}"
                            stats[latex_col] = escape_latex_special_chars(mean_str)
                else:
                    stats[latex_col] = "---"

            # Prepare the LaTeX row for the current group
            latex_row = escape_latex_special_chars(row['tag'])  # Use the tag name directly without escaping
            for column in latex_columns:
                latex_row += " & " + stats.get(column, "---")
            latex_row += " \\\\"

            # Append to output string
            output_str += latex_row + "\n"
        
        output_str += "\\midrule\n"

    return output_str

latex_rows_with_categories = categorize_and_generate_latex(all_runs_df_deduplicated)
print(latex_rows_with_categories)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Set the font family to serif
plt.rcParams['font.family'] = 'serif'

# Assuming 'metrics' DataFrame is already loaded from the CSV

# Define the criteria for filtering
selected_method = 'AWQ'
edit_method = ["FT"]

# Filter based on the criteria
edit_then_compress = categories['Edit to Compression']
compress_then_edit = categories['Compression to Edit']

# Filter based on selected method
edit_then_compress = edit_then_compress[edit_then_compress['compression']==selected_method]
compress_then_edit = compress_then_edit[compress_then_edit['compression']==selected_method]

# Add baselines to dfs
baseline = categories['Editing']
baseline['wbits'] = 16
edit_then_compress = pd.concat([edit_then_compress, baseline], axis=0)
compress_then_edit = pd.concat([compress_then_edit, baseline], axis=0)

# Sort by 'wbits' in ascending order
edit_then_compress = edit_then_compress.sort_values(by='wbits')
compress_then_edit = compress_then_edit.sort_values(by='wbits')

# Define the metrics to plot
metrics_to_plot = ['Rewrite accuracy', 'Generalization', 'mmlu']
x_axis_metric = 'wbits'

# Compute baselines
# edit_then_compress_baselines = {model: edit_then_compress[(edit_then_compress['model_name'] == model) & (edit_then_compress['sparsity_ratio'] == 0)][metrics_to_plot].mean() for model in included_models}
# compress_then_edit_baselines = {model: compress_then_edit[(compress_then_edit['model_name'] == model) & (compress_then_edit['sparsity_ratio'] == 0)][metrics_to_plot].mean() for model in included_models}
# edit_then_compress_baselines = categories['No Intervention']
# compress_then_edit_baselines = categories['No Intervention']

# Define plot parameters
title_fontsize = 20
label_fontsize = 20
legend_fontsize = 18
tick_fontsize = 18
line_width = 3
marker_size = 8

# Create subplots
fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(15, 5))

# Iterate over each metric and plot
for i, metric in enumerate(metrics_to_plot):
    ax = axes[i]
    
    # Plot the data with scatter and lines
    ax.plot(edit_then_compress['wbits'], edit_then_compress[metric], linestyle='--', marker='o', markerfacecolor='purple', color='purple', label='Edit then compress',
            linewidth=line_width, markersize=marker_size, markeredgewidth=line_width)
    ax.plot(compress_then_edit['wbits'], compress_then_edit[metric], linestyle='-', marker='o', markerfacecolor='none', color='purple', label='Compress then edit',
            linewidth=line_width, markersize=marker_size, markeredgewidth=line_width)
    
    # Fill the area between the lines
    ax.fill_between(edit_then_compress['wbits'], edit_then_compress[metric], compress_then_edit[metric], color='purple', alpha=0.2)
    
    # Integrate baselines into the scatter plots
    for model in included_models:
        baseline_edit = edit_then_compress_baselines[metric]
        baseline_compress = compress_then_edit_baselines[metric]
        
        if x_axis_metric == 'Average bits':
            baseline_x = 16
            ax.set_xlim(2, 16)
        elif x_axis_metric == 'sparsity_ratio':
            baseline_x = 0.0
            ax.set_xlim(0, 1)
        else:
            baseline_x = 0  # Adjust based on your default x-axis range

        # Add baselines to the scatter plots
        # ax.scatter([baseline_x], [baseline_edit], color='purple', marker='o', s=marker_size**2, edgecolor='purple', linewidth=line_width)
        # ax.scatter([baseline_x], [baseline_compress], color='purple', marker='o', s=marker_size**2, edgecolor='purple', linewidth=line_width)
    if x_axis_metric == 'wbits':
        ax.set_xlabel('Bits', fontsize=label_fontsize)
    else:
        ax.set_xlabel(x_axis_metric, fontsize=label_fontsize)
    if metric == 'Rewrite accuracy':
        ax.set_ylabel('Edit success', fontsize=label_fontsize)
    else:
        ax.set_ylabel(metric, fontsize=label_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

# Move the legend to the bottom of the figure
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', fontsize=legend_fontsize, ncol=2)

plt.tight_layout(rect=[0, 0.1, 1, 1])  # Adjust the bottom margin to make space for the legend
plt.show()
plt.savefig('figures/memit-gptq.pdf', format='pdf')
