In [1]:
import sys
from pathlib import Path

# Make the parent of the current working directory importable. The working
# directory is presumably the notebooks folder, so its parent is the repository
# root and `src` then resolves as a package -- TODO confirm when run elsewhere.
src_path = str(Path(".").resolve().parent)
if all(entry != src_path for entry in sys.path):
    sys.path.append(src_path)

sys.path

['/Users/anton/dev/MARS/correctness-model-internals/notebooks',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python311.zip',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11',
 '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload',
 '',
 '/Users/anton/dev/.virtual_envs/default_env/lib/python3.11/site-packages',
 '/Users/anton/dev/MARS/correctness-model-internals']

In [8]:
from pathlib import Path

import torch as pt
import pandas as pd
# import seaborn as sns

# sns.set_theme(style="whitegrid")


In [20]:
def load_activations(
    model_id,
    dataset_id,
    prompt_id,
    subset_id,
    input_type,
    layer,
    batch_ids=None,
):
    """Load one layer's activations from every (selected) batch file.

    Batch files are read from
    ``../activations/{model_id}/{dataset_id}/{prompt_id}/{subset_id}/{input_type}``
    in ascending order of the integer that follows the last ``_`` in each file
    stem (the batch id).

    Args:
        model_id, dataset_id, prompt_id, subset_id, input_type: Path
            components of the activations directory.
        layer: Key used to pick one entry out of each loaded batch object.
        batch_ids: Optional iterable of batch ids (anything ``int()`` accepts).
            ``None`` or an empty iterable means "load every batch".

    Returns:
        A tuple ``(activations, indices)``: the per-batch tensors concatenated
        along dim 0, and a ``pd.Series`` with one global row index per
        activation row, computed as ``batch_id * batch_size + row_in_batch``.

    Raises:
        KeyError: If a batch file has no entry for ``layer``.
        ValueError: If no batch file was selected, or if a batch other than
            the last one in the directory has a deviating number of rows.

    Note:
        ``batch_size`` is taken as the largest loaded batch, so that a shorter
        final batch still gets correct global indices. If *only* a short final
        batch is selected via ``batch_ids``, the full batch size cannot be
        inferred and the offsets are based on that batch's own size.
    """
    # Falsy/empty `batch_ids` keeps its original meaning of "no filter"; the
    # explicit None check also makes array-like inputs safe to pass.
    wanted_batch_ids = (
        {int(batch_id) for batch_id in batch_ids}
        if batch_ids is not None
        else set()
    )

    activations_dir = Path(
        f"../activations/{model_id}/{dataset_id}/{prompt_id}/{subset_id}/{input_type}"
    )
    # Parse each batch id once and sort on it (stable for equal ids).
    batch_files = sorted(
        (
            (int(path.stem.split("_")[-1]), path)
            for path in activations_dir.iterdir()
        ),
        key=lambda id_and_path: id_and_path[0],
    )
    last_batch_id = batch_files[-1][0] if batch_files else None

    activations_list, loaded_batch_ids = [], []
    for batch_id, batch_file in batch_files:
        if wanted_batch_ids and batch_id not in wanted_batch_ids:
            continue

        # NOTE(security): pt.load may unpickle arbitrary objects; only point
        # this at activation files you produced yourself.
        activations = pt.load(batch_file, map_location=pt.device("cpu"))[layer]
        activations_list.append(activations)
        loaded_batch_ids.append(batch_id)

    if not activations_list:
        raise ValueError(
            f"No activations found for {model_id} {dataset_id} {prompt_id} "
            f"{subset_id} {input_type}"
        )

    # Using each batch's own size as the offset multiplier gives a shorter final
    # batch indices that overlap earlier batches, so use the full batch size.
    batch_size = max(activations.shape[0] for activations in activations_list)

    indices = []
    for batch_id, activations in zip(loaded_batch_ids, activations_list):
        n_rows = activations.shape[0]
        if n_rows != batch_size and batch_id != last_batch_id:
            raise ValueError(
                f"Batch {batch_id} has {n_rows} rows but expected {batch_size}; "
                "only the final batch may be smaller"
            )
        indices.append(
            pd.Series(range(n_rows), name="index") + batch_id * batch_size
        )

    return (
        pt.cat(activations_list, dim=0),
        pd.concat(indices).reset_index(drop=True),
    )


