In [1]:
# Replicate ITI results, make sure ITI utils and probing utils work right

#%%
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the TransformerLens code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")
    
import plotly.io as pio
# pio.renderers.default = "png"
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

from tqdm import tqdm
from probing_utils import ModelActs
from dataset_utils import CounterFact_Dataset, TQA_MC_Dataset, EZ_Dataset

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from iti_utils import patch_top_activations, patch_iti

from analytics_utils import plot_probe_accuracies, plot_norm_diffs, plot_cosine_sims

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:

device = "cuda"
print("loading model")
model = HookedTransformer.from_pretrained(
    "gpt2-xl",
    center_unembed=False,
    center_writing_weights=False,
    fold_ln=False,
    refactor_factored_attn_matrices=True,
    device=device,
)
# model.to(device)
print("done")
model.set_use_attn_result(True)
model.cfg.total_heads = model.cfg.n_heads * model.cfg.n_layers

model.reset_hooks()

loading model


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-xl into HookedTransformer
done


In [3]:
random_seed = 5

datanames = ["tqa", "cfact", "ez"]

tqa_data = TQA_MC_Dataset(model.tokenizer, seed=random_seed)
cfact_data = CounterFact_Dataset(model.tokenizer, seed=random_seed)
ez_data = EZ_Dataset(model.tokenizer, seed=random_seed)

datasets = {"tqa":tqa_data, "cfact":cfact_data, "ez":ez_data}

Found cached dataset truthful_qa (/root/.cache/huggingface/datasets/truthful_qa/multiple_choice/1.1.0/63502f6bc6ee493830ce0843991b028d0ab568d221896b2ee3b8a5dfdaa9d7f4)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset parquet (/root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--counterfact-tracing-39c4f800d46af5cf/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-37473b04ab8b95b9/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


  0%|          | 0/1 [00:00<?, ?it/s]

In [26]:
n_acts = 200
acts = {}

for name in datanames:
    acts[name] = ModelActs(model, datasets[name], act_types=["z", "mlp_out", "result"])
    model_acts: ModelActs = acts[name]
    model_acts.gen_acts(N=n_acts, id=f"{name}_gpt2xl_{n_acts}")
    break
    # ez_acts.load_acts(id=f"ez_gpt2xl_{n_acts}", load_probes=False)
    # model_acts.train_probes(max_iter=1000)


100%|██████████| 200/200 [00:14<00:00, 13.39it/s]


Stored at tqa_gpt2xl_200


In [22]:
model_acts.load_acts(id=f"tqa_gpt2xl_{n_acts}", load_probes=False)

In [25]:
model_acts.stored_acts['z'].shape

torch.Size([200, 48, 25, 64])

In [28]:
acts["tqa"].train_z_probes()

torch.Size([160, 1200, 64]), torch.Size([40, 1200, 64]), torch.Size([160, 1200]), torch.Size([40, 1200])


100%|██████████| 1200/1200 [00:08<00:00, 140.33it/s]


In [30]:
acts["tqa"].train_mlp_out_probes()

torch.Size([160, 48, 1600]), torch.Size([40, 48, 1600]), torch.Size([160, 48]), torch.Size([40, 48])


100%|██████████| 48/48 [00:01<00:00, 26.27it/s]


In [32]:
acts["tqa"].probe_accs["z"].mean()

0.5880833333333334

In [25]:
from plotly.subplots import make_subplots

plots = []

for name in datanames:
    model_acts: ModelActs = acts[name]
    for other_name in datanames:
        transfer_accs = model_acts.get_transfer_acc(acts[other_name])
        plots.append(plot_probe_accuracies(model_acts, sorted=False, title=f"{name} probes on {other_name} data", other_head_accs=transfer_accs).show())



1200it [00:00, 2524.73it/s]


1200it [00:00, 2415.03it/s]


1200it [00:00, 1420.32it/s]


1200it [00:00, 2422.11it/s]


1200it [00:00, 2481.90it/s]


1200it [00:00, 1824.89it/s]


1200it [00:00, 2253.89it/s]


1200it [00:00, 2460.75it/s]


1200it [00:00, 2572.37it/s]


In [28]:
import plotly.graph_objects as go

fig_combined = make_subplots(rows=3, cols=3)

for i, fig in enumerate(plots):
    row = i // 3 + 1  # calculate the row index
    col = i % 3 + 1   # calculate the column index

    # Extract data from the individual figures and add it to the subplots
    for trace in fig.data:
        fig_combined.add_trace(
            go.Heatmap(
                z=trace.z,
                x0=trace.x[0],
                dx=trace.x[1] - trace.x[0],
                y0=trace.y[0],
                dy=trace.y[1] - trace.y[0],
                zmin=trace.zmin,
                zmax=trace.zmax,
                coloraxis=trace.coloraxis,
                showscale=False,
            ),
            row=row,
            col=col,
        )

# Add a colorbar that's common to all subplots
fig.update_layout(coloraxis=dict(colorscale='viridis', colorbar=dict(tickfont=dict(size=10))))

fig.show()

AttributeError: 'NoneType' object has no attribute 'data'

In [29]:
plots

[None, None, None, None, None, None, None, None, None]