In [1]:

from sae_lens import HookedSAETransformer, SAE, SAEConfig
from gemma_utils import get_gemma_2_config, gemma_2_sae_loader


In [2]:


from sae_lens import HookedSAETransformer, SAE, SAEConfig
from gemma_utils import get_gemma_2_config, gemma_2_sae_loader
import numpy as np
import torch
import tqdm
import einops
import re
from jaxtyping import Int, Float
from typing import List, Optional, Any
from torch import Tensor
import json
import os
from torch.utils.data import Dataset, DataLoader
import random
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import random
from transformer_lens.utils import get_act_name
from IPython.display import display, HTML
import plotly.express as px


In [3]:

# Seed NumPy, Python's `random` and PyTorch with the same fixed value (0).
np.random.seed(0)
random.seed(0)
torch.random.manual_seed(0)


<torch._C.Generator at 0x7f063db3b9f0>

In [4]:

# Load the "google/gemma-2-2b-it" checkpoint as a HookedSAETransformer; every later cell reads this `model`.
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it")




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [5]:
# Categories the model is asked to enumerate as short bulleted lists.
topics = [
    "City Names",
    "Countries",
    "Animals",
    "Types of Trees",
    "Types of Flowers",
    "Fruits",
    "Vegetables",
    "Car Brands",
    "Sports",
    "Rivers",
    "Mountains",
    "Ocean",
    "Inventions",
    "Languages",
    "Capital Cities",
    "Movies",
    "Books",
    "TV Shows",
    "Famous Scientists",
    "Famous Writers",
    "Video Games",
    "Companies",
    "Colors",
]


In [5]:
def generate_lists(topics, n_samples=5):
    """
    Sample several bulleted-list completions from `model` for every topic.

    Args:
        topics: iterable of topic strings; each one is inserted into the user prompt.
        n_samples: number of independent generations per topic. Defaults to 5,
            the value that used to be hard-coded.

    Returns:
        dict mapping each topic to a list of `n_samples` token tensors, each a
        detached clone of what `model.generate` returned
        (presumably prompt + continuation — TODO confirm against TransformerLens).
    """
    generation_dict = {}
    for topic in topics:
        samples = []
        for _ in range(n_samples):
            messages = [
                {"role": "user", "content": f"Provide me with a short list of a few {topic}. Just provide the names, no need for any other information."},
            ]
            # tokenize=False: the chat template is returned as a string, not as token ids.
            prompt = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            # Start the model's reply with "-" so that it continues as a hyphen-bulleted list.
            prompt += "-"
            # prepend_bos=False: presumably the chat template already emits <bos> — TODO confirm.
            tokens = model.to_tokens(prompt, prepend_bos=False)
            out = model.generate(
                tokens,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.8,
                stop_at_eos=True,
            )
            samples.append(out.detach().clone())
        generation_dict[topic] = samples
    return generation_dict
# Generating 5 samples for every topic is slow, so the result is cached on disk.
# The cache is rebuilt automatically when the file is missing, which keeps the
# notebook runnable from a fresh checkout.
GENERATION_CACHE = "gemma2_generation_dict.pt"
if os.path.exists(GENERATION_CACHE):
    # NOTE: torch.load unpickles the file — only load caches you created yourself.
    generation_dict = torch.load(GENERATION_CACHE)
else:
    generation_dict = generate_lists(topics)
    torch.save(generation_dict, GENERATION_CACHE)


In [6]:




# Token ids of the structural tokens used to parse the generated lists.
hypen_tok_id = 235290
break_tok_id = 108
eot_tok_id = 107
blanck_tok_id = 235248

# Walk through one generation by hand: the first "City Names" sample.
toks = generation_dict["City Names"][0].squeeze()

hypen_positions = torch.where(toks == hypen_tok_id)[0].to("cpu")
print(hypen_positions)
break_positions = torch.where(toks == break_tok_id)[0].to("cpu")
print(break_positions)
eot_positions = torch.where(toks == eot_tok_id)[0].to("cpu")
print(eot_positions)

# Line breaks that are immediately followed by a hyphen.
filter_break_pos = [brk.item() for brk in break_positions if brk + 1 in hypen_positions]

# One (start, end) span per list item: hyphen to next hyphen, and last hyphen to last EOT.
topic_spans = list(zip(hypen_positions[:-1].tolist(), hypen_positions[1:].tolist()))
topic_spans = topic_spans + [(hypen_positions[-1].item(), eot_positions[-1].item())]
token_spans = [toks[start:end].tolist() for start, end in topic_spans]

print(len(token_spans))
number_of_tokens_per_item = list(map(len, token_spans))
print(number_of_tokens_per_item)

# Flag the items whose tokens include the blank token.
white_space_tok = torch.tensor([blanck_tok_id in item_toks for item_toks in token_spans])
white_spaces_tok_pos = torch.where(white_space_tok)[0].to("cpu")
print(white_space_tok)
print(white_spaces_tok_pos)



### What's happening with whitespaces


In [8]:
def get_stats(dict_toks, hyphen_id=None, eot_id=None, blank_id=None):
    """
    Build per-topic statistics about the generated hyphen-bulleted lists.

    A list item is the token span that starts at a hyphen token and runs up to
    (not including) the next hyphen. The last item ends at the final
    end-of-turn token when that token comes after the last hyphen; otherwise
    (e.g. a generation cut off by the token limit) it runs to the end of the
    sequence.

    Args:
        dict_toks: dict mapping a topic to a list of token tensors (one per generation).
        hyphen_id, eot_id, blank_id: token ids of the hyphen, end-of-turn and
            blank tokens. When omitted they fall back to the notebook-level
            `hypen_tok_id`, `eot_tok_id` and `blanck_tok_id`.

    Returns:
        dict mapping each topic to a list with one dict per generation:
        - "num_tokens": 1-D tensor with the number of tokens of every item
        - "num_items": number of items in the list
        - "avg_tokens_per_item": the same per-item tensor as "num_tokens"
          (kept as-is because the plotting helpers aggregate it themselves)
        - "blank_positions": indices of the items in which a blank token is found
    """
    hyphen_id = hypen_tok_id if hyphen_id is None else hyphen_id
    eot_id = eot_tok_id if eot_id is None else eot_id
    blank_id = blanck_tok_id if blank_id is None else blank_id

    stats_dict = {}
    for topic, toks_list in dict_toks.items():
        topic_stats = []
        for toks in toks_list:
            flat = toks.squeeze()
            hyphen_positions = torch.where(flat == hyphen_id)[0].tolist()
            eot_positions = torch.where(flat == eot_id)[0].tolist()

            token_spans = []
            if hyphen_positions:
                last_start = hyphen_positions[-1]
                # Only an EOT *after* the last hyphen closes the list; an earlier
                # EOT (e.g. one inside the prompt) would yield an empty last item.
                if eot_positions and eot_positions[-1] > last_start:
                    list_end = eot_positions[-1]
                else:
                    list_end = len(flat)
                boundaries = hyphen_positions + [list_end]
                token_spans = [
                    flat[start:end].tolist()
                    for start, end in zip(boundaries[:-1], boundaries[1:])
                ]

            num_items = len(token_spans)
            number_of_tokens_per_item = torch.tensor(
                [len(span) for span in token_spans], dtype=torch.long
            )
            has_blank = torch.tensor(
                [blank_id in span for span in token_spans], dtype=torch.bool
            )
            white_spaces_tok_pos = torch.where(has_blank)[0].to("cpu")

            topic_stats.append({
                "num_tokens": number_of_tokens_per_item,
                "num_items": num_items,
                "avg_tokens_per_item": number_of_tokens_per_item,
                "blank_positions": white_spaces_tok_pos,
            })
        stats_dict[topic] = topic_stats
    return stats_dict




# Per-topic list statistics for every cached generation; consumed by the plotting cells.
stats_dict = get_stats(generation_dict)
    


In [9]:
import plotly.graph_objects as go
import torch

