# Tuesday Analyze Hook Q

## Set Up

In [None]:
import os
import sys
import torch
import wandb
import json
import plotly.express as px
from transformer_lens import utils
from datasets import load_dataset
from typing import  Dict
from pathlib import Path

from functools import partial
sys.path.append("..")
from sae_training.utils import LMSparseAutoencoderSessionloader
from sae_analysis.visualizer import data_fns, html_fns
from sae_analysis.visualizer.data_fns import get_feature_data, FeatureData

# Pick the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.
device = "cpu"
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"

# This notebook only runs inference/analysis, so autograd is switched off globally.
torch.set_grad_enabled(False)

def imshow(x, **kwargs):
    """Display `x` as a Plotly heat map.

    `x` is converted with `utils.to_numpy` and any keyword arguments are
    forwarded unchanged to `px.imshow`. The figure is shown immediately;
    nothing is returned.
    """
    figure = px.imshow(utils.to_numpy(x), **kwargs)
    figure.show()
    
    
from sae_training.sparse_autoencoder import SparseAutoencoder
# Load model from Huggingface


In [None]:
# import wandb
# run = wandb.init()
# artifact = run.use_artifact('jbloom/mats_sae_training_gpt2_small_hook_q/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096:v16', type='model')
# artifact_dir = artifact.download()


# import wandb
# run = wandb.init()
# artifact = run.use_artifact('jbloom/mats_sae_training_gpt2_small_hook_q/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096:v15', type='model')
# artifact_dir = artifact.download()
# run.finish()


# Download the trained sparse autoencoder checkpoint from Weights & Biases.
import wandb
run = wandb.init()
artifact = run.use_artifact('jbloom/mats_sae_training_gpt2_small/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v56', type='model')
artifact_dir = artifact.download()
# The run exists only to fetch the artifact; close it so it is not left open
# for the rest of the notebook session.
run.finish()



In [None]:
# Load the trained sparse autoencoder (SAE) from a local checkpoint file.
# To use the artifact downloaded above instead of a fixed path:
# path=artifact_dir+"/"+os.listdir(artifact_dir)[0]
# print(path)
from sae_training.utils import LMSparseAutoencoderSessionloader
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.sparse_autoencoder import CPU_Unpickler
from sae_training.sparse_autoencoder import SparseAutoencoder

path="../artifacts/sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576:v56/final_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_24576.pkl"
# path="../artifacts/sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096:v16/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096.pkl"#.gz"

# NOTE(security): unpickling can execute arbitrary code; only load checkpoints
# that come from a trusted source.
with open(path, 'rb') as file:
    state_dict = CPU_Unpickler(file).load()

# Copy the stored config's attributes so the object inside `state_dict` is not
# mutated by the edits below.
cfg = dict(vars(state_dict["cfg"]))
# Use the device selected at the top of the notebook instead of assuming MPS,
# so the notebook also runs on CUDA-only and CPU-only machines.
cfg["device"] = device
# These two entries are dropped before rebuilding the config (presumably they
# are derived by the config class and not accepted by its constructor — TODO
# confirm). `pop` tolerates checkpoints that do not contain them.
cfg.pop("d_sae", None)
cfg.pop("tokens_per_buffer", None)
cfg = LanguageModelSAERunnerConfig(**cfg)
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder.load_state_dict(state_dict["state_dict"])

In [None]:
from transformer_lens import HookedTransformer
from sae_training.activations_store import ActivationsStore

# Language model whose activations are analysed in this notebook.
model = HookedTransformer.from_pretrained("gpt2-small")

# Batch source built from the autoencoder's config and the model.
activations_loader = ActivationsStore(cfg, model)



## Test the Autoencoder

### L0 Test and Reconstruction Test

