In [1]:
import torch
from utils import DataManager, dataset_sizes, collect_training_data, compute_statistics, compute_average_accuracies, plot_lr_feature_importance
import matplotlib.pyplot as plt
from probes import CCSProbe, LRProbe, MMProbe, ALL_PROBES, TTPD_TYPES, measure_polarity_direction_lr, run_ray, get_average_coef
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import pandas as pd

In [2]:
# hyperparameters
model_family = 'Llama3' # options are 'Llama3', 'Llama2', 'Gemma', 'Gemma2' or 'Mistral'
model_size = '8B'
model_type = 'chat' # options are 'chat' or 'base'
layer = 12 # layer from which to extract activations

# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
# An accelerator speeds up CCS training a fair bit but is not required.
# torch.backends.mps.is_available() is the documented availability check;
# torch.mps.is_available() is missing from older PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'mps'

In [3]:
# Datasets used for training: each topic appears as an affirmative dataset
# immediately followed by its "neg_"-prefixed counterpart.
train_sets = [
    "cities", "neg_cities",
    "sp_en_trans", "neg_sp_en_trans",
    "inventors", "neg_inventors",
    "animal_class", "neg_animal_class",
    "element_symb", "neg_element_symb",
    "facts", "neg_facts",
]

# Affirmative-only alternative, kept for reference:
# train_sets = ["cities", "sp_en_trans", "inventors",  "animal_class", "element_symb", "facts"]

# Size of each training dataset, used to include an equal number of statements
# from each topic in the training data.
train_set_sizes = dataset_sizes(train_sets)

# Second element of every TTPD_TYPES pair (used below to recognise TTPD probe types).
TTPD_CLASSES = [ttpd_class for (_, ttpd_class) in TTPD_TYPES]


### Hyperparameter optimization

In [4]:
import ray

# Start a local Ray runtime limited to 4 CPUs; ignore_reinit_error lets the cell be re-run
# in a kernel where Ray is already initialised.
ray.init(num_cpus=4, ignore_reinit_error=True)

# Validation datasets passed to the search alongside the training sets.
val_sets = ["cities_conj", "cities_disj", "sp_en_trans_conj", "sp_en_trans_disj",
                "inventors_conj", "inventors_disj", "animal_class_conj", "animal_class_disj",
                "element_symb_conj", "element_symb_disj", "facts_conj", "facts_disj",
                "common_claim_true_false", "counterfact_true_false"]

try:
    # Run optimization
    final_probe, best_config, analysis = run_ray(train_sets, val_sets)
finally:
    # Shutdown Ray even when the search raises, so no Ray runtime is left running
    ray.shutdown()

0,1
Current time:,2025-10-10 18:22:03
Running for:,00:00:16.54
Memory:,18.0/32.0 GiB

Trial name,status,loc,features,final_C,final_combo/penalty,final_combo/solver,final_l1_ratio,final_max_iter,polarity_C,polarity_combo/penal ty,polarity_combo/solve r,polarity_l1_ratio,polarity_max_iter,use_scaler,iter,total time (s),accuracy,accuracy_std,loss
train_ttpd_with_cv_8606f3c5,PENDING,,"['proj_t_g', 'p_3540",0.316366,l1,liblinear,0.341033,2000,0.00702554,l1,saga,0.269722,5000,False,,,,,
train_ttpd_with_cv_45c0bf76,TERMINATED,127.0.0.1:12810,"['proj_t_g', 'p_a1c0",0.263993,l1,liblinear,0.818001,1000,0.0423506,l2,lbfgs,0.43874,3000,True,13.0,0.127828,0.950263,0.0856619,0.0497373
train_ttpd_with_cv_7c461e90,TERMINATED,127.0.0.1:12813,"['proj_t_g', 'p_c880",0.000828578,,lbfgs,0.192403,2000,0.0009304,l1,saga,0.672523,5000,False,13.0,0.103346,0.931189,0.0968422,0.0688114
train_ttpd_with_cv_6c5f04fe,TERMINATED,127.0.0.1:12814,"['proj_t_g', 'p_a440",0.344553,l1,liblinear,0.702185,2000,0.0108945,l2,liblinear,0.774081,1000,False,13.0,0.151488,0.956175,0.0588715,0.043825
train_ttpd_with_cv_02e153fd,TERMINATED,127.0.0.1:12815,"['proj_t_g', 'p_1c40",0.00272752,l2,liblinear,0.407002,2000,2.04698,l2,lbfgs,0.472824,1000,False,13.0,0.14701,0.93229,0.0850108,0.0677103
train_ttpd_with_cv_17ec2f1e,TERMINATED,127.0.0.1:12825,"['proj_t_g', 'p_b6c0",0.085112,,lbfgs,0.609705,1000,0.233161,l2,lbfgs,0.134051,1000,True,12.0,0.153144,0.689655,,0.310345


