In [1]:
import sys
from pathlib import Path

# Make the directory two levels above the notebook's working directory
# importable, so that the `src` package imported below can be resolved.
notebook_dir = Path.cwd().resolve()
src_path = str(notebook_dir.parent.parent)
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)

sys.path

['/Library/Frameworks/Python.framework/Versions/3.11/lib/python311.zip',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload',
 '',
 '/Users/anton/dev/MARS/correctness-model-internals/venv/lib/python3.11/site-packages',
 '/Users/anton/dev/MARS/correctness-model-internals']

In [2]:
from pathlib import Path
from collections import defaultdict

import torch as pt
import pandas as pd
# import seaborn as sns
import plotly.graph_objects as go

from src.classifying import (
    ActivationsHandler,
    combine_activations_handlers,
    get_correctness_direction_classifier,
    get_logistic_regression_classifier,
)

# sns.set_theme(style="whitegrid")


In [3]:
def load_activations(
    model_id,
    dataset_id,
    prompt_id,
    subset_id,
    input_type,
    layer,
    batch_ids=None,
):
    """Load and concatenate the saved activation batches of one layer.

    Batch files are read from
    ``../../activations/{model_id}/{dataset_id}/{prompt_id}/{subset_id}/{input_type}/layer_{layer}``
    (relative to the current working directory). The integer after the last
    ``_`` of each file stem is used as the batch id, and files are processed in
    ascending batch-id order.

    Args:
        model_id, dataset_id, prompt_id, subset_id, input_type, layer: Path
            components selecting the layer directory.
        batch_ids: Optional iterable of batch ids (anything ``int()`` accepts).
            When given and non-empty, only those batches are loaded; otherwise
            every batch in the directory is loaded.

    Returns:
        A tuple ``(activations, indices)``: the batches concatenated along
        dim 0, and a ``pd.Series`` named ``"index"`` with one entry per row,
        computed as ``batch_id + position within the batch``.

    Raises:
        ValueError: If no batch file matched the selection.
    """
    wanted_batch_ids = {int(batch_id) for batch_id in batch_ids} if batch_ids else None

    layer_dir = Path(
        f"../../activations/{model_id}/{dataset_id}/{prompt_id}/{subset_id}/{input_type}/layer_{layer}"
    )
    # Parse the batch id once per file and keep it next to the path.
    batch_files = sorted(
        ((int(path.stem.split("_")[-1]), path) for path in layer_dir.iterdir()),
        key=lambda pair: pair[0],
    )

    activations_list, indices = [], []
    for batch_id, batch_file in batch_files:
        if wanted_batch_ids is not None and batch_id not in wanted_batch_ids:
            continue

        # NOTE(review): torch.load unpickles the file, so only load files you trust.
        activations = pt.load(batch_file, map_location=pt.device("cpu"))
        activations_list.append(activations)

        # The size is read per batch; batches are not required to be equally sized.
        batch_size = activations.shape[0]

        # NOTE(review): this assumes the batch id is the offset of the batch's
        # first sample in the dataset - confirm against the code that writes the files.
        indices.append(
            pd.Series(range(batch_size), name="index") + batch_id
        )

    if not activations_list:
        raise ValueError(
            f"No activation batches found in {layer_dir} for batch_ids={batch_ids}"
        )

    return (
        pt.cat(activations_list, dim=0),
        pd.concat(indices, axis=0).reset_index(drop=True),
    )


def load_labels(model_id, dataset_id, prompt_id, subset_id, indices=None):
    """Read the evaluation CSV that belongs to ``subset_id``.

    Every file in ``../../evaluations/{model_id}/{dataset_id}/{prompt_id}/`` is
    inspected; the first one whose stem, cut at ``"_generations_evaluated"``,
    equals ``subset_id`` is loaded with ``pd.read_csv``.

    Args:
        model_id, dataset_id, prompt_id: Path components of the evaluations directory.
        subset_id: Expected prefix of the file stem.
        indices: Optional positional row selection; when given, the rows are
            picked with ``.iloc`` and the index is reset.

    Returns:
        The (optionally sub-selected) DataFrame.

    Raises:
        ValueError: If no file in the directory matches ``subset_id``.
    """
    evaluations_dir = Path(f"../../evaluations/{model_id}/{dataset_id}/{prompt_id}/")
    stem_marker = "_generations_evaluated"

    for candidate in list(evaluations_dir.iterdir()):
        if candidate.stem.split(stem_marker)[0] == subset_id:
            labels_df = pd.read_csv(candidate)
            if indices is None:
                return labels_df
            return labels_df.iloc[indices].reset_index(drop=True)

    raise ValueError(
        f"No labels found for {model_id} {dataset_id} {prompt_id} {subset_id}"
    )