def plot_stats(stats_dict):
    """
    Render three interactive Plotly figures from the output of `get_stats`.

    1. Overview: total tokens, mean tokens per item and number of items.
    2. Mean tokens per item on its own.
    3. Number of items that contain a blank token, with the indices of those
       items overlaid as markers.

    Every generation contributes one x entry labelled with its topic, so a
    topic with several generations appears several times on the x axis.

    Args:
        stats_dict: dict topic -> list of per-generation stat dicts with the keys
            "num_tokens", "avg_tokens_per_item", "num_items" and "blank_positions".
    """
    topics = []
    num_tokens_per_topic = []
    avg_tokens_per_topic = []
    num_items_per_topic = []
    num_blank_tokens_per_topic = []
    blank_token_positions = []

    for topic, stats_list in stats_dict.items():
        for stats in stats_list:
            topics.append(topic)
            # Total number of tokens over all items. This used to be divided by
            # the item count, which made it a copy of the average trace below.
            num_tokens_per_topic.append(stats['num_tokens'].sum().item())
            # Mean number of tokens per item.
            avg_tokens_per_topic.append(stats['avg_tokens_per_item'].float().mean().item())
            num_items_per_topic.append(stats['num_items'])
            # Number of items in which a blank token was found.
            num_blank_tokens_per_topic.append(len(stats['blank_positions']))
            # Keep (topic, item index) pairs for the marker overlay.
            blank_token_positions.extend([(topic, pos.item()) for pos in stats['blank_positions']])

    # Figure 0: overview with bars for totals/items and a line for the average.
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=topics,
        y=num_tokens_per_topic,
        name='Number of Tokens',
        marker_color='blue'
    ))
    fig.add_trace(go.Scatter(
        x=topics,
        y=avg_tokens_per_topic,
        name='Average Tokens per Item',
        mode='lines+markers',
        marker_color='green'
    ))
    fig.add_trace(go.Bar(
        x=topics,
        y=num_items_per_topic,
        name='Number of Items',
        marker_color='orange',
        opacity=0.6
    ))
    fig.update_layout(
        title="Token Statistics per Topic",
        xaxis_title="Topics",
        yaxis_title="Values",
        barmode='group',  # Bars are shown side by side
        legend=dict(
            x=0.1,
            y=1.1,
            orientation="h"
        ),
        template='plotly_dark'
    )

    # Figure 1: average tokens per item only.
    fig1 = go.Figure()
    fig1.add_trace(go.Scatter(
        x=topics,
        y=avg_tokens_per_topic,
        mode='lines+markers',
        name='Avg Tokens per Item',
        marker=dict(color='green', size=8),
        line=dict(color='green', width=2)
    ))
    fig1.update_layout(
        title="Average Tokens per Item across Topics",
        xaxis_title="Topics",
        yaxis_title="Average Number of Tokens",
        template="plotly_dark"
    )

    # Figure 2: blank-token counts with the affected item indices as markers.
    fig2 = go.Figure()
    fig2.add_trace(go.Bar(
        x=topics,
        y=num_blank_tokens_per_topic,
        name='Number of Blank Tokens',
        marker_color='orange',
        opacity=0.7
    ))
    if blank_token_positions:
        topics_pos, positions = zip(*blank_token_positions)
        fig2.add_trace(go.Scatter(
            x=topics_pos,
            y=positions,
            mode='markers',
            name='Blank Token Positions',
            marker=dict(color='red', size=8, symbol='x')
        ))
    fig2.update_layout(
        title="Blank Token Information across Topics",
        xaxis_title="Topics",
        yaxis_title="Blank Token Count / Position",
        template="plotly_dark",
        showlegend=True
    )

    fig.show()
    fig1.show()
    fig2.show()

# Show the three overview figures for the statistics computed by get_stats.
plot_stats(stats_dict)


In [10]:
import plotly.graph_objects as go
import torch

def plot_items_vs_blank_tokens(stats_dict):
    """
    Box-plot, per topic, the fraction of list items that contain a blank token.

    For every generation the fraction is
    len(stats["blank_positions"]) / len(stats["num_tokens"]), i.e. the number
    of items with a blank token over the number of items. Generations without
    any parsed item are skipped because the fraction is undefined for them.

    Args:
        stats_dict: output of `get_stats` (topic -> list of per-generation stat dicts).
    """
    topics = []
    proportion_blank_tokens_per_topic = []

    for topic, stats_list in stats_dict.items():
        for stats in stats_list:
            n_items = len(stats['num_tokens'])
            if n_items == 0:
                # Nothing to divide by; an empty list carries no information here.
                continue
            topics.append(topic)
            proportion_blank_tokens_per_topic.append(len(stats['blank_positions']) / n_items)

    fig = go.Figure()

    # One box per topic, built from that topic's generations.
    fig.add_trace(go.Box(
        x=topics,
        y=proportion_blank_tokens_per_topic,
        name='Proportion of Blank Tokens',
        marker_color='blue'
    ))

    fig.update_layout(
        title="Proportion of List Items Containing a Blank Token per Topic",
        xaxis_title="Topics",
        yaxis_title="Proportion of items with a blank token",
        template='plotly_dark',
        legend=dict(x=0.1, y=1.1, orientation="h")
    )

    fig.show()

# Box-plot the blank-token proportions from the statistics computed earlier by get_stats.
plot_items_vs_blank_tokens(stats_dict)


## Template

- Create Template
- Sample from the Template 
- Filtering with final logits


In [27]:


# Token tensor of the first "Vegetables" generation; the following cells analyse this sample.
toks = generation_dict["Vegetables"][0]
# The full-cache forward pass is left disabled, so this cell creates neither `logits` nor `cache`.
#with torch.no_grad():
#    logits,cache = model.run_with_cache(toks)


In [12]:
# No active cell defines `logits` (the only run_with_cache call that produced it
# is commented out), so compute it here with a plain forward pass. This keeps
# the cell runnable after Restart & Run All.
with torch.no_grad():
    logits = model(toks, return_type="logits")

# Per-position preference for ending the turn over starting a new list item:
# logit(<end_of_turn>) - logit("-").
all_logit_diff = logits[:, :, eot_tok_id] - logits[:, :, hypen_tok_id]
all_logit_diff = all_logit_diff.float().cpu().squeeze().numpy()

# Prefix every token with its position so repeated tokens get distinct x labels.
str_tokens = model.to_str_tokens(toks[0])
unique_tokens = [f"{i}/{t}" for i, t in enumerate(str_tokens)]

fig = px.bar(
    x=unique_tokens,
    y=all_logit_diff,
    labels={'x': 'Unique Tokens', 'y': 'Logit Difference'},  # Axis labels
    title="Logit Differences Across Unique Tokens"  # Plot title
)

# Centre the title and rotate the tick labels so the token strings stay readable.
fig.update_layout(
    title={
        'text': "Logit Differences Across Unique Tokens",
        'x': 0.5,  # Center the title
        'xanchor': 'center',
        'yanchor': 'top'
    },
    xaxis_title="Unique Tokens",
    yaxis_title="Logit Difference",
    xaxis_tickangle=-45,  # Rotate x-axis labels for better readability
    legend_title="Legend Title",
    showlegend=True
)

fig.show()



### Create Template with the same structure 


## Logit Lens


In [28]:
import pandas as pd


In [29]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

def plot_superposed_lineplots(data, x, y, hue, title=None, xlabel=None, ylabel=None,
                              palette='husl', line_styles=None, markers=True,
                              grid=True, legend=True, legend_loc='best',
                              figsize=(10, 6), linewidth=2, tick_labels=None):
    """
    Plots multiple superposed line plots with custom aesthetics using Seaborn.

    Parameters:
        data (pd.DataFrame): DataFrame containing the data (long format).
        x (str): Column name to be used for the x-axis.
        y (str): Column name to be used for the y-axis.
        hue (str): Column name to be used for grouping data (each group gets a different line).
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        palette (str or list): Seaborn color palette or list of colors.
        line_styles (list or dict): List or dictionary of line styles for each hue category.
        markers (bool): Whether to show markers at data points.
        grid (bool): Whether to show grid lines.
        legend (bool): Whether to show the legend.
        legend_loc (str): Location of the legend.
        figsize (tuple): Size of the figure.
        linewidth (int): Width of the lines.
        tick_labels (list of str, optional): Labels for the x ticks, placed at
            positions 0..len(tick_labels)-1. When omitted, the notebook-level
            `labels` variable is used if it exists (the previous behaviour);
            if that is missing too, the default ticks are kept.
    """
    # Set up the matplotlib figure and size
    plt.figure(figsize=figsize)

    # Apply a Seaborn style
    sns.set(style="whitegrid")

    # `style` is only set when custom dashes are supplied; otherwise all lines are solid.
    sns.lineplot(data=data, x=x, y=y, hue=hue, palette=palette, style=hue if line_styles else None,
                 dashes=line_styles, markers=markers, linewidth=linewidth)

    # Customize gridlines
    if grid:
        plt.grid(True, which='major', linestyle='--', linewidth=0.5)

    # Set the title and labels
    if title:
        plt.title(title, fontsize=16, fontweight='bold')
    if xlabel:
        plt.xlabel(xlabel, fontsize=14)
    if ylabel:
        plt.ylabel(ylabel, fontsize=14)

    # Earlier revisions read a global `labels` list and assumed exactly 27 ticks.
    # The global is still honoured for existing callers, but the tick count now
    # follows the number of labels instead of being hard-coded.
    if tick_labels is None:
        tick_labels = globals().get("labels")
    if tick_labels is not None:
        plt.xticks(ticks=range(len(tick_labels)), labels=tick_labels,
                   rotation=45, ha="right", fontsize=12)
    else:
        plt.xticks(rotation=45, ha="right", fontsize=12)

    # Customize the legend
    if legend:
        plt.legend(title=hue, loc=legend_loc, title_fontsize='13', fontsize='11', frameon=True, framealpha=0.9)

    # Improve the overall layout
    plt.tight_layout()

    # Show the plot
    plt.show()