2025-10-10 18:22:03,612	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/flohop/ray_results/ttpd_optimization' in 0.0051s.
2025-10-10 18:22:04,193	INFO tune.py:1041 -- Total run time: 17.14 seconds (16.53 seconds for the tuning loop).
Resume experiment with: tune.run(..., resume=True)
- train_ttpd_with_cv_8606f3c5: FileNotFoundError('Could not fetch metrics for train_ttpd_with_cv_8606f3c5: both result.json and progress.csv were not found at /Users/flohop/ray_results/ttpd_optimization/train_ttpd_with_cv_8606f3c5_6_features=proj_t_g_proj_p_proj_t_p_inter,final_C=0.3164,penalty=l1,solver=liblinear,final_l1_ratio=0._2025-10-10_18-22-01')



OPTIMIZATION RESULTS

Best Trial: train_ttpd_with_cv_6c5f04fe
Best Accuracy: 0.9562
Accuracy Std Dev: 0.0589

Best Hyperparameters:
--------------------------------------------------------------------------------
  polarity_combo: {'penalty': 'l2', 'solver': 'liblinear'}
  final_combo: {'penalty': 'l1', 'solver': 'liblinear'}
  polarity_C: 0.010894470780762364
  polarity_l1_ratio: 0.7740809777807832
  polarity_max_iter: 1000
  final_C: 0.34455286384677336
  final_max_iter: 2000
  final_l1_ratio: 0.7021846107475395
  use_scaler: False
  features: ['proj_t_g', 'proj_p', 'proj_t_p_inter']

TOP 10 CONFIGURATIONS
          accuracy  config/polarity_C config/polarity_penalty config/final_penalty
trial_id                                                                          
6c5f04fe  0.956175           0.010894                      l2                   l1
45c0bf76  0.950263           0.042351                      l2                   l1
02e153fd  0.932290           2.046978              

### Feature importance

In [None]:
# TODO: Add Scaler here
# Probe settings shared by every run below; only the "features" entry is swapped per run.
# NOTE(review): the tuning output above reports nested `polarity_combo` / `final_combo`
# entries, while this dict uses flat penalty/solver keys — confirm which layout
# get_average_coef expects, and which of these keys it actually reads.
config = {
    # polarity-stage settings
    "polarity_C": 1.0,
    "polarity_penalty": "l2",
    "polarity_solver": "lbfgs",
    "polarity_max_iter": 1000,

    # filled in per run by the loop below
    "features": [],

    # final-stage settings
    "final_penalty": None,
    "final_C": None,
    "final_solver": "lbfgs",
    "final_max_iter": 2000,
}

# Feature subsets to compare; every subset contains "proj_t_g".
features_config = [
    ["proj_t_g", "proj_p", "proj_t_p", "proj_t_p_inter"],
    ["proj_t_g", "proj_t_p_inter"],
    ["proj_t_g", "proj_p"],
    ["proj_t_g", "proj_p", "proj_t_p"],
    ["proj_t_g", "proj_p", "proj_t_p_inter"],
]

# Training data is collected once and reused for every feature subset.
acts_centered_train, acts_train, labels_train, polarities_train = collect_training_data(
    train_sets, train_set_sizes, model_family, model_size, model_type, layer
)

for feature_config in features_config:
    config["features"] = feature_config

    # Coefficients averaged over 10 runs, raw and normalised.
    avg_coef, avg_norm_coef = get_average_coef(
        acts_centered_train,
        acts_train,
        labels_train,
        polarities_train,
        runs=10,
        config=config,
    )

    # One plot per coefficient set; the first uses the plot function's default title.
    plot_lr_feature_importance(avg_coef, feature_names=feature_config)
    plot_lr_feature_importance(
        avg_norm_coef,
        feature_names=feature_config,
        title="Feature Importance Norm. (|Coefficient|)",
    )


### Polarity direction accuracy

In [None]:
# Datasets used as the held-out test set for the polarity direction.
val_sets = ["cities_conj", "cities_disj", "sp_en_trans_conj", "sp_en_trans_disj",
                "inventors_conj", "inventors_disj", "animal_class_conj", "animal_class_disj",
                "element_symb_conj", "element_symb_disj", "facts_conj", "facts_disj",
                "common_claim_true_false", "counterfact_true_false"]

