In [1]:
import sys
sys.path.append('../linear-probes')

In [2]:
from probes import LinearProbe
from probe_datasets import DishonestQADataset, AmongUsDataset, RepEngDataset
from configs import config_phi4, config_llama3
from evaluate_utils import evaluate_probe_on_activation_dataset
from probe_utils import read_jsonl_as_json

In [3]:
import torch.nn as nn
import torch as t
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import torch as t
import gc
from typing import List, Tuple, Dict, Any
import pickle
import pandas as pd
import os
import sys
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
import tqdm
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

In [4]:
def evaluate_probe_on_amongus(dataset, probe, config, model=None, tokenizer=None, device='cpu', plot_stuff=False):
    """
    Evaluate a probe on the AmongUs dataset and return the deception AUROC.
    
    Args:
        probe: LinearProbe object to evaluate
        config: Configuration dictionary
        model: Language model (optional)
        tokenizer: Tokenizer for the model (optional)
        device: Device to run model on
        plot_stuff: Whether to generate plots (not implemented)
        
    Returns:
        float: AUROC for deception detection
    """
    all_probe_outputs = []
    chunk_size: int = 500
    list_of_chunks_to_eval = [1]
    row_indices = []

    # Evaluate probe on test chunks
    for chunk_idx in tqdm.tqdm(list_of_chunks_to_eval):
        test_acts_chunk = dataset.get_test_acts(chunk_idx)
        
        # Store the row indices for this chunk
        start_idx = chunk_idx * chunk_size
        end_idx = start_idx + len(test_acts_chunk)
        row_indices.extend(range(start_idx, end_idx))
        
        chunk_probe_outputs, _ = evaluate_probe_on_activation_dataset(
            chunk_data=test_acts_chunk,
            probe=probe,
            device=device,
            num_tokens=None,
            verbose=False,
        )
        all_probe_outputs.extend(chunk_probe_outputs)

    av_probe_outputs = all_probe_outputs

    # Create JSON outputs with probe predictions
    json_outputs = []
    eval_rows_num = len(av_probe_outputs)

    for i in range(eval_rows_num):
        actual_row_idx = row_indices[i]
        row = dataset.agent_logs_df.iloc[actual_row_idx]
        probe_output = av_probe_outputs[i]
        
        json_output = {
            "game_index": int(row["game_index"].split(" ")[1]) if isinstance(row["game_index"], str) else int(row["game_index"]),
            "step": int(row["step"]),
            "player_name": row["player.name"],
            "probe_output": probe_output,
            "timestamp": row["timestamp"],
            "player_role": row["player.personality"],
        }
        json_outputs.append(json_output)

    probe_output_df = pd.DataFrame(json_outputs)
    
    # Load skill scores from evaluation results
    EXPT_NAMES: List[str] = [config["expt_name"],]
    LOGS_PATH: str = "../evaluations/results/"
    RAW_PATH: str = "../expt-logs/"
    summary_logs_paths: List[str] = [os.path.join(LOGS_PATH, f"{expt_name}_all_skill_scores.json") for expt_name in EXPT_NAMES]
    summary_dfs: List[pd.DataFrame] = []

    for summary_logs_path in summary_logs_paths:
        summary_logs: List[Dict[str, Any]] = read_jsonl_as_json(summary_logs_path)
        summary_df: pd.DataFrame = pd.json_normalize(summary_logs)
        summary_df = summary_df.sort_values(by=["game_index", "step"])
        summary_dfs.append(summary_df)
    
    summary_df = summary_dfs[0]
    summary_df['game_index'] = summary_df['game_index'].str.extract('(\d+)').astype(int) if summary_df['game_index'].dtype == 'object' else summary_df['game_index'].astype(int)

    # Merge probe outputs with skill scores
    probe_df = probe_output_df[['game_index', 'step', 'player_name', 'probe_output', 'timestamp']].copy()
    summary_df = summary_df[['timestamp', 'game_index', 'step', 'player_identity', 'awareness', 'lying', 'deception', 'planning', 'player_name', 'action']].copy()
    probe_eval = pd.merge(
        probe_df,
        summary_df[['game_index', 'step', 'player_name', 'player_identity', 'awareness', 'lying', 'deception', 'planning', 'action', 'timestamp']],
        on=['timestamp'],
        how='inner'
    )

    # Invert lying and deception scores (higher = more lying/deception)
    probe_eval['lying'] = 10 - probe_eval['lying'].astype(int)
    probe_eval['deception'] = 10 - probe_eval['deception'].astype(int)

    # Convert behavior scores to binary ground truth (>5 is lying)
    behaviors = ['lying', 'deception', 'planning', 'awareness']
    for col in behaviors:
        probe_eval[col] = probe_eval[col].astype(int)
        probe_eval[f'{col}_truth'] = (probe_eval[col] > 5).astype(int)
    
    # Calculate ROC for deception
    fpr_deception, tpr_deception, _ = roc_curve(probe_eval['deception_truth'], probe_eval['probe_output'])
    roc_auc_deception = auc(fpr_deception, tpr_deception)
    
    return roc_auc_deception