def load_labels(model_id, dataset_id, prompt_id, subset_id, indices=None):
    """Read the evaluation CSV belonging to ``subset_id``.

    Scans ``../evaluations/{model_id}/{dataset_id}/{prompt_id}/`` and picks the
    first file whose stem, cut at ``"_generations_evaluated"``, equals
    ``subset_id``.

    Args:
        model_id, dataset_id, prompt_id: Path components of the directory.
        subset_id: Expected file-stem prefix before ``_generations_evaluated``.
        indices: Optional positional row selector passed to ``DataFrame.iloc``;
            when given, the selected rows are returned with a fresh index.

    Returns:
        The loaded (and optionally row-selected) ``pd.DataFrame``.

    Raises:
        ValueError: If no file in the directory matches ``subset_id``.
    """
    evaluations_dir = Path(f"../evaluations/{model_id}/{dataset_id}/{prompt_id}/")
    matching_file = next(
        (
            candidate
            for candidate in evaluations_dir.iterdir()
            if candidate.stem.partition("_generations_evaluated")[0] == subset_id
        ),
        None,
    )
    if matching_file is None:
        raise ValueError(
            f"No labels found for {model_id} {dataset_id} {prompt_id} {subset_id}"
        )

    labels = pd.read_csv(matching_file)
    if indices is None:
        return labels
    return labels.iloc[indices].reset_index(drop=True)


In [21]:
# Experiment selectors. They are passed to load_activations / load_labels,
# which interpolate them into the ../activations/... and ../evaluations/...
# paths (SUBSET_ID is also matched against the evaluation file stem).
MODEL_ID = "llama3_3b_chat"
DATASET_ID = "gsm8k"
PROMPT_ID = "base_3_shot"
SUBSET_ID = "main"
# Only used for the activations path (last directory component).
INPUT_TYPE = "prompt_only"

In [22]:
from src.classifying import (
    ActivationsHandler,
    combine_activations_handlers,
    get_correctness_direction_classifier,
    get_logistic_regression_classifier,
)

In [28]:
from collections import defaultdict

# Per-layer, per-fold evaluation of two classifiers on the stored activations.
# Each (layer, fold) pair contributes one row to `res_df`; metric columns are
# prefixed with "direction_" or "logistic_regression_".
res_dict = defaultdict(list)
# 100 is an upper bound on the layer index; layers with no stored activations
# raise KeyError inside load_activations and are skipped.
for layer in range(100):
    try:
        activations, indices = load_activations(
            model_id=MODEL_ID,
            dataset_id=DATASET_ID,
            prompt_id=PROMPT_ID,
            subset_id=SUBSET_ID,
            input_type=INPUT_TYPE,
            layer=layer,
        )
    except KeyError:
        continue
    print(f"layer {layer}")
    labels_df = load_labels(
        model_id=MODEL_ID,
        dataset_id=DATASET_ID,
        prompt_id=PROMPT_ID,
        subset_id=SUBSET_ID,
        indices=indices,
    )

    ################### DELETE ME ###################
    # WARNING(review): every sample is duplicated 10x before splitting, so
    # identical rows can presumably land in both the train and the test fold
    # and inflate the metrics. Remove before trusting any result.
    activations = pt.cat([activations] * 10, dim=0)
    labels_df = pd.concat([labels_df] * 10, axis=0).reset_index(drop=True)
    ################### DELETE ME ###################

    activation_handler = ActivationsHandler(
        activations=activations, labels=labels_df["correct"].astype(bool)
    )

    # Presumably balances the False/True groups -- confirm in src.classifying.
    activation_handler = activation_handler.sample_equally_across_groups(
        group_labels=[False, True]
    )

    # Two equally sized folds; each one serves as the test set once.
    activations_handler_folds = list(
        activation_handler.split_dataset(split_sizes=[0.5] * 2)
    )

    for fold_i, activations_handler_test in enumerate(activations_handler_folds):
        # Train on every fold except the one currently held out.
        activations_handler_train = combine_activations_handlers(
            [ah for j, ah in enumerate(activations_handler_folds) if j != fold_i]
        )

        res_dict["layer"].append(layer)
        res_dict["fold_n"].append(fold_i)

        # Both factories return a sequence whose first item exposes
        # `classification_metrics`; the remaining items are not used here.
        direction_classifier = get_correctness_direction_classifier(
            activations_handler_train=activations_handler_train,
            activations_handler_test=activations_handler_test,
        )[0]
        for key, value in direction_classifier.classification_metrics.items():
            res_dict[f"direction_{key}"].append(value)

        logistic_regression_classifier = get_logistic_regression_classifier(
            activations_handler_train=activations_handler_train,
            activations_handler_test=activations_handler_test,
        )[0]
        for key, value in (
            logistic_regression_classifier.classification_metrics.items()
        ):
            res_dict[f"logistic_regression_{key}"].append(value)

