In [1]:
import sys
from pathlib import Path

# Make the directory two levels above the current working directory importable.
# Later cells import the `src` package, which is expected to live in that directory.
working_dir = Path("./").resolve()
src_path = str(working_dir.parent.parent)
if sys.path.count(src_path) == 0:
    sys.path.append(src_path)

sys.path

['/Library/Frameworks/Python.framework/Versions/3.11/lib/python311.zip',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload',
 '',
 '/Users/anton/dev/MARS/correctness-model-internals/venv/lib/python3.11/site-packages',
 '/Users/anton/dev/MARS/correctness-model-internals']

In [2]:
from collections import defaultdict

import pandas as pd

from src.classifying import (
    ActivationsHandler,
    combine_activations_handlers,
    get_correctness_direction_classifier,
    get_logistic_regression_classifier,
    get_between_class_variance_and_within_class_variance,
)
from src.visualisations.utils import plot_interactive_lineplot
from src.utils.data import load_activations, load_labels, get_experiment_activations_configs_df_subset


In [3]:
# Base directory passed as `base_path` to the data-loading helpers.
BASE_PATH = "../../data_for_classification"
# Number of PCA components passed to `reduce_dims`; a falsy value yields an empty label below.
PCA_COMPONENTS = 10

# Suffix embedded in the saved figure file names so the PCA setting is visible in the name.
PCA_COMPONENTS_LABEL = f"_pca_{PCA_COMPONENTS}" if PCA_COMPONENTS else ""
# Backward-compatible alias: the misspelled name is still referenced by other cells.
PCA_COMONENTS_LABEL = PCA_COMPONENTS_LABEL

In [4]:
# Filters used to select experiment configurations. Per the original note, leaving a
# filter as None selects everything for that field. A fully pinned example would be:
#   DATASET_ID = "gsm8k", PROMPT_ID = "cot_3_shot", SUBSET_ID = "main", INPUT_TYPE = "prompt_only"
MODEL_ID = "llama3.1_8b_chat"
DATASET_ID, PROMPT_ID, SUBSET_ID, INPUT_TYPE = None, None, None, None

config_filters = {
    "model_id": MODEL_ID,
    "dataset_id": DATASET_ID,
    "prompt_id": PROMPT_ID,
    "subset_id": SUBSET_ID,
    "input_type": INPUT_TYPE,
}
activation_exp_configs_df = get_experiment_activations_configs_df_subset(
    base_path=BASE_PATH,
    **config_filters,
)
activation_exp_configs_df

Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,path
0,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,0,../../data_for_classification/activations/llam...
1,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,7,../../data_for_classification/activations/llam...
2,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,9,../../data_for_classification/activations/llam...
3,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,17,../../data_for_classification/activations/llam...
4,llama3.1_8b_chat,football_leagues_1k,base,main,prompt_only,28,../../data_for_classification/activations/llam...
...,...,...,...,...,...,...,...
379,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,23,../../data_for_classification/activations/llam...
380,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,15,../../data_for_classification/activations/llam...
381,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,12,../../data_for_classification/activations/llam...
382,llama3.1_8b_chat,birth_years_4k,base,main,prompt_answer,2,../../data_for_classification/activations/llam...


In [5]:
# For every experiment configuration (model / dataset / prompt / subset / input type)
# and every layer: reduce the activations with PCA, then compute the between-class and
# within-class variance for the groups False/True of the `correct` label, plus their ratio.

# Fixed seed for the row shuffle below so the analysis is reproducible across kernel restarts.
SHUFFLE_SEED = 42

between_over_within_class_variance_dict = defaultdict(list)