val_set_sizes = dataset_sizes(val_sets)

cv_train_sets = np.array(train_sets)
cv_test_sets = np.array(val_sets)

# Training set
acts_centered, acts, labels, polarities = collect_training_data(cv_train_sets, train_set_sizes, model_family,
                                                                    model_size, model_type, layer)

# Test set: reuse val_set_sizes computed above instead of calling dataset_sizes a second time
t_acts_centered, t_acts, t_labels, t_polarities = collect_training_data(cv_test_sets, val_set_sizes, model_family,
                                                                        model_size, model_type, layer)

# Only the uncentered activations and the polarity labels are passed to the measurement.
measure_polarity_direction_lr(acts, polarities, t_acts, t_polarities)


### Unseen topics

In [None]:
# compare TTPD, LR and CCS on topic-specific datasets
# Leave-one-topic-out evaluation: for each topic, train on all other topics and test on the
# held-out (affirmative, negated) dataset pair.
probe_types = [t for (name, t) in ALL_PROBES]
results = {t: defaultdict(list) for t in probe_types}
num_iter = 3

TTPD_CLASSES = [v for (k, v) in TTPD_TYPES]

# Start index of every (affirmative, negated) pair in train_sets. Derived from the list length
# rather than a hard-coded 12 so the loop follows train_sets if it changes.
indices = np.arange(0, len(train_sets), 2)

# The loop below holds out train_sets[i] and train_sets[i+1] together and centres them with the
# affirmative / negated means respectively, so the pairing must actually hold.
if len(train_sets) % 2 != 0 or any(train_sets[k + 1] != f"neg_{train_sets[k]}" for k in indices):
    raise ValueError("train_sets must consist of consecutive (topic, neg_topic) pairs")

# One progress tick per evaluated dataset: 2 datasets per held-out pair.
total_iterations = len(probe_types) * num_iter * len(train_sets)
with tqdm(total=total_iterations, desc="Training and evaluating classifiers") as pbar: # progress bar
    for probe_type in probe_types:
        for n in range(num_iter):
            for i in indices:
                # drop the held-out topic (both polarities) from the training sets
                cv_train_sets = np.delete(np.array(train_sets), [i, i+1], axis=0)
                # load training data
                acts_centered, acts, labels, polarities = collect_training_data(cv_train_sets, train_set_sizes, model_family,
                                                                                model_size, model_type, layer)

                if probe_type in TTPD_CLASSES:
                    probe = probe_type.from_data(acts_centered, acts, labels, polarities)
                elif probe_type == LRProbe:
                    probe = LRProbe.from_data(acts, labels)
                elif probe_type == CCSProbe:
                    # CCS is trained on affirmative/negated activations, each centred on its own mean
                    acts_affirm = acts[polarities == 1.0]
                    acts_neg = acts[polarities == -1.0]
                    labels_affirm = labels[polarities == 1.0]
                    mean_affirm = torch.mean(acts_affirm, dim=0)
                    mean_neg = torch.mean(acts_neg, dim=0)
                    acts_affirm = acts_affirm - mean_affirm
                    acts_neg = acts_neg - mean_neg
                    probe = CCSProbe.from_data(acts_affirm, acts_neg, labels_affirm, device=device).to('cpu')
                elif probe_type == MMProbe:
                    probe = MMProbe.from_data(acts, labels)
                else:
                    # fail loudly instead of silently evaluating the probe left over from a previous iteration
                    raise ValueError(f"Unsupported probe type: {probe_type}")

                # evaluate classification accuracy on held out datasets
                dm = DataManager()
                for j in range(0,2):
                    dm.add_dataset(train_sets[i+j], model_family, model_size, model_type, layer, split=None, center=False, device='cpu')
                    acts, labels = dm.data[train_sets[i+j]]

                    # classifier specific predictions
                    if probe_type == CCSProbe:
                        # j == 0 is the affirmative dataset, j == 1 its negated counterpart
                        if j == 0:
                            acts = acts - mean_affirm
                        if j == 1:
                            acts = acts - mean_neg
                    predictions = probe.pred(acts)
                    # fraction of predictions that match the labels
                    results[probe_type][train_sets[i+j]].append((predictions == labels).float().mean().item())
                    pbar.update(1)

stat_results = compute_statistics(results)