In [None]:
# Run one batch through the model and the SAE, then report activation norms
# and the number of active features (L0) per position.
with torch.no_grad():
    batch_tokens = activations_loader.get_batch_tokens()
    print(batch_tokens.shape)
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
    activations = cache[sparse_autoencoder.cfg.hook_point]

    sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(activations)
    # del cache

    # Norms skip index 0 along dim 1 (presumably the BOS position — TODO confirm).
    l2_norms_of_input = torch.norm(activations[:, 1:], dim=-1)
    l2_norms_of_sae_out = torch.norm(sae_out[:, 1:], dim=-1)
    print("l2_norms_of_input", l2_norms_of_input.mean().item())
    print("l2_norms_of_sae_out", l2_norms_of_sae_out.mean().item())

    # L0: count of strictly positive feature activations along the last dim.
    l0 = (feature_acts > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    l0_histogram = px.histogram(l0.flatten().cpu().numpy())
    l0_histogram.show()

# Monday Stuff

In [None]:
# # cache.apply_ln_to_stack(x_reconstruct[0],layer=10).mean()
# example_batch = torch.randint(0,32,(1,)).item(); example_position = torch.randint(0, 10, (1,)).item()
# print(example_batch, example_position)
# print(model.to_str_tokens(batch_tokens[example_batch])[max(example_position-5,0):min(example_position+3,128)])
# px.line(feature_acts[example_batch,example_position].cpu().numpy()).show()
# lnd_activations = cache.apply_ln_to_stack(activations, layer=10)
# _, feature_acts_after_ln, _, _, _ = sparse_autoencoder(
#         lnd_activations
#     )
# px.line(feature_acts_after_ln[example_batch,example_position].cpu().numpy()).show()
# vals, inds = torch.topk(feature_acts[example_batch,example_position].detach(), 10)
# px.bar(
#     x=utils.to_numpy(vals),
#     y=[str(i.item()) for i in inds],
#     orientation="h",
# ).show()
# utils.test_prompt(
#     prompt = model.to_string(batch_tokens[example_batch][1:example_position+1]),
#     answer = model.to_string(batch_tokens[example_batch][example_position+1]),
#     model = model)


import pandas as pd

def plot_feature_unembed_bar(feature_id, sparse_autoencoder, feature_name = ""):
    """Show the 20 tokens with the largest unembedding score for one SAE feature.

    The decoder row `sparse_autoencoder.W_dec[feature_id]` is multiplied by the
    unembedding matrix of the notebook-global `model` (`model.W_U`), giving one
    score per vocabulary entry. The 20 highest-scoring tokens are drawn as a
    horizontal bar chart, which is displayed; nothing is returned.

    Args:
        feature_id: Row index into `sparse_autoencoder.W_dec`.
        sparse_autoencoder: Object exposing a `W_dec` matrix.
        feature_name: Column name / axis label used for the scores.
    """
    feature_unembed = sparse_autoencoder.W_dec[feature_id] @ model.W_U
    scores = feature_unembed.detach().cpu().numpy()

    # Take the vocabulary size from the projection itself instead of
    # hard-coding GPT-2's 50257, so other tokenizers/models also work.
    vocab_size = scores.shape[0]
    feature_unembed_df = pd.DataFrame(
        scores,
        columns = [feature_name],
        index = [model.tokenizer.decode(i) for i in range(vocab_size)]
    )

    # Highest scores first; the decoded token text moves from the index into a "token" column.
    feature_unembed_df = feature_unembed_df.sort_values(feature_name, ascending=False).reset_index().rename(columns={'index': 'token'})
    # Re-sort the top 20 ascending so the largest bar is drawn at the top.
    fig = px.bar(feature_unembed_df.head(20).sort_values(feature_name, ascending=True),
                 color_continuous_midpoint=0,
                 color_continuous_scale="RdBu",
            y = 'token', x = feature_name, orientation='h', color = feature_name, hover_data=[feature_name])

    fig.update_layout(
        width=500,
        height=600,
    )

    fig.show()


# Use a single id for both the plotted feature and its label; previously
# feature 3541 was plotted but labelled "3457".
example_feature_id = 3541
plot_feature_unembed_bar(example_feature_id, sparse_autoencoder, feature_name = str(example_feature_id))
# for i in inds:
#     plot_feature_unembed_bar(int(i), sparse_autoencoder, feature_name = str(i.item()))

## Specific Capability Test

When studying a specific task, it is important to validate that the model can still perform that task when the original activation is replaced by the autoencoder's reconstruction.

In [None]:
# Example prompt and its expected completion.
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
# Baseline check. BOS is prepended here as in every other pass in this
# notebook (tokens, cache, hooked test_prompt), so results are comparable.
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt, prepend_bos=True)
# Run the SAE on the cached activation at its hook point.
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder(
    cache[sparse_autoencoder.cfg.hook_point]
)