for (model_id, dataset_id, prompt_id, subset_id, input_type), config_df in activation_exp_configs_df.groupby(["model_id", "dataset_id", "prompt_id", "subset_id", "input_type"]):
    print(f"\n{model_id=}, {dataset_id=}, {prompt_id=}, {subset_id=}, {input_type=}")
    labels_df = load_labels(
        base_path=BASE_PATH,
        model_id=model_id,
        dataset_id=dataset_id,
        prompt_id=prompt_id,
        subset_id=subset_id,
    )

    # Shuffled row order: drawn once per configuration from the first layer processed
    # and reused for all remaining layers of that configuration.
    check_indices = None
    for layer in config_df["layer"].astype(int).sort_values():
        print(f"{layer=}", end=", ")
        activations, indices = load_activations(
            base_path=BASE_PATH,
            model_id=model_id,
            dataset_id=dataset_id,
            prompt_id=prompt_id,
            subset_id=subset_id,
            input_type=input_type,
            layer=layer,
        )

        if check_indices is None:
            # Seeded so that the same permutation is produced on every run.
            check_indices = indices.sample(frac=1, replace=False, random_state=SHUFFLE_SEED)

        # Every layer must expose exactly the same set of indices as the first one.
        if set(indices) != set(check_indices):
            raise RuntimeError("indices across layers are not the same")

        # NOTE(review): the index values are used positionally for both the labels and
        # the activations — assumes they are 0-based row positions; confirm.
        labels_df_subset = labels_df.iloc[check_indices]
        activations = activations[check_indices]

        activations_handler = ActivationsHandler(
            activations=activations,
            labels=labels_df_subset["correct"].astype(bool),
        )

        # The second return value of reduce_dims is not needed in this cell.
        activations_handler, _ = activations_handler.reduce_dims(pca_components=PCA_COMPONENTS)

        # presumably sample_equally_across_groups balances the two classes — verify in src.classifying
        between_class_variance, within_class_variance = get_between_class_variance_and_within_class_variance(
            activations_handler.sample_equally_across_groups(group_labels=(False, True)),
            group_labels=(False, True),
        )
        between_over_within_class_variance_dict["model_id"].append(model_id)
        between_over_within_class_variance_dict["dataset_id"].append(dataset_id)
        between_over_within_class_variance_dict["prompt_id"].append(prompt_id)
        between_over_within_class_variance_dict["subset_id"].append(subset_id)
        between_over_within_class_variance_dict["input_type"].append(input_type)
        between_over_within_class_variance_dict["layer"].append(layer)
        between_over_within_class_variance_dict["between_class_variance"].append(between_class_variance)
        between_over_within_class_variance_dict["within_class_variance"].append(within_class_variance)
        between_over_within_class_variance_dict["between_over_within_class_variance"].append(between_class_variance/within_class_variance)

# One row per (configuration, layer).
between_over_within_class_variance_df = pd.DataFrame(between_over_within_class_variance_dict)



model_id='llama3.1_8b_chat', dataset_id='birth_years_4k', prompt_id='base', subset_id='main', input_type='prompt_answer'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', dataset_id='birth_years_4k', prompt_id='base', subset_id='main', input_type='prompt_only'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', dataset_id='cities_10k', prompt_id='base', subset_id='main', input_type='prompt_answer'
layer=0, layer=1, la

In [6]:
# One line per experiment configuration: the variance ratio indexed by layer.
variance_group_columns = ["model_id", "dataset_id", "prompt_id", "subset_id", "input_type"]
variance_ratio_by_config = {
    str(config_key): config_rows.set_index("layer")[["between_over_within_class_variance"]]
    for config_key, config_rows in between_over_within_class_variance_df.groupby(variance_group_columns)
}

plot_interactive_lineplot(
    df_dict=variance_ratio_by_config,
    x_label="Layer",
    y_label="Between Class Variance / Within Class Variance",
    save_path=f"./classification_data/figures/res_df_llama31_8B_4_memory_datasets_between_over_within_class_variance{PCA_COMONENTS_LABEL}.html"
)

In [None]:

# Cross-validate two correctness classifiers (direction-based and logistic regression)
# on the activations of every experiment configuration and layer, collecting the
# classification metrics of each fold into `res_df`.

# Fixed seed for the row shuffle below so the analysis is reproducible across kernel restarts.
SHUFFLE_SEED = 42
# Number of equally sized folds used for cross-validation.
N_FOLDS = 5

res_dict = defaultdict(list)

