In [14]:
import os
import sys
import json
import numpy as np
import requests
import pandas as pd
from typing import List, Dict, Any, Tuple, Union, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch as t
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from pprint import pprint as pp
import importlib
import pickle

# Make the parent of the current working directory importable.
# NOTE(review): presumably this is where project-level modules such as `utils`
# live -- confirm against the repository layout.
# The membership check keeps the cell idempotent: re-running it no longer
# appends the same directory to sys.path again and again.
_parent_dir = os.path.dirname(os.path.abspath('.'))
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

import datasets, plots, configs, probes, evaluate_utils
# Re-execute the local modules so the from-imports below bind to their
# current definitions rather than whatever the kernel cached on first import.
for module in [datasets, plots, configs, probes, evaluate_utils]:
    importlib.reload(module)

from datasets import AmongUsDataset
from plots import plot_behavior_distribution, plot_roc_curves, add_roc_curves, print_metrics, plot_roc_curve_eval
from configs import config_phi4, config_gpt2
from evaluate_utils import evaluate_probe_on_string, evaluate_probe_on_dataset, evaluate_probe_on_activation_dataset
from utils import load_agent_logs_df, read_jsonl_as_json, load_game_summary

# Switch for the model-loading cell: when False, no model/tokenizer is loaded.
load_models = True

# Active experiment configuration; the model identifier is read from it.
config = config_gpt2
model_name = config["model_name"]

In [15]:
if not load_models:
    # No model requested: downstream cells receive None placeholders and a CPU device.
    model, tokenizer, device = None, None, 'cpu'
else:
    # Load the tokenizer first, then the causal LM named by `model_name`.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
    )
    # Take the device from the loaded model itself.
    device = model.device

In [16]:
# Relative locations of evaluation results and raw experiment logs.
LOGS_PATH, RAW_PATH = "../evaluations/results/", "../expt-logs/"
# Keep the parent directory importable; the guard stops duplicate entries
# from piling up in sys.path when this cell is re-run.
if ".." not in sys.path:
    sys.path.append("..")
EXPT_NAME, DESCRIPTIONS = "2025-02-01_phi_phi_100_games_v3", "Crew: Phi, Imp: Phi"

# Detach the forward hook registered by a previous run of this cell. Without
# this, every re-run stacks another hook on the same model, each still feeding
# the activation cache of an older (discarded) dataset object.
if "hook_handle" in globals() and hook_handle is not None:
    hook_handle.remove()
hook_handle = None

# Build the dataset from the active config with "test_split" overridden to 1.0.
dataset = AmongUsDataset(
    {**config, "test_split": 1.0},
    model=model,
    tokenizer=tokenizer,
    device=device,
    expt_name=EXPT_NAME,
)

# Only attach the hook when a model was actually loaded; with load_models=False
# `model` is None and the attribute lookup below would raise AttributeError.
if model is not None:
    # NOTE(review): eval() resolves the submodule path stored in
    # config['hook_component']. This is acceptable only because the config is
    # trusted local code -- never feed it an externally supplied string.
    hook_module = eval(f"model.{config['hook_component']}")
    # Keep the handle so the hook can be removed later (see the cleanup above).
    hook_handle = hook_module.register_forward_hook(dataset.activation_cache.hook_fn)

<torch.utils.hooks.RemovableHandle at 0x7f3a16300370>

In [20]:
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Fill the dataset row by row (batched=False) with seq_len=1024.
# NOTE(review): force_redo=True presumably bypasses any cached results -- confirm.
# NOTE(review): the recorded output reports "of 8585" rows with max_rows=0, so 0
# does not appear to cap the row count -- confirm against populate_dataset.
dataset.populate_dataset(
    force_redo=True,
    max_rows=0,
    batched=False,
    seq_len=1024,
)

Populated 0 rows of 8585
Populated 1 rows of 8585
Populated 2 rows of 8585
Populated 3 rows of 8585
Populated 4 rows of 8585
Populated 5 rows of 8585
Populated 6 rows of 8585
Populated 7 rows of 8585
Populated 8 rows of 8585
Populated 9 rows of 8585
Populated 10 rows of 8585
Populated 11 rows of 8585
Populated 12 rows of 8585
Populated 13 rows of 8585
Populated 14 rows of 8585
Populated 15 rows of 8585
Populated 16 rows of 8585
Populated 17 rows of 8585
Populated 18 rows of 8585
Populated 19 rows of 8585
Populated 20 rows of 8585
Populated 21 rows of 8585
Populated 22 rows of 8585
Populated 23 rows of 8585
Populated 24 rows of 8585
Populated 25 rows of 8585


KeyboardInterrupt: 