In [16]:
import os
import sys
import pickle
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch as t
# Append the parent of the current working directory to the module search path.
# NOTE(review): presumably this is what makes the project-local `configs`,
# `datasets` and `probes` modules importable below — confirm against the repo layout.
sys.path.append(os.path.dirname(os.path.abspath('.')))
import configs
from datasets import TruthfulQADataset, DishonestQADataset, AmongUsDataset, RolePlayingDataset, RepEngDataset
from probes import LinearProbe

from configs import config_phi4

In [8]:
# --- Experiment configuration ---
# Name of the dataset class the probe is trained on; must be a key of
# DATASET_CLASSES below.
dataset_name = 'AmongUsDataset'
config = config_phi4
# No model or tokenizer is instantiated in this notebook; the dataset is
# constructed with both set to None and everything stays on the CPU.
model, tokenizer, device = None, None, 'cpu'
EXPT_NAME: str = "2025-02-01_phi_phi_100_games_v3"

# Seed torch's global RNG so that runs from a fresh kernel are repeatable
# (the train loader built later shuffles its samples).
SEED: int = 0
t.manual_seed(SEED)

# Resolve the dataset class through an explicit whitelist rather than `eval`:
# `eval` would execute any expression placed in `dataset_name`, and a typo
# would surface as an opaque NameError instead of a helpful message.
DATASET_CLASSES = {
    'TruthfulQADataset': TruthfulQADataset,
    'DishonestQADataset': DishonestQADataset,
    'AmongUsDataset': AmongUsDataset,
    'RolePlayingDataset': RolePlayingDataset,
    'RepEngDataset': RepEngDataset,
}
if dataset_name not in DATASET_CLASSES:
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(DATASET_CLASSES)}"
    )
dataset = DATASET_CLASSES[dataset_name](
    config,
    model=model,
    tokenizer=tokenizer,
    device=device,
    expt_name=EXPT_NAME,
    test_split=0,
)

In [9]:
# Fill the dataset without forcing a redo.
# NOTE(review): `just_load=True` presumably restricts this to reading chunks
# that were saved by an earlier run — confirm in `populate_dataset`.
dataset.populate_dataset(
    force_redo=False,
    just_load=True,
)

Loaded 86 existing chunks


In [17]:
# Batch size shared by the dataset's own loader and the rebuilt loader below.
BATCH_SIZE = 32

# Always build the train loader so the cells below work for every dataset,
# not only for AmongUsDataset (previously `train_loader` was left undefined
# for any other `dataset_name`).
# NOTE(review): assumes the other dataset classes expose the same
# `get_train(batch_size, num_tokens, chunk_idx)` signature — confirm.
train_loader = dataset.get_train(batch_size=BATCH_SIZE, num_tokens=10, chunk_idx=0)

if dataset_name == 'AmongUsDataset':
    # Flip the binary labels (so 1 is crewmate and 0 is impostor).
    # This workaround won't be needed once the AmongUs activations are recreated.
    flipped_dataset = [(x, 1 - y) for x, y in train_loader.dataset]

    # Rebuild the loader around the relabelled samples.
    train_loader = t.utils.data.DataLoader(
        flipped_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        # Pinned host memory only helps host-to-GPU copies; requesting it on a
        # machine without CUDA just triggers a warning.
        pin_memory=t.cuda.is_available(),
    )

In [18]:
# Build a linear probe whose input width matches the dataset's activation size.
activation_dim = dataset.activation_size
probe = LinearProbe(
    input_dim=activation_dim,
    device=device,
)

# Report how much data the probe is about to be trained on.
print(f'Training probe on {len(train_loader)} batches and {len(train_loader.dataset)} samples.')

Training probe on 32 batches and 1000 samples.


In [19]:
# Number of passes over the training loader.
NUM_EPOCHS = 10
probe.fit(train_loader, epochs=NUM_EPOCHS)

Epoch 1: Train Loss = 0.4037, Train Acc = 0.8030
Epoch 2: Train Loss = 0.1286, Train Acc = 0.9870
Epoch 3: Train Loss = 0.0762, Train Acc = 0.9960
Epoch 4: Train Loss = 0.0506, Train Acc = 0.9960
Epoch 5: Train Loss = 0.0366, Train Acc = 0.9990
Epoch 6: Train Loss = 0.0306, Train Acc = 0.9990
Epoch 7: Train Loss = 0.0248, Train Acc = 1.0000
Epoch 8: Train Loss = 0.0189, Train Acc = 1.0000
Epoch 9: Train Loss = 0.0169, Train Acc = 1.0000
Epoch 10: Train Loss = 0.0141, Train Acc = 1.0000
Final Train Acc: 1.0000


In [20]:
# Persist the trained probe under checkpoints/<dataset>_probe_<model short name>.pkl.
checkpoint_path = f'checkpoints/{dataset_name}_probe_{config["short_name"]}.pkl'

# Create the target directory first; otherwise `open(..., 'wb')` raises
# FileNotFoundError on a fresh checkout where `checkpoints/` does not exist yet.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# NOTE(review): this pickles the whole probe object, so loading the checkpoint
# requires the same `LinearProbe` class to be importable, and `pickle.load`
# must only ever be used on trusted files.
with open(checkpoint_path, 'wb') as f:
    pickle.dump(probe, f)

# Report success only after the file has been closed (and therefore flushed).
print(f"Probe saved to {checkpoint_path}")

Probe saved to checkpoints/AmongUsDataset_probe_phi4.pkl
