In [1]:
import sys
from pathlib import Path

# Add the src directory to Python path
src_path = str(Path("./").resolve().parent)
if src_path not in sys.path:
    sys.path.append(src_path)

sys.path

['/Library/Frameworks/Python.framework/Versions/3.11/lib/python311.zip',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload',
 '',
 '/Users/anton/dev/MARS/correctness-model-internals/venv/lib/python3.11/site-packages',
 '/Users/anton/dev/MARS/correctness-model-internals']

In [2]:
from pathlib import Path
from collections import defaultdict

import torch as pt
import pandas as pd
# import seaborn as sns

from src.classifying import (
    ActivationsHandler,
    combine_activations_handlers,
    get_correctness_direction_classifier,
    get_logistic_regression_classifier,
)

# sns.set_theme(style="whitegrid")


In [3]:
def load_activations(
    model_id,
    dataset_id,
    prompt_id,
    subset_id,
    input_type,
    layer,
    batch_ids=None,
):
    if batch_ids:
        batch_ids = [int(batch_id) for batch_id in batch_ids]

    paths = sorted(
        list(
            Path(
                f"../activations/{model_id}/{dataset_id}/{prompt_id}/{subset_id}/{input_type}/layer_{layer}"
            ).iterdir()
        ),
        key=lambda p: int(p.stem.split("_")[-1]),
    )

    activations_list, indices = [], []
    batch_size = None
    for batch_file in paths:
        batch_id = int(batch_file.stem.split("_")[-1])
        if batch_ids and batch_id not in batch_ids:
            continue

        activations = pt.load(batch_file, map_location=pt.device("cpu"))
        activations_list.append(activations)

        batch_size = activations.shape[0]

        if batch_size is None:
            batch_size = activations.shape[0]
        else:
            assert batch_size == activations.shape[0]

        indices.append(
            pd.Series(range(batch_size), name="index") + batch_id * batch_size
        )
    return (
        pt.cat(activations_list, dim=0),
        pd.concat(indices).reset_index(drop=True),
    )


def load_labels(model_id, dataset_id, prompt_id, subset_id, indices=None):
    paths = list(
        Path(f"../evaluations/{model_id}/{dataset_id}/{prompt_id}/").iterdir()
    )
    for path in paths:
        filename = path.stem
        if subset_id != filename.split("_generations_evaluated")[0]:
            continue
        df = pd.read_csv(path)
        if indices is not None:
            df = df.iloc[indices].reset_index(drop=True)
        return df
    raise ValueError(
        f"No labels found for {model_id} {dataset_id} {prompt_id} {subset_id}"
    )


In [41]:
all_activation_exp_configs = defaultdict(list)
for model_path in Path("../activations").iterdir():
    for dataset_path in model_path.iterdir():
        for prompt_path in dataset_path.iterdir():
            for subset_path in prompt_path.iterdir():
                for input_type_path in subset_path.iterdir():
                    for layer_path in input_type_path.iterdir():
                        all_activation_exp_configs["model_id"].append(model_path.stem)
                        all_activation_exp_configs["dataset_id"].append(dataset_path.stem)
                        all_activation_exp_configs["prompt_id"].append(prompt_path.stem)
                        all_activation_exp_configs["subset_id"].append(subset_path.stem)
                        all_activation_exp_configs["input_type"].append(input_type_path.stem)
                        all_activation_exp_configs["layer"].append(int(layer_path.stem.split("_")[-1]))
                        all_activation_exp_configs["path"].append(layer_path)