import pandas as pd

def pattern_hook_names_filter(name):
    """Select every block's resid_pre, the last block's resid_post and the final LayerNorm scale."""
    last_layer = model.cfg.n_layers - 1
    return (
        name.endswith("hook_resid_pre")
        or name.endswith(f"{last_layer}.hook_resid_post")
        or name.endswith('ln_final.hook_scale')
    )


# Keep `toks` 2-D (batch, seq): later cells index it as `toks[0]` / `toks[0, 47]`,
# which fails once it has been rebound to a squeezed 1-D tensor.
toks = generation_dict["Vegetables"][0]

with torch.no_grad():
    _, cache = model.run_with_cache(toks, names_filter=pattern_hook_names_filter)

# The structural positions are looked up on a flattened view instead.
toks_flat = toks.squeeze()
hypen_positions = torch.where(toks_flat == hypen_tok_id)[0].to("cpu")
break_positions = torch.where(toks_flat == break_tok_id)[0].to("cpu")
eot_positions = torch.where(toks_flat == eot_tok_id)[0].to("cpu")
# Line breaks directly followed by a hyphen, plus the position just before the last EOT.
filter_break_pos = [pos.item() for pos in break_positions if pos+1 in hypen_positions] + [eot_positions[-1].item()-1]


In [30]:


# Logit lens at every item-ending line break: project the accumulated residual
# stream onto the (<end_of_turn> - "-") unembedding direction.
all_logit_lens = []
for pos in filter_break_pos:
    accumulated_residual, labels = cache.accumulated_resid(
        pos_slice=pos, incl_mid=False, apply_ln=True, return_labels=True,
    )
    dir = model.W_U[:, eot_tok_id].detach() - model.W_U[:, hypen_tok_id].detach()
    all_logit_lens.append(
        einops.einsum(dir, accumulated_residual, "d_model,comps batch d_model -> comps batch ")
    )

# One column per list item, one row per accumulated-residual component.
data_dict = {"Layers": list(range(model.cfg.n_layers + 1))}
data_dict.update(
    {f"Item {idx}": lens.cpu().reshape(-1).numpy() for idx, lens in enumerate(all_logit_lens)}
)
data = pd.DataFrame(data_dict)

# Long format for seaborn: every column except "Layers" becomes a "List Item" group.
melted_data = pd.melt(
    data,
    id_vars=['Layers'],
    value_vars=[col for col in data_dict if col != "Layers"],
    var_name='List Item',
    value_name='Value',
)

plot_superposed_lineplots(melted_data, x='Layers', y='Value', hue='List Item',
                          title='Logit lens', xlabel='Layer', ylabel='Value',
                          palette='husl', markers=True)


In [31]:

# Logit lens at the token right after every hyphen: project the accumulated
# residual stream onto the (blank - line break) unembedding direction.
all_logit_lens = []
for pos in hypen_positions + 1:
    accumulated_residual, labels = cache.accumulated_resid(
        pos_slice=pos, incl_mid=False, apply_ln=True, return_labels=True
    )
    dir = model.W_U[:, blanck_tok_id].detach() - model.W_U[:, break_tok_id].detach()
    all_logit_lens.append(
        einops.einsum(dir, accumulated_residual, "d_model,comps batch d_model -> comps batch ")
    )

# One column per list item, one row per accumulated-residual component.
data_dict = {"Layers": list(range(model.cfg.n_layers + 1))}
data_dict.update(
    {f"Item {idx}": lens.cpu().reshape(-1).numpy() for idx, lens in enumerate(all_logit_lens)}
)
data = pd.DataFrame(data_dict)

# Long format for seaborn: every column except "Layers" becomes a "List Item" group.
melted_data = pd.melt(
    data,
    id_vars=['Layers'],
    value_vars=[col for col in data_dict if col != "Layers"],
    var_name='List Item',
    value_name='Value',
)

plot_superposed_lineplots(melted_data, x='Layers', y='Value', hue='List Item',
                          title='Logit lens', xlabel='Layer', ylabel='Value',
                          palette='husl', markers=True)


## What are the most important components


### Attention Exploration


In [32]:
# Re-run without a names filter so that everything stack_head_results needs is cached.
with torch.no_grad():
    _,cache = model.run_with_cache(toks)
# Unembedding direction "blank token minus line break".
dir = model.W_U[:,235248].detach()-model.W_U[:,break_tok_id].detach()
# Per-head outputs at the position right after the last hyphen.
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=hypen_positions[-1]+1, return_labels=True
)
# Split the flat (layer * head) axis into separate layer and head axes.
per_head_residual_per_layer = einops.rearrange(per_head_residual,"(layer n_head) batch d_model-> layer n_head batch d_model", layer = model.cfg.n_layers, n_head = model.cfg.n_heads)
# Project every head onto `dir`; keep batch element 0 and move the result to the CPU.
per_head_layer_logit_diff = einops.einsum(per_head_residual_per_layer,dir.detach(),"layer n_head batch d_model, d_model-> layer n_head batch")[:,:,0].to("cpu").detach()


In [33]:
# Take the grid size from the model config instead of hard-coding 26 layers x 8 heads.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
np_labels = np.array(labels).reshape(n_layers, n_heads)

# One column per layer, one row per head.
data = {f"Layer {i}": arr.numpy() for i, arr in enumerate(per_head_layer_logit_diff)}
# Plain ints give a clean "Head" index (a tensor column is not a reliable index).
data["Head"] = list(range(n_heads))
df = pd.DataFrame(data)
df = df.set_index("Head")

ax = sns.heatmap(df)
ax.set_title("Per-head contribution to the logit difference")
plt.show()


### Naively get the top model components by logit diff


In [17]:
# Inspect the token id at position 47 of the first sequence (needs `toks` to be 2-D).
toks[0,47]


In [None]:


# Cache every activation of a forward pass over `toks`.
with torch.no_grad():
    _,cache = model.run_with_cache(toks)
# Decompose the residual stream at position 48 into labelled component contributions
# (final LayerNorm applied, neurons not expanded).
# NOTE(review): 48 is hard-coded; presumably the position analysed for this sample — confirm.
resid_decomp, labels = cache.get_full_resid_decomposition(layer = -1, pos_slice = 48, apply_ln=True, return_labels = True,expand_neurons=False)


In [None]:
# Map every component label to an x coordinate and a colour class:
#   labels starting with "L"  -> number between "L" and "H", class 1 (blue)
#   labels ending with "out"  -> number before the first "_", class 2 (red)
#   anything else             -> -1, class 0 (green)
label_index, color_values = [], []
for lab in labels:
    if lab.startswith("L"):
        layer, kind = int(lab.split("L")[1].split("H")[0]), 1
    elif lab.endswith("out"):
        layer, kind = int(lab.split("_")[0]), 2
    else:
        layer, kind = -1, 0
    label_index.append(layer)
    color_values.append(kind)

color_map = {0: 'green', 1: 'blue', 2: 'red'}
colors = list(map(color_map.__getitem__, color_values))


In [None]:
# Project every component's contribution onto `dir`; keep batch element 0 on the CPU.
per_comp_logit_diff = einops.einsum(resid_decomp,dir.detach()," comp batch d_model, d_model->  comp batch")[:,0].to("cpu").detach()


In [None]:
# Ten components with the largest projection onto the logit-difference direction.
val, ind = per_comp_logit_diff.topk(k=10)

plt.figure(figsize=(10, 6))
plt.scatter(label_index, per_comp_logit_diff, color=colors)

# x coordinate and name of each of the ten components.
top_10_x = [label_index[i] for i in ind]
top_10_names = [labels[i] for i in ind]

# Put the component name next to each of the ten points.
for name, x_coord, y_coord in zip(top_10_names, top_10_x, val):
    plt.annotate(
        name,
        (x_coord, y_coord),
        xytext=(5, 5),
        textcoords='offset points',
        arrowprops=dict(arrowstyle='->', color='red')
    )

plt.xlabel('Layers')
plt.ylabel('Logit Difference')
plt.title('Logit Lens for all the model components')

plt.show()