def reconstr_hook(mlp_out, hook, new_mlp_out):
    """Forward hook that discards the incoming activation and returns `new_mlp_out`.

    `mlp_out` and `hook` are accepted only to satisfy the hook signature;
    bind `new_mlp_out` with `functools.partial` before registering.
    """
    replacement = new_mlp_out
    return replacement

def reconstr_key_hook(mlp_out, hook, reconstructed_key):
    """Forward hook that ignores the incoming activation and returns `reconstructed_key`."""
    replacement = reconstructed_key
    return replacement

def reconstr_query_hook(mlp_out, hook, reconstructed_query):
    """Forward hook that ignores the incoming activation and returns `reconstructed_query`."""
    replacement = reconstructed_query
    return replacement


def zero_abl_hook(mlp_out, hook):
    """Forward hook for zero ablation: return zeros shaped like `mlp_out`.

    The input tensor is not modified; `hook` is unused.
    """
    zeroed = torch.zeros_like(mlp_out)
    return zeroed

# Loss on the example prompt under three conditions:
#   1. unmodified model,
#   2. activation at the SAE hook point replaced by the SAE reconstruction,
#   3. activation at the SAE hook point replaced by zeros.
print("Orig", model(tokens, return_type="loss").item())

print(
    "reconstr",
    model.run_with_hooks(
        tokens,
        fwd_hooks=[(sparse_autoencoder.cfg.hook_point, partial(reconstr_hook, new_mlp_out=sae_out))],
        return_type="loss",
    ).item(),
)

print(
    "Zero",
    model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(sparse_autoencoder.cfg.hook_point, zero_abl_hook)],
    ).item(),
)