# Compute mean accuracies and standard deviations for each probe type
probe_accuracies = compute_average_accuracies(results, num_iter)

for probe_type, stats in probe_accuracies.items():
    print(f"{probe_type}:")
    print(f"  Mean Accuracy: {stats['mean']*100:.2f}%")
    print(f"  Standard Deviation of the mean accuracy: {stats['std_dev']*100:.2f}%")

In [None]:
# Heatmap of per-dataset accuracy (mean ± std) from stat_results: one single-column panel per probe.
probes = [p for (t, p) in ALL_PROBES]
titles = [t for (t, p) in ALL_PROBES]

fig, axes = plt.subplots(figsize=(14, 6), ncols=len(probes))

# plt.subplots returns a bare Axes (not an array) when there is only one column
if len(probes) == 1:
    axes = [axes]


for t, (ax, key) in enumerate(zip(axes, probes)):
    # one row per dataset, a single column per probe
    grid = [[stat_results[key]['mean'][dataset]] for dataset in train_sets]
    grid_std = [[stat_results[key]['std'][dataset]] for dataset in train_sets]

    im = ax.imshow(grid, vmin=0, vmax=1, cmap='plasma', aspect='auto')

    for i, row in enumerate(grid):
        for j, val in enumerate(row):
            # raw f-string: "\p" is an invalid escape sequence in a normal string literal,
            # the backslash must reach matplotlib's mathtext unchanged
            ax.text(j, i, rf'{round(grid[i][j] * 100):2d} $\pm$ {round(grid_std[i][j] * 100):2d}',
                    ha='center', va='center', fontsize=13)

    ax.set_yticks(range(len(train_sets)))
    ax.set_xticks([])
    ax.set_title(titles[t], fontsize=12)

# y tick labels only on first subplot
axes[0].set_yticklabels(train_sets, fontsize=12)
for ax in axes[1:]:
    ax.set_yticklabels([])

# one shared colorbar; all panels use the same 0..1 colour scale
cbar = fig.colorbar(im, ax=axes, shrink=0.6, location="right")
cbar.ax.tick_params(labelsize=12)

fig.suptitle("Classification accuracies", fontsize=15)
plt.show()


### Generalisation to logical conjunctions and disjunctions

In [None]:
# compare TTPD, LR, CCS and MM on logical conjunctions and disjunctions
# (the list also contains the common_claim_true_false and counterfact_true_false datasets)
val_sets = ["cities_conj", "cities_disj", "sp_en_trans_conj","sp_en_trans_disj",
             "inventors_conj", "inventors_disj", "animal_class_conj", "animal_class_disj",
               "element_symb_conj", "element_symb_disj", "facts_conj", "facts_disj",
            "common_claim_true_false", "counterfact_true_false"]


probe_types = [t for (name, t) in ALL_PROBES]
results = {t: defaultdict(list) for t in probe_types}

TTPD_CLASSES = [v for (k, v) in TTPD_TYPES]


num_iter = 20

# one progress tick per (probe type, iteration)
total_iterations = len(probe_types) * num_iter
with tqdm(total=total_iterations, desc="Training and evaluating classifiers") as pbar: # progress bar
    for probe_type in probe_types:
        for n in range(num_iter):
            # load training data
            acts_centered, acts, labels, polarities = collect_training_data(train_sets, train_set_sizes, model_family, model_size,
                                                                             model_type, layer)

            if probe_type in TTPD_CLASSES:
                probe = probe_type.from_data(acts_centered, acts, labels, polarities)
            elif probe_type == LRProbe:
                probe = LRProbe.from_data(acts, labels)
            elif probe_type == CCSProbe:
                # CCS is trained on affirmative/negated activations, each centred on its own mean
                acts_affirm = acts[polarities == 1.0]
                acts_neg = acts[polarities == -1.0]
                labels_affirm = labels[polarities == 1.0]
                mean_affirm = torch.mean(acts_affirm, dim=0)
                mean_neg = torch.mean(acts_neg, dim=0)
                acts_affirm = acts_affirm - mean_affirm
                acts_neg = acts_neg - mean_neg
                probe = CCSProbe.from_data(acts_affirm, acts_neg, labels_affirm, device=device).to('cpu')
            elif probe_type == MMProbe:
                probe = MMProbe.from_data(acts, labels)
            else:
                # fail loudly instead of silently evaluating the probe left over from a previous iteration
                raise ValueError(f"Unsupported probe type: {probe_type}")

            # evaluate classification accuracy on validation datasets
            dm = DataManager()
            for val_set in val_sets:
                dm.add_dataset(val_set, model_family, model_size, model_type, layer, split=None, center=False, device='cpu')
                acts, labels = dm.data[val_set] # retrieve the activations and labels that were just added to the DM

                # classifier specific predictions
                if probe_type == CCSProbe:
                    # validation statements are centred with the average of the two training means
                    acts = acts - (mean_affirm + mean_neg)/2
                predictions = probe.pred(acts) # one prediction per example (per the original note: 0 = lie, 1 = true)

                # compare prediction with ground truth labels and average it
                results[probe_type][val_set].append((predictions == labels).float().mean().item())
            pbar.update(1)