# Ablation Experiments



1. After the last element of the last item, the model usually outputs a blank space.

Which components can we zero-ablate to make it output a "\n" instead of " "?


2. Which elements can we ablate at the last line break to make the model output "-" instead of "<end_of_turn>"?


In [11]:

from transformer_lens.hook_points import HookPoint
from transformer_lens import utils
from functools import partial
import tqdm


**Attention Value Ablation**


In [19]:


def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    pos: int,
    head_to_ablate: int,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """
    Forward hook that zero-ablates one head at one sequence position.

    `value` is modified in place and returned. `hook` is not used, but it is
    part of the hook call signature.
    """
    value[:, pos, head_to_ablate, :] = 0.
    return value


def get_all_v_ablations(toks, pos):
    """
    Zero-ablate each (KV head, layer) value vector at `pos` and measure the effect.

    The metric is logit(blank) - logit(line break) at position `pos`. Every
    entry is the ablated metric divided by the clean (un-ablated) metric.

    NOTE: later cells redefine `head_ablation_hook` and `get_all_v_ablations`;
    whichever cell ran last wins.

    Args:
        toks: token tensor fed to the model (a single sequence is expected,
            because every result is read with `.item()`).
        pos: sequence position that is ablated and at which the logits are read.

    Returns:
        np.ndarray of shape (n_key_value_heads, n_layers).
    """
    n_layers = model.cfg.n_layers
    n_kv_heads = model.cfg.n_key_value_heads
    logit_diff_mat = np.zeros((n_kv_heads, n_layers))

    # Inference only: without no_grad every forward pass would keep an autograd graph.
    with torch.no_grad():
        original_logits = model(toks, return_type="logits")
        clean_logit_diff = original_logits[:, pos, blanck_tok_id] - original_logits[:, pos, break_tok_id]
        for head_to_ablate in tqdm.tqdm(range(n_kv_heads)):
            for layer_to_ablate in range(n_layers):
                hook_func = partial(head_ablation_hook, pos=pos, head_to_ablate=head_to_ablate)
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(
                        utils.get_act_name("v", layer_to_ablate),
                        hook_func
                    )]
                )
                ablated_logit_diff = ablated_logits[:, pos, blanck_tok_id] - ablated_logits[:, pos, break_tok_id]
                # Divide as tensors (keeps inf/nan semantics), then store a Python float.
                logit_diff_mat[head_to_ablate, layer_to_ablate] = (ablated_logit_diff / clean_logit_diff).item()
    return logit_diff_mat




pos = 46
logit_diff_mat = get_all_v_ablations(toks, pos)


In [20]:
# Plotly heatmap
fig = px.imshow(logit_diff_mat, labels=dict(x="Layers", y="Heads", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                x=[f"Layer {i}" for i in range(model.cfg.n_layers)],
                y=[f"Head {i}" for i in range(model.cfg.n_key_value_heads)],
                color_continuous_scale='Viridis')
fig.show()
torch.cuda.empty_cache()


In [29]:


def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    pos: int,
    head_to_ablate: int,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """
    Forward hook that zero-ablates one head at one sequence position.

    `value` is modified in place and returned. `hook` is not used, but it is
    part of the hook call signature.
    """
    value[:, pos, head_to_ablate, :] = 0
    return value


def get_all_v_ablations(toks, pos):
    """
    Zero-ablate each (head, layer) attention output `z` at `pos` and record the metric.

    Despite the historical name, this version hooks `hook_z`, not `hook_v`.
    The metric is logit(blank) - logit(line break) at position `pos`, stored
    as the raw ablated value (no normalisation by the clean run).

    Args:
        toks: token tensor fed to the model (a single sequence is expected,
            because every result is read with `.item()`).
        pos: sequence position that is ablated and at which the logits are read.

    Returns:
        np.ndarray of shape (n_heads, n_layers).
    """
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    logit_diff_mat = np.zeros((n_heads, n_layers))

    # Inference only: without no_grad every forward pass would keep an autograd graph.
    # The clean forward pass was dropped because its result was never used.
    with torch.no_grad():
        for head_to_ablate in tqdm.tqdm(range(n_heads)):
            for layer_to_ablate in range(n_layers):
                hook_func = partial(head_ablation_hook, pos=pos, head_to_ablate=head_to_ablate)
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(
                        utils.get_act_name("z", layer_to_ablate),
                        hook_func
                    )]
                )
                ablated_logit_diff = ablated_logits[:, pos, blanck_tok_id] - ablated_logits[:, pos, break_tok_id]
                logit_diff_mat[head_to_ablate, layer_to_ablate] = ablated_logit_diff.item()
    return logit_diff_mat




pos = 46
logit_diff_mat = get_all_v_ablations(toks, pos)


In [30]:

# Heatmap of the ablation matrix: rows are attention heads, columns are layers.
fig = px.imshow(logit_diff_mat, labels=dict(x="Layers", y="Heads", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                x=[f"Layer {i}" for i in range(model.cfg.n_layers)],
                y=[f"Head {i}" for i in range(model.cfg.n_heads)],
                color_continuous_scale='Viridis')
fig.show()


**Residual Stream Patching**


In [32]:

def rs_ablation_hook(
    act: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    pos: int,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """
    Forward hook that zeroes the residual stream at one sequence position.

    `act` is modified in place and returned. `hook` is not used, but it is
    part of the hook call signature.
    """
    act[:, pos, :] = 0
    return act


def get_all_resiudal_ablations(toks, pos, start_pos=33):
    """
    Zero-ablate `hook_resid_pre` at every (layer, earlier position) and record the metric.

    The metric is logit(blank) - logit(line break) read at position `pos`,
    stored as the raw ablated value.

    Args:
        toks: token tensor fed to the model (a single sequence is expected,
            because every result is read with `.item()`).
        pos: position at which the logits are read.
        start_pos: first ablated position. Defaults to 33, the value that used
            to be hard-coded (presumably the start of the generated list for
            this sample — TODO confirm). The ablated positions are
            range(start_pos, pos - 1).

    Returns:
        np.ndarray of shape (n_layers, len(range(start_pos, pos - 1))).
    """
    n_layers = model.cfg.n_layers
    positions = list(range(start_pos, pos - 1))
    logit_diff_mat = np.zeros((n_layers, len(positions)))

    # Inference only: without no_grad every forward pass would keep an autograd graph.
    # The clean forward pass was dropped because its result was never used.
    with torch.no_grad():
        # Wrapping the list (not the enumerate) lets tqdm show the total.
        for i, pos1 in enumerate(tqdm.tqdm(positions)):
            for layer_to_ablate in range(n_layers):
                hook_func = partial(rs_ablation_hook, pos=pos1)
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(
                        f"blocks.{layer_to_ablate}.hook_resid_pre",
                        hook_func
                    )]
                )
                ablated_logit_diff = ablated_logits[:, pos, blanck_tok_id] - ablated_logits[:, pos, break_tok_id]
                logit_diff_mat[layer_to_ablate, i] = ablated_logit_diff.item()
    return logit_diff_mat
pos = 46
logit_diff_mat = get_all_resiudal_ablations(toks, pos)


In [33]:

# Heatmap of the ablation matrix: rows are layers, columns are the ablated positions 33 .. pos-2.
fig = px.imshow(logit_diff_mat, labels=dict(x="Positions", y="Layers", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                x=[f"Position {i}" for i in range(33,pos-1)],
                y=[f"Layers {i}" for i in range(model.cfg.n_layers)],
                color_continuous_scale='Viridis')
fig.show()


**Logit Diff Patching hook_q**


In [35]:
# Inspect the shape of the cached layer-0 query activations (needs a `cache` that contains hook_q).
cache["blocks.0.attn.hook_q"].shape


In [36]:

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    pos: int,
    head_to_ablate: int,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """
    Forward hook that zero-ablates one head at one sequence position.

    `value` is modified in place and returned. `hook` is not used, but it is
    part of the hook call signature.
    """
    value[:, pos, head_to_ablate, :] = 0
    return value


def get_all_v_ablations(toks, pos):
    """
    Zero-ablate each (head, layer) query vector at `pos` and record the metric.

    Despite the historical name, this version hooks `hook_q`, not `hook_v`.
    The metric is logit(blank) - logit(line break) at position `pos`, stored
    as the raw ablated value (no normalisation by the clean run).

    Args:
        toks: token tensor fed to the model (a single sequence is expected,
            because every result is read with `.item()`).
        pos: sequence position that is ablated and at which the logits are read.

    Returns:
        np.ndarray of shape (n_heads, n_layers).
    """
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    logit_diff_mat = np.zeros((n_heads, n_layers))

    # Inference only: without no_grad every forward pass would keep an autograd graph.
    # The clean forward pass was dropped because its result was never used.
    with torch.no_grad():
        for head_to_ablate in tqdm.tqdm(range(n_heads)):
            for layer_to_ablate in range(n_layers):
                hook_func = partial(head_ablation_hook, pos=pos, head_to_ablate=head_to_ablate)
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(
                        utils.get_act_name("q", layer_to_ablate),
                        hook_func
                    )]
                )
                ablated_logit_diff = ablated_logits[:, pos, blanck_tok_id] - ablated_logits[:, pos, break_tok_id]
                logit_diff_mat[head_to_ablate, layer_to_ablate] = ablated_logit_diff.item()
    return logit_diff_mat