In [4]:
def _iter_subdirs(directory):
    """Yield the sub-directories of ``directory``, skipping plain files (e.g. ``.DS_Store``)."""
    return (entry for entry in directory.iterdir() if entry.is_dir())


# Walk activations/<model>/<dataset>/<prompt>/<subset>/<input_type>/<layer dir>
# and record one row per layer directory. Only directories are descended into,
# so a stray file at any level no longer aborts the walk.
all_activation_exp_configs = defaultdict(list)
for model_path in _iter_subdirs(Path("../../activations")):
    for dataset_path in _iter_subdirs(model_path):
        for prompt_path in _iter_subdirs(dataset_path):
            for subset_path in _iter_subdirs(prompt_path):
                for input_type_path in _iter_subdirs(subset_path):
                    for layer_path in _iter_subdirs(input_type_path):
                        all_activation_exp_configs["model_id"].append(model_path.name)
                        all_activation_exp_configs["dataset_id"].append(dataset_path.name)
                        all_activation_exp_configs["prompt_id"].append(prompt_path.name)
                        all_activation_exp_configs["subset_id"].append(subset_path.name)
                        all_activation_exp_configs["input_type"].append(input_type_path.name)
                        # The layer number is the integer after the last "_" of the directory name.
                        all_activation_exp_configs["layer"].append(int(layer_path.name.split("_")[-1]))
                        all_activation_exp_configs["path"].append(layer_path)

# Explicit columns keep the schema stable even when no layer directory was found.
all_activation_exp_configs_df = pd.DataFrame(
    all_activation_exp_configs,
    columns=["model_id", "dataset_id", "prompt_id", "subset_id", "input_type", "layer", "path"],
)
all_activation_exp_configs_df


Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,path
0,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,0,../../activations/llama3.1_8b_chat/football_le...
1,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,7,../../activations/llama3.1_8b_chat/football_le...
2,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,9,../../activations/llama3.1_8b_chat/football_le...
3,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,17,../../activations/llama3.1_8b_chat/football_le...
4,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,28,../../activations/llama3.1_8b_chat/football_le...
...,...,...,...,...,...,...,...
251,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,23,../../activations/llama3.1_8b_chat/birth_years...
252,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,15,../../activations/llama3.1_8b_chat/birth_years...
253,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,12,../../activations/llama3.1_8b_chat/birth_years...
254,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,2,../../activations/llama3.1_8b_chat/birth_years...


In [5]:
# Selection of the experiment configurations to analyse.
# Set a value to None to keep every value of that column.
# Example of a narrower selection:
#   DATASET_ID = "gsm8k"; PROMPT_ID = "cot_3_shot"; SUBSET_ID = "main"; INPUT_TYPE = "prompt_only"
MODEL_ID = "llama3.1_8b_chat"
DATASET_ID = None
PROMPT_ID = None
SUBSET_ID = None
INPUT_TYPE = None

# Column name -> required value (None means "do not filter on this column").
_config_filters = {
    "model_id": MODEL_ID,
    "dataset_id": DATASET_ID,
    "prompt_id": PROMPT_ID,
    "subset_id": SUBSET_ID,
    "input_type": INPUT_TYPE,
}

# Start from a copy so the full configuration table stays untouched.
activation_exp_configs_df = all_activation_exp_configs_df.copy()
for _column, _wanted in _config_filters.items():
    if _wanted is not None:
        activation_exp_configs_df = activation_exp_configs_df[
            activation_exp_configs_df[_column] == _wanted
        ]

activation_exp_configs_df

Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,path
0,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,0,../../activations/llama3.1_8b_chat/football_le...
1,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,7,../../activations/llama3.1_8b_chat/football_le...
2,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,9,../../activations/llama3.1_8b_chat/football_le...
3,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,17,../../activations/llama3.1_8b_chat/football_le...
4,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,28,../../activations/llama3.1_8b_chat/football_le...
...,...,...,...,...,...,...,...
251,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,23,../../activations/llama3.1_8b_chat/birth_years...
252,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,15,../../activations/llama3.1_8b_chat/birth_years...
253,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,12,../../activations/llama3.1_8b_chat/birth_years...
254,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,2,../../activations/llama3.1_8b_chat/birth_years...


