In [36]:
from abc import ABC, abstractmethod
from pathlib import Path
from pprint import pprint as pp
from typing import List, Optional, Tuple, cast
import os

import einops
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from einops import einsum, rearrange
from jaxtyping import Float
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys

# Put the parent of the current working directory on the import path.
# The membership check keeps this cell idempotent: re-running it no longer
# appends the same entry to sys.path a second time.
PARENT_DIR = os.path.dirname(os.path.abspath('.'))
if PARENT_DIR not in sys.path:
    sys.path.append(PARENT_DIR)

In [47]:
import importlib

import datasets
import configs
import probes

# Re-execute each project module in the running kernel, in the same order as
# they were imported.
# NOTE(review): `datasets` presumably resolves to a project-local module here
# (TruthfulQADataset is imported from it) -- confirm it is not shadowed by an
# installed package of the same name.
for _module in (datasets, configs, probes):
    importlib.reload(_module)

# Imported after the reload so the name is bound to the reloaded module's class.
from datasets import TruthfulQADataset

In [38]:
from configs import (
    config_gpt2,
    config_phi4,
)

# Active configuration for every cell below; swap in config_gpt2 to run the
# same pipeline with the other configuration.
config = config_phi4

In [39]:
model_name = config["model_name"]

# Toggle for this notebook run:
#   True  -> load tokenizer + model and hook the configured model component.
#   False -> build the dataset with model=None / tokenizer=None on 'cpu'.
load_models: bool = False

if load_models:
    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code at load time -- only use with repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
    device = model.device
else:
    model, tokenizer, device = None, None, 'cpu'

# Built once for both modes; the arguments are the only thing that differs.
tqa_dataset = TruthfulQADataset(config, model=model, tokenizer=tokenizer, device=device)

# Handle of the registered forward hook (None when no model is loaded). Keeping
# it allows the hook to be detached later with hook_handle.remove().
hook_handle = None
if load_models:
    # NOTE(review): eval() executes whatever string config['hook_component']
    # holds as an attribute/index expression on `model`. Acceptable only while
    # the config is trusted project code; never feed it external input.
    hook_module = eval(f"model.{config['hook_component']}")
    hook_handle = hook_module.register_forward_hook(tqa_dataset.activation_cache.hook_fn)

In [40]:
# Fill the dataset. force_redo follows load_models, so a redo is only requested
# when a model was actually loaded in the setup cell.
# NOTE(review): presumably force_redo=False reuses previously stored results
# instead of recomputing them -- confirm in the datasets module.
tqa_dataset.populate_dataset(force_redo=load_models)

In [41]:
# Sanity check, shown as the cell output: the recorded result reports the total
# sample count and the number of samples per class.
tqa_dataset.get_stats()

{'total_samples': 6320, 'class_distribution': {0: 3160, 1: 3160}}

In [48]:
from probes import LinearProbe

# Mini-batch size used for the training loader.
BATCH_SIZE = 32

train_loader = tqa_dataset.get_train(batch_size=BATCH_SIZE)

# The probe's input width is taken from the dataset, and it is placed on the
# same device that was chosen in the setup cell.
probe = LinearProbe(
    input_dim=tqa_dataset.activation_size,
    device=device,
)

In [49]:
# Number of passes requested from probe.fit; the recorded log prints training
# loss and accuracy every 100 epochs.
# NOTE(review): only training metrics are reported in this section -- no
# held-out evaluation of the probe is visible here.
NUM_EPOCHS = 1000
probe.fit(train_loader, epochs=NUM_EPOCHS)

Epoch 100: Train Loss = 0.1101, Train Acc = 0.9555
Epoch 200: Train Loss = 0.0870, Train Acc = 0.9611
Epoch 300: Train Loss = 0.0785, Train Acc = 0.9657
Epoch 400: Train Loss = 0.0741, Train Acc = 0.9677
Epoch 500: Train Loss = 0.0703, Train Acc = 0.9688
Epoch 600: Train Loss = 0.0683, Train Acc = 0.9691
Epoch 700: Train Loss = 0.0665, Train Acc = 0.9703
Epoch 800: Train Loss = 0.0656, Train Acc = 0.9704
Epoch 900: Train Loss = 0.0648, Train Acc = 0.9717
Epoch 1000: Train Loss = 0.0640, Train Acc = 0.9715
Final Train Acc: 0.9715


In [50]:
from plots import plot_metrics

In [51]:
# Series handed to plot_metrics: one entry, the accuracy history recorded on
# the probe during fitting (the key is presumably used as the series label).
metrics = {"Train Accuracy": probe.train_accs}
plot_metrics(metrics)

In [52]:
import pickle

# Checkpoint location, relative to the current working directory; the file name
# carries the active config's short name.
checkpoint_path = Path("checkpoints") / f'tqa_lying_post_gen_probe_{config["short_name"]}.pkl'

# Opening a file for writing does not create missing parent directories, so
# make sure the folder exists instead of failing with FileNotFoundError.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): this pickles the whole probe object, so loading it later needs
# the same class to be importable, and unpickling runs arbitrary code -- only
# load checkpoints you produced yourself.
with checkpoint_path.open('wb') as f:
    pickle.dump(probe, f)