In [1]:
import sys
from pathlib import Path

# Make the parent of the notebook's working directory importable so that the
# `from src.utils.data import ...` line below can resolve (presumably this
# parent is the repository root that contains `src/` -- confirm if the
# notebook is moved).
src_path = str(Path("./").resolve().parent)
if src_path not in sys.path:
    sys.path.append(src_path)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import re
import torch
from src.utils.data import load_activations, load_labels, get_experiment_activations_configs_df_subset

In [2]:
def extract_number(filename):
    """Return the first run of decimal digits found in ``filename`` as an int.

    Args:
        filename: String to scan (e.g. a file name containing an index).

    Returns:
        The integer value of the left-most digit sequence, or ``-1`` when the
        string contains no digits at all.
    """
    first_digits = re.search(r'\d+', filename)
    if not first_digits:
        return -1
    return int(first_digits.group())

def intra_class_variance_over_inter_class_variance(vectors, labels):
    """Score how well two classes separate in the given vector space.

    NOTE: despite the function name, the value returned is the *between*-class
    spread divided by the *within*-class spread, so a larger value means the
    classes are better separated. The name is kept so existing callers keep
    working.

    Neither spread is a true variance:
      * within-class spread of a class = mean Euclidean distance of its
        samples to the class mean;
      * between-class spread = sum of the Euclidean distances of the two class
        means to the overall mean of ``vectors``.

    Args:
        vectors: Samples along the first axis, features along the second
            (the per-sample norm is taken over axis 1). May be a torch tensor
            (on any device, with or without grad) or anything ``np.asarray``
            accepts.
        labels: One label per sample. Entries equal to 0 form class 0 and
            entries equal to 1 form class 1 (booleans work too). A list,
            NumPy array or pandas Series is accepted; it is matched to
            ``vectors`` by position, not by index.

    Returns:
        The between/within ratio as a scalar. ``nan`` if either class has no
        samples; NumPy division semantics apply if the within-class spread is
        zero.
    """
    # Tensors have to be detached from autograd and copied to host memory
    # before they can be viewed as a NumPy array.
    if hasattr(vectors, "detach"):
        vectors = vectors.detach().cpu().numpy()
    else:
        vectors = np.asarray(vectors)
    # Plain lists would make `labels == 0` a single bool; always use an array.
    labels = np.asarray(labels)

    class_0_vectors = vectors[labels == 0]
    class_1_vectors = vectors[labels == 1]
    if len(class_0_vectors) == 0 or len(class_1_vectors) == 0:
        # The ratio is undefined with an empty class; say so explicitly rather
        # than relying on NumPy's "mean of empty slice" warnings.
        return np.nan

    class_0_mean = np.mean(class_0_vectors, axis=0)
    class_1_mean = np.mean(class_1_vectors, axis=0)
    class_0_spread = np.mean(np.linalg.norm(class_0_vectors - class_0_mean, axis=1))
    class_1_spread = np.mean(np.linalg.norm(class_1_vectors - class_1_mean, axis=1))
    total_mean = np.mean(vectors, axis=0)

    between_class = (
        np.linalg.norm(class_0_mean - total_mean)
        + np.linalg.norm(class_1_mean - total_mean)
    )
    within_class = class_0_spread + class_1_spread
    # Between-class spread over within-class spread.
    return between_class / within_class

In [3]:
# Model identifiers to sweep over, in processing order.
models = [
    "mistral_7b_instruct", "ministral_8b_instruct", "qwen_2.5_7b_instruct",
    "llama3.1_8b_chat", "llama3.3_70b", "deepseek_qwen_32b",
]

# Dataset identifiers evaluated for every model, in processing order.
datasets = [
    "gsm8k", "trivia_qa_2_60k", "birth_years_4k",
    "cities_10k", "medals_9k", "math_operations_6k",
]

# Per-model root directory passed as `base_path` to the data loaders. Only two
# distinct directories are in use, so they are named once and reused.
# NOTE(review): these are hard-coded absolute paths; they only resolve on a
# machine where /runpod-volume is mounted.
_ANTON_DATA_DIR = "/runpod-volume/anton/correctness-model-internals/data_for_classification"
_ARNAU_DATA_DIR = "/runpod-volume/arnau/correctness-model-internals/data_for_classification"

BASE_PATH = {
    "llama3.1_8b_chat": _ANTON_DATA_DIR,
    "llama3.3_70b": _ARNAU_DATA_DIR,
    "deepseek_qwen_32b": _ARNAU_DATA_DIR,
    "mistral_7b_instruct": _ANTON_DATA_DIR,
    "ministral_8b_instruct": _ARNAU_DATA_DIR,
    "qwen_2.5_7b_instruct": _ARNAU_DATA_DIR,
}

# Layer indices analysed for each model: evenly spaced from 0 with a stride of
# 2 or 4. The `range` stop value is exclusive, so the last index is stop - stride.
layers = {
    "llama3.1_8b_chat": list(range(0, 32, 2)),       # 0, 2, ..., 30
    "llama3.3_70b": list(range(0, 80, 4)),           # 0, 4, ..., 76
    "deepseek_qwen_32b": list(range(0, 64, 4)),      # 0, 4, ..., 60
    "mistral_7b_instruct": list(range(0, 32, 2)),    # 0, 2, ..., 30
    "ministral_8b_instruct": list(range(0, 36, 2)),  # 0, 2, ..., 34
    "qwen_2.5_7b_instruct": list(range(0, 28, 2)),   # 0, 2, ..., 26
}

In [None]:
# Sweep every (model, dataset) pair, compute the class-separation ratio at each
# configured layer, and write one CSV row per pair.

# These two settings are the same for every pair in the sweep.
subset_id = "main"
input_type = "prompt_only"

results = []

for model_id in models:
    layers_model = layers[model_id]
    BASE_PATH_MODEL = BASE_PATH[model_id]

    for dataset_id in datasets:
        # gsm8k uses the "base_3_shot" prompt id; every other dataset uses "base".
        prompt_id = "base_3_shot" if dataset_id == "gsm8k" else "base"

        # load_labels takes no layer argument, so load once per
        # (model, dataset) instead of once per layer.
        labels_df = load_labels(
            base_path=BASE_PATH_MODEL,
            model_id=model_id,
            dataset_id=dataset_id,
            prompt_id=prompt_id,
            subset_id=subset_id,
        )
        labels = labels_df['correct']

        # One ratio per entry of layers[model_id], in the same order.
        variance_ratios = []
        for layer in layers_model:
            # NOTE(review): `indices` is returned but never used here. This
            # assumes the rows of `activations` line up positionally with
            # `labels_df` -- confirm against load_activations.
            activations, indices = load_activations(
                base_path=BASE_PATH_MODEL,
                model_id=model_id,
                dataset_id=dataset_id,
                prompt_id=prompt_id,
                subset_id=subset_id,
                input_type=input_type,
                layer=layer,
            )
            ratio = intra_class_variance_over_inter_class_variance(activations, labels)
            # Store plain Python floats: a list of NumPy scalars is written to
            # the CSV via repr(), which is "np.float32(...)" under NumPy >= 2
            # and cannot be parsed back as a list of numbers.
            variance_ratios.append(float(ratio))

        print(model_id, dataset_id, variance_ratios)
        results.append({
            "model": model_id,
            "dataset": dataset_id,
            # The key stays "aucs" so the CSV schema does not change, but the
            # values are the per-layer variance ratios, not AUCs.
            "aucs": variance_ratios
        })

df = pd.DataFrame(results)
df.to_csv("intra_class_variance_over_inter_class_variance_data.csv", index=False)