In [3]:
import sys
from pathlib import Path

# Make the directory two levels above the current working directory importable,
# so that the `src.*` package imports in this cell resolve (presumably that
# directory is the repository root containing `src/`). The membership check
# keeps re-runs of this cell from appending duplicate entries.
src_path = str(Path("./").resolve().parent.parent)
if src_path not in sys.path:
    sys.path.append(src_path)

from collections import defaultdict

import pandas as pd

from src.classifying import (
    ActivationsHandler,
    combine_activations_handlers,
    get_correctness_direction_classifier,
    get_logistic_regression_classifier,
    get_between_class_variance_and_within_class_variance,
)
# from src.visualisations.utils import plot_interactive_lineplot
from src.utils.data import load_activations, load_labels, get_experiment_activations_configs_df_subset

In [4]:
# Models to sweep, in the order the evaluation loop processes them.
models = [
    "ministral_8b_instruct",
    "qwen_2.5_7b_instruct",
    "llama3.1_8b_chat",
    "llama3.3_70b",
    "deepseek_qwen_32b",
    "mistral_7b_instruct",
]

# Root directories holding each model's cached activations/labels.
# NOTE(review): absolute paths on a specific RunPod volume; consider making
# them configurable (e.g. an environment variable) so the notebook runs elsewhere.
_ANTON_DATA_ROOT = "/runpod-volume/anton/correctness-model-internals/data_for_classification"
_ARNAU_DATA_ROOT = "/runpod-volume/arnau/correctness-model-internals/data_for_classification"

BASE_PATH = {
    "llama3.1_8b_chat": _ANTON_DATA_ROOT,
    "llama3.3_70b": _ARNAU_DATA_ROOT,
    "deepseek_qwen_32b": _ARNAU_DATA_ROOT,
    "mistral_7b_instruct": _ANTON_DATA_ROOT,
    "ministral_8b_instruct": _ARNAU_DATA_ROOT,
    "qwen_2.5_7b_instruct": _ARNAU_DATA_ROOT,
}

# Layer indices probed per model: every 2nd layer, except every 4th for
# llama3.3_70b and deepseek_qwen_32b. range() stop values are exclusive.
layers = {
    "llama3.1_8b_chat": list(range(0, 32, 2)),       # 0, 2, ..., 30
    "llama3.3_70b": list(range(0, 80, 4)),           # 0, 4, ..., 76
    "deepseek_qwen_32b": list(range(0, 64, 4)),      # 0, 4, ..., 60
    "mistral_7b_instruct": list(range(0, 32, 2)),    # 0, 2, ..., 30
    "ministral_8b_instruct": list(range(0, 36, 2)),  # 0, 2, ..., 34
    "qwen_2.5_7b_instruct": list(range(0, 28, 2)),   # 0, 2, ..., 26
}

# Per-model {layer: mean test AUROC}, filled in by the sweep. Derived from
# BASE_PATH so every configured model gets its own fresh, empty dict.
aurocs = {model_id: {} for model_id in BASE_PATH}

# Fail fast here rather than partway through a long sweep if a model listed
# in `models` is missing its data path or layer configuration.
_unconfigured_models = [
    model_id
    for model_id in models
    if model_id not in BASE_PATH or model_id not in layers
]
assert not _unconfigured_models, f"Missing BASE_PATH/layers config for: {_unconfigured_models}"

In [None]:
# Dataset / prompt configuration shared by every model in the sweep.
dataset_id = "trivia_qa_2_60k"
prompt_id = "base"
subset_id = "main"
input_type = "prompt_only"
n_folds = 3
# Only the first MAX_EXAMPLES rows are used per layer (presumably to bound
# memory and runtime).
MAX_EXAMPLES = 10_000

# For each model and each probed layer, compute the mean test AUROC of
# get_correctness_direction_classifier over n_folds cross-validation folds,
# store it in aurocs[model_id][layer], and print the best layer per model.
for model_id in models:
    print(model_id)
    BASE_PATH_MODEL = BASE_PATH[model_id]
    layers_model = layers[model_id]

    # Labels do not depend on the layer, so load (and truncate) them once per
    # model instead of re-reading them for every layer.
    labels_df = load_labels(
        base_path=BASE_PATH_MODEL,
        model_id=model_id,
        dataset_id=dataset_id,
        prompt_id=prompt_id,
        subset_id=subset_id,
    ).iloc[:MAX_EXAMPLES]

    best_auc = -1
    best_layer = None
    for layer in layers_model:
        activations, indices = load_activations(
            base_path=BASE_PATH_MODEL,
            model_id=model_id,
            dataset_id=dataset_id,
            prompt_id=prompt_id,
            subset_id=subset_id,
            input_type=input_type,
            layer=layer,
        )
        # NOTE(review): activations and labels are truncated independently and
        # paired purely by row position; `indices` is not used to align them.
        # Confirm load_activations and load_labels return rows in the same order.
        activations, indices = activations[:MAX_EXAMPLES], indices[:MAX_EXAMPLES]
        activations_handler = ActivationsHandler(
            activations=activations,
            labels=labels_df["correct"].astype(bool),
        )
        # NOTE(review): if split_dataset / sample_equally_across_groups sample
        # randomly, seed them so these AUROCs are reproducible across runs.
        activations_handler_folds = list(
            activations_handler.split_dataset(split_sizes=[1 / n_folds] * n_folds)
        )
        test_aurocs = []
        for i in range(n_folds):
            # Fold i is held out for testing; all other folds form the training set.
            activations_handler_test = activations_handler_folds[i]
            activations_handler_train = combine_activations_handlers(
                [ah for j, ah in enumerate(activations_handler_folds) if j != i]
            )
            # Assumed to subsample so that False (incorrect) and True (correct)
            # labels are equally represented in each split.
            activations_handler_train = activations_handler_train.sample_equally_across_groups(
                group_labels=[False, True]
            )
            activations_handler_test = activations_handler_test.sample_equally_across_groups(
                group_labels=[False, True]
            )
            direction_classifier, direction_calculator = get_correctness_direction_classifier(
                activations_handler_train=activations_handler_train,
                activations_handler_test=activations_handler_test,
            )
            test_aurocs.append(direction_classifier.classification_metrics["test_roc_auc"])

        # Compute the fold-averaged AUROC once and reuse it below.
        mean_auroc = sum(test_aurocs) / len(test_aurocs)
        print(layer, "AUROC", mean_auroc)
        aurocs[model_id][layer] = mean_auroc
        if mean_auroc > best_auc:
            best_auc = mean_auroc
            best_layer = layer
    print("Best layer", best_layer, "Best AUROC", best_auc)
    print()

ministral_8b_instruct