stat_results = compute_statistics(results)

# Compute mean accuracies and standard deviations for each probe type
probe_accuracies = compute_average_accuracies(results, num_iter)

for probe_type, stats in probe_accuracies.items():
    print(f"{probe_type}:")
    print(f"  Mean Accuracy: {stats['mean']*100:.2f}%")
    print(f"  Standard Deviation of the mean accuracy: {stats['std_dev']*100:.2f}%")

In [None]:

# Heatmap of per-dataset accuracy (mean ± std) from stat_results: one single-column panel per probe.
probes = [p for (t, p) in ALL_PROBES]
titles = [t for (t, p) in ALL_PROBES]

fig, axes = plt.subplots(figsize=(14, 6), ncols=len(probes))

# plt.subplots returns a bare Axes (not an array) when there is only one column
if len(probes) == 1:
    axes = [axes]


for t, (ax, key) in enumerate(zip(axes, probes)):
    # one row per validation dataset, a single column per probe
    grid = [[stat_results[key]['mean'][dataset]] for dataset in val_sets]
    grid_std = [[stat_results[key]['std'][dataset]] for dataset in val_sets]

    im = ax.imshow(grid, vmin=0, vmax=1, cmap='plasma', aspect='auto')

    for i, row in enumerate(grid):
        for j, val in enumerate(row):
            # raw f-string: "\p" is an invalid escape sequence in a normal string literal,
            # the backslash must reach matplotlib's mathtext unchanged
            ax.text(j, i, rf'{round(grid[i][j] * 100):2d} $\pm$ {round(grid_std[i][j] * 100):2d}',
                    ha='center', va='center', fontsize=13)

    ax.set_yticks(range(len(val_sets)))
    ax.set_xticks([])
    ax.set_title(titles[t], fontsize=12)

# y tick labels only on first subplot
axes[0].set_yticklabels(val_sets, fontsize=12)
for ax in axes[1:]:
    ax.set_yticklabels([])

# one shared colorbar; all panels use the same 0..1 colour scale
cbar = fig.colorbar(im, ax=axes, shrink=0.6, location='right')
cbar.ax.tick_params(labelsize=12)

fig.suptitle("Classification accuracies", fontsize=15)
plt.show()



### Generalisation to German statements

In [None]:
# compare TTPD, LR, CCS and MM on statements translated to german
val_sets = ["cities_de", "neg_cities_de", "sp_en_trans_de", "neg_sp_en_trans_de", "inventors_de", "neg_inventors_de", "animal_class_de",
                  "neg_animal_class_de", "element_symb_de", "neg_element_symb_de", "facts_de", "neg_facts_de"]

probe_types = [t for (name, t) in ALL_PROBES]
results = {t: defaultdict(list) for t in probe_types}

# defined locally, like in the sibling evaluation cells, so this cell does not depend on an earlier one for it
TTPD_CLASSES = [v for (k, v) in TTPD_TYPES]

num_iter = 20

