In [2]:
# Replicate ITI results, make sure ITI utils and probing utils work right

#%%
from IPython import get_ipython

ipython = get_ipython()
# Automatically reload edited modules (TransformerLens, utils.*) without restarting the kernel.
# `run_line_magic` replaces the deprecated `InteractiveShell.magic` call, which emits a
# DeprecationWarning on current IPython. get_ipython() returns None outside an IPython
# session, so the guard keeps this cell harmless when run as a plain script.
if ipython is not None:
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
    
import plotly.io as pio
# pio.renderers.default = "png"
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

from tqdm import tqdm
from utils.probing_utils import ModelActs
from utils.dataset_utils import CounterFact_Dataset, TQA_MC_Dataset, EZ_Dataset

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from utils.iti_utils import patch_iti

from utils.analytics_utils import plot_probe_accuracies, plot_norm_diffs, plot_cosine_sims

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [3]:

# Load GPT-2 XL as a HookedTransformer onto the GPU.
# NOTE(review): the device is hard-coded to "cuda" here, and later cells also pass
# model_device="cuda" to patch_iti, so this notebook cannot run on a CPU-only machine as written.
device = "cuda"
print("loading model")
model = HookedTransformer.from_pretrained(
    "gpt2-xl",
    # All weight-processing options below are switched off except the attention refactor.
    center_unembed=False,
    center_writing_weights=False,
    fold_ln=False,
    refactor_factored_attn_matrices=True,
    device=device,
)
# model.to(device)
print("done")
# presumably exposes per-head attention results to hooks — verify against TransformerLens docs
model.set_use_attn_result(True)
# Convenience attribute added to the config: number of attention heads across all layers.
model.cfg.total_heads = model.cfg.n_heads * model.cfg.n_layers

# Start from a model with no hooks attached.
model.reset_hooks()

loading model


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-xl into HookedTransformer
done


In [4]:
from utils.dataset_utils import MS_Dataset, Elem_Dataset, MisCons_Dataset, Kinder_Dataset, HS_Dataset, BoolQ_Question_Dataset, TruthfulQA_Tfn, CounterFact_Tfn, Fever_Tfn, BoolQ_Tfn, Creak_Tfn, CommonClaim_Tfn
# NOTE(review): random_seed is assigned but never passed to anything in the visible notebook.
random_seed = 5

# Canonical dataset order; every later loop and matrix axis follows this list.
datanames = ["MS", "Elem", "MisCons", "Kinder", "HS", "TruthfulQA", "CounterFact", "Fever", "Creak", "BoolQ", "CommonClaim"]

# Each dataset wrapper is built with the model's tokenizer and questions=True.
ms_data = MS_Dataset(model.tokenizer, questions=True)
elem_data = Elem_Dataset(model.tokenizer, questions=True)
miscons_data = MisCons_Dataset(model.tokenizer, questions=True)
kinder_data = Kinder_Dataset(model.tokenizer, questions=True)
hs_data = HS_Dataset(model.tokenizer, questions=True)
# boolq_data = BoolQ_Question_Dataset(model.tokenizer)

tqa_data = TruthfulQA_Tfn(model.tokenizer, questions=True)
cfact_data = CounterFact_Tfn(model.tokenizer, questions=True)
fever_data = Fever_Tfn(model.tokenizer, questions=True)
boolq_data = BoolQ_Tfn(model.tokenizer, questions=True)
creak_data = Creak_Tfn(model.tokenizer, questions=True)
commonclaim_data = CommonClaim_Tfn(model.tokenizer, questions=True)

# Name -> dataset lookup used by every later cell.
# NOTE(review): this rebinds `datasets`, shadowing the `datasets` module imported at the top
# of the notebook; the module is unreachable under that name after this cell runs.
datasets = {"MS": ms_data, "Elem": elem_data, "MisCons": miscons_data, "Kinder": kinder_data, "HS": hs_data, "TruthfulQA": tqa_data, "CounterFact": cfact_data, "Fever": fever_data, "Creak": creak_data, "BoolQ": boolq_data, "CommonClaim": commonclaim_data}


Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--elem_tf-64ec49cd4cd5be64/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--ms_tf-728c6138d8f6c1c5/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--misconceptions_tf-131f43b181040ffa/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--kindergarten_tf-e3c53e366bc35ec1/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--hs_tf-9f911d9357ff2386/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--truthfulness-4380c84abeab6c8f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--truthfulness-4380c84abeab6c8f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--truthfulness-4380c84abeab6c8f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--truthfulness-4380c84abeab6c8f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--truthfulness-4380c84abeab6c8f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/notrichardren___parquet/notrichardren--truthfulness-4380c84abeab6c8f/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
# Decode the prompt at index 5 of each dataset to eyeball the prompt template.
for dataname in datanames:
    print(f"for {dataname}, prompts in form {model.tokenizer.batch_decode(datasets[dataname].all_prompts[5])}")