pos = 46
logit_diff_mat = get_all_v_ablations(toks, pos)


In [37]:


# Heatmap of the ablation matrix: rows are attention heads, columns are layers.
fig = px.imshow(logit_diff_mat, labels=dict(x="Layers", y="Heads", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                x=[f"Layer {i}" for i in range(model.cfg.n_layers)],
                y=[f"Head {i}" for i in range(model.cfg.n_heads)],
                color_continuous_scale='Viridis')
fig.show()


**Patch key value**


In [12]:

def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    pos: int,
    head_to_ablate: int,
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """
    Forward hook that zero-ablates one head at one sequence position.

    `value` is modified in place and returned. `hook` is not used, but it is
    part of the hook call signature. The ablated position is the `pos`
    argument; it used to be ignored in favour of a hard-coded 30.
    """
    value[:, pos, head_to_ablate, :] = 0
    return value


def get_all_v_ablations(toks, pos, key_pos=30):
    """
    Zero-ablate each (KV head, layer) key vector at `key_pos` and record the metric.

    Despite the historical name, this version hooks `hook_k`, not `hook_v`.
    The metric is logit(blank) - logit(line break) read at position `pos`,
    stored as the raw ablated value.

    Args:
        toks: token tensor fed to the model (a single sequence is expected,
            because every result is read with `.item()`).
        pos: position at which the logits are read.
        key_pos: position whose key vector is zeroed. Defaults to 30, the value
            that used to be hard-coded inside the hook
            (NOTE(review): why 30 for this sample is not visible here — confirm).

    Returns:
        np.ndarray of shape (n_key_value_heads, n_layers).
    """
    n_layers = model.cfg.n_layers
    n_kv_heads = model.cfg.n_key_value_heads
    logit_diff_mat = np.zeros((n_kv_heads, n_layers))

    # Inference only: without no_grad every forward pass would keep an autograd graph.
    # The clean forward pass was dropped because its result was never used.
    with torch.no_grad():
        for head_to_ablate in tqdm.tqdm(range(n_kv_heads)):
            for layer_to_ablate in range(n_layers):
                hook_func = partial(head_ablation_hook, pos=key_pos, head_to_ablate=head_to_ablate)
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(
                        utils.get_act_name("k", layer_to_ablate),
                        hook_func
                    )]
                )
                ablated_logit_diff = ablated_logits[:, pos, blanck_tok_id] - ablated_logits[:, pos, break_tok_id]
                logit_diff_mat[head_to_ablate, layer_to_ablate] = ablated_logit_diff.item()
    return logit_diff_mat




pos = 46
logit_diff_mat = get_all_v_ablations(toks, pos)


In [14]:

# Heatmap of the ablation matrix: rows are key/value heads, columns are layers.
fig = px.imshow(logit_diff_mat, labels=dict(x="Layers", y="Heads", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                x=[f"Layer {i}" for i in range(model.cfg.n_layers)],
                y=[f"Head {i}" for i in range(model.cfg.n_key_value_heads)],
                color_continuous_scale='Viridis')
fig.show()
# Release cached GPU memory after the ablation sweep.
torch.cuda.empty_cache()


**MLP output ablation**


In [28]:
def mlp_ablation_hook(
    act: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    pos: int,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """
    Forward hook that zeroes an MLP output at one sequence position.

    `act` is modified in place and returned. `hook` is not used, but it is
    part of the hook call signature.
    """
    act[:, pos, :] = 0
    return act


def get_all_mlp_ablations(toks, pos, start_pos=33):
    """
    Zero-ablate `hook_mlp_out` at every (layer, earlier position) and record the metric.

    The metric is logit(blank) - logit(line break) read at position `pos`,
    stored as the raw ablated value.

    Args:
        toks: token tensor fed to the model (a single sequence is expected,
            because every result is read with `.item()`).
        pos: position at which the logits are read.
        start_pos: first ablated position. Defaults to 33, the value that used
            to be hard-coded (presumably the start of the generated list for
            this sample — TODO confirm). The ablated positions are
            range(start_pos, pos - 1).

    Returns:
        np.ndarray of shape (n_layers, len(range(start_pos, pos - 1))).
    """
    n_layers = model.cfg.n_layers
    positions = list(range(start_pos, pos - 1))
    logit_diff_mat = np.zeros((n_layers, len(positions)))

    # Inference only: without no_grad every forward pass would keep an autograd graph.
    # The clean forward pass was dropped because its result was never used.
    with torch.no_grad():
        # Wrapping the list (not the enumerate) lets tqdm show the total.
        for i, pos1 in enumerate(tqdm.tqdm(positions)):
            for layer_to_ablate in range(n_layers):
                hook_func = partial(mlp_ablation_hook, pos=pos1)
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(
                        f"blocks.{layer_to_ablate}.hook_mlp_out",
                        hook_func
                    )]
                )
                ablated_logit_diff = ablated_logits[:, pos, blanck_tok_id] - ablated_logits[:, pos, break_tok_id]
                logit_diff_mat[layer_to_ablate, i] = ablated_logit_diff.item()
    return logit_diff_mat

pos = 46
logit_diff_mat = get_all_mlp_ablations(toks, pos)
torch.cuda.empty_cache()


In [29]:

# Heatmap of the ablation matrix: rows are layers, columns are the ablated positions 33 .. pos-2.
fig = px.imshow(logit_diff_mat, labels=dict(x="Positions", y="Layers", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                x=[f"Position {i}" for i in range(33,pos-1)],
                y=[f"Layers {i}" for i in range(model.cfg.n_layers)],
                color_continuous_scale='Viridis')
fig.show()


**Replacement Hook Residual Stream**


In [125]:
torch.cuda.empty_cache()


def _mean_activation_per_layer(toks, positions, hook_suffix):
    """
    Average one activation type over positions and batch, for every layer.

    Runs the model once while caching only `blocks.{layer}.{hook_suffix}`,
    reads the activation at `p + 1` for every `p` in `positions`, and averages
    first over the positions and then over the batch axis. The shape of the
    result is printed before it is returned.

    Args:
        toks: token tensor fed to the model.
        positions: iterable of sequence positions; the activation is read one
            position to the right of each.
        hook_suffix: hook name inside a block, e.g. "hook_resid_pre" or "attn.hook_z".

    Returns:
        Tensor whose first axis is the layer, followed by the activation's trailing axes.
    """
    # Use the model's own layer count instead of a hard-coded 26.
    hook_names = [f"blocks.{layer}.{hook_suffix}" for layer in range(model.cfg.n_layers)]
    with torch.no_grad():
        _, cache = model.run_with_cache(toks, names_filter=hook_names)
    per_layer = [
        torch.stack([cache[name][:, p + 1] for p in positions])
        for name in hook_names
    ]
    # Stacked axes: (layer, position, batch, ...). Average out position, then batch.
    mean_acts = torch.stack(per_layer).mean(1).mean(1)
    print(mean_acts.shape)
    return mean_acts


def get_resid(toks, positions):
    """Mean `hook_resid_pre` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "hook_resid_pre")


def get_ln1_normalized(toks, positions):
    """Mean `ln1.hook_normalized` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "ln1.hook_normalized")


def get_mlp(toks, positions):
    """Mean `hook_mlp_out` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "hook_mlp_out")


def get_attn(toks, positions):
    """Mean `hook_attn_out` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "hook_attn_out")


def get_attn_z(toks, positions):
    """Mean `attn.hook_z` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "attn.hook_z")


def get_attn_q(toks, positions):
    """Mean `attn.hook_q` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "attn.hook_q")


