# SAEval

### Setup

In [1]:
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/jbloomAus/SAELens
  
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import torch
import transformer_lens.utils as utils

import plotly.express as px
import tqdm
from functools import partial
import einops
import plotly.graph_objects as go

update_layout_set = {
    "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis",
     "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid",
     "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"
}

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "facet_labels" in kwargs_pre:
        facet_labels = kwargs_pre.pop("facet_labels")
    else:
        facet_labels = None
    if "color_continuous_scale" not in kwargs_pre:
        kwargs_pre["color_continuous_scale"] = "RdBu"
    fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    fig = px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)
    if return_fig:
        return fig
    fig.show(renderer)

from typing import List
def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.tensor]):


    y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]
    error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs] 

    fig = go.Figure(data=[go.Bar(
        x=x_axis,
        y=y_data,
        error_y=dict(
            type='data',  # specifies that the actual values are given
            array=error_y_data,  # the magnitudes of the errors
            visible=True  # make error bars visible
        ),
    )])

    # Customize layout
    fig.update_layout(title_text=f'Logit Diff after Interventions',
                    xaxis_title_text='Intervention',
                    yaxis_title_text='Logit diff',
                    plot_bgcolor='white')

    # Show the figure
    fig.show()

In [3]:
import os

os.environ['HF_HOME'] = '/workspace/huggingface/'

if torch.cuda.is_available():
    device = "cuda"
# elif torch.backends.mps.is_available():
#     device = "mps"
else: 
    device = "cpu"
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa59dfd8550>

### Load model and data

In [83]:
from sae_lens import HookedSAETransformer

model: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2").to(device)
#model.set_use_attn_in(True)
#model.set_use_attn_result(True)
#model.set_use_hook_mlp_in(True)
#model.set_use_split_qkv_input(True)



Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


In [84]:
import json

with open('tasks/ioi/task.json') as f:
    task = json.load(f)

In [85]:
task['prompts'][0]

{'clean_prompt': 'When John and Mary went to the store, Mary gave a drink to',
 'corr_prompt': 'When John and Mary went to the store, John gave a drink to',
 'variables': {'IO': 'John', 'S1': 'Mary', 'S2': 'Mary', 'Pos': 'ABB'}}

In [130]:
from functools import partial
import re

def logits_diff(logits, correct_answer, incorrect_answer):
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

def zero_abl_hook(x, hook, pos, head):
    x[:, pos, head] = 0
    return x

def patching_hook(x, hook, pos, head, corr):
    x[:, pos, head] = corr[:, pos, head]
    return x

class TransformerCircuit:
    def __init__(self, model, task):
        self.model = model
        self.task = task

    def get_node(self, node_name):
        for node in self.task['nodes']:
            if node['name'] == node_name:
                return node
        return None

    def get_variable(self, variable_name):
        for variable in self.task['variables']:
            if variable['name'] == variable_name:
                return variable
        return None

    @classmethod
    def read_variable(self, x):
        if '+' in x:
            offset = int(x.split('+')[-1])
        elif '-' in x:
            offset = int(x.split('-')[-1])
        else:
            offset = 0

        pattern = r"\{([^}]*)\}"
        return re.findall(pattern, x)[0], offset

class IOICircuit(TransformerCircuit):
    def __init__(self, model, cfg):
        super().__init__(model, cfg)

    def run_with_patch(self, prompt, node_names, component_name, method='zero'):
        assert method in ['zero', 'corr'], "Method must be either 'zero' or 'corr'"
        assert component_name in ['q', 'k', 'v', 'z'], "Method must be either 'result', 'q', 'k', 'v'"
        
        io = ' '+prompt['variables']['IO']
        s1 = ' '+prompt['variables']['S1']
        pos = prompt['variables']['Pos']
        pos = 0 if pos == "ABB" else 1

        clean_prompt = prompt['clean_prompt']
        clean_tokens = model.to_tokens(clean_prompt)

        with torch.no_grad():
            clean_logits = self.model(clean_tokens)
        
        nodes = [self.get_node(node_name) for node_name in node_names]

        self.model.reset_hooks(including_permanent=True)

        for node in nodes:
            if component_name in ['q', 'result']:
                var, offset = self.read_variable(node['q'])
            else:
                var, offset = self.read_variable(node['kv'])

            var_pos = self.get_variable(var)['position'][pos] + offset

            for head in node['heads']:
                l, h = head.split('.')
                hook_name = utils.get_act_name(component_name, int(l))
                
                if method == 'zero':
                    hook_fn = partial(zero_abl_hook, pos=var_pos, head=int(h))
                else:
                    corr_prompt = prompt['corr_prompt']
                    corr_tokens = model.to_tokens(corr_prompt)
                    
                    with torch.no_grad():
                        _, corr_cache = self.model.run_with_cache(corr_tokens)

                    hook_fn = partial(patching_hook, pos=var_pos, head=int(h), corr=corr_cache[hook_name])

                print(hook_name, int(h))
                self.model.add_perma_hook(hook_name, hook_fn)

        with torch.no_grad():
            patched_logits = model(clean_tokens)

        self.model.reset_hooks(including_permanent=True)

        return logits_diff(clean_logits, io, s1) - logits_diff(patched_logits, io, s1)