res_df = pd.DataFrame(res_dict)
res_df


layer 1


  activations = pt.load(batch_file, map_location=pt.device("cpu"))[layer]
  labels=pd.concat([self.labels, other.labels]),
  labels=pd.concat([self.labels, other.labels]),
  activations = pt.load(batch_file, map_location=pt.device("cpu"))[layer]


Unnamed: 0,layer,fold_n,direction_optimal_cut,direction_optimal_train_set_cut,direction_test_roc_auc,direction_accuracy_score,direction_f1_score,logistic_regression_optimal_cut,logistic_regression_optimal_train_set_cut,logistic_regression_test_roc_auc,logistic_regression_accuracy_score,logistic_regression_f1_score
0,1,0,0.161888,0.161888,1.0,1.0,1.0,0.5,0.999332,1.0,0.9,0.9
1,1,1,0.251568,0.251568,1.0,1.0,1.0,0.5,0.998992,1.0,1.0,1.0


In [24]:
# experiment_results = {
#     "llama3_3b_chat": {
#         "gsm8k": {
#             "base_3_shot": {
#                 "main": {
#                     "prompt_answer": {
#                         1: {
#                             "fold_1": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.7,
#                                     "accuracy": 0.6,
#                                     "f1_score": 0.7,
#                                 },
#                             },
#                             "fold_2": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.8,
#                                     "accuracy": 0.7,
#                                     "f1_score": 0.8,
#                                 },
#                             },
#                             "fold_3": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.9,
#                                     "accuracy": 0.8,
#                                     "f1_score": 0.9,
#                                 },
#                             },
#                         },
#                         2: {
#                             "fold_1": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.7,
#                                     "accuracy": 0.6,
#                                     "f1_score": 0.7,
#                                 },
#                             },
#                             "fold_2": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.8,
#                                     "accuracy": 0.7,
#                                     "f1_score": 0.8,
#                                 },
#                             },
#                             "fold_3": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.9,
#                                     "accuracy": 0.8,
#                                     "f1_score": 0.9,
#                                 },
#                             },
#                         },
#                         3: {
#                             "fold_1": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.7,
#                                     "accuracy": 0.6,
#                                     "f1_score": 0.7,
#                                 },
#                             },
#                             "fold_2": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.8,
#                                     "accuracy": 0.7,
#                                     "f1_score": 0.8,
#                                 },
#                             },
#                             "fold_3": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.9,
#                                     "accuracy": 0.8,
#                                     "f1_score": 0.9,
#                                 },
#                             },
#                         },
#                         4: {
#                             "fold_1": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.7,
#                                     "accuracy": 0.6,
#                                     "f1_score": 0.7,
#                                 },
#                             },
#                             "fold_2": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.8,
#                                     "accuracy": 0.7,
#                                     "f1_score": 0.8,
#                                 },
#                             },
#                             "fold_3": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.9,
#                                     "accuracy": 0.8,
#                                     "f1_score": 0.9,
#                                 },
#                             },
#                         },
#                         5: {
#                             "fold_1": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.7,
#                                     "accuracy": 0.6,
#                                     "f1_score": 0.7,
#                                 },
#                             },
#                             "fold_2": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.8,
#                                     "accuracy": 0.7,
#                                     "f1_score": 0.8,
#                                 },
#                             },
#                             "fold_3": {
#                                 "correctness_direction_classifier": {
#                                     "optimal_cut": 0.5,
#                                     "optimal_train_set_cut": 0.5,
#                                     "test_roc_auc": 0.9,
#                                     "accuracy": 0.8,
#                                     "f1_score": 0.9,
#                                 },
#                             },
#                         },
#                     }
#                 }
#             }
#         }
#     }
# }