all_activation_exp_configs_df = pd.DataFrame(all_activation_exp_configs)
all_activation_exp_configs_df


Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,path
0,llama3_3b_chat,football_leagues_1k,base,main,prompt_only,0,../activations/llama3_3b_chat/football_leagues...
1,llama3_3b_chat,football_leagues_1k,base,main,prompt_only,7,../activations/llama3_3b_chat/football_leagues...
2,llama3_3b_chat,football_leagues_1k,base,main,prompt_only,9,../activations/llama3_3b_chat/football_leagues...
3,llama3_3b_chat,football_leagues_1k,base,main,prompt_only,17,../activations/llama3_3b_chat/football_leagues...
4,llama3_3b_chat,football_leagues_1k,base,main,prompt_only,10,../activations/llama3_3b_chat/football_leagues...
...,...,...,...,...,...,...,...
499,llama3_3b_chat,birth_years_4k,base,main,prompt_answer,23,../activations/llama3_3b_chat/birth_years_4k/b...
500,llama3_3b_chat,birth_years_4k,base,main,prompt_answer,15,../activations/llama3_3b_chat/birth_years_4k/b...
501,llama3_3b_chat,birth_years_4k,base,main,prompt_answer,12,../activations/llama3_3b_chat/birth_years_4k/b...
502,llama3_3b_chat,birth_years_4k,base,main,prompt_answer,2,../activations/llama3_3b_chat/birth_years_4k/b...


In [42]:
# Set None to get all
MODEL_ID = "llama3_3b_chat"
DATASET_ID = "gsm8k"
PROMPT_ID = "cot_3_shot"
SUBSET_ID = "main"
INPUT_TYPE = "prompt_only"
# DATASET_ID = None
# PROMPT_ID = None
# SUBSET_ID = None
# INPUT_TYPE = None


activation_exp_configs_df = all_activation_exp_configs_df.copy()

if MODEL_ID != None:
    activation_exp_configs_df = activation_exp_configs_df[activation_exp_configs_df["model_id"] == MODEL_ID]
if DATASET_ID != None:
    activation_exp_configs_df = activation_exp_configs_df[activation_exp_configs_df["dataset_id"] == DATASET_ID]
if PROMPT_ID != None:
    activation_exp_configs_df = activation_exp_configs_df[activation_exp_configs_df["prompt_id"] == PROMPT_ID]
if SUBSET_ID != None:
    activation_exp_configs_df = activation_exp_configs_df[activation_exp_configs_df["subset_id"] == SUBSET_ID]
if INPUT_TYPE != None:
    activation_exp_configs_df = activation_exp_configs_df[activation_exp_configs_df["input_type"] == INPUT_TYPE]


activation_exp_configs_df

Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,path
168,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,0,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
169,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,7,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
170,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,9,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
171,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,17,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
172,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,10,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
173,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,26,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
174,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,19,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
175,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,21,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
176,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,8,../activations/llama3_3b_chat/gsm8k/cot_3_shot...
177,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,6,../activations/llama3_3b_chat/gsm8k/cot_3_shot...


In [43]:
for x in activation_exp_configs_df.groupby(["model_id", "dataset_id", "prompt_id", "subset_id", "input_type"]):
    print(x[1])

           model_id dataset_id   prompt_id subset_id   input_type  layer  \
168  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only      0   
169  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only      7   
170  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only      9   
171  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only     17   
172  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only     10   
173  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only     26   
174  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only     19   
175  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only     21   
176  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only      8   
177  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only      6   
178  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only      1   
179  llama3_3b_chat      gsm8k  cot_3_shot      main  prompt_only     20   
180  llama3_

In [48]:

res_dict = defaultdict(list)

