# Causal Experiments

In [None]:
import sys 
sys.path.append("../..")
sys.path.append("..")

from importlib import reload
from tqdm import tqdm

import joseph
from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *


reload(joseph.analysis)
reload(joseph.visualisation)
reload(joseph.utils)
reload(joseph.data)

from joseph.analysis import *
from joseph.visualisation import *
from joseph.utils import *
from joseph.data import *

# Turn torch grad tracking off globally for everything that runs after this
# cell.
torch.set_grad_enabled(False)


# Load the model under study. The commented-out entries are alternative model
# names and loading options that are currently disabled.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    # "tiny-stories-2L-33M",
    # "attn-only-2l",
    # center_unembed=True,
    # center_writing_weights=True,
    # fold_ln=True,
    # refactor_factored_attn_matrices=True,
)
# Switch on two optional model settings.
# NOTE(review): presumably these expose the separate q/k/v input hooks and the
# per-head attention result hook used later in the notebook - confirm against
# the HookedTransformer documentation.
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)


# Checkpoint of the pretrained sparse autoencoder. The file name suggests
# gpt2-small, blocks.10.hook_resid_pre and 49152 features; the cfg printed
# below is the authoritative description.
path = "./artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152:v28/1100001280_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt"
sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)

print(sparse_autoencoder.cfg)


# Sanity check: run the model on a short passage and ask for the loss. As the
# last expression of the cell, the value is what the notebook displays.
text = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."
model(text, return_type="loss")

# Dev / Single Example

In [None]:
# Prompt / answer pair used for the single-example experiments below. The
# answer string carries a leading space.
prompt = "When John and Mary went to the shops, John gave the shopping to"
answer = " Mary"
# Alternative prompts, currently disabled:
# prompt = "All's fair in love and"
# answer = " war"
# prompt = " The cat is cute. The dog is"
# prompt = " Alice, with her keen intelligence and artistic talent, discussed philosophy with Bob, who shared her intellect and also possessed remarkable culinary skills, while"
# answer = " cute"
# Clear any hooks left on the model, then report the baseline behaviour on the
# prompt / answer pair with no intervention.
model.reset_hooks()
utils.test_prompt(prompt, answer, model)

# Hook point and (layer, head) targeted by the ablation hook defined below.
# NOTE(review): the hook name hard-codes layer 10 and points at `hook_z` even
# though the constant is called ..._RESULT_NAME, while LAYER_IDX is read from
# the SAE config. hook_to_ablate_head asserts that the hook's layer equals
# LAYER_IDX, so the two only agree when cfg.hook_point_layer is 10.
HEAD_HOOK_RESULT_NAME = "blocks.10.attn.hook_z"
LAYER_IDX = sparse_autoencoder.cfg.hook_point_layer
HEAD_IDX = 7
def hook_to_ablate_head(
    head_output: Float[Tensor, "batch seq_len head_idx d_head"],
    hook: HookPoint,
    head = (LAYER_IDX, HEAD_IDX),
):
    """
    Forward hook that zero-ablates a single attention head.

    Args:
        head_output: Activation at the hooked point; axis 2 is indexed by the
            head number. It is modified in place.
        hook: Hook point the function is attached to. Its layer must equal
            ``head[0]`` and its name must contain "result", "q" or "z"
            (plain substring checks).
        head: ``(layer, head_index)`` pair selecting the head to ablate. The
            default is captured from LAYER_IDX / HEAD_IDX when the function is
            defined, not when it is called.

    Returns:
        The same ``head_output`` tensor with the selected head's slice set to 0.
    """
    hook_name = hook.name
    # Debug trace showing which hook point fired.
    print(hook.layer(), hook_name)

    expected_layer = head[0]
    assert expected_layer == hook.layer(), f"{head[0]} != {hook.layer()}"
    assert any(marker in hook_name for marker in ("result", "q", "z"))

    ablated_head = head[1]
    head_output[:, :, ablated_head, :] = 0
    return head_output