# Repeat the prompt test while the reconstruction hook is active.
with model.hooks(
    fwd_hooks=[(sparse_autoencoder.cfg.hook_point, partial(reconstr_hook, new_mlp_out=sae_out))]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

In [None]:
# Feature activations at the final token of the example prompt.
# Assumes `feature_acts` is (batch, position, feature), matching the shape of
# the cached activation passed to the SAE — TODO confirm. The previous code
# indexed the batch dimension twice (`feature_acts[0]` then `[0, -1]`), which
# selects a single scalar, and rebinding `feature_acts` made the cell
# non-idempotent on re-run.
last_token_acts = feature_acts[0, -1].detach().cpu()
print((last_token_acts > 0).float().sum())
px.line(last_token_acts.numpy()).show()
# Ten most active features at that position.
vals, inds = torch.topk(last_token_acts, 10)
# Use plain integers as labels rather than tensor reprs such as "tensor(123)".
px.bar(x=[str(i.item()) for i in inds], y=vals.numpy()).show()

In [None]:
def plot_feature_unembed_bar(feature_id, sparse_autoencoder, feature_name = "", layer = 10, head = 7):
    """Show the 20 tokens with the largest unembedding score for one SAE feature.

    The decoder row `sparse_autoencoder.W_dec[feature_id]` is mapped through
    the transposed query weights of one attention head of the notebook-global
    `model` and then through `model.W_U`, giving one score per vocabulary
    entry. The 20 highest-scoring tokens are drawn as a horizontal bar chart,
    which is displayed; nothing is returned.

    Args:
        feature_id: Row index into `sparse_autoencoder.W_dec`.
        sparse_autoencoder: Object exposing a `W_dec` matrix.
        feature_name: Column name / axis label used for the scores.
        layer: Layer index into `model.W_Q` (defaults to the previously
            hard-coded 10).
        head: Head index into `model.W_Q` (defaults to the previously
            hard-coded 7).
    """
    decoder_direction = sparse_autoencoder.W_dec[feature_id]
    query_weights = model.W_Q[layer, head]
    # Only project through W_Q^T when the decoder direction has the width that
    # product requires (the last dim of W_Q[layer, head]). Otherwise the
    # direction is unembedded directly, which is what is needed when the SAE
    # was trained on a residual-stream hook rather than on queries
    # (presumably the case for the currently loaded checkpoint — TODO confirm).
    if decoder_direction.shape[-1] == query_weights.shape[-1]:
        decoder_direction = decoder_direction @ query_weights.T
    feature_unembed = decoder_direction @ model.W_U
    scores = feature_unembed.detach().cpu().numpy()

    # Take the vocabulary size from the projection itself instead of
    # hard-coding GPT-2's 50257.
    vocab_size = scores.shape[0]
    feature_unembed_df = pd.DataFrame(
        scores,
        columns = [feature_name],
        index = [model.tokenizer.decode(i) for i in range(vocab_size)]
    )

    # Highest scores first; the decoded token text moves from the index into a "token" column.
    feature_unembed_df = feature_unembed_df.sort_values(feature_name, ascending=False).reset_index().rename(columns={'index': 'token'})
    # Re-sort the top 20 ascending so the largest bar is drawn at the top.
    fig = px.bar(feature_unembed_df.head(20).sort_values(feature_name, ascending=True),
                 color_continuous_midpoint=0,
                 color_continuous_scale="RdBu",
            y = 'token', x = feature_name, orientation='h', color = feature_name, hover_data=[feature_name])

    fig.update_layout(
        width=500,
        height=600,
    )

    fig.show()


# One unembedding bar chart per feature index in `inds`, labelled with that index.
for i in inds:
    feature_index = int(i)
    plot_feature_unembed_bar(feature_index, sparse_autoencoder, feature_name = str(feature_index))

In [None]:

# `cache_original` and `cache_reconstructed_res_stream` are used from here on
# but are not defined anywhere else in this notebook, so build them here.

# Baseline: activations of the unmodified model on the example prompt.
_, cache_original = model.run_with_cache(example_prompt, prepend_bos=True)

# Intervention: the activation at the SAE hook point is replaced by the SAE
# reconstruction (`sae_out`), and everything downstream is cached.
with model.hooks(
    fwd_hooks=[
        (
            sparse_autoencoder.cfg.hook_point,
            partial(reconstr_hook, new_mlp_out=sae_out),
        )
    ]
):
    _, cache_reconstructed_res_stream = model.run_with_cache(example_prompt, prepend_bos=True)

# Layer-10 queries and keys as computed under the reconstruction.
reconstructed_q = cache_reconstructed_res_stream["blocks.10.attn.hook_q"].detach()
reconstructed_k = cache_reconstructed_res_stream["blocks.10.attn.hook_k"].detach()


def reconstr_key_hook(key, hook, reconstructed_key):
    """Forward hook that ignores the incoming `key` and returns `reconstructed_key`."""
    replacement = reconstructed_key
    return replacement

def reconstr_query_hook(query, hook, reconstructed_query):
    """Forward hook that swaps `query` for `reconstructed_query`.

    Prints both shapes. When the shapes match, `reconstructed_query` is
    returned as is. Otherwise the slice of `query` at index 0 along dim 1 is
    kept and `reconstructed_query` is concatenated after it along dim 1.
    `hook` is unused.
    """
    print("reconstr_query_hook", query.shape, reconstructed_query.shape)
    if query.shape != reconstructed_query.shape:
        first_slice = query[:, 0].unsqueeze(1)
        return torch.cat([first_slice, reconstructed_query], dim=1)
    return reconstructed_query

model.reset_hooks()

# Pass 1: replace only the layer-10 queries with the reconstructed ones.
with model.hooks(
    fwd_hooks=[
        (
            "blocks.10.attn.hook_q",
            partial(reconstr_query_hook, reconstructed_query=reconstructed_q),
        )
    ]
):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
    _, cache_reconstructed_queries = model.run_with_cache(example_prompt, prepend_bos=True)

# Pass 2: replace only the layer-10 keys. `cache_reconstructed_keys` is read by
# the MSE-by-layer cell below but was never produced anywhere in the notebook.
with model.hooks(
    fwd_hooks=[
        (
            "blocks.10.attn.hook_k",
            partial(reconstr_key_hook, reconstructed_key=reconstructed_k),
        )
    ]
):
    _, cache_reconstructed_keys = model.run_with_cache(example_prompt, prepend_bos=True)

model.reset_hooks()

### Get the MSE Loss by Layer

In [None]:
def get_mse_loss_cache_df(original_cache, intervention_cache):
    """Tabulate the per-key mean squared error between two activation caches.

    For every key of `original_cache` (in iteration order) the MSE between
    `original_cache[key]` and `intervention_cache[key]` is computed. Leading
    rows whose MSE is not greater than zero are dropped, so the table starts
    at the first key where the two caches differ.

    Args:
        original_cache: Mapping from key to tensor (the baseline).
        intervention_cache: Mapping with at least the same keys.

    Returns:
        A DataFrame with columns "key" and "mse_loss". It is empty when no
        key has an MSE greater than zero.

    Raises:
        KeyError: If a key of `original_cache` is missing from
            `intervention_cache`.
    """
    # Iterate the *argument*, not a notebook global, so the function works
    # for whichever pair of caches is passed in.
    rows = [
        {
            "key": key,
            "mse_loss": (original_cache[key] - intervention_cache[key]).pow(2).mean().item(),
        }
        for key in original_cache.keys()
    ]
    df = pd.DataFrame(rows, columns=["key", "mse_loss"])

    differs = df["mse_loss"] > 0
    if not differs.any():
        # Nothing differs: return an empty table instead of raising IndexError.
        return df.iloc[0:0]

    # The frame has a default RangeIndex, so the label of the first True entry
    # is also its position.
    first_non_zero_idx = differs.idxmax()
    return df.iloc[first_non_zero_idx:]

# MSE per cache key between the original run and the run with the SAE
# reconstruction spliced in at the hook point.
df = get_mse_loss_cache_df(cache_original, cache_reconstructed_res_stream)
px.line(df, x="key", y="mse_loss").show()

# Same comparison against the run with only the queries replaced.
df = get_mse_loss_cache_df(cache_original, cache_reconstructed_queries)
px.line(df, x="key", y="mse_loss").show()

# Same comparison against the run with only the keys replaced.
df = get_mse_loss_cache_df(cache_original, cache_reconstructed_keys)
px.line(df, x="key", y="mse_loss").show()

### Visualize attn patterns

In [None]:
from circuitsvis.attention import attention_patterns
# Layer-10 attention patterns of the original run, first batch element.
patterns_original = cache_original["blocks.10.attn.hook_pattern"][0].detach().cpu()
attention_patterns(
    tokens=model.to_str_tokens(example_prompt),
    attention=patterns_original,
)

In [None]:
# Heat map of index 7 along the first dim of the original pattern
# (presumably head 7 — TODO confirm axis order).
px.imshow(patterns_original[7].detach().cpu())

In [None]:
# Layer-10 attention patterns from the run with reconstructed queries, first
# batch element; index 7 along the first dim is shown as a heat map.
patterns_reconstructed = (
    cache_reconstructed_queries["blocks.10.attn.hook_pattern"][0]
    .detach()
    .cpu()
)
px.imshow(patterns_reconstructed[7].detach().cpu())

In [None]:
# Choose which intervention to visualise by uncommenting exactly one of the
# assignments below; the active one uses the run with reconstructed queries.
# patterns_reconstructed = cache_reconstructed_res_stream["blocks.10.attn.hook_pattern"][0].detach().cpu()
# patterns_reconstructed = cache_reconstructed_keys["blocks.10.attn.hook_pattern"][0].detach().cpu()
patterns_reconstructed = cache_reconstructed_queries["blocks.10.attn.hook_pattern"][0].detach().cpu()
attention_patterns(tokens=model.to_str_tokens(example_prompt), attention=patterns_reconstructed)

### Visualize change in Attn Scores/Patterns

In [None]:
# Cache to compare against the original run in the cells below.
intervention_cache = cache_reconstructed_queries

# Layer-10 attention scores, first batch element.
scores_original = cache_original["blocks.10.attn.hook_attn_scores"].detach()[0].cpu()
scores_reconstructed = intervention_cache["blocks.10.attn.hook_attn_scores"].detach()[0].cpu()

# Layer-10 attention patterns, first batch element.
patterns_original = cache_original["blocks.10.attn.hook_pattern"].detach()[0].cpu()
patterns_reconstructed = intervention_cache["blocks.10.attn.hook_pattern"].detach()[0].cpu()


import pandas as pd
import itertools
import numpy as np

def tensor_to_long_data_frame(tensor_result, dimension_names, value_name = "Score"):
    """Flatten a tensor into a long-format DataFrame with one row per element.

    Args:
        tensor_result: Tensor to flatten.
        dimension_names: One column name per tensor dimension; each column
            holds the element's index along that dimension, stored as a
            pandas categorical.
        value_name: Name of the column holding the tensor values.

    Returns:
        DataFrame with the `dimension_names` columns followed by `value_name`,
        rows in the tensor's flattened (row-major) order.
    """
    assert len(tensor_result.shape) == len(
        dimension_names
    ), "The number of dimension names must match the number of dimensions in the tensor"

    flat_values = tensor_result.detach().cpu().reshape(-1).numpy()

    # np.ndindex enumerates index tuples in the same row-major order as reshape(-1).
    coordinates = pd.MultiIndex.from_tuples(
        list(np.ndindex(tensor_result.shape)),
        names=dimension_names,
    )

    long_df = pd.DataFrame({value_name: flat_values}, index=coordinates).reset_index()
    # Every dimension column (not the value column) becomes categorical.
    return long_df.astype({name: "category" for name in dimension_names})

# Long-format table: one row per (head, query, key) entry, with the original
# and reconstructed attention scores and patterns side by side.
scores_df = tensor_to_long_data_frame(scores_original, ["Head", "Query", "Key"], value_name = "Original")
scores_df["Reconstructed"] = scores_reconstructed.flatten().detach().cpu().numpy()
scores_df["Pattern Original"] = patterns_original.flatten().detach().cpu().numpy()
scores_df["Pattern Reconstructed"] = patterns_reconstructed.flatten().detach().cpu().numpy()
# Drop rows whose score is not finite. Masked positions are presumably stored
# as -inf (TODO confirm for the installed TransformerLens version), which the
# previous `!= float("inf")` comparison never removed.
scores_df = scores_df[np.isfinite(scores_df[["Original", "Reconstructed"]]).all(axis=1)]
scores_df.head()


In [None]:
# Original vs. reconstructed attention scores, coloured by head, in a fixed-size figure.
fig = px.scatter(
    scores_df,
    x="Original",
    y="Reconstructed",
    color="Head",
    hover_data=["Head", "Query", "Key"],
)
fig.update_layout(width=800, height=600)
fig.show()

In [None]:
# Original vs. reconstructed attention pattern on log-log axes, coloured by head.
# The scatter must be bound to `fig`; otherwise the layout update and `show()`
# below act on the figure left over from the previous cell.
fig = px.scatter(scores_df, x="Pattern Original", y="Pattern Reconstructed", color = "Head", hover_data=["Head", "Query", "Key"], log_x=True, log_y=True)
fig.update_layout(
    width=800,
    height=600,
)
fig.show()

In [None]:
# Original vs. reconstructed attention pattern on log-log axes, coloured by key position.
# The scatter must be bound to `fig`; otherwise the layout update and `show()`
# below act on a figure created in an earlier cell.
fig = px.scatter(scores_df, x="Pattern Original", y="Pattern Reconstructed", color = "Key", hover_data=["Head", "Query", "Key"], log_x=True, log_y=True)
fig.update_layout(
    width=800,
    height=600,
)
fig.show()

In [None]:
import torch
import torch.nn.functional as F
from einops import einsum


def kl_divergence_attention(y_true, y_pred):
    """Return the element-wise KL-divergence terms `y_true * log(y_true / y_pred)`.

    No reduction is applied; the caller sums over whichever dimension it
    wants. Entries where `y_true` is zero contribute exactly 0 (the usual
    `0 * log 0 = 0` convention) instead of NaN. Entries where `y_true > 0`
    and `y_pred == 0` are still `+inf`.

    Args:
        y_true: Tensor of reference probabilities.
        y_pred: Tensor of comparison probabilities, broadcastable with `y_true`.

    Returns:
        Tensor of per-element KL terms.
    """
    # torch.xlogy(x, y) computes x * log(y) and returns 0 wherever x == 0,
    # which avoids the 0 * (-inf) = NaN produced by multiplying with torch.log.
    return torch.xlogy(y_true, y_true) - torch.xlogy(y_true, y_pred)


# KL terms between original and reconstructed attention patterns, summed over
# the last dimension and shown as a heat map.
print(patterns_original.shape)
kl_result = kl_divergence_attention(patterns_original, patterns_reconstructed)
# Zero out NaN entries in place before summing.
kl_result.masked_fill_(kl_result.isnan(), 0)
fig = px.imshow(
    kl_result.sum(dim=-1).detach().cpu(),
    color_continuous_midpoint=0,
    color_continuous_scale="RdBu",
    labels=dict(x="Query", y="Head"),
    text_auto=".2f",
)
fig.layout.coloraxis.colorbar.title = "KL Divergence"
fig.update_layout(width=800, height=600)
fig.show()