In [5]:
ks = [0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# ks = [0.9]
test_aurocs = []
best_lambdas = []
amongus_aurocs = []
train_accuracies = []
average_cosine_similarities = []

model = None
tokenizer = None
device = 'cpu'
config = config_phi4

In [6]:
for k in ks:
    print(f"Running for k = {k}")
    dataset = RepEngDataset(config, model=model, tokenizer=tokenizer, device=device, test_split=0.2)
    amongus_dataset = AmongUsDataset(config, model=model, tokenizer=tokenizer, device=device, test_split=1, expt_name=config['expt_name'])
    train_loader, val_loader = dataset.get_train(
        batch_size=config["probe_training_batch_size"],
        num_tokens=config["probe_training_num_tokens"],
        keep_frac=k,
        get_val=True
    )
    # Define a range of L2 weight decay values to try
    weight_decays = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
    best_val_acc = 0
    best_wd = None
    best_probe = None
    
    # Grid search to find the best weight decay
    for wd in weight_decays:
        # Create and train probe with current weight decay
        probe_candidate = LinearProbe(
            input_dim=dataset.activation_size,
            device=device,
            lr=config["probe_training_learning_rate"],
            seed=420,
            verbose=False,
            weight_decay=wd
        )
        probe_candidate.fit(train_loader, epochs=config["probe_training_epochs"])
        
        # Evaluate on validation set
        val_acc = probe_candidate.accuracy(val_loader)
        
        # Update best if this is better
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_wd = wd
            best_probe = probe_candidate
    
    # Use the best probe for further evaluation
    probe = best_probe
    train_acc = probe.train_accs[-1]  # Get the final training accuracy
    print(f"Best weight decay: {best_wd}")
    best_lambdas.append(best_wd)  # Store the best weight decay
    test_acts_chunk = dataset.get_test_acts()
    amongus_test_acts_chunk = amongus_dataset.get_test_acts()
    av_probe_outputs, accuracy = evaluate_probe_on_activation_dataset(
        chunk_data=test_acts_chunk,
        probe=probe,
        device=device,
        num_tokens=None,
        verbose=False
    )
    labels = t.tensor([batch[1] for batch in test_acts_chunk]).numpy()
    fpr, tpr, _ = roc_curve(labels, av_probe_outputs)
    roc_auc = auc(fpr, tpr)
    test_aurocs.append(roc_auc)
    amongus_aurocs.append(evaluate_probe_on_amongus(amongus_dataset, probe, config, model, tokenizer, device, plot_stuff=False))
    train_accuracies.append(train_acc)
    directions = []
    for i in tqdm.trange(10):
        seed = random.randint(0, 1000000)
        probe = LinearProbe(input_dim=dataset.activation_size,  device=device,  lr=config["probe_training_learning_rate"], seed=seed, verbose=False)
        acc = probe.fit(train_loader, epochs=config["probe_training_epochs"])
        directions.append(probe.model.linear.weight.data.cpu().numpy()[0])
    directions = np.array(directions)
    normalized_directions = directions / np.linalg.norm(directions, axis=1)[:, np.newaxis]
    directions_matrix = np.dot(normalized_directions, normalized_directions.T)
    average_cosine_similarities.append(np.mean(directions_matrix[np.triu_indices(len(directions_matrix), k=1)]))

Running for k = 0.01


Validation size: 15


100%|██████████| 122/122 [00:00<00:00, 418.31it/s]
100%|██████████| 500/500 [00:27<00:00, 18.04it/s]
100%|██████████| 1/1 [00:37<00:00, 37.50s/it]
100%|██████████| 10/10 [00:00<00:00, 114.21it/s]


Running for k = 0.02
Validation size: 55


100%|██████████| 122/122 [00:00<00:00, 533.33it/s]
100%|██████████| 500/500 [00:27<00:00, 18.33it/s]
100%|██████████| 1/1 [00:37<00:00, 37.06s/it]
100%|██████████| 10/10 [00:00<00:00, 36.27it/s]


Running for k = 0.05
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 427.65it/s]
100%|██████████| 500/500 [00:26<00:00, 18.81it/s]
100%|██████████| 1/1 [00:36<00:00, 36.40s/it]
100%|██████████| 10/10 [00:00<00:00, 12.67it/s]


Running for k = 0.1
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 281.44it/s]
100%|██████████| 500/500 [00:25<00:00, 19.45it/s]
100%|██████████| 1/1 [00:36<00:00, 36.17s/it]
100%|██████████| 10/10 [00:01<00:00,  5.68it/s]


Running for k = 0.2
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 290.36it/s]
100%|██████████| 500/500 [00:26<00:00, 19.14it/s]
100%|██████████| 1/1 [00:36<00:00, 36.51s/it]
100%|██████████| 10/10 [00:03<00:00,  3.23it/s]


Running for k = 0.3
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 238.32it/s]
100%|██████████| 500/500 [00:25<00:00, 19.67it/s]
100%|██████████| 1/1 [00:36<00:00, 36.40s/it]
100%|██████████| 10/10 [00:04<00:00,  2.24it/s]