In [49]:

import plotly.graph_objects as go
def plot_interactive_lineplot(df, x_label, y_label, title=None):
    """Build a dark-themed line plot summarising each row of ``df``.

    Statistics are taken across the columns of every row (``axis=1``), so each
    index value of ``df`` becomes one x position. The figure contains, in
    drawing order: a faint min/max band, a ±1 standard deviation band, the mean
    line with markers, and one marker trace per column of ``df``.

    Args:
        df: DataFrame whose index supplies the x values and whose columns are
            the values summarised per row.
        x_label: X-axis title.
        y_label: Y-axis title.
        title: Optional figure title.

    Returns:
        The assembled ``go.Figure`` (not shown).
    """
    x_values = df.index
    row_mean = df.mean(axis=1)
    row_std = df.std(axis=1)
    row_min = df.min(axis=1)
    row_max = df.max(axis=1)

    def hidden_line(y_values, trace_name, **fill_options):
        # Zero-width line: only visible through the fill towards the previous trace.
        return go.Scatter(
            x=x_values,
            y=y_values,
            mode='lines',
            line=dict(width=0),
            showlegend=False,
            name=trace_name,
            **fill_options,
        )

    traces = [
        # Min/max range (very faint blue).
        hidden_line(row_max, 'Max'),
        hidden_line(
            row_min, 'Min', fillcolor='rgba(68, 138, 255, 0.1)', fill='tonexty'
        ),
        # ±1 std range (slightly more visible blue).
        hidden_line(row_mean + row_std, '+1 STD'),
        hidden_line(
            row_mean - row_std,
            '-1 STD',
            fillcolor='rgba(68, 138, 255, 0.3)',
            fill='tonexty',
        ),
        # Mean line.
        go.Scatter(
            x=x_values,
            y=row_mean,
            mode='lines+markers',
            line=dict(color='#1f77b4', width=2),
            marker=dict(size=8),
            name='Mean',
        ),
    ]
    # Raw points, one trace per column.
    for col in df.columns:
        point_style = dict(color='#1f77b4', size=6, opacity=0.5)
        traces.append(
            go.Scatter(
                x=x_values,
                y=df[col],
                mode='markers',
                marker=point_style,
                showlegend=False,
                name=f'Fold {col}',
            )
        )

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    layout_options = dict(
        title=title,
        yaxis_title=y_label,
        xaxis_title=x_label,
        template='plotly_dark',
        plot_bgcolor='rgba(32, 32, 32, 1)',
        paper_bgcolor='rgba(32, 32, 32, 1)',
        font=dict(color='white'),
        margin=dict(t=50, l=50, r=30),
        showlegend=False,
    )
    fig.update_layout(**layout_options)

    # Same subdued grid on both axes to match the dark theme.
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)

    return fig

