In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from glob import glob
import json
import re
from sklearn.metrics import auc, roc_auc_score, roc_curve, accuracy_score, balanced_accuracy_score, f1_score
from statistics import mean, stdev, variance
from scipy import stats
import numpy as np
from scipy import stats
from math import sqrt
import plotly
from tqdm import tqdm

# Calculating the FMS score

In [None]:
import pandas as pd
import re
from glob import glob
import pandas as pd
from glob import glob
from typing import List, Dict, Optional

def load_local_tree_stats_data(
    file_info_list: List[Dict],
) -> pd.DataFrame:
    """
    Loads per-cut tree stats CSVs and computes the local model shift (MS_local).

    Parameters:
    - file_info_list: List of dicts with:
        {
            "file": str,         # Path to CSV file; needs "Nodes", "num_cuts", "Accuracy"
            "model_name": str,
            "concept": str,
            "model_type": str,
        }

    For each file separately, the baseline is the mean "Accuracy" of the rows
    with Nodes == 1 and num_cuts == 0. Every row with num_cuts != 0 gets

        MS_local = 2 * (baseline - Accuracy)

    and rows with num_cuts == 0 get NaN.

    Returns:
    - pd.DataFrame with the Nodes == 1 rows of all files and columns
      "num_cuts", "concept", "model type", "MS_local"
      (empty with those columns if file_info_list is empty).
    """
    columns = ["num_cuts", "concept", "model type", "MS_local"]
    frames = []
    for file_info in file_info_list:
        df = pd.read_csv(file_info["file"])

        # Inject metadata
        df["model_name"] = file_info["model_name"]
        df["concept"] = file_info["concept"]
        df["model type"] = file_info["model_type"]

        # The baseline depends only on the file, so compute it once instead of
        # once per row.
        is_baseline = (df["Nodes"] == 1) & (df["num_cuts"] == 0)
        baseline_acc = df.loc[is_baseline, "Accuracy"].mean()
        df["MS_local"] = (2 * (baseline_acc - df["Accuracy"])).where(
            df["num_cuts"] != 0
        )
        frames.append(df)

    if not frames:
        return pd.DataFrame(columns=columns)

    # Concatenate once rather than growing a frame inside the loop.
    all_data = pd.concat(frames, ignore_index=True)
    return all_data[all_data["Nodes"] == 1][columns]


def load_tree_stats_data(
    file_info_list: List[Dict],
) -> pd.DataFrame:
    """
    Loads and combines tree stats CSVs and computes the global model shift (MS_global).

    Parameters:
    - file_info_list: List of dicts with:
        {
            "file": str,                  # Path to CSV file; needs "Nodes", "Accuracy"
            "model_name": str,
            "concept": str,              # e.g., "sp", "rtp", or "pii"
            "model_type": str,           # e.g., "G-SAE", "Baseline", etc.
        }
      Entries whose file name contains "cut" are skipped.

    For each (model type, concept) group, let a_root be the Accuracy of its
    single Nodes == 1 row and a_i the Accuracy of its other rows. Then

        MS_global = 1 - mean_i(a_i - a_root)

    MS_global is NaN for a group that has no rows besides its root.
    Note: model_name is not part of the grouping key.

    Returns:
    - pd.DataFrame: the Nodes == 1 rows with columns
      "Accuracy", "concept", "model type", "MS_global"
      (empty with those columns if no file was loaded).

    Raises:
    - ValueError: if a (model type, concept) group has more than one Nodes == 1 row.
    """
    import os  # local import: only needed for the file-name check below

    columns = ["Accuracy", "concept", "model type", "MS_global"]
    frames = []

    for info in file_info_list:
        file = info["file"]

        # Check only the file name, so a directory whose name happens to
        # contain "cut" (e.g. "execute/") does not cause a skip.
        if "cut" in os.path.basename(file):
            continue

        df = pd.read_csv(file)

        df["model_name"] = info["model_name"]
        df["concept"] = info["concept"]
        df["model type"] = info["model_type"]

        frames.append(df)

    if not frames:
        return pd.DataFrame(columns=columns)

    # Concatenate once, then compute MS_global once, for root rows only.
    all_depths = pd.concat(frames, ignore_index=True)
    is_root = all_depths["Nodes"] == 1
    roots = all_depths[is_root].copy()

    ms_global = []
    for _, root in roots.iterrows():
        in_group = (all_depths["model type"] == root["model type"]) & (
            all_depths["concept"] == root["concept"]
        )
        # .item() enforces exactly one root row per group (ValueError otherwise).
        root_acc = all_depths.loc[in_group & is_root, "Accuracy"].item()
        deltas = all_depths.loc[in_group & ~is_root, "Accuracy"] - root_acc
        if len(deltas) > 0:
            ms_global.append(1 - sum(deltas) / len(deltas))
        else:
            # A root-only tree has no shift to average; avoid ZeroDivisionError.
            ms_global.append(float("nan"))
    roots["MS_global"] = ms_global

    return roots[columns]


In [None]:
# Whole-tree stats file, consumed by load_tree_stats_data for MS_global.
file_info_list = [
    dict(
        file="./llama3_SAE/SAE_eval/SP-Block_v2/sp_tree_valid_llama3-l24576-b03-k2048_s1_statsV2.csv",
        model_name="LLaMA3",
        concept="Shakespeare",
        model_type="G-SAE",
    )
]


In [None]:
# Root-row Accuracy and MS_global per (model type, concept).
df_global = load_tree_stats_data(file_info_list=file_info_list)

In [None]:
# Per-cut stats file, consumed by load_local_tree_stats_data for MS_local.
file_info_list = [
    dict(
        file="./llama3_SAE/SAE_eval/SP-Block_v2/sp_tree_valid_llama3-l24576-b03-k2048_s1_cut.csv",
        model_name="LLaMA3",
        concept="Shakespeare",
        model_type="G-SAE",
    )
]


In [None]:
# Root-node rows per num_cuts with their MS_local.
df_local = load_local_tree_stats_data(file_info_list=file_info_list)


In [None]:
# Attach each local (per-cut) row to the global row with the same
# (concept, model type). The keys are spelled out instead of relying on the
# implicit column overlap. validate="many_to_one" fails loudly if df_global
# ever has duplicate keys, which would otherwise silently multiply rows.
df = pd.merge(
    df_local,
    df_global,
    on=["concept", "model type"],
    how="inner",
    validate="many_to_one",
)

In [None]:
# FMS = Accuracy * mean(MS_local, MS_global), computed column-wise
# (vectorized; unlike a row-wise apply, this also works on an empty frame).
df["FMS"] = df["Accuracy"] * ((df["MS_local"] + df["MS_global"]) / 2)

In [None]:
# Show the 1-cut and 5-cut rows, rounded to two decimals.
df.loc[df["num_cuts"].isin([1, 5])].round(2)