# one progress tick per (probe type, iteration)
total_iterations = len(probe_types) * num_iter
with tqdm(total=total_iterations, desc="Training and evaluating classifiers") as pbar: # progress bar
    for probe_type in probe_types:
        for n in range(num_iter):
            # load training data
            acts_centered, acts, labels, polarities = collect_training_data(train_sets, train_set_sizes, model_family, model_size,
                                                                            model_type, layer)
            if probe_type in TTPD_CLASSES:
                probe = probe_type.from_data(acts_centered, acts, labels, polarities)
            elif probe_type == LRProbe:
                probe = LRProbe.from_data(acts, labels)
            elif probe_type == CCSProbe:
                # CCS is trained on affirmative/negated activations, each centred on its own mean
                acts_affirm = acts[polarities == 1.0]
                acts_neg = acts[polarities == -1.0]
                labels_affirm = labels[polarities == 1.0]
                mean_affirm = torch.mean(acts_affirm, dim=0)
                mean_neg = torch.mean(acts_neg, dim=0)
                acts_affirm = acts_affirm - mean_affirm
                acts_neg = acts_neg - mean_neg
                probe = CCSProbe.from_data(acts_affirm, acts_neg, labels_affirm, device=device).to('cpu')
            elif probe_type == MMProbe:
                probe = MMProbe.from_data(acts, labels)
            else:
                # fail loudly instead of silently evaluating the probe left over from a previous iteration
                raise ValueError(f"Unsupported probe type: {probe_type}")

            # evaluate classification accuracy on validation datasets
            dm = DataManager()
            for val_set in val_sets:
                dm.add_dataset(val_set, model_family, model_size, model_type, layer, split=None, center=False, device='cpu')
                acts, labels = dm.data[val_set]

                # classifier specific predictions
                if probe_type == CCSProbe:
                    # validation statements are centred with the average of the two training means
                    acts = acts - (mean_affirm + mean_neg)/2
                predictions = probe.pred(acts)

                # fraction of predictions that match the labels
                results[probe_type][val_set].append((predictions == labels).float().mean().item())
            pbar.update(1)

stat_results = compute_statistics(results)

# Compute mean accuracies and standard deviations for each probe type
probe_accuracies = compute_average_accuracies(results, num_iter)

for probe_type, stats in probe_accuracies.items():
    print(f"{probe_type}:")
    print(f"  Mean Accuracy: {stats['mean']*100:.2f}%")
    print(f"  Standard Deviation of the mean accuracy: {stats['std_dev']*100:.2f}%")

In [None]:
# Heatmap of per-dataset accuracy (mean ± std) from stat_results: one single-column panel per probe.
probes = [p for (t, p) in ALL_PROBES]
titles = [t for (t, p) in ALL_PROBES]

fig, axes = plt.subplots(figsize=(14, 6), ncols=len(probes))


# plt.subplots returns a bare Axes (not an array) when there is only one column
if len(probes) == 1:
    axes = [axes]

for t, (ax, key) in enumerate(zip(axes, probes)):
    # one row per validation dataset, a single column per probe
    grid = [[stat_results[key]['mean'][dataset]] for dataset in val_sets]
    grid_std = [[stat_results[key]['std'][dataset]] for dataset in val_sets]

    im = ax.imshow(grid, vmin=0, vmax=1, cmap='plasma', aspect='auto')

    for i, row in enumerate(grid):
        for j, val in enumerate(row):
            # raw f-string: "\p" is an invalid escape sequence in a normal string literal,
            # the backslash must reach matplotlib's mathtext unchanged
            ax.text(j, i, rf'{round(grid[i][j] * 100):2d} $\pm$ {round(grid_std[i][j] * 100):2d}',
                    ha='center', va='center', fontsize=13)

    ax.set_yticks(range(len(val_sets)))
    ax.set_xticks([])
    ax.set_title(titles[t], fontsize=12)

# y tick labels only on first subplot
axes[0].set_yticklabels(val_sets, fontsize=12)
for ax in axes[1:]:
    ax.set_yticklabels([])

# one shared colorbar; all panels use the same 0..1 colour scale
cbar = fig.colorbar(im, ax=axes, shrink=0.6, location='right')
cbar.ax.tick_params(labelsize=12)

fig.suptitle("Classification accuracies", fontsize=15)
plt.show()


### Displaying generalisation to Conjunctions, Disjunctions and German statements in one table

In [None]:
# Define the validation sets and the probe types
# (conjunctions, disjunctions, German statements, plus two extra true/false datasets)
val_sets = ["cities_conj", "cities_disj", "sp_en_trans_conj","sp_en_trans_disj",
             "inventors_conj", "inventors_disj", "animal_class_conj", "animal_class_disj",
               "element_symb_conj", "element_symb_disj", "facts_conj", "facts_disj", "cities_de", "neg_cities_de", "sp_en_trans_de", "neg_sp_en_trans_de", "inventors_de", "neg_inventors_de", "animal_class_de",
                  "neg_animal_class_de", "element_symb_de", "neg_element_symb_de", "facts_de", "neg_facts_de",
            "common_claim_true_false", "counterfact_true_false"]

probe_types = [t for (name, t) in ALL_PROBES]
results = {t: defaultdict(list) for t in probe_types}
num_iter = 20

TTPD_CLASSES = [v for (k, v) in TTPD_TYPES]