In [6]:
# Seed for the per-dataset shuffle below, so re-running the cell draws the same permutation.
# NOTE(review): ActivationsHandler sampling/splitting may have its own randomness - confirm.
SHUFFLE_SEED = 0
# Datasets that are only evaluated on and never contribute training folds.
TEST_ONLY_DATASET_IDS = {"football_leagues_1k"}
# Number of equally sized folds each dataset is split into.
N_FOLDS = 5

# One list per output column; every (layer, fold, test dataset) combination
# appends exactly one entry to each of the fixed columns.
res_dict = defaultdict(list)

for (model_id, prompt_id, subset_id, input_type), config_df in activation_exp_configs_df.groupby(["model_id", "prompt_id", "subset_id", "input_type"]):
    print(f"\n{model_id=}, {prompt_id=}, {subset_id=}, {input_type=}")

    # Build a new frame instead of assigning a column into the groupby chunk.
    config_df = config_df.assign(layer=config_df["layer"].astype(int)).sort_values(by="layer")

    # Shuffled sample order per dataset: drawn for the first layer, reused for all others.
    check_indices = {}
    # The label file does not depend on the layer, so it is read once per dataset.
    labels_by_dataset = {}
    for layer, layer_config_df in config_df.groupby("layer"):
        print(f"{layer=}", end=", ")

        # dataset_id -> list of fold handlers for this layer
        dataset_activation_handlers = {}
        for _, row in layer_config_df.iterrows():
            dataset_id = row["dataset_id"]

            if dataset_id not in labels_by_dataset:
                labels_by_dataset[dataset_id] = load_labels(
                    model_id=model_id,
                    dataset_id=dataset_id,
                    prompt_id=prompt_id,
                    subset_id=subset_id,
                )
            labels_df = labels_by_dataset[dataset_id]

            activations, indices = load_activations(
                model_id=model_id,
                dataset_id=dataset_id,
                prompt_id=prompt_id,
                subset_id=subset_id,
                input_type=input_type,
                layer=layer,
            )

            if dataset_id not in check_indices:
                check_indices[dataset_id] = indices.sample(
                    frac=1, replace=False, random_state=SHUFFLE_SEED
                )

            # Every layer of a dataset must cover the same set of sample indices.
            if set(indices) != set(check_indices[dataset_id]):
                raise RuntimeError("indices across layers are not the same")

            # Apply the same shuffled order to labels and activations.
            labels_df_subset = labels_df.iloc[check_indices[dataset_id]].reset_index(drop=True)
            activations_subset = activations[check_indices[dataset_id]]

            activations_handler = ActivationsHandler(
                activations=activations_subset,
                labels=labels_df_subset["correct"].astype(bool),
            ).sample_equally_across_groups(
                group_labels=[False, True], interleave=True
            )

            activations_handler_folds = list(
                activations_handler.split_dataset(split_sizes=[1 / N_FOLDS] * N_FOLDS)
            )

            dataset_activation_handlers[dataset_id] = activations_handler_folds

        # The number of folds is taken from the first dataset's fold list.
        n_folds = len(next(iter(dataset_activation_handlers.values())))
        for fold_i in range(n_folds):
            train_activations_handlers, test_activations_handlers = {}, {}
            for dataset_id, ah_folds in dataset_activation_handlers.items():

                # Fold `fold_i` of every dataset is held out for testing.
                test_activations_handlers[dataset_id] = ah_folds[fold_i].sample_equally_across_groups(
                    group_labels=[False, True], interleave=True
                )

                if dataset_id in TEST_ONLY_DATASET_IDS:
                    continue

                # All remaining folds of this dataset go into the training pool.
                train_activations_handlers[dataset_id] = [ah for j, ah in enumerate(ah_folds) if j != fold_i]

            activations_handler_train = combine_activations_handlers(
                # Flatten the per-dataset fold lists into one list of handlers.
                [ah for folds in train_activations_handlers.values() for ah in folds],
                equal_counts=True
            ).sample_equally_across_groups(
                group_labels=[False, True], interleave=True
            )

            # Passed back into get_logistic_regression_classifier on later iterations.
            # NOTE(review): presumably lets the fitted scaler/model be reused for the
            # same training set instead of being refit - confirm in src.classifying.
            scaler_model_tuple = None
            for test_dataset_id, activations_handler_test in test_activations_handlers.items():
                res_dict["model_id"].append(model_id)
                res_dict["train_dataset_ids"].append(list(train_activations_handlers.keys()))
                res_dict["test_dataset_id"].append(test_dataset_id)
                res_dict["prompt_id"].append(prompt_id)
                res_dict["subset_id"].append(subset_id)
                res_dict["input_type"].append(input_type)
                res_dict["layer"].append(layer)
                res_dict["fold"].append(fold_i)

                direction_classifier, direction_calculator = get_correctness_direction_classifier(
                    activations_handler_train=activations_handler_train,
                    activations_handler_test=activations_handler_test,
                )
                for key, value in direction_classifier.classification_metrics.items():
                    res_dict[f"direction_{key}"].append(value)

                logistic_regression_classifier, scaler_model_tuple = get_logistic_regression_classifier(
                    activations_handler_train=activations_handler_train,
                    activations_handler_test=activations_handler_test,
                    scaler_model_tuple=scaler_model_tuple,
                )
                for key, value in logistic_regression_classifier.classification_metrics.items():
                    res_dict[f"logistic_regression_{key}"].append(value)