for (model_id, dataset_id, prompt_id, subset_id, input_type), config_df in activation_exp_configs_df.groupby(["model_id", "dataset_id", "prompt_id", "subset_id", "input_type"]):
    print(f"\n{model_id=}, {dataset_id=}, {prompt_id=}, {subset_id=}, {input_type=}")
    labels_df = load_labels(
        base_path=BASE_PATH,
        model_id=model_id,
        dataset_id=dataset_id,
        prompt_id=prompt_id,
        subset_id=subset_id,
    )

    # Shuffled row order: drawn once per configuration from the first layer processed
    # and reused for all remaining layers of that configuration.
    check_indices = None
    for layer in config_df["layer"].astype(int).sort_values():
        print(f"{layer=}", end=", ")
        activations, indices = load_activations(
            base_path=BASE_PATH,
            model_id=model_id,
            dataset_id=dataset_id,
            prompt_id=prompt_id,
            subset_id=subset_id,
            input_type=input_type,
            layer=layer,
        )

        if check_indices is None:
            # Seeded so that the same permutation is produced on every run.
            check_indices = indices.sample(frac=1, replace=False, random_state=SHUFFLE_SEED)

        # Every layer must expose exactly the same set of indices as the first one.
        if set(indices) != set(check_indices):
            raise RuntimeError("indices across layers are not the same")

        # NOTE(review): the index values are used positionally for both the labels and
        # the activations — assumes they are 0-based row positions; confirm.
        labels_df_subset = labels_df.iloc[check_indices]
        activations = activations[check_indices]

        activations_handler = ActivationsHandler(
            activations=activations,
            labels=labels_df_subset["correct"].astype(bool),
        )

        activations_handler_folds = list(
            activations_handler.split_dataset(split_sizes=[1 / N_FOLDS] * N_FOLDS)
        )

        for fold_i, activations_handler_test in enumerate(activations_handler_folds):
            # Train on every fold except the held-out one.
            activations_handler_train = combine_activations_handlers(
                [ah for j, ah in enumerate(activations_handler_folds) if j != fold_i]
            )

            # PCA info comes from the training folds only and is then passed to the
            # reduce_dims call for the held-out fold.
            activations_handler_train, pca_info = activations_handler_train.reduce_dims(pca_components=PCA_COMPONENTS)

            activations_handler_train = activations_handler_train.sample_equally_across_groups(
                group_labels=[False, True]
            )
            activations_handler_test, _ = activations_handler_test.reduce_dims(pca_components=PCA_COMPONENTS, pca_info=pca_info)
            activations_handler_test = activations_handler_test.sample_equally_across_groups(
                group_labels=[False, True]
            )

            res_dict["model_id"].append(model_id)
            res_dict["dataset_id"].append(dataset_id)
            res_dict["prompt_id"].append(prompt_id)
            res_dict["subset_id"].append(subset_id)
            res_dict["input_type"].append(input_type)
            res_dict["layer"].append(layer)
            res_dict["fold"].append(fold_i)

            # direction_calculator.classifying_direction could also be stored here if
            # the direction vectors themselves are needed; it is currently not recorded.
            direction_classifier, direction_calculator = get_correctness_direction_classifier(
                activations_handler_train=activations_handler_train,
                activations_handler_test=activations_handler_test,
            )
            for key, value in direction_classifier.classification_metrics.items():
                res_dict[f"direction_{key}"].append(value)

            # Only the first element of the returned pair is used here.
            logistic_regression_classifier = get_logistic_regression_classifier(
                activations_handler_train=activations_handler_train,
                activations_handler_test=activations_handler_test,
            )[0]
            for key, value in logistic_regression_classifier.classification_metrics.items():
                res_dict[f"logistic_regression_{key}"].append(value)



# One row per (configuration, layer, fold).
res_df = pd.DataFrame(res_dict)
res_df


model_id='llama3.1_8b_chat', dataset_id='birth_years_4k', prompt_id='base', subset_id='main', input_type='prompt_answer'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', dataset_id='birth_years_4k', prompt_id='base', subset_id='main', input_type='prompt_only'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, layer=28, layer=29, layer=30, layer=31, 
model_id='llama3.1_8b_chat', dataset_id='cities_10k', prompt_id='base', subset_id='main', input_type='prompt_answer'
layer=0, layer=1, la

In [8]:
# CSV path that `res_df` (the per-fold classifier metrics) is written to and read back from.
# NOTE(review): unlike the figure paths, this name does not include the PCA label, so runs
# with different PCA_COMPONENTS values write to the same file — confirm this is intended.
res_file = "./classification_data/res_df_llama31_8B_4_memory_datasets.csv"

In [9]:
# Write the metrics table to disk without the DataFrame index column.
res_df.to_csv(path_or_buf=res_file, index=False)

In [10]:
# Load the metrics table from disk, replacing any in-memory `res_df`.
res_df = pd.read_csv(filepath_or_buffer=res_file)

In [11]:
# Plot, for each classifier and metric, the per-layer score of every fold for each
# experiment configuration, and save each figure as an HTML file.
GROUP_COLUMNS = ["model_id", "dataset_id", "prompt_id", "subset_id", "input_type"]
CLASSIFIERS = ["direction", "logistic_regression"]
# Metrics to plot. The original cell also listed "accuracy_score", "precision_score"
# and "recall_score" as alternatives.
METRICS = ["f1_score"]

# The pivot (rows: layer, columns: (metric column, fold)) depends only on the
# configuration, so build it once per configuration rather than once per
# (classifier, metric) pair.
pivot_by_conf = {
    conf: pd.pivot(
        res_df_.drop(columns=GROUP_COLUMNS),
        index="layer",
        columns="fold",
    )
    for conf, res_df_ in res_df.groupby(GROUP_COLUMNS)
}

for classifier in CLASSIFIERS:
    for metric in METRICS:
        score_column = f"{classifier}_{metric}"
        plot_dict = {}
        for conf, res_df_pivot in pivot_by_conf.items():
            print(f"{conf=}")
            # Selects all fold columns belonging to this score.
            plot_dict[str(conf)] = res_df_pivot[[score_column]]

        plot_interactive_lineplot(
            plot_dict,
            x_label="Layer",
            y_label=score_column.replace("_", " ").title(),
            save_path=f"./classification_data/figures/res_df_llama31_8B_4_memory_datasets_{score_column}{PCA_COMONENTS_LABEL}.html"
        ).show()



conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_only')


conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'birth_years_4k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'cities_10k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'football_leagues_1k', 'base', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'base_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'gsm8k', 'cot_3_shot', 'main', 'prompt_only')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_answer')
conf=('llama3.1_8b_chat', 'medals_9k', 'base', 'main', 'prompt_only')