# Training and evaluating classifiers
# one progress tick per (probe type, iteration)
total_iterations = len(probe_types) * num_iter
with tqdm(total=total_iterations, desc="Training and evaluating classifiers") as pbar:
    for probe_type in probe_types:
        for n in range(num_iter):
            # load training data
            acts_centered, acts, labels, polarities = collect_training_data(train_sets, train_set_sizes, model_family, model_size,
                                                                            model_type, layer)
            if probe_type in TTPD_CLASSES:
                probe = probe_type.from_data(acts_centered, acts, labels, polarities)
            elif probe_type == LRProbe:
                probe = LRProbe.from_data(acts, labels)
            elif probe_type == CCSProbe:
                # CCS is trained on affirmative/negated activations, each centred on its own mean
                acts_affirm = acts[polarities == 1.0]
                acts_neg = acts[polarities == -1.0]
                labels_affirm = labels[polarities == 1.0]
                mean_affirm = torch.mean(acts_affirm, dim=0)
                mean_neg = torch.mean(acts_neg, dim=0)
                acts_affirm = acts_affirm - mean_affirm
                acts_neg = acts_neg - mean_neg
                probe = CCSProbe.from_data(acts_affirm, acts_neg, labels_affirm, device=device).to('cpu')
            elif probe_type == MMProbe:
                probe = MMProbe.from_data(acts, labels)
            else:
                # fail loudly instead of silently evaluating the probe left over from a previous iteration
                raise ValueError(f"Unsupported probe type: {probe_type}")

            # evaluate classification accuracy on validation datasets
            dm = DataManager()
            for val_set in val_sets:
                dm.add_dataset(val_set, model_family, model_size, model_type, layer, split=None, center=False, device='cpu')
                acts, labels = dm.data[val_set]

                # classifier specific predictions
                if probe_type == CCSProbe:
                    # validation statements are centred with the average of the two training means
                    acts = acts - (mean_affirm + mean_neg)/2
                predictions = probe.pred(acts)
                # fraction of predictions that match the labels
                results[probe_type][val_set].append((predictions == labels).float().mean().item())
            pbar.update(1)

In [None]:
# Aggregate the per-dataset accuracies in `results` into dataset groups and plot
# mean ± std per group as one single-column heatmap per classifier.

# Define the groups
groups = {
    'Conjunctions': [dataset for dataset in val_sets if dataset.endswith('_conj')],
    'Disjunctions': [dataset for dataset in val_sets if dataset.endswith('_disj')],
    'Affirmative German': [dataset for dataset in val_sets if dataset.endswith('_de') and not dataset.startswith('neg_')],
    'Negated German': [dataset for dataset in val_sets if dataset.startswith('neg_') and dataset.endswith('_de')],
    'common_claim_true_false': ['common_claim_true_false'],
    'counterfact_true_false': ['counterfact_true_false']
}

# Initialize group results
group_results = {probe_type: {group_name: [] for group_name in groups} for probe_type in probe_types}

# Process results to compute mean accuracies per group per classifier
# (one group mean per iteration n, so the std below is taken across iterations)
for probe_type in probe_types:
    for n in range(num_iter):
        for group_name, group_datasets in groups.items():
            accuracies = [results[probe_type][dataset][n] for dataset in group_datasets]
            mean_accuracy = sum(accuracies) / len(accuracies)
            group_results[probe_type][group_name].append(mean_accuracy)

# Compute statistics
stat_group_results = {probe_type: {'mean': {}, 'std': {}} for probe_type in probe_types}

for probe_type in probe_types:
    for group_name in groups:
        accuracies = group_results[probe_type][group_name]
        stat_group_results[probe_type]['mean'][group_name] = np.mean(accuracies)
        stat_group_results[probe_type]['std'][group_name] = np.std(accuracies)

# Map probe types to classifier names
probe_type_to_name = {probe: name for (name, probe) in ALL_PROBES}

# Create DataFrames for mean accuracies and standard deviations
# Row order follows the insertion order of `groups`, so the two cannot drift apart.
group_names = list(groups)
classifier_names = [n for (n, _) in ALL_PROBES]

mean_df = pd.DataFrame(index=group_names, columns=classifier_names)
std_df = pd.DataFrame(index=group_names, columns=classifier_names)

for probe_type in probe_types:
    classifier_name = probe_type_to_name[probe_type]
    for group_name in group_names:
        mean_df.loc[group_name, classifier_name] = stat_group_results[probe_type]['mean'][group_name]
        std_df.loc[group_name, classifier_name] = stat_group_results[probe_type]['std'][group_name]