res_df = pd.DataFrame(res_dict)
res_df


model_id='llama3.1_8b_chat', prompt_id='base', subset_id='main', input_type='prompt_answer'
layer=0, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=1, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=2, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=3, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', prompt_id='base', subset_id='main', input_type='prompt_only'
layer=0, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=1, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=2, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=3, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

layer=4, 

STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=5, layer=6, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=7, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=8, layer=9, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=10, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=11, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=12, layer=13, layer=14, layer=15, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,model_id,train_dataset_ids,test_dataset_id,prompt_id,subset_id,input_type,layer,fold,direction_optimal_cut,direction_optimal_train_set_cut,...,direction_f1_score,direction_precision_score,direction_recall_score,logistic_regression_optimal_cut,logistic_regression_optimal_train_set_cut,logistic_regression_test_roc_auc,logistic_regression_accuracy_score,logistic_regression_f1_score,logistic_regression_precision_score,logistic_regression_recall_score
0,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",football_leagues_1k,base,main,prompt_answer,0,0,-0.104880,-0.104880,...,0.637363,0.500000,0.878788,0.5,0.512922,0.523416,0.515152,0.448276,0.520000,0.393939
1,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",medals_9k,base,main,prompt_answer,0,0,-0.104880,-0.104880,...,0.636730,0.543756,0.768056,0.5,0.512922,0.761444,0.710417,0.712215,0.707819,0.716667
2,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",birth_years_4k,base,main,prompt_answer,0,0,-0.104880,-0.104880,...,0.545932,0.553191,0.538860,0.5,0.512922,0.583371,0.588083,0.605459,0.580952,0.632124
3,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",cities_10k,base,main,prompt_answer,0,0,-0.104880,-0.104880,...,0.680504,0.557039,0.874286,0.5,0.512922,0.653194,0.604762,0.628469,0.592905,0.668571
4,llama3.1_8b_chat,"[medals_9k, birth_years_4k, cities_10k]",football_leagues_1k,base,main,prompt_answer,0,1,-0.085766,-0.085766,...,0.635294,0.519231,0.818182,0.5,0.524956,0.539945,0.515152,0.500000,0.516129,0.484848
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1275,llama3.1_8b_chat,"[birth_years_4k, cities_10k, medals_9k]",medals_9k,base,main,prompt_only,31,3,7.130029,7.130029,...,0.537715,0.581583,0.500000,0.5,0.492002,0.726635,0.685417,0.690785,0.679195,0.702778
1276,llama3.1_8b_chat,"[birth_years_4k, cities_10k, medals_9k]",birth_years_4k,base,main,prompt_only,31,4,4.809654,4.809654,...,0.428135,0.522388,0.362694,0.5,0.462916,0.781900,0.727979,0.720000,0.741758,0.699482
1277,llama3.1_8b_chat,"[birth_years_4k, cities_10k, medals_9k]",cities_10k,base,main,prompt_only,31,4,4.809654,4.809654,...,0.482850,0.785408,0.348571,0.5,0.462916,0.774984,0.728571,0.721408,0.740964,0.702857
1278,llama3.1_8b_chat,"[birth_years_4k, cities_10k, medals_9k]",football_leagues_1k,base,main,prompt_only,31,4,4.809654,4.809654,...,0.363636,0.727273,0.242424,0.5,0.462916,0.489440,0.469697,0.461538,0.468750,0.454545