In [131]:
ioi_circuit = IOICircuit(model, task)

In [132]:
ioi_circuit.run_with_patch(task['prompts'][0], ['NMH', 'bNMH'], 'k', method='corr')

blocks.9.attn.hook_k 6
blocks.9.attn.hook_k 9
blocks.10.attn.hook_k 0
blocks.0.attn.hook_k 1
blocks.3.attn.hook_k 0


blocks.0.attn.hook_k 10


tensor(0., device='cuda:0')

## Eval

In [119]:
from tqdm.auto import tqdm

def get_task_activations(task, bs=64):
    activations = []
    for b in tqdm(range(0, len(task), bs)):
        prompts = task['prompt'].iloc[b:b+bs]
        tokens = model.to_tokens(prompts.tolist())
        
        with torch.no_grad():
            _, cache = model.run_with_cache(tokens)

        activations.append(cache.stack_activation('resid_post'))

    return torch.cat(activations, 1)

In [120]:
activations = get_task_activations(ioi_task)

  0%|          | 0/16 [00:00<?, ?it/s]

In [149]:
def ioi_supervised_dictionary(task, activations):
    l, n, p, dm = activations.shape
    centered_activations = activations - activations.mean(2).mean(1)[:, None, None]
    # IO
    io_vec = {}
    for name in names:
        mask = task['io'] == ' '+name
        pos = torch.tensor(task.loc[mask, 'io_pos'].tolist(), device=device)
        features = centered_activations[:, mask].gather(2, pos[None, :, None, None].expand(l, len(pos), 1, dm))
        io_vec[name] = features.mean(1)[:, 0]

    # S1
    s1_vec = {}
    s2_vec = {}
    for name in names:
        mask = task['s'] == ' '+name
        
        pos = torch.tensor(task.loc[mask, 's1_pos'].tolist(), device=device)
        features = centered_activations[:, mask].gather(2, pos[None, :, None, None].expand(l, len(pos), 1, dm))
        s1_vec[name] = features.mean(1)[:, 0]
    
        pos = torch.tensor(task.loc[mask, 's2_pos'].tolist(), device=device)
        features = centered_activations[:, mask].gather(2, pos[None, :, None, None].expand(l, len(pos), 1, dm))
        s2_vec[name] = features.mean(1)[:, 0]

    # Pos
    X = centered_activations[:, :, -1]
    y = torch.tensor(task['pos'] == 'ABB', device=device)
    print(X.shape)
    pos_vec = torch.linalg.lstsq(X, y.type(torch.float32)[None])

    return (io_vec, s1_vec, s2_vec, pos_vec)

In [150]:
ioi_supervised_dictionary(ioi_task, activations)

torch.Size([6, 1024, 15, 512])


RuntimeError: torch.linalg.lstsq: input.dim() must be greater or equal to other.dim() and (input.dim() - other.dim()) <= 1

In [137]:
vec[0]

KeyError: 0

In [124]:
ioi_task

Unnamed: 0,prompt,io,s,pos,io_pos,s1_pos,s2_pos
0,"When Paul and Nancy went to the store, Paul ga...",Nancy,Paul,BAB,4,2,10
1,"When Henry and Lucy went to the store, Lucy ga...",Henry,Lucy,ABB,2,4,10
2,"When Aaron and Maria went to the store, Maria ...",Aaron,Maria,ABB,2,4,10
3,"When Megan and Chloe went to the store, Chloe ...",Megan,Chloe,ABB,2,4,10
4,"When Maria and Sara went to the store, Maria g...",Sara,Maria,BAB,4,2,10
...,...,...,...,...,...,...,...
1019,"When Linda and Diane went to the store, Linda ...",Diane,Linda,BAB,4,2,10
1020,"When Grace and Susan went to the store, Susan ...",Grace,Susan,ABB,2,4,10
1021,"When Maria and Olivia went to the store, Olivi...",Maria,Olivia,ABB,2,4,10
1022,"When Julie and Amber went to the store, Julie ...",Amber,Julie,BAB,4,2,10