# Repeat the prompt test with the head-ablation hook attached; the hook is
# only active inside this `with` block.
with model.hooks(fwd_hooks=[(HEAD_HOOK_RESULT_NAME, hook_to_ablate_head)]):
    utils.test_prompt(prompt, answer, model)

In [None]:
import joseph
reload(joseph.analysis)
from joseph.analysis import *


# Evaluate the prompt followed by its answer, passed as a one-element list,
# with the head override set to 7.
# NOTE(review): judging by the names, the results are a per-token dataframe,
# the activation cache of the unmodified run, the cache of the run that uses
# the SAE reconstruction, and the SAE feature activations - confirm against
# eval_prompt in joseph.analysis.
token_df, original_cache, cache_reconstructed_query, feature_acts = eval_prompt([prompt + answer], model, sparse_autoencoder, head_idx_override=7)
print(token_df.columns)
# Columns shown in the summary table below.
filter_cols = ["str_tokens", "unique_token", "context", "batch", "pos", "label", "loss", "loss_diff", "mse_loss", "num_active_features", "explained_variance", "kl_divergence",
               "top_k_features"]
# Display the last rows of the table, colour-grading the numeric metric
# columns. As the last expression of the cell, this is the cell output.
token_df[filter_cols].tail().style.background_gradient(
    subset=["loss_diff", "mse_loss","explained_variance", "num_active_features", "kl_divergence"],
    cmap="coolwarm")


In [None]:
def get_max_attn_key(cache, token_df, layer_idx, head_idx):
    '''
    Look up the key vectors of one attention head at the positions listed in
    token_df["max_idx_pos"].

    cache:     activation cache, indexed by hook name.
    token_df:  dataframe with a "max_idx_pos" column of integer positions
               (presumably the key position each token attends to most -
               confirm against the code that builds token_df).
    layer_idx: layer whose key activations are read.
    head_idx:  head whose key vectors are selected.

    Only batch element 0 of the cached keys is used. Returns a CPU tensor
    with one key vector per row of token_df.
    '''
    key_hook_name = utils.get_act_name("k", layer_idx)
    head_keys = cache[key_hook_name][:, :, head_idx, :].cpu()
    max_attended_pos = torch.tensor(token_df["max_idx_pos"].values, dtype=torch.long)
    return head_keys[0, max_attended_pos, :]

# Key vectors for layer 10, head 7 of the single-example run; print the shape
# as a quick check.
keys = get_max_attn_key(original_cache, token_df, 10, 7)
print(keys.shape)

In [None]:
def plot_attn(patterns, token_df, title="", facet_col_labels = ["Original", "Reconstructed"]):
    '''
    Show a stack of matrices as side-by-side heatmap facets with tokens as
    tick labels on both axes.

    patterns:         stacked matrices; axis 0 selects the facet and the size
                      of axis 2 sets the number of ticks.
    token_df:         dataframe whose "unique_token" column supplies the tick
                      labels.
    title:            figure title.
    facet_col_labels: one heading per facet, in facet order.

    The figure is displayed with fig.show(); nothing is returned.

    Example:
        both_patterns = torch.stack([patterns_original, patterns_reconstructed])
        plot_attn(both_patterns.detach().cpu(), token_df, title="Original and Reconstructed Attention Distribution")
    '''
    fig = px.imshow(
        patterns,
        text_auto=".2f",
        title=title,
        facet_col=0,
        color_continuous_midpoint=0,
        color_continuous_scale="RdBu",
    )

    tick_positions = np.arange(patterns.shape[2])
    token_labels = token_df["unique_token"].tolist()
    token_axis = dict(tickmode='array', tickvals=tick_positions, ticktext=token_labels)

    # For every facet: put the tokens on both axes and set its heading.
    for facet_idx, facet_label in enumerate(facet_col_labels):
        facet_col = facet_idx + 1
        fig.update_xaxes(token_axis, row=1, col=facet_col)
        fig.update_yaxes(token_axis, row=1, col=facet_col)

        heading = fig.layout.annotations[facet_idx]
        heading.text = facet_label
        heading.font.size = 20

    fig.update_layout(width=1200, height=800)
    fig.show()