In [13]:
# Path of the CSV file that res_df is written to and read back from in the next cells.
res_file = "./classification_data/res_df_llama31_8B_4_memory_datasets_combined_train_sets.csv"


In [14]:
# Create the target directory first so the write also works on a fresh checkout.
Path(res_file).parent.mkdir(parents=True, exist_ok=True)
res_df.to_csv(res_file, index=False)

In [15]:
# Read the results back from res_file and rebind res_df to the loaded frame.
# NOTE(review): values that were Python lists when collected (e.g. "train_dataset_ids")
# presumably come back as plain strings after the CSV round trip - confirm before using them.
res_df = pd.read_csv(res_file)

In [16]:
def _hex_to_rgba(hex_color, alpha):
    """Turn a ``'#rrggbb'`` colour and an opacity into an ``'rgba(r, g, b, a)'`` string."""
    digits = hex_color.lstrip("#")
    red, green, blue = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({red}, {green}, {blue}, {alpha})"


def _add_filled_band(fig, x, upper, lower, fillcolor, upper_name, lower_name, legendgroup):
    """Add two invisible lines to ``fig`` and fill the area between them."""
    fig.add_trace(go.Scatter(
        x=x,
        y=upper,
        mode='lines',
        line=dict(width=0),
        showlegend=False,
        name=upper_name,
        legendgroup=legendgroup,
    ))
    # 'tonexty' fills from this trace to the trace added immediately before it.
    fig.add_trace(go.Scatter(
        x=x,
        y=lower,
        mode='lines',
        line=dict(width=0),
        fillcolor=fillcolor,
        fill='tonexty',
        showlegend=False,
        name=lower_name,
        legendgroup=legendgroup,
    ))


def plot_interactive_lineplot(df_dict, x_label, y_label, title=None, save_path=None):
    """Plot one mean line per entry of ``df_dict``, with spread bands and per-column points.

    Args:
        df_dict: Dictionary mapping labels to dataframes, where each dataframe contains
            multiple columns for the same measurement (e.g., {'Metric 1': df1, 'Metric 2': df2}).
            The dataframe index is used as the x axis; statistics are taken across columns.
        x_label: Title of the x axis.
        y_label: Title of the y axis.
        title: Optional figure title.
        save_path: Optional ``str`` or ``pathlib.Path``. A path ending in ``.html`` is written
            with ``fig.write_html``, one ending in ``.png`` with ``fig.write_image``; missing
            parent directories are created. Any other ending is ignored and nothing is saved.

    Returns:
        The plotly figure.
    """
    fig = go.Figure()

    # Colour palette; entries are reused cyclically when there are more labels than colours.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

    for metric_idx, (metric_name, df) in enumerate(df_dict.items()):
        color = colors[metric_idx % len(colors)]

        # Row-wise mean across the dataframe's columns
        means = df.mean(axis=1)

        # Mean line (the only trace of the group that appears in the legend)
        fig.add_trace(go.Scatter(
            x=df.index,
            y=means,
            mode='lines+markers',
            line=dict(color=color, width=2),
            marker=dict(size=8),
            name=f'{metric_name}',
            legendgroup=metric_name,
        ))

        # Spread information only makes sense with more than one column.
        if df.shape[1] > 1:
            stds = df.std(axis=1)
            mins = df.min(axis=1)
            maxs = df.max(axis=1)

            # Min/max range (very faint)
            _add_filled_band(
                fig,
                x=df.index,
                upper=maxs,
                lower=mins,
                fillcolor=_hex_to_rgba(color, 0.1),
                upper_name=f'{metric_name} Max',
                lower_name=f'{metric_name} Min',
                legendgroup=metric_name,
            )

            # +/- 1 std range (moderately faint)
            _add_filled_band(
                fig,
                x=df.index,
                upper=means + stds,
                lower=means - stds,
                fillcolor=_hex_to_rgba(color, 0.3),
                upper_name=f'{metric_name} +1 STD',
                lower_name=f'{metric_name} -1 STD',
                legendgroup=metric_name,
            )

            # Individual points for each column (fold)
            for col in df.columns:
                fig.add_trace(go.Scatter(
                    x=df.index,
                    y=df[col],
                    mode='markers',
                    marker=dict(
                        color=color,
                        size=6,
                        opacity=0.5
                    ),
                    showlegend=False,
                    name=f'{metric_name} Fold {col}',
                    legendgroup=metric_name,
                ))

    fig.update_layout(
        title=title,
        yaxis_title=y_label,
        xaxis_title=x_label,
        template='plotly_dark',
        plot_bgcolor='rgba(32, 32, 32, 1)',
        paper_bgcolor='rgba(32, 32, 32, 1)',
        font=dict(color='white'),
        margin=dict(t=100, l=50, r=30, b=50),  # Increased top margin to accommodate legend
        showlegend=True,
        legend=dict(
            orientation="h",    # Horizontal legend
            yanchor="bottom",
            y=1.02,           # Position above the plot
            xanchor="center",  # Center horizontally
            x=0.5,
            bgcolor='rgba(32, 32, 32, 0.8)'  # Semi-transparent background
        ),
        width=1000,   # Set explicit width in pixels
        height=600    # Set explicit height in pixels
    )

    # Update axes for consistency with dark theme
    fig.update_xaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
    fig.update_yaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)

    # Save the plot if a path is provided
    if save_path:
        save_target = str(save_path)  # accept pathlib.Path as well as str
        is_html = save_target.endswith('.html')
        is_png = save_target.endswith('.png')
        if is_html or is_png:
            # Writing fails if the directory is missing, so create it up front.
            Path(save_target).parent.mkdir(parents=True, exist_ok=True)
        if is_html:
            fig.write_html(save_target)
        elif is_png:
            fig.write_image(save_target)

    return fig