def get_layer_performance_df(
    layer_results, classifier_name, metric_name, noise_std=0.0, seed=None
):
    """Collect one metric per layer and fold into a layer-indexed DataFrame.

    Args:
        layer_results: Mapping ``layer -> {fold_name: fold_stats}`` where
            ``fold_stats[classifier_name][metric_name]`` is the metric value.
        classifier_name: Key of the classifier inside each fold's stats.
        metric_name: Key of the metric inside the classifier's stats.
        noise_std: Standard deviation of optional Gaussian jitter added to
            every value. Defaults to ``0.0`` (values are returned unchanged);
            pass e.g. ``0.5`` only to fake spread on placeholder data.
        seed: Optional seed for the jitter, so noisy output is reproducible.

    Returns:
        DataFrame with one row per layer (sorted by layer) and one column per
        fold position (0, 1, ...), in the iteration order of each layer's folds.
    """
    rng = None
    if noise_std:
        # Local import: numpy is only needed for the optional jitter.
        import numpy as np

        rng = np.random.default_rng(seed)

    res_dict = {}
    for layer, layer_stats in layer_results.items():
        metric_val_folds = pd.Series(
            [
                fold_stats[classifier_name][metric_name]
                for fold_stats in layer_stats.values()
            ]
        )
        if rng is not None:
            # Size the jitter to the actual number of folds, whatever it is.
            metric_val_folds = metric_val_folds + rng.normal(
                0.0, noise_std, len(metric_val_folds)
            )
        res_dict[layer] = metric_val_folds

    return pd.DataFrame(res_dict).T.sort_index()

# Per-layer F1 score of the correctness-direction classifier, shown as a line plot.
# NOTE(review): `experiment_results` is only visible in this notebook as
# commented-out code, so this cell presumably fails on a fresh kernel -- confirm
# where it is defined. It also rebinds the name `res_df`.
res_df = get_layer_performance_df(
    experiment_results["llama3_3b_chat"]["gsm8k"]["base_3_shot"]["main"]["prompt_answer"],
    "correctness_direction_classifier",
    "f1_score",
)
plot_interactive_lineplot(
    res_df,
    x_label="Layer",
    y_label="F1 Score",
).show()
            

In [52]:
def plot_interactive_box(df, x_label, y_label, title=None):
    """Build a dark-themed box plot with one box per index value of ``df``.

    Every row of ``df`` is flattened into individual points: the row's index
    value is repeated once per column so that all of a row's values share the
    same x position.

    Args:
        df: DataFrame whose index supplies the x categories and whose columns
            hold the values drawn for each category.
        x_label: X-axis title.
        y_label: Y-axis title.
        title: Optional figure title.

    Returns:
        The assembled ``go.Figure`` (not shown).
    """
    n_columns = len(df.columns)
    # Row-major order, matching the flattened value array below.
    x_values = []
    for idx in df.index:
        x_values.extend([idx] * n_columns)
    y_values = df.values.flatten()

    box_trace = go.Box(
        x=x_values,
        y=y_values,
        boxpoints='all',  # draw every individual point
        jitter=0,         # no horizontal spread of the points
        pointpos=0,       # points centred on the box
        marker=dict(color='#1f77b4', size=8, opacity=0.5),
        line=dict(color='#1f77b4', width=2),
        fillcolor='rgba(68, 138, 255, 0.5)',
        opacity=0.6,
        showlegend=False,
        boxmean=True,     # mean drawn as a dashed line
        width=0.5,        # box width
    )

    fig = go.Figure()
    fig.add_trace(box_trace)

    layout_options = dict(
        title=title,
        yaxis_title=y_label,
        xaxis_title=x_label,
        template='plotly_dark',
        plot_bgcolor='rgba(32, 32, 32, 1)',
        paper_bgcolor='rgba(32, 32, 32, 1)',
        font=dict(color='white'),
        margin=dict(t=50, l=50, r=30),
        showlegend=False,
    )
    fig.update_layout(**layout_options)

    # Same subdued grid on both axes to match the dark theme.
    for update_axes in (fig.update_xaxes, fig.update_yaxes):
        update_axes(gridcolor='rgba(128, 128, 128, 0.2)', zeroline=False)

    return fig

# Show the values currently held in `res_df` as one box per index value.
plot_interactive_box(
    res_df,
    x_label="Layer",
    y_label="F1 Score",
).show()