In [1]:
import os
import sys
import pickle
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch as t
sys.path.append(os.path.dirname(os.path.abspath('.')))
import configs
from datasets import TruthfulQADataset, DishonestQADataset, AmongUsDataset, RolePlayingDataset, RepEngDataset, ApolloProbeDataset
from probes import LinearProbe

from configs import config_phi4

In [2]:
# --- Experiment configuration -------------------------------------------------
# Name of the dataset class to train the probe on; must be a key of
# DATASET_CLASSES below.
dataset_name = 'TruthfulQADataset'
config = config_phi4
# No model/tokenizer is instantiated in this notebook; the dataset is handed
# None for both and everything runs on CPU.
# NOTE(review): presumably this works because activations are read from cache
# rather than recomputed -- confirm against the dataset implementation.
model, tokenizer, device = None, None, 'cpu'
EXPT_NAME: str = "2025-02-01_phi_phi_100_games_v3"

# Explicit name -> class registry. This replaces eval(dataset_name): no
# arbitrary string is executed, and a misspelled name produces a clear error
# listing the valid choices instead of a bare NameError.
DATASET_CLASSES = {
    'TruthfulQADataset': TruthfulQADataset,
    'DishonestQADataset': DishonestQADataset,
    'AmongUsDataset': AmongUsDataset,
    'RolePlayingDataset': RolePlayingDataset,
    'RepEngDataset': RepEngDataset,
    'ApolloProbeDataset': ApolloProbeDataset,
}
if dataset_name not in DATASET_CLASSES:
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(DATASET_CLASSES)}"
    )
dataset_cls = DATASET_CLASSES[dataset_name]

# Alternative construction previously used (passes the experiment name and no test split):
# dataset = dataset_cls(config, model=model, tokenizer=tokenizer, device=device, expt_name=EXPT_NAME, test_split=0)
dataset = dataset_cls(config, model=model, tokenizer=tokenizer, device=device, test_split=0.2)

In [3]:
# Populate the dataset without forcing a rebuild (force_redo=False).
# In the recorded run this picked up an existing chunk file from disk.
# NOTE(review): the exact semantics of `force_redo` / `just_load` live in the
# dataset class, which is not visible here -- confirm before relying on them.
# Alternative call that was used previously:
# dataset.populate_dataset(force_redo=False, just_load=True)
dataset.populate_dataset(force_redo=False)

Loading existing chunk from ./data/TruthfulQA_phi4_acts/chunk_0.pkl


In [10]:
# Batch size shared by the dataset-provided loader and the label-flipped loader
# below, so the two can never drift apart when one of them is edited.
BATCH_SIZE = 32

# Training loader for chunk 0.
# NOTE(review): `num_tokens=5` is passed straight to the dataset; its meaning
# is defined by the dataset class -- confirm there.
train_loader = dataset.get_train(batch_size=BATCH_SIZE, num_tokens=5, chunk_idx=0)

if dataset_name == 'AmongUsDataset':
    # Rebuild the samples with flipped binary labels (so 1 is crewmate and
    # 0 is impostor). This materialises the whole training set as a list.
    flipped_dataset = [(x, 1 - y) for x, y in train_loader.dataset]

    # Replace the loader with one over the flipped samples (this workaround
    # won't be needed once the AmongUs activations are recreated).
    train_loader = t.utils.data.DataLoader(
        flipped_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        # Pinned host memory only speeds up host->GPU copies; requesting it
        # when CUDA is unavailable has no effect other than a warning.
        pin_memory=t.cuda.is_available(),
    )

In [11]:
# Build the probe with an input width equal to the dataset's activation size,
# placed on the device selected in the configuration cell.
activation_dim = dataset.activation_size
probe = LinearProbe(input_dim=activation_dim, device=device)

# Sanity check on the amount of training data before fitting.
print(f'Training probe on {len(train_loader)} batches and {len(train_loader.dataset)} samples.')

Training probe on 198 batches and 6320 samples.


In [12]:
# Train the probe on the training loader for 20 epochs. In the recorded run
# the method logged train loss/accuracy every other epoch plus a final
# training accuracy.
# NOTE(review): no random seed is set anywhere in the visible cells, so
# repeated runs may not reproduce these exact numbers -- confirm whether
# seeding happens inside the project modules.
probe.fit(train_loader, epochs=20)

Epoch 1: Train Loss = 0.5173, Train Acc = 0.7337
Epoch 3: Train Loss = 0.3281, Train Acc = 0.8541
Epoch 5: Train Loss = 0.2667, Train Acc = 0.8834
Epoch 7: Train Loss = 0.2381, Train Acc = 0.9028
Epoch 9: Train Loss = 0.2129, Train Acc = 0.9142
Epoch 11: Train Loss = 0.2036, Train Acc = 0.9166
Epoch 13: Train Loss = 0.1942, Train Acc = 0.9212
Epoch 15: Train Loss = 0.1800, Train Acc = 0.9245
Epoch 17: Train Loss = 0.1765, Train Acc = 0.9282
Epoch 19: Train Loss = 0.1710, Train Acc = 0.9345
Final Train Acc: 0.9340


In [13]:
# Persist the trained probe, keyed by dataset name and the config's short name.
checkpoint_path = f'checkpoints/{dataset_name}_probe_{config["short_name"]}.pkl'

# Create the target directory if needed; open(..., 'wb') does not create
# parent directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# The whole probe object is pickled, so loading it later requires its class to
# be importable. Only unpickle checkpoints from trusted sources.
with open(checkpoint_path, 'wb') as f:
    pickle.dump(probe, f)

# Report success only after the file has been written and closed.
print(f"Probe saved to {checkpoint_path}")

Probe saved to checkpoints/TruthfulQADataset_probe_phi4.pkl