for (model_id, dataset_id, prompt_id, subset_id, input_type), config_df in activation_exp_configs_df.groupby(["model_id", "dataset_id", "prompt_id", "subset_id", "input_type"]):
    print(f"\n{model_id=}, {dataset_id=}, {prompt_id=}, {subset_id=}, {input_type=}")
    labels_df = load_labels(
        model_id=model_id,
        dataset_id=dataset_id,
        prompt_id=prompt_id,
        subset_id=subset_id,
    )

    check_indices = None
    for layer in config_df["layer"].astype(int).sort_values():
        print(f"{layer=}", end=", ")
        activations, indices = load_activations(
            model_id=model_id,
            dataset_id=dataset_id,
            prompt_id=prompt_id,
            subset_id=subset_id,
            input_type=input_type,
            layer=layer,
        )
        
        if check_indices is None:
            check_indices = indices.sample(frac=1, replace=False)
        
        if set(indices) != set(check_indices):
            raise RuntimeError(f"indices across layers are not the same")

        labels_df_subset = labels_df.iloc[check_indices]
        activations = activations[check_indices]
        
        labels_df_subset = pd.concat([labels_df_subset]*10, axis=0)
        activations = pt.cat([activations]*10, dim=0)

        activations_handler = ActivationsHandler(
            activations=activations,
            labels=labels_df_subset["correct"].astype(bool),
        )

        activations_handler_folds = list(
            activations_handler.split_dataset(split_sizes=[1/5] * 5)
        )

        fold_stats = {}
        for fold_i, activations_handler_test in enumerate(activations_handler_folds):
            activations_handler_train = combine_activations_handlers(
                [ah for j, ah in enumerate(activations_handler_folds) if j != fold_i]
            ).sample_equally_across_groups(
                group_labels=[False, True]
            )
            activations_handler_test = activations_handler_test.sample_equally_across_groups(
                group_labels=[False, True]
            )

            res_dict["model_id"].append(model_id)
            res_dict["dataset_id"].append(dataset_id)
            res_dict["prompt_id"].append(prompt_id)
            res_dict["subset_id"].append(subset_id)
            res_dict["input_type"].append(input_type)
            res_dict["layer"].append(layer)
            res_dict[f"fold"].append(fold_i)

            direction_classifier, direction_calculator = get_correctness_direction_classifier(
                activations_handler_train=activations_handler_train,
                activations_handler_test=activations_handler_test,
            )
            # res_dict["classifying_direction"].append(direction_calculator.classifying_direction.tolist())
            for key, value in direction_classifier.classification_metrics.items():
                res_dict[f"direction_{key}"].append(value)
            
            for key, value in get_logistic_regression_classifier(
                    activations_handler_train=activations_handler_train,
                    activations_handler_test=activations_handler_test,
                )[0].classification_metrics.items():
                res_dict[f"logistic_regression_{key}"].append(value)



res_df = pd.DataFrame(res_dict)
res_df


model_id='llama3_3b_chat', dataset_id='gsm8k', prompt_id='cot_3_shot', subset_id='main', input_type='prompt_only'
layer=0, layer=1, layer=2, layer=3, layer=4, layer=5, layer=6, layer=7, layer=8, layer=9, layer=10, layer=11, layer=12, layer=13, layer=14, layer=15, layer=16, layer=17, layer=18, layer=19, layer=20, layer=21, layer=22, layer=23, layer=24, layer=25, layer=26, layer=27, 

Unnamed: 0,model_id,dataset_id,prompt_id,subset_id,input_type,layer,fold,direction_optimal_cut,direction_optimal_train_set_cut,direction_test_roc_auc,...,direction_f1_score,direction_precision_score,direction_recall_score,logistic_regression_optimal_cut,logistic_regression_optimal_train_set_cut,logistic_regression_test_roc_auc,logistic_regression_accuracy_score,logistic_regression_f1_score,logistic_regression_precision_score,logistic_regression_recall_score
0,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,0,0,0.017149,0.017149,1.0,...,0.8,1.0,0.666667,0.5,0.999421,1.0,1.0,1.0,1.0,1.0
1,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,0,1,0.017149,0.017149,1.0,...,0.8,1.0,0.666667,0.5,0.999421,1.0,1.0,1.0,1.0,1.0
2,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,0,2,0.017149,0.017149,1.0,...,0.8,1.0,0.666667,0.5,0.999421,1.0,1.0,1.0,1.0,1.0
3,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,0,3,0.017149,0.017149,1.0,...,0.8,1.0,0.666667,0.5,0.999421,1.0,1.0,1.0,1.0,1.0
4,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,0,4,0.017149,0.017149,1.0,...,0.8,1.0,0.666667,0.5,0.999421,1.0,1.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
135,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,27,0,9.218114,9.218114,1.0,...,1.0,1.0,1.000000,0.5,0.999593,1.0,1.0,1.0,1.0,1.0
136,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,27,1,9.218114,9.218114,1.0,...,1.0,1.0,1.000000,0.5,0.999593,1.0,1.0,1.0,1.0,1.0
137,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,27,2,9.218114,9.218114,1.0,...,1.0,1.0,1.000000,0.5,0.999593,1.0,1.0,1.0,1.0,1.0
138,llama3_3b_chat,gsm8k,cot_3_shot,main,prompt_only,27,3,9.218114,9.218114,1.0,...,1.0,1.0,1.000000,0.5,0.999593,1.0,1.0,1.0,1.0,1.0