# Compare head HEAD_IDX of layer LAYER_IDX between the unmodified run and the
# run stored in cache_reconstructed_query.
LAYER_IDX = sparse_autoencoder.cfg.hook_point_layer
HEAD_IDX = 7
# Despite the "patterns" variable names, these read the "attn_scores"
# activation of batch element 0; the "pattern" variant is kept commented out
# at the end of the cell.
patterns_original = original_cache[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
# Stack along a new leading axis so each matrix becomes one facet of the plot.
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
plot_attn(both_patterns.detach().cpu(), token_df, title="Original and Reconstructed Attention Distribution")
# patterns_original = original_cache[utils.get_act_name("pattern", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
# patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("pattern", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
# both_patterns = torch.stack([patterns_original, patterns_reconstructed])
# plot_attn(both_patterns.detach().cpu(), token_df, title="Original and Reconstructed Attention Distribution")

In [None]:

# Layer / head used by the top-k experiments in the following cells.
layer_idx = 10
head_idx = 7
# Tokenise the prompt followed by its answer as a one-element batch.
tokens =model.to_tokens([prompt + answer])

# Activation the SAE was attached to, taken from the unmodified run.
original_act = original_cache[sparse_autoencoder.cfg.hook_point]
# token_df["q_norm"] = torch.norm(original_act, dim=-1)[:,1:].flatten().tolist()
# Run the SAE on that activation. Only the reconstruction, the feature
# activations and the MSE loss are kept; the 3rd and 5th outputs are discarded.
sae_out, feature_acts, _, mse_loss, _ = sparse_autoencoder(original_act)
# Hook names for the query activation and the residual stream entering
# layer_idx.
head_hook_query_name = utils.get_act_name("q", layer_idx)
head_hook_resid_name = utils.get_act_name("resid_pre", layer_idx)

@torch.no_grad()
def get_top_k_sae_approximation(sparse_autoencoder, feature_acts, top_k):
    """
    Decode an SAE output that keeps only the ``top_k`` largest feature
    activations at each position.

    Args:
        sparse_autoencoder: Object exposing the decoder weight ``W_dec`` and
            decoder bias ``b_dec``.
        feature_acts: Feature activations whose last axis indexes the SAE
            features (e.g. ``(batch, pos, n_features)``).
        top_k: Number of features to keep per position. Values larger than
            the number of features are clamped, which reproduces the full
            reconstruction.

    Returns:
        ``masked_acts @ W_dec + b_dec``, where ``masked_acts`` equals
        ``feature_acts`` on each position's own top-k entries and is zero
        elsewhere.
    """
    k = min(top_k, feature_acts.shape[-1])
    top_values, top_indices = torch.topk(feature_acts, k, dim=-1, sorted=False)

    # Scatter every position's own top-k values back into a zero tensor.
    # Indexing with `[:, :, indices[0]]` (the previous approach) instead kept
    # the union of batch 0's top-k feature ids at *every* position and batch
    # element, so far more than k features survived per position.
    feature_acts_top_k = torch.zeros_like(feature_acts)
    feature_acts_top_k.scatter_(-1, top_indices, top_values)

    return (feature_acts_top_k @ sparse_autoencoder.W_dec) + sparse_autoencoder.b_dec

# get the top k features by activation, and construct a new sae out 
# For each k, print the norm of the difference between the full SAE output and
# its top-k approximation. After the loop, `top_k` and `new_sae_out` are left
# holding the values from the last iteration (k=100).
for top_k in tqdm([1,3,5,10,30,50,60,100]):
    new_sae_out = get_top_k_sae_approximation(sparse_autoencoder, feature_acts, top_k)
    print((sae_out - new_sae_out).norm().item())


In [None]:

# need to generate query
# Build the top-66 approximation BEFORE defining the hook: a default argument
# is evaluated once, when the `def` statement runs, so defining the hook first
# would capture whatever `new_sae_out` was left over from an earlier cell
# rather than the tensor computed here.
new_sae_out = get_top_k_sae_approximation(sparse_autoencoder, feature_acts, top_k=66)


def replacement_hook(resid_pre, hook, new_resid_pre=new_sae_out):
    """Forward hook that replaces the hooked activation with ``new_resid_pre``."""
    return new_resid_pre


# Re-run the model with the residual stream entering the layer replaced by the
# top-k SAE output, caching all activations of that run.
model.reset_hooks()
with model.hooks(fwd_hooks=[(head_hook_resid_name, replacement_hook)]):
    _, top_k_sae_out_cache = model.run_with_cache(tokens, return_type="loss", loss_per_token=True)
    # Query activations of head `head_idx` under the replaced residual stream.
    top_k_acts_queries = top_k_sae_out_cache[head_hook_query_name][:,:,head_idx]

print(top_k_acts_queries.shape)
# Norm of the difference between the full SAE output and the top-66
# approximation that was actually injected above.
print((sae_out - new_sae_out).norm().item())

In [None]:
LAYER_IDX = 10
HEAD_IDX = 7
# Figure 1: unmodified run vs. the run stored in cache_reconstructed_query.
patterns_original = original_cache[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = cache_reconstructed_query[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
plot_attn(both_patterns.detach().cpu(), token_df, title="Original and Reconstructed Attention Distribution")


# Figure 2: the cache_reconstructed_query run vs. the top-k SAE run. The
# variable names are reused (left facet = reconstruction, right facet = top-k),
# so the title and facet headings are passed explicitly; otherwise the figure
# would be labelled "Original" / "Reconstructed" like figure 1.
patterns_original = cache_reconstructed_query[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
patterns_reconstructed = top_k_sae_out_cache[utils.get_act_name("attn_scores", LAYER_IDX)][0,HEAD_IDX].detach().cpu()
both_patterns = torch.stack([patterns_original, patterns_reconstructed])
plot_attn(
    both_patterns.detach().cpu(),
    token_df,
    title="Reconstructed and Top-K SAE Attention Distribution",
    facet_col_labels=["Reconstructed", "Top-K SAE"],
)

# Full Distribution

In [None]:
# Load the text corpus used for the full-distribution run below; it is indexed
# by integer position there (data[i]).
data = get_webtext()

In [None]:
# Accumulators that are not filled in this cell; kept in case other cells
# use them.
str_token_list = []
loss_list = []
ablated_loss_list = []
# data = get_webtext()

NUM_PROMPTS = 200
# MAX_PROMPT_LEN = 100
# BATCH_SIZE = 10
dataframe_list = []
feature_acts_list = []
with torch.no_grad():
    for i in tqdm(range(NUM_PROMPTS)):

        # Get Token Data
        # Truncate each document to its first 128 tokens and convert back to
        # text. Note this rebinds the notebook-level `prompt`, `token_df` and
        # `feature_acts` used by the single-example cells.
        prompt = model.to_string(model.to_tokens(data[i])[0,:128])
        token_df, _, _, feature_acts = eval_prompt(prompt, model, sparse_autoencoder, head_idx_override=7)
        feature_acts_list.append(feature_acts)
        dataframe_list.append(token_df)

all_token_df = pd.concat(dataframe_list)
# reset_index returns a new frame, so the result has to be assigned back;
# otherwise every prompt's rows keep their own per-prompt index and the labels
# repeat across prompts.
all_token_df = all_token_df.reset_index(drop=True)
# NOTE(review): torch.cat joins along dim 0, so all per-prompt tensors must
# agree on every other dimension - confirm what eval_prompt returns for
# documents shorter than 128 tokens.
all_token_features = torch.cat(feature_acts_list)

print(all_token_df.shape)
print(all_token_df.columns)
all_token_df.head()