def get_attn_k(toks, positions):
    """Mean `attn.hook_k` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "attn.hook_k")


def get_attn_v(toks, positions):
    """Mean `attn.hook_v` per layer at the positions following `positions`."""
    return _mean_activation_per_layer(toks, positions, "attn.hook_v")



hypen_positions = torch.where(toks[0] == hypen_tok_id)[0]
def get_all_mean_ablations(toks,pos):
    """
    Mean-ablate one component at one layer at a time and record the logit difference.

    For each layer and each of eight hook points, the activation at sequence
    position `pos` is overwritten with that hook's mean activation (computed by
    the `get_*` helpers over every hyphen position except the last), the model
    is re-run, and `logit[235248] - logit[break_tok_id]` at `pos` is stored.

    Args:
        toks: Token tensor fed to the model (the write into a single matrix
            cell presumes a batch of one).
        pos: Sequence position that is overwritten and whose logits are read.

    Returns:
        numpy array of shape (n_layers, 8); columns follow `component_order`.
    """
    # Reference positions for the means: every hyphen except the final one.
    reference_positions = hypen_positions[:-1]

    # Helper calls are kept in their original order.
    mean_by_component = {
        "hook_resid_pre": get_resid(toks, reference_positions),
        "ln1.hook_normalized": get_ln1_normalized(toks, reference_positions),
        "hook_mlp_out": get_mlp(toks, reference_positions),
        "hook_attn_out": get_attn(toks, reference_positions),
        "attn.hook_z": get_attn_z(toks, reference_positions),
        "attn.hook_q": get_attn_q(toks, reference_positions),
        "attn.hook_k": get_attn_k(toks, reference_positions),
        "attn.hook_v": get_attn_v(toks, reference_positions),
    }
    # Column order of the returned matrix.
    component_order = ("hook_resid_pre","hook_mlp_out","hook_attn_out","attn.hook_z","attn.hook_q","attn.hook_k","ln1.hook_normalized","attn.hook_v")

    def overwrite_with_mean(acts, hook, replacement):
        # Replace the activation at `pos` for every batch element.
        acts[:,pos,:] = replacement.unsqueeze(0)
        return acts

    n_layers = model.cfg.n_layers
    logit_diff_mat = np.zeros((n_layers,8))

    for layer_to_ablate in range(n_layers):
        for comp_id, comp in enumerate(component_order):
            hook_func = partial(
                overwrite_with_mean,
                replacement=mean_by_component[comp][layer_to_ablate],
            )
            hook_name = f"blocks.{layer_to_ablate}.{comp}"
            with torch.no_grad():
                ablated_logits = model.run_with_hooks(
                    toks,
                    return_type="logits",
                    fwd_hooks=[(hook_name, hook_func)],
                )
            ablated_logit_diff = ablated_logits[:,pos,235248] - ablated_logits[:,pos,break_tok_id]
            logit_diff_mat[layer_to_ablate,comp_id] = ablated_logit_diff.cpu().numpy()
    return logit_diff_mat
    

# Mean-ablation sweep at position 46.
# NOTE(review): the plotting cell below titles the figure with the global `pos`, not this literal — keep them in sync.
logit_diff_mat = get_all_mean_ablations(toks, 46)
torch.cuda.empty_cache()


In [119]:
# NOTE(review): relies on a global `cache` that already contains ln1.hook_scale; presumably left over
# from an earlier run_with_cache call — confirm this cell still runs on a fresh kernel.
cache["blocks.0.ln1.hook_scale"].shape


In [126]:


# Same component order as the columns written by get_all_mean_ablations.
comps = ["hook_resid_pre","hook_mlp_out","hook_attn_out","attn.hook_z","attn.hook_q","attn.hook_k","ln1.hook_normalized","attn.hook_v"]
# The matrix is transposed so components run down the y axis and layers along the x axis;
# the axis titles are set to match.
fig = px.imshow(logit_diff_mat.T, labels=dict(x="Layers", y="Components", color="Logit Difference"),
                title=f"Logit Difference Ablations for Position {pos}",
                y=comps,
                x=[f"Layers {i}" for i in range(model.cfg.n_layers)],
                color_continuous_scale='Viridis')
fig.show()


# Visualize attention patterns


In [127]:
# Cache only hooks whose name contains "pattern" (the attention patterns) for the prompt.
pattern_hook_names_filter = lambda name: "pattern" in name
with torch.no_grad():
    _,cache = model.run_with_cache(toks,names_filter = pattern_hook_names_filter)


# Move the cached activations to the CPU; the plotting cells below read from this `cache`.
cache = cache.to("cpu")



In [128]:
pos = 46
# Cover every layer of the model (a hard-coded 25 skipped the final layer).
n_layers = model.cfg.n_layers
# For each layer, take batch element 0 and the attention row of query position `pos` for all heads,
# then keep the maximum over heads (dim 1 of the stacked tensor).
all_patterns_pos = torch.stack([cache[f"blocks.{i}.attn.hook_pattern"][0,:,pos] for i in range(n_layers)])
all_patterns_max = all_patterns_pos.max(dim = 1).values

str_tokens = model.to_str_tokens(toks)
# Prefix each token with its index so repeated tokens get distinct tick labels.
unique_tokens = [f"{i}/{t}" for i, t in enumerate(str_tokens)]
fig = px.imshow(all_patterns_max, labels=dict(x="Toks", y="Layers", color="Attention Pattern"),
                title=f"Max Attention Over Heads from Position {pos}",
                x=unique_tokens,
                y=[f"Layer {i}" for i in range(n_layers)],
                color_continuous_scale='Viridis')
fig.show()
torch.cuda.empty_cache()
torch.cuda.empty_cache()


In [126]:
# Per-head attention of layer 21 for query position `pos` (batch element 0).
pos = 46
str_tokens = model.to_str_tokens(toks)
# Prefix each token with its index so repeated tokens get distinct tick labels.
unique_tokens = [f"{i}/{t}" for i, t in enumerate(str_tokens)] 
pattern = cache["blocks.21.attn.hook_pattern"][0,:,pos]
# NOTE(review): the y labels assume exactly 8 heads — confirm against model.cfg.n_heads.
fig = px.imshow(pattern, labels=dict(x="Toks", y="Heads", color="Attention Pattern"),
                title=f"Attention Patterns",
                x=[i for i in unique_tokens],
                y=[f"Head {i}" for i in range(8)],
                color_continuous_scale='Viridis')
fig.show()


## SAEs


In [1]:
from attribution_utils import calculate_feature_attribution
from torch.nn.functional import log_softmax
from gemma_utils import get_all_string_min_l0_resid_gemma
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils
from functools import partial
import tqdm
from sae_lens import HookedSAETransformer, SAE, SAEConfig
from gemma_utils import get_gemma_2_config, gemma_2_sae_loader
import numpy as np
import torch
import tqdm
import einops
import re
from jaxtyping import Int, Float
from typing import List, Optional, Any
from torch import Tensor
import json
import os
from torch.utils.data import Dataset, DataLoader
import random
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import random
from transformer_lens.utils import get_act_name
from IPython.display import display, HTML
import plotly.express as px


In [2]:

model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it")
# NOTE(review): torch.load unpickles the file — only load a generation dict you produced yourself.
generation_dict = torch.load("gemma2_generation_dict.pt")
# First stored generation for the "Vegetables" topic; indexed as toks[0] below, so it carries a leading batch axis.
toks = generation_dict["Vegetables"][0]

# Token ids used throughout the analysis.
# NOTE(review): presumably the Gemma tokenizer ids for "-", a line break, end-of-turn and a blank token — confirm.
hypen_tok_id = 235290
break_tok_id = 108
eot_tok_id = 107
blanck_tok_id = 235248
# Sequence indices at which each special token occurs in the first (only) sequence.
hypen_positions = torch.where(toks[0] == hypen_tok_id)[0]
break_positions = torch.where(toks[0] == break_tok_id)[0]
eot_positions = torch.where(toks[0] == eot_tok_id)[0]
# Line-break positions that are immediately followed by a hyphen.
filter_break_pos = [pos.item() for pos in break_positions if pos+1 in hypen_positions]


In [3]:
# Position under study; display the token id found there.
pos = 46
toks[:,pos]


In [7]:

def metric_fn(logits: torch.Tensor, pos:int = 46) -> torch.Tensor:
    """
    Logit difference used as the attribution target.

    Reads batch element 0 at sequence position `pos` and returns the logit of
    token id 235248 minus the logit of `break_tok_id`.
    """
    logits_at_pos = logits[0, pos]
    return logits_at_pos[235248] - logits_at_pos[break_tok_id]


In [4]:
# Load the residual-stream SAE for one layer and register it under its hook name.
# NOTE(review): get_all_string_min_l0_resid_gemma presumably returns one repo folder name per layer — confirm.
full_strings = get_all_string_min_l0_resid_gemma()
layer = 20
saes_dict = {}
with torch.no_grad():
    repo_id = "google/gemma-scope-2b-pt-res"
    folder_name = full_strings[layer]
    # NOTE(review): `config` and `log_spar` are not used again in this cell.
    config = get_gemma_2_config(repo_id, folder_name)
    cfg, state_dict, log_spar = gemma_2_sae_loader(repo_id, folder_name)
    sae_cfg = SAEConfig.from_dict(cfg)
    sae = SAE(sae_cfg)
    sae.load_state_dict(state_dict)
    sae.to("cuda:0")
    # NOTE(review): presumably makes the spliced-in SAE add back its reconstruction error — confirm in sae_lens.
    sae.use_error_term = True

    saes_dict[sae.cfg.hook_name] = sae


In [5]:
import pandas as pd
import plotly.express as px


In [8]:

# Attribute the metric (logit difference at the studied position) to the SAE features in saes_dict.
# NOTE(review): despite the `_df` suffix the result is accessed via attributes
# (.sae_feature_attributions, .sae_feature_activations) further down, so it is not a DataFrame.
feature_attribution_df = calculate_feature_attribution(
    model = model,
    input = toks,
    metric_fn = metric_fn,
    include_saes=saes_dict,
    include_error_term=True,
    return_logits=True,
)


In [12]:
def convert_sparse_feature_to_long_df(sparse_tensor: torch.Tensor) -> pd.DataFrame:
    """
    Reshape a 2-D attribution tensor into a long table of its non-zero entries.

    Row indices of the input become the `position` column, column indices
    become `feature`, and the cell values become `attribution`. Entries equal
    to zero are dropped.

    Returns:
        DataFrame with columns ["position", "feature", "attribution"] and a
        fresh 0..n-1 index.
    """
    dense = pd.DataFrame(sparse_tensor.detach().cpu().numpy())
    # Melt while keeping the original row index, so it can be recovered as `position`.
    long_form = dense.melt(ignore_index=False, var_name="feature", value_name="attribution")
    non_zero = long_form.loc[long_form["attribution"].ne(0)]
    return non_zero.reset_index().rename(columns={"index": "position"})

# Long table of non-zero feature attributions for batch element 0, shown with the largest attributions first.
df_long_nonzero = convert_sparse_feature_to_long_df(feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0])
df_long_nonzero.sort_values("attribution", ascending=False)


## Parallel Coordinates Plot


In [5]:
# NOTE(review): all_tok2_pos and tok_pos are not defined in this section — they must come from an earlier cell.
# Each listed position is shifted back by one before filtering.
positions = [pos-1 for pos in all_tok2_pos+[tok_pos]]
attr_df_pos = df_long_nonzero[df_long_nonzero['position'].isin(positions)]
# One row per feature, one column per selected position; missing (zero) attributions become NaN.
attr_df_pos = attr_df_pos.pivot(index = "feature", columns = "position", values = "attribution").reset_index()
# Keep only features with a non-zero attribution at every selected position.
# Defined before its first use so the cell runs on a fresh kernel.
df = attr_df_pos.dropna(axis = 0)
# NOTE(review): column 0 is `feature`, so iloc[:, 2:] also skips the first position column — confirm this is intended.
mat = df.iloc[:,2:].to_numpy()
# NOTE(review): argsort yields sort indices per column, not ranks; the plot below labels these values "Rank".
sort = mat.argsort(axis = 0)
dataframe = pd.DataFrame.from_records(sort).reset_index()
dataframe["feature"] = df['feature'].tolist()


In [None]:

# Mean attribution of each feature across the selected positions; drives the line colours.
mean = mat.mean(axis=1)

# Min-max normalise the means, then raise to a power to exaggerate the contrast.
# NOTE(review): divides by zero when every feature has the same mean.
power = 2  # Adjust this value to increase or decrease the effect
colors = (mean - mean.min()) / (mean.max() - mean.min())
colors = np.power(colors, power)
# Row indices of the ten largest means, largest first.
top_10_indices = np.argsort(mean)[-10:][::-1]

plt.figure(figsize=(10, 6))


# Plot each feature as one line (skipping the leading index column and the trailing `feature` column)
for i in range(len(dataframe)):
    y = dataframe.iloc[i, 1:-1].values
    x = range(len(y))
    color = str(1 - colors[i])  # Grayscale string; inverted so a higher value gives a darker line
    line = plt.plot(x, y, color=color, linewidth=2, alpha=0.7)
    
    # Label the ten highest-mean lines with their feature id at the right edge
    if i in top_10_indices:
        plt.annotate(dataframe['feature'][i], 
                     xy=(len(x)-1, y[-1]), 
                     xytext=(5, 0), 
                     textcoords='offset points',
                     ha='left', 
                     va='center', 
                     fontsize=8,
                     color=color,
                     fontweight='bold')
# Customize the plot
plt.title('Parallel Coordinates Plot')
plt.xlabel('Positions')
plt.ylabel('Rank')
plt.xticks(range(len(dataframe.columns[1:-1])), dataframe.columns[1:-1], rotation=45)
plt.grid(False)

# Show the plot
plt.tight_layout()
plt.show()


In [None]:
# Decoder direction of SAE feature 4411.
feat_dir = sae.W_dec[4411,:].detach()

# Project the direction through the unembedding and keep the 10 tokens with the largest values.
f_tok_val,f_tok_ind = (feat_dir@model.W_U.detach()).topk(k = 10)
highest_logtis_dict = {}
highest_logtis_dict["Token"] = []
highest_logtis_dict["Value"] = []
# Collect the decoded token string and its value, in descending order of value.
for v,t in zip(f_tok_val,f_tok_ind):
    highest_logtis_dict["Token"] += [model.to_string(t)]
    highest_logtis_dict["Value"] += [v.to("cpu").item()]


In [None]:

# Table of the top tokens, indexed by token string; reset_index keeps the rank as an `index` column.
df_highest_tok = pd.DataFrame.from_dict(highest_logtis_dict).reset_index().set_index("Token")

df_highest_tok


In [None]:
# NOTE(review): needs `tokens` to be a 2-D id tensor and `tok_pos` to be set by an earlier cell;
# other cells rebind `tokens` to a list of strings, so execution order matters here.
model.to_string(tokens[0,tok_pos])


In [None]:
# NOTE(review): `string` and `tok_pos` are defined in other cells; this rebinds `tokens` to a list of strings.
tokens = model.to_str_tokens(string, prepend_bos=False)[:tok_pos]
# Prefix each token with its index so repeated tokens get distinct bar labels.
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

# Total attribution per position (summed over all SAE features), truncated to the first tok_pos positions.
px.bar(x = unique_tokens,
       y = feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0].sum(-1).detach().cpu().numpy()[:tok_pos])


In [None]:
# Per-position attribution of feature 4411 only, truncated to the first tok_pos positions.


px.bar(x = unique_tokens,
       y = feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0][:,4411].detach().cpu().numpy()[:tok_pos],title ="Attribution for feature 4411")


### Get the top features in the last position

Promote end of list [ 4411, 13491,  5325,  1777,  6004, 11942,  9369,  1000,  4384,  3855]


Promote continuation of list [10529,   152, 12523,  8323,   492, 10548,  5169, 14540,   561,  7368]


In [None]:
# Ten features with the largest attribution at position tok_pos-1 (batch element 0), as a plain list of ids.
val, ind = feature_attribution_df.sae_feature_attributions[sae.cfg.hook_name][0,tok_pos-1].topk(k = 10,dim = -1)
ind = ind.tolist()
ind


In [None]:
from functools import partial

def prompt_with_ablation(model, sae, prompt, ablation_features, position=46, value=30):
    """
    Run `prompt` through `model` with `sae` attached while overwriting selected SAE features.

    Args:
        model: Model exposing add_sae / add_hook / reset_saes / reset_hooks and callable on `prompt`.
        sae: SAE to attach; its `cfg.hook_name` selects the hook point.
        prompt: Input passed unchanged to `model(...)`.
        ablation_features: Feature index (or indices) whose activation is overwritten.
        position: Sequence position to overwrite. `None` zeroes the features at every
            position instead. Defaults to 46, the value that was previously hard-coded.
        value: Activation written at `position`. Defaults to 30, the value that was
            previously hard-coded; pass 0 for a true zero ablation at one position.

    Returns:
        Whatever `model(prompt)` returns with the intervention in place.

    Note:
        With the defaults this sets the feature to 30 at position 46; it does not zero it.
    """

    def ablate_feature_hook(feature_activations, hook, feature_ids, position = None):
        if position is None:
            # No position given: zero the features everywhere.
            feature_activations[:,:,feature_ids] = 0
        else:
            feature_activations[:,position,feature_ids] = value
        return feature_activations

    ablation_hook = partial(ablate_feature_hook, feature_ids = ablation_features, position = position)
    hook_point = sae.cfg.hook_name + '.hook_sae_acts_post'

    model.add_sae(sae)
    try:
        model.add_hook(hook_point, ablation_hook, "fwd")
        logits = model(prompt)
    finally:
        # Always detach the SAE and the hook, even if the forward pass raises,
        # so a failed run cannot contaminate later cells.
        model.reset_saes()
        model.reset_hooks()
    return logits



## Zero ablation experiment
**How do the logits change when we ablate each of the features?**


In [213]:


# Baseline: logit difference at the last position of the truncated prompt, without any intervention.
# NOTE(review): `string` and `tok_pos` come from other cells; 3119 and 2577 are hard-coded token ids —
# document which tokens they stand for.
tokens = model.to_tokens(string,prepend_bos=False)[0][:tok_pos]
with torch.no_grad():
    original_logits = model(tokens)
    original_logit_diff = original_logits[0,-1,3119] - original_logits[0,-1,2577]

original_logit_diff


In [272]:
# Re-run the prompt once per top feature with that feature intervened on, and record the logit difference.
# NOTE(review): as written, prompt_with_ablation sets the feature to 30 at position 46 rather than zeroing it,
# which does not match the "zero ablation" heading — confirm which intervention is intended.
all_logit_diff = []
for feat in ind:
#for i in range(0,len(ind)):
    # Start from a clean model: drop every hook, permanent ones included.
    model.reset_hooks(including_permanent=True)
    logits = prompt_with_ablation(model, sae, tokens,feat)
    logit_diff = logits[0,-1,3119] - logits[0,-1,2577]
    all_logit_diff.append(logit_diff.cpu().item()) 


print(all_logit_diff)


In [248]:
# Tabulate each intervened feature id next to the logit difference measured for it.
feature_effect = defaultdict(list)
for row, measured_diff in enumerate(all_logit_diff):
    feature_effect["Feature"].append(ind[row])
    feature_effect["Zero Ablation Logit Diff "].append(measured_diff)

df_feature_effect = pd.DataFrame.from_dict(feature_effect)
df_feature_effect


### Generate with intervention

Set the feature 4411 in the last position to 10


In [256]:
# Activation of feature 4411 at every position (batch element 0), then only at the selected `positions`.
# NOTE(review): the hook name is hard-coded to layer 22 while the SAE-loading cell uses layer = 20 and other
# cells index with sae.cfg.hook_name — confirm the key exists for the SAE actually loaded.
feat_4411_acts = feature_attribution_df.sae_feature_activations['blocks.22.hook_resid_post'][0,:,4411]
filterd_feat_4411_acts = [feat_4411_acts[i] for i in positions]


In [257]:
# Display the selected activations of feature 4411.
filterd_feat_4411_acts


In [285]:

def steering_hook(feature_activations, hook, feature_ids=4411, position = None,val= filterd_feat_4411_acts[2]):
    """
    Forward hook that overwrites SAE feature activations on multi-token passes only.

    Passes with a single sequence position are returned untouched. Otherwise:
      * `position is None`: `feature_ids` are zeroed at every position.
      * `position` given and `feature_ids` is a list: each listed feature is zeroed at `position`.
      * `position` given and `feature_ids` is a single id: that feature is set to `val` at `position`.

    The default for `val` is read from `filterd_feat_4411_acts[2]` once, when the function is defined.
    """
    # Single-position pass: nothing to do.
    if not feature_activations.shape[1]>1:
        return feature_activations

    # Inital batch of activations
    if position is None:
        feature_activations[:,:,feature_ids] = 0
    elif type(feature_ids)==list:
        for f in feature_ids:
            feature_activations[:,position,f] = 0
    else:
        print(feature_activations.shape)
        feature_activations[:,position,feature_ids] = val

    return feature_activations

# Prompt text; it already starts with <bos>, hence prepend_bos=False when tokenising.
string = "<bos><body><h1>List of My Brother's Favourite Cities</h1><ul><li>Bangkok, Thailand</li><li>London, England</li><li>Paris, France</li><li>Melbourne, Australia</li><li>Toronto, Canada</li><li>New York, USA</li></ul></body>\n<eos>"
# Keep only the first tok_pos tokens as the generation prefix (tok_pos is defined in another cell).
tokens = model.to_tokens(string, prepend_bos=False)[:,:tok_pos]
# NOTE(review): the SAE is attached here and not detached in this cell; later cells run with it still attached.
model.add_sae(sae)
# NOTE(review): steering_hook is used with its defaults (position=None), which zeroes feature 4411 at every
# position of each multi-token pass; the heading above describes setting it at the last position — confirm.
with model.hooks(fwd_hooks=[(sae.cfg.hook_name+".hook_sae_acts_post", steering_hook)]):
    output = model.generate(
        tokens,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.9,
        stop_at_eos =  True,
        prepend_bos = sae.cfg.prepend_bos,
    )


In [286]:
# Show the generated text, raw and through the HTML helper.
# NOTE(review): visualize_html is not defined in this section — presumably a helper from an earlier cell.
print(model.to_string(output))
visualize_html(model.to_string(output)[0])


### DFA







In [32]:
# What are the features that contribute most to the firing of feature 4411?
# Cache all activations for the prompt; the cells below read attention values, patterns and LN scale from it.
with torch.no_grad():
    logits, cache = model.run_with_cache(tokens)


In [155]:

# Layer-22 attention pattern, attention output and value vectors from the cache.
pattern = cache[get_act_name("pattern",22)]
out = cache["blocks.22.hook_attn_out"]
stacked_values = cache['blocks.22.attn.hook_v']
# Repeat each value head twice along the head axis.
# NOTE(review): presumably expands grouped key/value heads to one per query head (n_heads / n_kv_heads == 2) — confirm.
v_cat = torch.repeat_interleave(stacked_values, dim=2, repeats=2)
# Flatten heads into a single axis: (batch, src_pos, n_heads * d_head).
v_cat = einops.rearrange(v_cat, "batch src_pos n_heads  d_head -> batch src_pos (n_heads d_head) ",d_head = model.cfg.d_head,n_heads = model.cfg.n_heads)


In [204]:

# Layer being decomposed; must match the layer the cached pattern and values were read from.
DFA_LAYER = 22

# Broadcast each head's attention weight over that head's d_head slots: (batch, dest, src, n_heads * d_head).
attn_weight = einops.repeat(pattern, "batch n_heads dest src -> batch dest src (n_heads d_head)",d_head = model.cfg.d_head, n_heads = model.cfg.n_heads)
# v_cat is (batch, src, d_attn): add the destination axis after batch so it lines up with attn_weight
# for any batch size (unsqueeze(0) only lined up when batch == 1).
decompose_z_cat = attn_weight*v_cat.unsqueeze(1)
# Output projection of the same layer as the pattern/values (this previously read layer 18).
W_O = model.state_dict()[f"blocks.{DFA_LAYER}.attn.W_O"].detach()
W_O_conc = einops.rearrange(W_O,"n_heads d_head d_model -> (n_heads d_head) d_model") 
# Contribution of every source position to every destination position, in residual-stream space.
decompose_out = einops.einsum(W_O_conc,decompose_z_cat,"d_attn d_model, batch dest src d_attn -> batch d_model dest src")
scale = cache[f'blocks.{DFA_LAYER}.ln2.hook_scale']
# RMS over d_model for each (dest, src) pair.
resid_rms = (decompose_out**2).mean(dim = 1,keepdim = True).sqrt()
# NOTE(review): the broadcast of `scale` against (batch, d_model, dest, src) relies on batch == 1,
# and the contribution is multiplied (not divided) by the LN scale — confirm both are intended.
decompose_out_normalized =  (decompose_out/resid_rms)*scale
# NOTE(review): `dir` must be a d_model direction vector defined in an earlier cell (it shadows the builtin);
# presumably the decoder direction of the feature under study — confirm.
dir_decompose_out = einops.einsum(decompose_out_normalized, dir, "batch d_model dest src, d_model -> batch dest src")


In [205]:
# Destination-by-source heatmap of the projected attention contributions (batch element 0).
sns.heatmap(dir_decompose_out[0].cpu())
plt.show()


In [211]:

# NOTE: this rebinds `tokens` to a list of strings; cells that need the id tensor must be run before it.
tokens = model.to_str_tokens(string, prepend_bos=False)[:tok_pos]
# Prefix each token with its index so repeated tokens get distinct bar labels.
unique_tokens = [f"{i}/{t}" for i, t in enumerate(tokens)]

# Contribution of each source token to the last destination position.
px.bar(x = unique_tokens,
       y =dir_decompose_out[0,-1].cpu().numpy()) 


In [208]:
# Raw per-source contributions for the last destination position.
dir_decompose_out[0,-1]