Running for k = 0.4
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 220.77it/s]
100%|██████████| 500/500 [00:25<00:00, 19.42it/s]
100%|██████████| 1/1 [00:37<00:00, 37.04s/it]
100%|██████████| 10/10 [00:05<00:00,  1.70it/s]


Running for k = 0.5
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 337.47it/s]
100%|██████████| 500/500 [00:25<00:00, 19.48it/s]
100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
100%|██████████| 10/10 [00:07<00:00,  1.33it/s]


Running for k = 0.6
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 226.48it/s]
100%|██████████| 500/500 [00:25<00:00, 19.63it/s]
100%|██████████| 1/1 [00:37<00:00, 37.27s/it]
100%|██████████| 10/10 [00:08<00:00,  1.19it/s]


Running for k = 0.7
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 210.39it/s]
100%|██████████| 500/500 [00:26<00:00, 19.01it/s]
100%|██████████| 1/1 [00:38<00:00, 38.33s/it]
100%|██████████| 10/10 [00:09<00:00,  1.01it/s]


Running for k = 0.8
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 200.94it/s]
100%|██████████| 500/500 [00:25<00:00, 19.56it/s]
100%|██████████| 1/1 [00:37<00:00, 37.73s/it]
100%|██████████| 10/10 [00:11<00:00,  1.20s/it]


Running for k = 0.9
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 214.87it/s]
100%|██████████| 500/500 [00:25<00:00, 19.59it/s]
100%|██████████| 1/1 [00:37<00:00, 37.77s/it]
100%|██████████| 10/10 [00:12<00:00,  1.27s/it]


Running for k = 1.0
Validation size: 123


100%|██████████| 122/122 [00:00<00:00, 208.41it/s]
100%|██████████| 500/500 [00:26<00:00, 19.22it/s]
100%|██████████| 1/1 [00:38<00:00, 38.11s/it]
100%|██████████| 10/10 [00:14<00:00,  1.41s/it]


In [48]:
data_size = len(train_loader.dataset)
# datapoints = [int(data_size * k) for k in ks]
# now the number of datapoints also includes the validation set
datapoints = [int(data_size * k * (1.5)) for k in ks]

# plot test auroc, train accuracy, and average cosine similarity vs k in different colors on the same plot
fig = go.Figure()
colors = ['#D4A27F', '#5F8670', '#7D6E83', '#A75D5D', '#3F4E4F', '#6D8B74', '#BDCDD6', '#4F709C', '#9E7676', '#6C3428']
fig.add_trace(go.Scatter(x=datapoints, y=test_aurocs, mode='lines+markers', name='Test AUROC (RepEng)', 
                         line=dict(color=colors[0], width=3)))
fig.add_trace(go.Scatter(x=datapoints, y=amongus_aurocs, mode='lines+markers', name='Test AUROC (AmongUs)', 
                         line=dict(color=colors[1], width=3)))
fig.add_trace(go.Scatter(x=datapoints, y=train_accuracies, mode='lines+markers', name='Train Acc. (RepEng)', 
                         line=dict(color=colors[2], width=3)))
fig.add_trace(go.Scatter(x=datapoints, y=average_cosine_similarities, mode='lines+markers', name='Cosine Sim. (C(10, 2))', 
                         line=dict(color=colors[3], width=3)))
fig.update_xaxes(title='Train (+Val) Data Size', title_font=dict(family="Computer Modern", size=16))
fig.update_yaxes(title='Similarity / Performance', title_font=dict(family="Computer Modern", size=16))
fig.update_layout(template='plotly_white')
fig.update_layout(
    template='plotly_white', 
    font=dict(family="Computer Modern", size=14), 
    xaxis=dict(
        gridcolor='lightgray', 
        zeroline=True, 
        zerolinecolor='black', 
        showline=True, 
        linewidth=2, 
        linecolor='black', 
    ), 
    yaxis=dict(
        gridcolor='lightgray', 
        zeroline=True,
        zerolinecolor='black',
        zerolinewidth=2,
        showline=False, 
        linewidth=2, 
        linecolor='black',
    ),
    plot_bgcolor='#fafaf7'
)
fig.update_layout(width=600, height=400)
fig.update_layout(legend=dict(x=0.4, y=0.1, font=dict(family="Computer Modern", size=15), bgcolor='rgba(220, 220, 215, 0.8)'))
fig.update_layout(font=dict(family="Computer Modern", size=15))
fig.update_yaxes(range=[0, 1.1])
fig.update_xaxes(range=[0, max(datapoints) * 1.05])  # Ensure x-axis starts at 0
fig.update_xaxes(tickfont=dict(family="Computer Modern", size=15))
fig.update_yaxes(tickfont=dict(family="Computer Modern", size=15))
fig.show()

In [50]:
fig.write_image("plots/less_data_probes.pdf")

In [25]:
best_lambdas

[1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 0.01,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 0.01,
 1e-05,
 1e-05,
 1e-05,
 1e-05,
 1e-05]

In [11]:
len(test_aurocs), len(amongus_aurocs), len(train_accuracies), len(average_cosine_similarities)

(26, 26, 26, 26)