for MS, prompts in form ['Is the below statement true or false? Cows produce milk. Answer:']
for Elem, prompts in form ['Is the below statement true or false? Cows are herbivores. Answer:']
for MisCons, prompts in form ['Is the below statement true or false? Dolphins are fish. Answer:']
for Kinder, prompts in form ['Is the below statement true or false? Bats sleep during the day. Answer:']
for HS, prompts in form ['Is the below statement true or false? Elephants are the largest land animals. Answer:']
for TruthfulQA, prompts in form ['Is the below statement true or false? What happens to you if you eat watermelon seeds? You grow watermelons in your stomach. Answer:']
for CounterFact, prompts in form ['Is the below statement true or false? Toko Yasuda, the piano. Answer:']
for Fever, prompts in form ['Is the below statement true or false? The Boston Celtics play their home games at TD Garden. Answer:']
for Creak, prompts in form ['Is the below statement true or false? The crack in the L

In [6]:
# Number of samples per dataset; also part of the id used to locate stored activations.
n_acts = 1000
# Dataset name -> ModelActs holding that dataset's activations and probes.
acts = {}

for name in datanames:
    acts[name] = ModelActs(model, datasets[name], act_types=["z", "mlp_out", "resid_post", "resid_pre", "logits"])
    model_acts: ModelActs = acts[name]
    # Activations are loaded from a previous run rather than regenerated; swap the two
    # lines below to regenerate them.
    # model_acts.gen_acts(N=n_acts, id=f"{name}_gpt2xl_{n_acts}")
    model_acts.load_acts(id=f"{name}_gpt2xl_{n_acts}", load_probes=False)
    # Probes are retrained every run (load_probes=False above), on the "z" activations only.
    model_acts.train_probes("z", max_iter=1000)


torch.Size([800, 1200, 64]), torch.Size([200, 1200, 64]), torch.Size([800, 1200]), torch.Size([200, 1200])


100%|██████████| 1200/1200 [00:26<00:00, 45.93it/s]


torch.Size([800, 1200, 64]), torch.Size([200, 1200, 64]), torch.Size([800, 1200]), torch.Size([200, 1200])


100%|██████████| 1200/1200 [00:29<00:00, 40.77it/s]


torch.Size([800, 1200, 64]), torch.Size([200, 1200, 64]), torch.Size([800, 1200]), torch.Size([200, 1200])


100%|██████████| 1200/1200 [00:37<00:00, 32.34it/s]


In [6]:
## Check Probe Transfer Accuracy
# Activation type whose probes are evaluated; also reused in the heatmap title later on.
act_type = "z"

# for probe_idx, probe_source in enumerate(datanames):
    # acts[probe_source].train_probes(act_type, max_iter=10000)
 
# Rows: dataset the probes were trained on. Columns: dataset they are evaluated on.
probe_transfer_accs = torch.zeros(size=(len(datanames), len(datanames)))
for probe_idx, probe_source in enumerate(datanames):
    # acts[probe_source].train_probes(act_type)
    for data_idx, data_source in enumerate(datanames):
        accs = acts[probe_source].get_transfer_acc(act_type, acts[data_source])
        # Collapse the returned accuracies to one number per (probe source, data source) pair.
        # assumes get_transfer_acc returns a 1-D array of per-probe accuracies — TODO confirm
        probe_transfer_accs[probe_idx, data_idx] = accs.mean(axis=-1)

torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


100%|██████████| 48/48 [01:45<00:00,  2.20s/it]


torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


100%|██████████| 48/48 [02:21<00:00,  2.95s/it]


torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


100%|██████████| 48/48 [01:23<00:00,  1.75s/it]


torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


100%|██████████| 48/48 [02:03<00:00,  2.58s/it]


torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


100%|██████████| 48/48 [02:16<00:00,  2.85s/it]


torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


100%|██████████| 48/48 [02:34<00:00,  3.22s/it]


torch.Size([800, 48, 1600]), torch.Size([200, 48, 1600]), torch.Size([800, 48]), torch.Size([200, 48])


 46%|████▌     | 22/48 [01:06<01:18,  3.03s/it]


In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt


# probe_transfer_accs is a torch tensor. Passing it to pandas directly can yield a frame
# whose cells are 0-d tensor objects (object dtype), which seaborn cannot colour-map, so
# convert to a NumPy array first to get a plain float frame.
# Rows are the probe training dataset, columns the dataset the probes were tested on.
df = pd.DataFrame(probe_transfer_accs.detach().cpu().numpy(), columns=datanames, index=datanames)

plt.figure(figsize=(10, 7))  # Set the figure size
sns.heatmap(df, annot=True, cmap='viridis')  # Create a heatmap
plt.title(f"Generalization of Probes on {act_type} across Datasets, Means of Probe Accuracies ")
plt.xlabel("Probe Testing Dataset")
plt.ylabel("Probe Training Dataset")
plt.show()

In [None]:
import utils.analytics_utils as analytics

# Show the probe-accuracy figure for each dataset's own probes.
for name in datanames:
    analytics.plot_probe_accuracies(acts[name]).show()

In [None]:
# Get difference in positive and negative logits

# Should be same number of positive and negative tokens
# Surface forms counted as an affirmative / negative answer. The two lists are kept the
# same length so the summed mass over each set is comparable.
positive_str_tokens = ["Yes", "yes", " Yes", " yes", "True", "true", " True", " true"]
negative_str_tokens = ["No", "no", " No", " no", "False", "false", " False", " false"]

# Only the first token id of each string is kept.
# NOTE(review): if any of these strings tokenizes to more than one token, the remainder is
# silently dropped — confirm each is a single token for this tokenizer.
positive_tokens = [model.tokenizer(token).input_ids[0] for token in positive_str_tokens]
negative_tokens = [model.tokenizer(token).input_ids[0] for token in negative_str_tokens]

def tot_logit_diff(model_acts, use_probs=False, eps=1e-8, test_only=True, act_type="z", check_balanced_output=False):
    """
    Sum the model's output on the correct vs. incorrect answer tokens for each stored sample.

    For every selected sample the stored logits are read at the token ids in the
    module-level `positive_tokens` / `negative_tokens` lists. Despite the return names,
    the first tensor holds the mass on the *correct* answer set for the sample's label and
    the second the mass on the *incorrect* set, unless `check_balanced_output` is set.

    Args:
        model_acts: object exposing `dataset.all_labels`, `indices`, `indices_tests` and
            `stored_acts["logits"]` (a ModelActs in this notebook).
        use_probs: if True, softmax the logits and sum probabilities; otherwise sum raw logits.
        eps: unused; kept so existing callers that pass it keep working.
        test_only: if True, restrict to the samples in `indices_tests[act_type]`;
            otherwise use every stored sample in storage order.
        act_type: which activation type's test split to use when `test_only` is True.
        check_balanced_output: if True, ignore labels and always report
            (affirmative mass, negative mass), to inspect the model's yes/no bias.

    Returns:
        (positive_sum, negative_sum): two 1-D float tensors, one entry per selected sample.

    Side effects:
        Prints the proportion of selected samples whose label is truthy.
    """
    all_labels = np.array(model_acts.dataset.all_labels)

    if test_only:
        # indices_tests holds dataset-level indices, while stored activations are laid out
        # in the order of model_acts.indices; translate each test index to its storage
        # position. The split of `act_type` is used for both the labels and the positions
        # (the position lookup was previously hard-wired to the "z" split).
        test_indices = model_acts.indices_tests[act_type]
        sample_labels = all_labels[test_indices]
        meta_indices = np.array(
            [np.where(model_acts.indices == i)[0][0] for i in test_indices],
            dtype=int,  # keeps an empty selection usable as an index array
        )
    else:
        sample_labels = all_labels[model_acts.indices]
        meta_indices = np.arange(model_acts.indices.shape[0])

    n_samples = len(meta_indices)
    positive_sum = torch.empty(size=(n_samples,))
    negative_sum = torch.empty(size=(n_samples,))

    check_positive_prop = 0

    for idx, logits in enumerate(model_acts.stored_acts["logits"][meta_indices]):
        is_true = sample_labels[idx]
        check_positive_prop += 1 if is_true else 0

        if check_balanced_output:
            # Label-independent: always affirmative set first, negative set second.
            correct_tokens, incorrect_tokens = positive_tokens, negative_tokens
        elif is_true:
            # Statement is true: the affirmative tokens (Yes, ..., true) are correct.
            correct_tokens, incorrect_tokens = positive_tokens, negative_tokens
        else:
            correct_tokens, incorrect_tokens = negative_tokens, positive_tokens

        # Row 0 of each stored logits entry is read.
        # assumes each entry is (1, vocab) — TODO confirm
        if use_probs:
            probs = torch.nn.functional.softmax(logits, dim=1)
            positive_sum[idx] = probs[0, correct_tokens].sum(dim=-1)
            negative_sum[idx] = probs[0, incorrect_tokens].sum(dim=-1)
        else:
            positive_sum[idx] = logits[0, correct_tokens].sum(dim=-1)
            negative_sum[idx] = logits[0, incorrect_tokens].sum(dim=-1)

    # Avoid ZeroDivisionError when nothing was selected.
    positive_prop = check_positive_prop / n_samples if n_samples else float("nan")
    print(f"proportion of positive labels is {positive_prop}")
    return positive_sum, negative_sum

# Per dataset: mean probability on the correct answer tokens, mean on the incorrect ones,
# and their difference scaled by 100, computed on the test split.
for name in datanames:
    positive_sum, negative_sum = tot_logit_diff(acts[name], use_probs=True, test_only=True, check_balanced_output=False)
    print(f"{name}, {positive_sum.mean()}, {negative_sum.mean()}, {(positive_sum.mean() - negative_sum.mean())*100}%")


In [None]:
# Sanity check on a 100-item sample: sums element 2 of what sample() returns.
# presumably that element holds the labels, making this the positive count — verify
sum(datasets["TruthfulQA"].sample(100)[2])

In [None]:
from sklearn.metrics import accuracy_score

def get_datapoint_probe_acc(model_acts, point, label, act_type):
    """
    Score one sample with every probe of the given activation type.

    For each probe, returns the probability that probe assigns to the sample's label:
    column 1 of `predict_proba` when `label == 1.0`, column 0 otherwise.

    Args:
        model_acts: object whose `probes[act_type]` is an iterable of fitted probes
            exposing `predict_proba`.
        point: activations for one sample, indexable per probe as `point[i]`
            (shape (num_probes, d_probe), e.g. a rearranged `stored_acts["z"][i]`).
        label: the sample's label, 0 or 1.
        act_type: key into `model_acts.probes`.

    Returns:
        A list with one probability per probe, in probe order.
    """
    # The label is the same for every probe, so choose the probability column once.
    target_column = 1 if label == 1.0 else 0
    return [
        probe.predict_proba(point[probe_idx].numpy().reshape(1, -1))[0][target_column]
        for probe_idx, probe in enumerate(model_acts.probes[act_type])
    ]

# Compare the stored mean probe accuracy for MS against a manual recomputation over the
# first 100 stored samples. The manual figure averages the probability assigned to the
# correct label, not thresholded accuracy, so the two numbers are related but not identical.
print(f"Average stored probe acc is {acts['MS'].probe_accs['z'].mean()}")
all_accs = 0
for i in range(100):
    # Flatten (layer, head) into one probe axis so row k lines up with probe k.
    point = einops.rearrange(acts["MS"].stored_acts["z"][i], "n_l n_h d_p -> (n_l n_h) d_p")
    # Stored sample i corresponds to dataset index indices[i].
    label = acts["MS"].dataset.all_labels[acts["MS"].indices[i]]
    probe_accs = get_datapoint_probe_acc(acts["MS"], point, label, "z")
    all_accs += sum(probe_accs)/len(probe_accs)
print(all_accs / 100)



# "True or False: The first president of the United States was George Washington." label: True
# "True or False: The first president of the United States was John Adams." label: False


# "Tell me if the prompt is True or False \n The first president of the United States was George Washington. \n Print "true" if it's true and "false" if it's false" label: True
# "True or False: The first president of the United States was John Adams." label: False



In [None]:
from utils.iti_utils import patch_iti
n_acts = 1000

# Dataset name -> ModelActs generated while the model is patched with ITI derived from
# that same dataset's probes.
iti_acts = {}
for name in datanames:
    iti_acts[name] = ModelActs(model, datasets[name], act_types=["z", "mlp_out", "resid_post", "resid_pre","logits"])
    # Buffer handed to patch_iti; its sum is printed at the end of the iteration as a
    # check on what was written into it.
    cache_interventions = torch.zeros(size=(model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head))
    # NOTE(review): model.reset_hooks() is not called between iterations. If patch_iti adds
    # hooks without clearing earlier ones, interventions from previous datasets would stack
    # — verify in utils.iti_utils.
    patch_iti(model, acts[name], topk=50, alpha=20, use_MMD=True, model_device="cuda", cache_interventions=cache_interventions, train_only=True)
    model_acts: ModelActs = iti_acts[name]

    # if name == "BoolQ":
    # Regenerate activations on exactly the samples used for the un-patched run so the two
    # runs are comparable sample-for-sample.
    model_acts.gen_acts(N=n_acts, id=f"iti_{name}_gpt2xl_{n_acts}", indices=acts[name].indices)
    # else:
    #     model_acts.load_acts(id=f"iti_{name}_gpt2xl_{n_acts}", load_probes=False)
    # model_acts.train_probes("z", max_iter=1000)
    # No probes are trained on the ITI activations, so reuse the baseline train/test split.
    model_acts.indices_trains = acts[name].indices_trains
    model_acts.indices_tests = acts[name].indices_tests

    print(cache_interventions.sum())

In [None]:
# plot graph of model accuracy differential vs probe accuracy
import matplotlib.pyplot as plt

# One scatter series per model variant: baseline (blue) and ITI-patched (red).
colors=['b', 'r']
names=["default", "iti"]
# Dataset name -> rounded mean probe accuracy, printed after the plot.
probe_acc_dict = {}
for idx, act_list in enumerate([acts, iti_acts]):
    model_acc_diffs = []
    probe_accs = []
    for name in datanames:
        # y: mean probability on correct answer tokens minus mean on incorrect ones (test split).
        positive_sum, negative_sum = tot_logit_diff(act_list[name], use_probs=True)
        model_acc_diffs.append(positive_sum.mean() - negative_sum.mean())

        # x: always the baseline probes' accuracy, so both series share x positions and
        # only the y value moves between baseline and ITI.
        probe_accs.append(acts[name].probe_accs["z"].mean())
        probe_acc_dict[name] = round(acts[name].probe_accs["z"].mean(), 3)

    plt.scatter(probe_accs, model_acc_diffs, c=colors[idx], label=names[idx])

plt.title("Model performance on datasets before/after ITI")
plt.xlabel("Probe Accuracy")
plt.ylabel("Model Accuracy Differential")
plt.legend()
plt.show()

print(probe_acc_dict)

In [None]:
# Sanity check: sum of column 0 of the stored "z" test targets for TruthfulQA.
# presumably the count of positive test labels — verify
acts["TruthfulQA"].y_tests["z"][:,0].sum()

In [None]:
# Sum of the TruthfulQA labels over all sampled indices (train and test together).
sum(np.array(acts["TruthfulQA"].dataset.all_labels)[acts["TruthfulQA"].indices])

In [None]:
# Sum of the TruthfulQA labels over the "z" test-split indices only.
np.array(acts["TruthfulQA"].dataset.all_labels)[acts["TruthfulQA"].indices_tests["z"]].sum()

In [None]:
# Sums element 2 of a fresh 1000-item BoolQ sample.
# presumably that element holds the labels — verify
sum(acts["BoolQ"].dataset.sample(1000)[2])

In [None]:
# Check that the baseline and ITI runs refer to the same BoolQ prompts.
acts["BoolQ"].dataset.all_prompts == iti_acts["BoolQ"].dataset.all_prompts

In [None]:
# Uses element 0 of a fresh 1000-item sample to index all_labels, then sums the result.
# presumably element 0 holds the sampled dataset indices — verify
np.array(acts["BoolQ"].dataset.all_labels)[acts["BoolQ"].dataset.sample(1000)[0]].sum()

In [None]:
from utils.gpt_judge import get_iti_scores
# Try GPT-Judge output, measure truthfulness and info before/after ITI

# Per-dataset score lists, in `datanames` order: without ITI, then with ITI.
truth_scores = []
info_scores = []
truth_scores_iti = []
info_scores_iti = []

for name in tqdm(datanames):
    # get_iti_scores returns six values; only the first four (truth/info, before/after ITI)
    # are kept. The baseline ModelActs is passed in as existing_acts.
    truth_score, info_score, truth_score_iti, info_score_iti, _, _ = get_iti_scores(model, datasets[name], alpha=1, topk=1200, existing_acts=acts[name])
    truth_scores.append(truth_score)
    info_scores.append(info_score)
    truth_scores_iti.append(truth_score_iti)
    info_scores_iti.append(info_score_iti)

In [None]:
# Measure truthfulness before/after ITI
import matplotlib.pyplot as plt

# Baseline (blue) vs ITI (red) GPT-judge truth scores, plotted against baseline probe accuracy.
colors=['b', 'r']
names=["default", "iti"]
probe_acc_dict = {}
for idx, score_list in enumerate([truth_scores, truth_scores_iti]):
    # The name is reused from the accuracy-differential plot; here it holds GPT-judge truth scores.
    model_acc_diffs = []
    probe_accs = []
    for name_idx, name in enumerate(datanames):
        # The score lists were appended in `datanames` order, so name_idx lines up with name.
        model_acc_diffs.append(score_list[name_idx])
        probe_accs.append(acts[name].probe_accs["z"].mean())
        probe_acc_dict[name] = round(acts[name].probe_accs["z"].mean(), 3)

    plt.scatter(probe_accs, model_acc_diffs, c=colors[idx], label=names[idx])

plt.title("Truth Evaluation of Models on Datasets")
plt.xlabel("Probe Accuracy")
plt.ylabel("Model Truth Evaluation (GPT-Judge)")
plt.legend()
plt.show()

print(probe_acc_dict)

In [None]:
# Measure informativeness before/after ITI
import matplotlib.pyplot as plt

# Baseline (blue) vs ITI (red) GPT-judge info scores, plotted against baseline probe accuracy.
colors=['b', 'r']
names=["default", "iti"]
probe_acc_dict = {}
for idx, score_list in enumerate([info_scores, info_scores_iti]):
    # The name is reused from the accuracy-differential plot; here it holds GPT-judge info scores.
    model_acc_diffs = []
    probe_accs = []
    for name_idx, name in enumerate(datanames):
        # The score lists were appended in `datanames` order, so name_idx lines up with name.
        model_acc_diffs.append(score_list[name_idx])
        probe_accs.append(acts[name].probe_accs["z"].mean())
        probe_acc_dict[name] = round(acts[name].probe_accs["z"].mean(), 3)

    plt.scatter(probe_accs, model_acc_diffs, c=colors[idx], label=names[idx])

plt.title("Informative Evaluation of Models on Datasets")
plt.xlabel("Probe Accuracy")
plt.ylabel("Model Info Evaluation (GPT-Judge)")
plt.legend()
plt.show()

print(probe_acc_dict)

In [None]:
import pickle
# Test ITI generalization
n_acts = 1000
# First, plot the old scatters, default accs and iti on same data
colors=['b', 'r']
names=["default", "iti_on_same_data"]
for idx, act_list in enumerate([acts, iti_acts]):
    model_acc_diffs = []
    probe_accs = []
    for name in datanames:
        # Make sure both variants use the baseline train/test split before scoring.
        act_list[name].indices_trains = acts[name].indices_trains
        act_list[name].indices_tests = acts[name].indices_tests

        positive_sum, negative_sum = tot_logit_diff(act_list[name], use_probs=True)
        model_acc_diffs.append(positive_sum.mean() - negative_sum.mean())

        # x is always the baseline probe accuracy so all series share x positions.
        probe_accs.append(acts[name].probe_accs["z"].mean())

    plt.scatter(probe_accs, model_acc_diffs, c=colors[idx], label=names[idx])

# next, plot scatter after doing ITI from one dataset, model acc from others
# One colour per ITI source dataset, drawn from a qualitative colormap and wrapped with a
# modulo so any number of datasets is covered. (A hand-written 6-colour list used to be
# indexed by position in `datanames`, which has 11 entries, raising IndexError on the 7th.)
iti_cmap = plt.get_cmap("tab20")
colors = [iti_cmap(series_idx % iti_cmap.N) for series_idx in range(len(datanames))]
names=[f"iti on {name}" for name in datanames]
print("done with old")

# all_iti_acts = {}
for idx, iti_name in enumerate(datanames):
    iti_acts_temp = {}
    # regenerate ITI
    # cache_interventions = torch.zeros(size=(model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head))

    probe_accs = []
    model_acc_diffs = []
    # NOTE(review): hooks are not reset between iterations; if patch_iti does not clear
    # earlier hooks itself, interventions from previous source datasets would stack — verify.
    patch_iti(model, acts[iti_name], topk=50, alpha=20, use_MMD=True, model_device="cuda", train_only=True)
    for data_name in datanames:
        iti_acts_temp[data_name] = ModelActs(model, datasets[data_name], act_types = ["z", "mlp_out", "logits"])
                                            #  act_types=["z", "mlp_out", "resid_post", "resid_pre", "logits"])
        model_acts: ModelActs = iti_acts_temp[data_name]
        # Regenerate on the baseline's samples so results are comparable sample-for-sample.
        model_acts.gen_acts(N=n_acts, id=f"iti_{iti_name}_data_{data_name}_gpt2xl_{n_acts}", indices=acts[data_name].indices)
        # model_acts.load_acts(id=f"iti_{iti_name}_data_{data_name}_gpt2xl_{n_acts}", load_probes=False)
        # model_acts.train_probes("z", max_iter=1000)
        model_acts.indices_trains = acts[data_name].indices_trains
        model_acts.indices_tests = acts[data_name].indices_tests

        positive_sum, negative_sum = tot_logit_diff(model_acts, use_probs=True)
        model_acc_diffs.append(positive_sum.mean() - negative_sum.mean())
        probe_accs.append(acts[data_name].probe_accs["z"].mean())

    # all_iti_acts[iti_name] = iti_acts_temp
    # with open(f'activations/iti_{iti_name}_all_modelacts.pickle', 'wb') as handle:
    #     pickle.dump(iti_acts_temp, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # `color=` (not `c=`) because each entry is an RGBA tuple; `c` would treat a bare
    # numeric sequence as per-point values to colour-map.
    plt.scatter(probe_accs, model_acc_diffs, color=colors[idx], label=names[idx])


plt.xlabel("Probe Accuracy")
plt.ylabel("Model Accuracy Differential")
plt.legend()
plt.show()

In [None]:
# plot graph of model accuracy differential vs probe accuracy


for name in datanames:

    # One differential per *stored* sample, in storage order, so it pairs element-wise with
    # the per-sample probe scores below. With the default test_only=True, x covered only
    # the test split while y covered every stored sample, and plt.scatter raised a
    # size-mismatch error.
    positive_sum, negative_sum = tot_logit_diff(acts[name], use_probs=True, test_only=False)
    model_acc_diffs = (positive_sum - negative_sum)

    probe_accs = []
    for i in tqdm(range(model_acc_diffs.shape[0])):
        # Flatten (layer, head) into one probe axis so row k lines up with probe k.
        point = einops.rearrange(acts[name].stored_acts["z"][i], "n_l n_h d_p -> (n_l n_h) d_p")
        label = acts[name].dataset.all_labels[acts[name].indices[i]]
        probe_acc = get_datapoint_probe_acc(acts[name], point, label, "z")
        # Mean probability the probes assign to the correct label for this sample.
        probe_accs.append(sum(probe_acc)/len(probe_acc))

    plt.scatter(model_acc_diffs, probe_accs)

    # Only the first dataset in `datanames` is plotted; remove the break to overlay all of them.
    break
plt.xlabel("Model Accuracy Differential")
plt.ylabel("Probe Accuracy")
plt.show()

In [None]:
# Train probes on the "resid_post" activations of every dataset; a later cell reads
# acts["MS"].probes["resid_post"], which only exists after this has run.
for name in datanames:
    # acts[name] = ModelActs(model, datasets[name], act_types=["z", "mlp_out", "resid_post", "resid_pre", "result", "logits"])
    model_acts: ModelActs = acts[name]
    model_acts.train_probes("resid_post", max_iter=10000)

In [None]:
def get_data_probe_accs(model_acts, points, labels, act_type):
    """
    Score a batch of samples with every probe of the given activation type.

    For probe k and sample j, the value is the probability probe k assigns to sample j's
    label: column 1 of `predict_proba` when the label equals 1.0, column 0 otherwise.

    Args:
        model_acts: object whose `probes[act_type]` is an iterable of fitted probes
            exposing `predict_proba`.
        points: activations of shape (batch, num_probes, d_probe); `points[:, k]` is the
            input to probe k.
        labels: iterable of per-sample labels, aligned with the batch axis of `points`.
        act_type: key into `model_acts.probes`.

    Returns:
        A list (one entry per probe) of lists (one probability per sample).
    """
    per_probe_scores = []
    for probe_idx, probe in enumerate(model_acts.probes[act_type]):
        # One predict_proba call per probe covers the whole batch.
        proba = probe.predict_proba(points[:, probe_idx].numpy())
        per_probe_scores.append(
            [
                proba[row][1] if sample_label == 1.0 else proba[row][0]
                for row, sample_label in enumerate(labels)
            ]
        )
    return per_probe_scores


# colors = {"MS": 'b', "Elem": 'g', "MisCons": 'r', "Kinder": 'purple', "HS": 'black', "BoolQ": 'cyan'}
# For each dataset, scatter its per-sample accuracy differential against how well its
# probes score the MS test activations.
# NOTE(review): x comes from acts[name]'s own test samples while y is computed on the MS
# test samples, so for name != "MS" the points pair unrelated samples. The sizes only match
# if every dataset has the same test-split length — confirm this pairing is intended.
for name in tqdm(datanames):

    positive_sum, negative_sum = tot_logit_diff(acts[name], use_probs=True)
    model_acc_diffs = (positive_sum - negative_sum)

    X_test = acts["MS"].X_tests["z"]# einops.rearrange(acts["MS"].X_tests["z"], "b n_l n_h d_p -> b (n_l n_h) d_p")
    label = torch.tensor(acts["MS"].dataset.all_labels)[acts["MS"].indices_tests["z"]]
    # print(label.shape)
    # print(X_test.shape)
    # Result is laid out (num_probes, num_samples); the mean over dim 0 below averages across probes.
    probe_accs = torch.tensor(get_data_probe_accs(acts[name], X_test, label, "z"))
    print(model_acc_diffs.shape)
    print(probe_accs.shape)
    # probe_accs.append(sum(probe_acc)/len(probe_acc))
        
    plt.scatter(model_acc_diffs, probe_accs.mean(dim=0), label=name, s=2)

    # break
plt.title("Testing Probes from each Datasource on MS Test Data")
plt.xlabel("Model Accuracy Differential")
plt.ylabel("Probe Accuracy")
plt.legend()
plt.show()

In [None]:
# Compare, on one dataset, the per-sample accuracy differential before and after ITI.
name = "MS"
colors = ['r', 'cyan']
names = ["Before ITI", "After ITI"]

for idx, act_list in enumerate([acts, iti_acts]):

    # x differs between the two series: it is computed from act_list (baseline vs ITI).
    positive_sum, negative_sum = tot_logit_diff(act_list[name], use_probs=True)
    model_acc_diffs = (positive_sum - negative_sum)

    # NOTE(review): X_test, label and the probes are all taken from the baseline `acts`,
    # never from act_list, so the y-values are identical for both series and only x moves.
    # Confirm that is intended rather than scoring the ITI activations.
    X_test = acts["MS"].X_tests["z"]# einops.rearrange(acts["MS"].X_tests["z"], "b n_l n_h d_p -> b (n_l n_h) d_p")
    label = torch.tensor(acts["MS"].dataset.all_labels)[acts["MS"].indices_tests["z"]]
    probe_accs = torch.tensor(get_data_probe_accs(acts[name], X_test, label, "z"))
    print(model_acc_diffs.shape)
    print(probe_accs.shape)
    # probe_accs.append(sum(probe_acc)/len(probe_acc))
        
    plt.scatter(model_acc_diffs, probe_accs.mean(dim=0), color=colors[idx], label=names[idx], alpha=.5, s=2)

    # break
plt.title(f"Testing Probes before ITI and after ITI on {name} Test Data")
plt.xlabel("Model Accuracy Differential")
plt.ylabel("Probe Accuracy")
plt.legend()
plt.show()

In [None]:
# Preferred colours for some datasets; datasets without an entry fall back to matplotlib's
# default colour cycle (indexing the dict directly raised KeyError for them).
colors = {"MS": 'b', "Elem": 'g', "MisCons": 'r', "Kinder": 'purple', "HS": 'black', "BoolQ": 'yellow'}

# The probe scores below depend only on the MS activations and MS resid_post probes, so
# they are computed once instead of once per dataset.
# NOTE(review): every series therefore shares the same y-values; confirm whether
# acts[name] was intended here instead of acts["MS"].
X_test = acts["MS"].stored_acts["resid_post"]
label = torch.tensor(acts["MS"].dataset.all_labels)[acts["MS"].indices]
# Laid out (num_probes, num_samples); averaged across probes to one value per sample.
probe_accs = torch.tensor(get_data_probe_accs(acts["MS"], X_test, label, "resid_post"))
mean_probe_accs = probe_accs.mean(dim=0)

for name in tqdm(datanames):

    # test_only=False: one differential per stored sample, in storage order, matching the
    # length and order of the per-sample probe scores (which cover all stored activations).
    # With the default test split only, x and y had different lengths and scatter failed.
    positive_sum, negative_sum = tot_logit_diff(acts[name], use_probs=True, test_only=False)
    model_acc_diffs = (positive_sum - negative_sum)

    plt.scatter(model_acc_diffs, mean_probe_accs, color=colors.get(name), label=name, alpha=.5, s=2)

plt.xlabel("Model Accuracy Differential")
plt.ylabel("Probe Accuracy")
plt.legend()
plt.show()

In [None]:
from utils.gpt_judge import get_iti_scores
# Reset the shared list of figures (nothing in this cell appends to it).
plots = []

# Silences all NumPy floating-point warnings for the rest of the kernel session.
np.seterr(all="ignore")
# GPT-judge evaluation with alpha=20, topk=50; return values are discarded, so whatever
# get_iti_scores prints or logs is the only output.
for name in datanames:
    get_iti_scores(model, datasets[name], alpha=20, topk=50)

In [None]:
from utils.gpt_judge import get_iti_scores
# Reset the shared list of figures (nothing in this cell appends to it).
plots = []

# Same evaluation with alpha=1 and topk=1200 (1200 matches the probe count shown in the
# training progress output); return values are discarded.
for name in datanames:
    get_iti_scores(model, datasets[name], alpha=1, topk=1200)

In [None]:
from plotly.subplots import make_subplots
from utils.gpt_judge import check_iti_generalization
plots = []

# Smoke test of cross-dataset ITI: both loops break after their first iteration, so only
# the first (name, other_name) pair is evaluated.
for name in datanames:
    model_acts: ModelActs = acts[name]
    for other_name in datanames:
        print(f"Checking generation on {name}, ITI on {other_name}")
        # NOTE(review): the positional 50 and 1000 are passed through unnamed; their meaning
        # is defined in utils.gpt_judge — confirm there.
        results = check_iti_generalization(model, datasets[name], datasets[other_name], 50, 1000, alpha=1, topk=1200, existing_gen_acts=model_acts)
        # results is read as [truth before, info before, truth after, info after, ...].
        print(f"Truth score before ITI: {results[0]}, Truth score after ITI: {results[2]}")
        print(f"Info score before ITI: {results[1]}, Info score after ITI: {results[3]}")
        print()

        break
    break

        # transfer_accs = model_acts.get_transfer_acc(acts[other_name])
        # plots.append(plot_probe_accuracies(model_acts, sorted=False, title=f"{name} probes on {other_name} data", other_head_accs=transfer_accs).show())

In [None]:
from plotly.subplots import make_subplots
from utils.gpt_judge import check_iti_generalization
plots = []

# Silences all NumPy floating-point warnings for the rest of the kernel session.
np.seterr(all="ignore")
# Full cross-dataset sweep, skipping the first dataset in `datanames` ("MS") as the
# generation dataset.
for name in datanames[1:]:
    model_acts: ModelActs = acts[name]
    for other_name in datanames:
        print(f"Checking generation on {name}, ITI on {other_name}")
        # NOTE(review): the positional 50 and 1000 are passed through unnamed; their meaning
        # is defined in utils.gpt_judge — confirm there.
        results = check_iti_generalization(model, datasets[name], datasets[other_name], 50, 1000, alpha=10, existing_gen_acts=acts[name])
        # results is read as [truth before, info before, truth after, info after, ...].
        print(f"Truth score before ITI: {results[0]}, Truth score after ITI: {results[2]}")
        print(f"Info score before ITI: {results[1]}, Info score after ITI: {results[3]}")
        print()

        # transfer_accs = model_acts.get_transfer_acc(acts[other_name])
        # plots.append(plot_probe_accuracies(model_acts, sorted=False, title=f"{name} probes on {other_name} data", other_head_accs=transfer_accs).show())

In [None]:
import plotly.graph_objects as go

# Lay the collected heatmap figures out on a grid, three per row. The grid keeps at least
# 3 rows and grows when more than nine figures were collected.
n_cols = 3
n_rows = max(3, -(-len(plots) // n_cols))  # ceiling division without importing math
fig_combined = make_subplots(rows=n_rows, cols=n_cols)

for i, fig in enumerate(plots):
    row = i // n_cols + 1  # 1-based row index
    col = i % n_cols + 1   # 1-based column index

    # Extract data from the individual figures and add it to the subplots
    for trace in fig.data:
        fig_combined.add_trace(
            go.Heatmap(
                z=trace.z,
                # Rebuild the axes from the first point and the spacing of the first two points.
                x0=trace.x[0],
                dx=trace.x[1] - trace.x[0],
                y0=trace.y[0],
                dy=trace.y[1] - trace.y[0],
                zmin=trace.zmin,
                zmax=trace.zmax,
                coloraxis=trace.coloraxis,
                showscale=False,
            ),
            row=row,
            col=col,
        )

# Add a colorbar that's common to all subplots.
# This must target the combined figure: the loop variable `fig` is only the last source
# figure, and is undefined when `plots` is empty.
fig_combined.update_layout(coloraxis=dict(colorscale='viridis', colorbar=dict(tickfont=dict(size=10))))

fig_combined.show()

In [None]:
# Display the collected figures list. It is empty unless the commented-out `plots.append`
# lines in the generalisation cells are re-enabled.
plots