In [17]:
# Draw one figure per (classifier, metric) pair; every figure holds one line per
# (model, test dataset, prompt, subset, input type) combination found in res_df.
# Other metrics available: "accuracy_score", "precision_score", "recall_score".
line_id_columns = ["model_id", "test_dataset_id", "prompt_id", "subset_id", "input_type"]

for classifier in ["direction", "logistic_regression"]:
    for metric in ["f1_score"]:
        metric_column = f"{classifier}_{metric}"

        plot_dict = {}
        for conf, group_df in res_df.groupby(line_id_columns):
            print(f"{conf=}")

            # Rows become layers, columns become (value column, fold) pairs.
            per_fold_wide = group_df.drop(columns=line_id_columns).pivot(
                index='layer',
                columns='fold',
            )
            # Keep only the fold columns of the metric being plotted.
            plot_dict[str(conf)] = per_fold_wide[[metric_column]]

        figure = plot_interactive_lineplot(
            plot_dict,
            x_label="Layer",
            y_label=metric_column.replace("_", " ").title(),
            save_path=f"./classification_data/figures/res_df_llama31_8B_4_memory_datasets_combined_train_sets_{metric_column}.html"
        )
        figure.show()



conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_only')


conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_only')


In [12]:
# def plot_interactive_box(df, x_label, y_label, title=None):
#     fig = go.Figure()
    
#     # Create box plot
#     fig.add_trace(go.Box(
#         x=[idx for idx in df.index for _ in range(len(df.columns))],  # Repeat each index for each fold
#         y=df.values.flatten(),  # Flatten all values
#         boxpoints='all',  # Show all points
#         jitter=0,        # No jitter for points
#         pointpos=0,      # Position points at center
#         marker=dict(
#             color='#1f77b4',
#             size=8,
#             opacity=0.5
#         ),
#         line=dict(
#             color='#1f77b4',
#             width=2
#         ),
#         fillcolor='rgba(68, 138, 255, 0.5)',
#         opacity=0.6,
#         showlegend=False,
#         boxmean=True,    # Show mean as a dashed line
#         width=0.5        # Width of boxes
#     ))

#     fig.update_layout(
#         title=title,
#         yaxis_title=y_label,
#         xaxis_title=x_label,
#         template='plotly_dark',
#         plot_bgcolor='rgba(32, 32, 32, 1)',
#         paper_bgcolor='rgba(32, 32, 32, 1)',
#         font=dict(color='white'),
#         margin=dict(t=50, l=50, r=30),
#         showlegend=False
#     )
    
#     # Update axes for consistency with dark theme
#     fig.update_xaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
#     fig.update_yaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
    
#     return fig

# plot_interactive_box(res_df, x_label="Layer", y_label="F1 Score").show()