In [None]:
import torch
import circuitsvis as cv
import pickle
import warnings


from transformer_lens import HookedTransformer


# Run on the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# NOTE(review): a wildcard import (`from transformer_lens.cautils.notebook import *`) used to
# sit here and presumably supplied `clear_output`. With it disabled, `clear_output()` raised
# NameError on a fresh kernel, so the name is now imported explicitly.
from IPython.display import clear_output

# Gradients are switched off globally for everything that runs after this line.
torch.set_grad_enabled(False)

from utils.cspa_functions import (
    get_cspa_results_batched,
    get_result_mean,
)
from utils.cspa_extra_utils import (
    process_webtext,
)

# Wipe any output produced while importing.
clear_output()

In [None]:
# Load GPT-2 small with centred unembed / writing weights and folded LayerNorm.
# The model is placed on `device` (selected in the first cell) instead of a hard-coded
# "cuda", so the CPU fallback chosen there is actually honoured.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)
# Hook configuration: no split Q/K/V inputs, but per-head attention results enabled.
model.set_use_split_qkv_input(False)
model.set_use_attn_result(True)
clear_output()

In [None]:
# Dimensions of the token batch requested below (80 x 61 is enough for viz).
BATCH_SIZE = 500
SEQ_LEN = 1000

# Smaller values used for visualization, since only these appear on Streamlit.
current_batch_size = 17
current_seq_len = 61

# Heads of interest; presumably (layer, head) pairs -- TODO confirm.
NEGATIVE_HEADS = [(10, 7), (11, 10)]

# Build the token batch; the seed is fixed so the same data is drawn on every run.
DATA_TOKS, DATA_STR_TOKS_PARSED, indices = process_webtext(
    seed=6,
    batch_size=BATCH_SIZE,
    seq_len=SEQ_LEN,
    model=model,
    verbose=True,
    return_indices=True,
)

In [None]:
# Toggle for the semantic dictionary: when True, the pickled dictionary is loaded from disk
# in the next cell; when False, an empty dict is used instead and a warning is emitted.
USE_SEMANTICITY = True

In [None]:
# Populate `cspa_semantic_dict`: either from the pickled file, or as an empty dict.
if USE_SEMANTICITY:
    # The context manager closes the file handle even if unpickling fails.
    # NOTE(security): pickle can execute arbitrary code while loading -- only use a
    # dictionary file that comes from a trusted source.
    with open("cspa/cspa_semantic_dict_full.pkl", "rb") as semantic_dict_file:
        cspa_semantic_dict = pickle.load(semantic_dict_file)
else:
    warnings.warn("Not using semanticity unlike old notebook versions!")
    cspa_semantic_dict = {}

In [None]:
# Compute a mean over the first 100 sequences and keep it for later use.
# The head list comes from NEGATIVE_HEADS rather than a second copy of the same literal,
# so the two can never drift apart.
result_mean = get_result_mean(NEGATIVE_HEADS, DATA_TOKS[:100, :], model, verbose=True)

# Optional: persist the mean. The destination below is a machine-specific absolute path;
# point it at a location that exists locally before uncommenting.
# torch.save(result_mean, f"/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/st_page/media/result_mean.pt")

## Run Experiment

In [None]:
# Empirically, as long as the sequence length is large, a small batch size already gives
# quite good estimates, so only a slice of DATA_TOKS is used here.
QK_OV_BATCH_SIZE = 20
QK_OV_SEQ_LEN = 600

cspa_results_qk_ov = get_cspa_results_batched(
    model = model,
    toks = DATA_TOKS[:QK_OV_BATCH_SIZE, :QK_OV_SEQ_LEN],
    max_batch_size = 1, # previously 50
    negative_head = (10, 7),
    interventions = ["ov", "qk"],
    # NOTE(review): the earlier note here said "most interesting in range 3-8 (out of 80)",
    # which reads like an absolute count, whereas 0.05 looks like a fraction -- confirm
    # which form get_cspa_results_batched expects.
    K_unembeddings = 0.05,
    K_semantic = 1, # either 1 or up to 8 to capture all sem similar
    semantic_dict = cspa_semantic_dict,
    result_mean = result_mean,
    # Follow the device detected in the first cell instead of assuming a GPU is present;
    # on a CUDA machine this is still True.
    use_cuda = (device == "cuda"),
    verbose = True,
    compute_s_sstar_dict = False,
    computation_device = "cpu", # alternative: device
)