In [45]:

import plotly.graph_objects as go
def plot_interactive_lineplot(df, x_label, y_label, title=None):
    fig = go.Figure()
    
    # Calculate statistics for each column
    means = df.mean(axis=1)

    # Add mean line
    fig.add_trace(go.Scatter(
        x=df.index,
        y=means,
        mode='lines+markers',
        line=dict(color='#1f77b4', width=2),
        marker=dict(size=8),
        name='Mean'
    ))

    if df.shape[1] > 1:
        multiple_vals = True
        stds = df.std(axis=1)
        mins = df.min(axis=1)
        maxs = df.max(axis=1)
    else:
        multiple_vals = False
    
    if multiple_vals:
    # Add min/max range (very faint)
        fig.add_trace(go.Scatter(
            x=df.index,
            y=maxs,
            mode='lines',
            line=dict(width=0),
            showlegend=False,
            name='Max'
        ))
        fig.add_trace(go.Scatter(
            x=df.index,
            y=mins,
            mode='lines',
            line=dict(width=0),
            fillcolor='rgba(68, 138, 255, 0.1)',  # Very faint blue
            fill='tonexty',
            showlegend=False,
            name='Min'
        ))
    
        # Add ±1 std range (moderately faint)
        fig.add_trace(go.Scatter(
            x=df.index,
            y=means + stds,
            mode='lines',
            line=dict(width=0),
            showlegend=False,
            name='+1 STD'
        ))
        fig.add_trace(go.Scatter(
            x=df.index,
            y=means - stds,
            mode='lines',
            line=dict(width=0),
            fillcolor='rgba(68, 138, 255, 0.3)',  # Slightly more visible blue
            fill='tonexty',
            showlegend=False,
            name='-1 STD'
        ))
        

    
        # Add individual points for each fold
        for col in df.columns:
            fig.add_trace(go.Scatter(
                x=df.index,
                y=df[col],
                mode='markers',
                marker=dict(
                    color='#1f77b4',
                    size=6,
                    opacity=0.5
                ),
                showlegend=False,
                name=f'Fold {col}'
            ))
    
    fig.update_layout(
        title=title,
        yaxis_title=y_label,
        xaxis_title=x_label,
        template='plotly_dark',
        plot_bgcolor='rgba(32, 32, 32, 1)',
        paper_bgcolor='rgba(32, 32, 32, 1)',
        font=dict(color='white'),
        margin=dict(t=50, l=50, r=30),
        showlegend=False
    )
    
    # Update axes for consistency with dark theme
    fig.update_xaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
    fig.update_yaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
    
    return fig



In [46]:

res_df_pivot =pd.pivot(
    res_df,
    index='layer',
    columns='fold',
    # values=['direction_f1_score', 'logistic_regression_f1_score']  # add all metrics you want to keep
)
res_df_pivot