num_classifiers = len(classifier_names)
# squeeze=False always returns a 2-D array of Axes, so indexing and ravel() below also work
# when there is only one classifier (plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(figsize=(2.5*num_classifiers, 6), ncols=num_classifiers, squeeze=False)
axes = axes.ravel()

for idx, classifier_name in enumerate(classifier_names):
    ax = axes[idx]
    # the frames were created empty, so the columns are cast to float here
    mean_values = mean_df[classifier_name].values.astype(float)
    std_values = std_df[classifier_name].values.astype(float)

    # Create heatmap with a single column
    im = ax.imshow(mean_values[:, np.newaxis], vmin=0, vmax=1, cmap='plasma', aspect='auto')

    # Annotate the heatmap
    for i in range(len(group_names)):
        mean_accuracy = mean_values[i]
        std_accuracy = std_values[i]
        ax.text(0, i, f'{round(mean_accuracy * 100):2d} ± {round(std_accuracy * 100):2d}',
                ha='center', va='center', fontsize=14)

    # Set ticks and labels
    ax.set_xticks([])
    if idx == 0:
        ax.set_yticks(np.arange(len(group_names)))
        ax.set_yticklabels(group_names, fontsize=14)
    else:
        ax.set_yticks([])
    ax.set_title(classifier_name, fontsize=15)

# Add colorbar on the right
cbar = fig.colorbar(im, ax=axes.tolist(), location='right', shrink=0.8)
cbar.ax.tick_params(labelsize=13)

fig.suptitle("Classification Accuracies", fontsize=17)
plt.show()

### Real world scenarios / lies

In [None]:
# compare the probes on the real-world scenarios dataset
probe_types = [t for (name, t) in ALL_PROBES]
results = {t: [] for t in probe_types}
num_iter = 50

# constant dataset name, defined once instead of on every iteration
real_world_dataset = "real_world_scenarios/all_unambiguous_replies"

# one progress tick per (probe type, iteration)
total_iterations = len(probe_types) * num_iter
with tqdm(total=total_iterations, desc="Training and evaluating classifiers") as pbar: # progress bar
    for probe_type in probe_types:
        for n in range(num_iter):
            # load training data
            acts_centered, acts, labels, polarities = collect_training_data(train_sets, train_set_sizes, model_family,
                                                                            model_size, model_type, layer)
            if probe_type in TTPD_CLASSES:
                probe = probe_type.from_data(acts_centered, acts, labels, polarities)
            elif probe_type == LRProbe:
                probe = LRProbe.from_data(acts, labels)
            elif probe_type == CCSProbe:
                # CCS is trained on affirmative/negated activations, each centred on its own mean
                acts_affirm = acts[polarities == 1.0]
                acts_neg = acts[polarities == -1.0]
                labels_affirm = labels[polarities == 1.0]
                mean_affirm = torch.mean(acts_affirm, dim=0)
                mean_neg = torch.mean(acts_neg, dim=0)
                acts_affirm = acts_affirm - mean_affirm
                acts_neg = acts_neg - mean_neg
                probe = CCSProbe.from_data(acts_affirm, acts_neg, labels_affirm, device=device).to('cpu')
            elif probe_type == MMProbe:
                probe = MMProbe.from_data(acts, labels)
            else:
                # fail loudly instead of silently evaluating the probe left over from a previous iteration
                raise ValueError(f"Unsupported probe type: {probe_type}")

            # evaluate classification accuracy on real world scenarios
            dm = DataManager()
            dm.add_dataset(real_world_dataset, model_family, model_size, model_type, layer, split=None, center=False, device='cpu')
            acts, labels = dm.data[real_world_dataset]

            # classifier specific predictions
            if probe_type == CCSProbe:
                # evaluation statements are centred with the average of the two training means
                acts = acts - (mean_affirm + mean_neg)/2

            predictions = probe.pred(acts)
            # fraction of predictions that match the labels
            results[probe_type].append((predictions == labels).float().mean().item())
            pbar.update(1)

# mean and standard deviation of the accuracy across iterations, per probe type
for probe_type in probe_types:
    mean = np.mean(results[probe_type])
    std = np.std(results[probe_type])
    print(f"{probe_type.__name__}:")
    print(f"  Mean Accuracy: {mean*100:.2f}%")
    print(f"  Standard Deviation: {std*100:.2f}%")