Unnamed: 0_level_0,direction_optimal_cut,direction_optimal_cut,direction_optimal_cut,direction_optimal_cut,direction_optimal_cut,direction_optimal_train_set_cut,direction_optimal_train_set_cut,direction_optimal_train_set_cut,direction_optimal_train_set_cut,direction_optimal_train_set_cut,...,logistic_regression_precision_score,logistic_regression_precision_score,logistic_regression_precision_score,logistic_regression_precision_score,logistic_regression_precision_score,logistic_regression_recall_score,logistic_regression_recall_score,logistic_regression_recall_score,logistic_regression_recall_score,logistic_regression_recall_score
fold,0,1,2,3,4,0,1,2,3,4,...,0,1,2,3,4,0,1,2,3,4
layer,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
0,0.018484,0.018484,0.018484,0.018484,0.018484,0.018484,0.018484,0.018484,0.018484,0.018484,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
1,0.018109,0.018109,0.018109,0.018109,0.018109,0.018109,0.018109,0.018109,0.018109,0.018109,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
2,0.026479,0.026479,0.026479,0.026479,0.026479,0.026479,0.026479,0.026479,0.026479,0.026479,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3,0.067752,0.067752,0.067752,0.067752,0.067752,0.067752,0.067752,0.067752,0.067752,0.067752,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
4,0.117733,0.117733,0.117733,0.117733,0.117733,0.117733,0.117733,0.117733,0.117733,0.117733,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
5,0.220754,0.220754,0.220754,0.220754,0.220754,0.220754,0.220754,0.220754,0.220754,0.220754,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
6,0.258142,0.258142,0.258142,0.258142,0.258142,0.258142,0.258142,0.258142,0.258142,0.258142,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
7,0.238972,0.238972,0.238972,0.238972,0.238972,0.238972,0.238972,0.238972,0.238972,0.238972,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
8,0.282022,0.282022,0.282022,0.282022,0.282022,0.282022,0.282022,0.282022,0.282022,0.282022,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
9,0.37992,0.37992,0.37992,0.37992,0.37992,0.37992,0.37992,0.37992,0.37992,0.37992,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [47]:
for classifier in ["direction", "logistic_regression"]:
    for metric in ["f1_score", "accuracy_score", "precision_score", "recall_score"]:
        plot_interactive_lineplot(
            res_df_pivot[[f"{classifier}_{metric}"]],
            x_label="Layer",
            y_label=f"{classifier}_{metric}".replace("_", " ").title()
        ).show()

In [22]:
# def plot_interactive_box(df, x_label, y_label, title=None):
#     fig = go.Figure()
    
#     # Create box plot
#     fig.add_trace(go.Box(
#         x=[idx for idx in df.index for _ in range(len(df.columns))],  # Repeat each index for each fold
#         y=df.values.flatten(),  # Flatten all values
#         boxpoints='all',  # Show all points
#         jitter=0,        # No jitter for points
#         pointpos=0,      # Position points at center
#         marker=dict(
#             color='#1f77b4',
#             size=8,
#             opacity=0.5
#         ),
#         line=dict(
#             color='#1f77b4',
#             width=2
#         ),
#         fillcolor='rgba(68, 138, 255, 0.5)',
#         opacity=0.6,
#         showlegend=False,
#         boxmean=True,    # Show mean as a dashed line
#         width=0.5        # Width of boxes
#     ))

#     fig.update_layout(
#         title=title,
#         yaxis_title=y_label,
#         xaxis_title=x_label,
#         template='plotly_dark',
#         plot_bgcolor='rgba(32, 32, 32, 1)',
#         paper_bgcolor='rgba(32, 32, 32, 1)',
#         font=dict(color='white'),
#         margin=dict(t=50, l=50, r=30),
#         showlegend=False
#     )
    
#     # Update axes for consistency with dark theme
#     fig.update_xaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
#     fig.update_yaxes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)
    
#     return fig

# plot_interactive_box(res_df, x_label="Layer", y_label="F1 Score").show()