# SAEval

### Setup

In [1]:
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/jbloomAus/SAELens
  
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import torch
import transformer_lens.utils as utils

import plotly.express as px
import tqdm
from functools import partial
import einops
import plotly.graph_objects as go

# Keyword names that `imshow` forwards to `Figure.update_layout` rather than to
# the plotly express constructor.
update_layout_set = {
    # axis ranges / titles / tick formats
    "xaxis_range",
    "yaxis_range",
    "xaxis_title",
    "yaxis_title",
    "xaxis_tickformat",
    "yaxis_tickformat",
    # grid styling
    "xaxis_showgrid",
    "xaxis_gridwidth",
    "xaxis_gridcolor",
    "yaxis_showgrid",
    "yaxis_gridwidth",
    # colour handling
    "colorbar",
    "colorscale",
    "coloraxis",
    # titles, legend and bar spacing
    "title_x",
    "title_y",
    "legend_title_text",
    "bargap",
    "bargroupgap",
    # interaction
    "hovermode",
}

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a heatmap of ``tensor`` with plotly express.

    Args:
        tensor: A tensor, or a list of tensors (stacked with ``torch.stack``).
        renderer: Passed to ``fig.show``.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Names listed in ``update_layout_set`` are applied through
            ``fig.update_layout``; everything else goes to ``px.imshow``.
            ``facet_labels`` (optional) replaces the text of the figure's
            annotations, in order. ``color_continuous_scale`` defaults to
            ``"RdBu"`` and ``color_continuous_midpoint`` to ``0.0``; both may
            be overridden. A ``labels`` dict is merged over the axis labels.
    """
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)

    layout_kwargs = {k: v for k, v in kwargs.items() if k in update_layout_set}
    plot_kwargs = {k: v for k, v in kwargs.items() if k not in update_layout_set}

    facet_labels = plot_kwargs.pop("facet_labels", None)

    # Defaults rather than hard-coded arguments: previously a caller passing
    # `color_continuous_midpoint` or `labels` hit a duplicate-keyword TypeError.
    plot_kwargs.setdefault("color_continuous_scale", "RdBu")
    plot_kwargs.setdefault("color_continuous_midpoint", 0.0)
    labels = {"x": xaxis, "y": yaxis, **(plot_kwargs.pop("labels", None) or {})}

    fig = px.imshow(utils.to_numpy(tensor), labels=labels, **plot_kwargs)
    fig.update_layout(**layout_kwargs)

    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label

    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
    """Draw a plotly express scatter plot of ``y`` against ``x``.

    ``x`` and ``y`` are converted with ``utils.to_numpy``. ``xaxis``, ``yaxis``
    and ``caxis`` label the x, y and colour axes; remaining ``kwargs`` go to
    ``px.scatter``. When ``return_fig`` is true the figure is returned without
    being shown; otherwise it is shown with ``renderer`` and nothing is returned.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    if not return_fig:
        figure.show(renderer)
        return None
    return figure

from typing import List
def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.Tensor]):
    """Plot the mean logit difference per intervention as a bar chart with error bars.

    Args:
        x_axis: One label per bar.
        per_prompt_logit_diffs: One tensor per bar; the bar height is the
            tensor's mean and the error bar is its standard deviation.
    """
    y_data = [per_prompt_logit_diff.mean().item() for per_prompt_logit_diff in per_prompt_logit_diffs]
    error_y_data = [per_prompt_logit_diff.std().item() for per_prompt_logit_diff in per_prompt_logit_diffs]

    fig = go.Figure(data=[go.Bar(
        x=x_axis,
        y=y_data,
        error_y=dict(
            type='data',  # specifies that the actual values are given
            array=error_y_data,  # the magnitudes of the errors
            visible=True  # make error bars visible
        ),
    )])

    # Customize layout
    fig.update_layout(title_text='Logit Diff after Interventions',
                    xaxis_title_text='Intervention',
                    yaxis_title_text='Logit diff',
                    plot_bgcolor='white')

    # Show the figure
    fig.show()

In [3]:
import os

# Default the Hugging Face cache location, but respect a value that is already
# configured in the environment instead of overwriting it.
# NOTE(review): huggingface_hub appears to read HF_HOME when it is first imported,
# and transformer_lens is imported in an earlier cell — confirm this assignment
# still takes effect, or move it above those imports.
os.environ.setdefault('HF_HOME', '/workspace/huggingface/')

if torch.cuda.is_available():
    device = "cuda"
# elif torch.backends.mps.is_available():
#     device = "mps"
else:
    device = "cpu"

# Gradients are disabled globally for the rest of the notebook.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fa59dfd8550>

### Load model and data

In [83]:
from sae_lens import HookedSAETransformer

# Load pretrained GPT-2 and move it to the selected device.
model: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2")
model = model.to(device)

# Optional hook points, currently left disabled:
#model.set_use_attn_in(True)
#model.set_use_attn_result(True)
#model.set_use_hook_mlp_in(True)
#model.set_use_split_qkv_input(True)



Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


In [134]:
import json

# Read the IOI task specification. The encoding is explicit so the result does
# not depend on the platform's locale default.
with open('tasks/ioi/task.json', encoding='utf-8') as f:
    task = json.load(f)

In [135]:
# Inspect the first prompt record of the loaded task.
task['prompts'][0]

{'clean_prompt': 'When John and Mary went to the store, Mary gave a drink to',
 'corr_prompt': 'When John and Mary went to the store, John gave a drink to',
 'variables': {'IO': 'John', 'S1': 'Mary', 'S2': 'Mary', 'Pos': 'ABB'}}

In [329]:
from functools import partial
import re

def logits_diff(logits, correct_answer, incorrect_answer):
    """Return the logit of ``correct_answer`` minus the logit of ``incorrect_answer``.

    Both answers are mapped to token ids with the notebook-level
    ``model.to_single_token``; the logits are read at batch index 0, last position.
    """
    token_ids = [model.to_single_token(answer) for answer in (correct_answer, incorrect_answer)]
    correct_logit = logits[0, -1, token_ids[0]]
    incorrect_logit = logits[0, -1, token_ids[1]]
    return correct_logit - incorrect_logit

def zero_abl_hook(x, hook, pos, head):
    """Forward hook: zero ``x[:, pos, head]`` in place and return ``x``.

    ``hook`` is accepted for the hook-function signature and is not used.
    """
    target = (slice(None), pos, head)
    x[target] = 0
    return x

def patching_hook(x, hook, pos, head, corr):
    """Forward hook: overwrite ``x[:, pos, head]`` with ``corr[:, pos, head]``.

    The write happens in place and ``x`` is returned. ``hook`` is accepted for
    the hook-function signature and is not used.
    """
    target = (slice(None), pos, head)
    x[target] = corr[target]
    return x

def feature_patching_hook(x, hook, pos, head, f_in, f_out):
    """Forward hook: add ``f_in[pos, head]`` to and subtract ``f_out[pos, head]`` from ``x[:, pos, head]``.

    The feature slices get a leading axis so they broadcast over the first
    dimension of ``x``. The update is written in place and ``x`` is returned.
    ``hook`` is accepted for the hook-function signature and is not used.
    """
    current = x[:, pos, head]
    added = f_in[None, pos, head]
    removed = f_out[None, pos, head]
    x[:, pos, head] = current + added - removed
    return x

class TransformerCircuit:
    """Pairs a model with a task specification and resolves names within that spec.

    ``task`` is a mapping with a ``'nodes'`` list and a ``'variables'`` list;
    each entry in those lists carries a ``'name'`` key.
    """

    def __init__(self, model, task):
        self.model = model
        self.task = task

    def get_node(self, node_name):
        """Return the entry of ``task['nodes']`` named ``node_name``, or ``None`` if absent."""
        for node in self.task['nodes']:
            if node['name'] == node_name:
                return node
        return None

    def get_variable(self, variable_name):
        """Return the entry of ``task['variables']`` named ``variable_name``, or ``None`` if absent."""
        for variable in self.task['variables']:
            if variable['name'] == variable_name:
                return variable
        return None

    @classmethod
    def read_variable(cls, x):
        """Parse a position spec such as ``"{S1}"``, ``"{S1}+1"`` or ``"{S1}-1"``.

        Args:
            x: Spec string containing a ``{NAME}`` placeholder, optionally
                followed by ``+k`` or ``-k``.

        Returns:
            ``(name, offset)`` where ``offset`` is ``k``, ``-k`` or ``0``.

        Raises:
            IndexError: If ``x`` contains no ``{...}`` placeholder.
            ValueError: If the text after ``+``/``-`` is not an integer.
        """
        if '+' in x:
            offset = int(x.split('+')[-1])
        elif '-' in x:
            # A "-k" suffix means k positions earlier, so the offset is negative
            # (it was previously returned as +k).
            offset = -int(x.split('-')[-1])
        else:
            offset = 0

        pattern = r"\{([^}]*)\}"
        return re.findall(pattern, x)[0], offset

class IOICircuit(TransformerCircuit):
    """Circuit wrapper that measures how patching attention heads changes the IO-vs-S1 logit difference."""

    def __init__(self, model, cfg):
        super().__init__(model, cfg)

    def run_with_patch(self, prompt, node_names, component_name, method='zero', patches=None):
        """Patch the heads of the given nodes on the clean prompt and report the effect.

        Args:
            prompt: Task prompt record with ``'clean_prompt'``, ``'corr_prompt'``
                and a ``'variables'`` dict containing ``'IO'``, ``'S1'`` and ``'Pos'``.
            node_names: Names of task nodes whose heads are patched.
            component_name: Attention activation to patch: ``'q'``, ``'k'``,
                ``'v'`` or ``'z'``.
            method: ``'zero'`` zeroes the activation, ``'corr'`` copies it from
                a run on the corrupted prompt, ``'feature'`` adds ``f_in`` and
                subtracts ``f_out``.
            patches: ``(f_in, f_out)``, each indexed by layer; required when
                ``method='feature'``.

        Returns:
            Clean logit difference minus patched logit difference (IO vs S1).

        Raises:
            AssertionError: On an invalid ``method``/``component_name`` or
                missing ``patches`` for ``method='feature'``.
            ValueError: If a node or a referenced variable is not in the task.
        """
        assert method in ['zero', 'corr', 'feature'], "Method must be either 'zero', 'corr', 'feature'"
        assert component_name in ['q', 'k', 'v', 'z'], "Component must be either 'q', 'k', 'v', 'z'"
        if method == 'feature':
            assert patches is not None, "method='feature' requires patches=(f_in, f_out)"
            f_in = patches[0]
            f_out = patches[1]

        # Answer names are looked up with a leading space.
        io = ' ' + prompt['variables']['IO']
        s1 = ' ' + prompt['variables']['S1']
        # Index into each variable's 'position' list: 0 for "ABB", 1 otherwise.
        order_idx = 0 if prompt['variables']['Pos'] == "ABB" else 1

        # Use the circuit's own model rather than the notebook-level global.
        clean_tokens = self.model.to_tokens(prompt['clean_prompt'])

        if method == 'corr':
            corr_tokens = self.model.to_tokens(prompt['corr_prompt'])
            with torch.no_grad():
                _, corr_cache = self.model.run_with_cache(corr_tokens)

        with torch.no_grad():
            clean_logits = self.model(clean_tokens)

        hooks = []
        for node_name in node_names:
            node = self.get_node(node_name)
            if node is None:
                raise ValueError(f"Unknown node {node_name!r}")

            # 'q' and 'z' are patched at the node's query position; 'k' and 'v'
            # at its key/value position. 'z' was previously routed to the kv
            # position because this check still named the removed 'result' option.
            # NOTE(review): assumes 'z' is indexed by query position — confirm.
            if component_name in ['q', 'z', 'result']:
                var, offset = self.read_variable(node['q'])
            else:
                var, offset = self.read_variable(node['kv'])

            variable = self.get_variable(var)
            if variable is None:
                raise ValueError(f"Unknown variable {var!r} referenced by node {node_name!r}")
            var_pos = variable['position'][order_idx] + offset

            for head in node['heads']:
                # Heads are written as "<layer>.<head>".
                l, h = head.split('.')
                hook_name = utils.get_act_name(component_name, int(l))

                if method == 'zero':
                    hook_fn = partial(zero_abl_hook, pos=var_pos, head=int(h))
                elif method == 'corr':
                    hook_fn = partial(patching_hook, pos=var_pos, head=int(h), corr=corr_cache[hook_name])
                elif method == 'feature':
                    hook_fn = partial(feature_patching_hook, pos=var_pos, head=int(h), f_in=f_in[int(l)], f_out=f_out[int(l)])

                hooks.append((hook_name, hook_fn))

        with torch.no_grad():
            patched_logits = self.model.run_with_hooks(clean_tokens, fwd_hooks=hooks)

        return logits_diff(clean_logits, io, s1) - logits_diff(patched_logits, io, s1)

In [304]:
# Bind the loaded model to the IOI task specification.
ioi_circuit = IOICircuit(model, task)

In [305]:
# Show the prompt record used in the sanity check that follows.
task['prompts'][0]

{'clean_prompt': 'When John and Mary went to the store, Mary gave a drink to',
 'corr_prompt': 'When John and Mary went to the store, John gave a drink to',
 'variables': {'IO': 'John', 'S1': 'Mary', 'S2': 'Mary', 'Pos': 'ABB'}}

In [306]:
# Sanity check: patch the 'q' activations of node 'IH' with values from the
# corrupted prompt and report the drop in logit difference on the first prompt.
ioi_circuit.run_with_patch(task['prompts'][0], ['IH'], 'q', method='corr')

tensor(0.8055, device='cuda:0')

## Eval

In [165]:
from tqdm.auto import tqdm

def get_task_activations(prompts, bs=64):
    """Collect stacked 'q', 'k', 'v' and 'z' activations for ``prompts``.

    Prompts are run through the notebook-level ``model`` in batches of ``bs``.
    For each component the per-batch results of ``cache.stack_activation`` are
    concatenated along dim 1.

    Returns:
        Dict mapping each component name to its concatenated tensor.
    """
    component_names = ['q', 'k', 'v', 'z']
    per_batch = {name: [] for name in component_names}

    batch_starts = range(0, len(prompts), bs)
    for start in tqdm(batch_starts):
        batch_tokens = model.to_tokens(prompts[start:start + bs])

        with torch.no_grad():
            _, batch_cache = model.run_with_cache(batch_tokens)

        for name, chunks in per_batch.items():
            chunks.append(batch_cache.stack_activation(name))

    stacked = {}
    for name, chunks in per_batch.items():
        stacked[name] = torch.cat(chunks, 1)
    return stacked

In [167]:
# Cache the attention activations of every clean prompt in the task.
prompts = [record['clean_prompt'] for record in task['prompts']]
activations = get_task_activations(prompts)

  0%|          | 0/15 [00:00<?, ?it/s]

In [168]:
# Check the shape of the stacked key activations.
activations['k'].shape

torch.Size([12, 930, 15, 12, 64])

In [175]:
import pandas as pd

# Tabulate every prompt with its variable values and their token positions.

# The 'position' list of each variable does not depend on the prompt, so look it
# up once instead of on every loop iteration.
variable_positions = {
    name: ioi_circuit.get_variable(name)['position'] for name in ('IO', 'S1', 'S2')
}

task_records = []
for prompt in task['prompts']:
    variables = prompt['variables']
    # Name-order code: 0 for "ABB", 1 otherwise; it also indexes the position lists.
    pos = 0 if variables['Pos'] == "ABB" else 1
    task_records.append({
        'prompt': prompt['clean_prompt'],
        'IO': variables['IO'],
        'S1': variables['S1'],
        'S2': variables['S2'],
        'Pos': pos,
        'IO_pos': variable_positions['IO'][pos],
        'S1_pos': variable_positions['S1'][pos],
        'S2_pos': variable_positions['S2'][pos],
    })

# Explicit columns keep the schema stable even when there are no prompts.
task_df = pd.DataFrame(
    task_records,
    columns=['prompt', 'IO', 'S1', 'S2', 'Pos', 'IO_pos', 'S1_pos', 'S2_pos'],
)

In [227]:
# Display the prompt table (pandas truncates the middle rows).
task_df

Unnamed: 0,prompt,IO,S1,S2,Pos,IO_pos,S1_pos,S2_pos
0,"When John and Mary went to the store, Mary gav...",John,Mary,Mary,0,2,4,10
1,"When Paul and John went to the store, Paul gav...",John,Paul,Paul,1,4,2,10
2,"When Anna and John went to the store, Anna gav...",John,Anna,Anna,1,4,2,10
3,"When Mark and John went to the store, Mark gav...",John,Mark,Mark,1,4,2,10
4,"When John and Lucy went to the store, Lucy gav...",John,Lucy,Lucy,0,2,4,10
...,...,...,...,...,...,...,...,...
925,"When Megan and Molly went to the store, Megan ...",Molly,Megan,Megan,1,4,2,10
926,"When Ryan and Molly went to the store, Ryan ga...",Molly,Ryan,Ryan,1,4,2,10
927,"When Julie and Molly went to the store, Julie ...",Molly,Julie,Julie,1,4,2,10
928,"When Molly and Steve went to the store, Steve ...",Molly,Steve,Steve,0,2,4,10


In [330]:
from sklearn.linear_model import LogisticRegression

# Candidate names, taken from the 'values' of the task's IO variable; read as a
# global by ioi_supervised_dictionary.
names = ioi_circuit.get_variable('IO')['values']

def ioi_supervised_dictionary(df, activations):
    """Build per-name mean-activation vectors for the IO, S1 and S2 variables.

    Args:
        df: Prompt table with 'IO', 'S1' and 'S2' columns; its rows are used as
            a mask over dim 1 of ``activations``.
        activations: Activation tensor whose dim 1 has one entry per row of ``df``.

    Returns:
        ``(io_vec, s1_vec, s2_vec, pos_vec)``. The first three map every entry
        of the notebook-level ``names`` to the mean, over the matching rows, of
        the centred activations. A name with no matching rows yields NaNs.
        ``pos_vec`` is currently an empty list (the probe is disabled).
    """
    # Centre by subtracting the mean taken over dims 2 and 1.
    centered_activations = activations - activations.mean(2).mean(1)[:, None, None]
    print(centered_activations.shape)

    def mean_where(column, name):
        # Convert the pandas mask to an explicit bool tensor rather than relying
        # on implicit conversion of a Series used as a tensor index.
        mask = torch.as_tensor(
            (df[column] == name).to_numpy(), device=centered_activations.device
        )
        return centered_activations[:, mask].mean(1)

    # IO
    io_vec = {name: mean_where('IO', name) for name in names}

    # S: s2_vec is keyed on the S2 column (it previously reused the S1 mask, which
    # only gives the same result when S1 and S2 are the same name in every row).
    s1_vec = {name: mean_where('S1', name) for name in names}
    s2_vec = {name: mean_where('S2', name) for name in names}

    # Pos: TODO — the logistic-regression probe per layer and head is disabled, so
    # an empty list is returned. Disabled code kept for reference:
    #lr = LogisticRegression(penalty=None)
    #X = centered_activations[:, :, -1]
    #y = df['Pos'].to_numpy()
    #L, N, H, *_ = X.shape
    #for l in range(L):
    #    pos_vec.append([])
    #    for h in range(H):
    #        lr.fit(X[l, :, h].cpu().numpy(), y)
    #        pos_vec[l].append(lr.coef_[0])
    #
    #pos_vec = torch.tensor(pos_vec)
    pos_vec = []

    return (io_vec, s1_vec, s2_vec, pos_vec)

In [331]:
# Build the name-conditioned mean vectors from the key ('k') activations.
# NOTE(review): these vectors are later patched into the 'q' hook via
# run_with_patch(..., 'q', method='feature') — confirm that combining
# 'k'-derived features with 'q' activations is intended.
io_vec, s1_vec, s2_vec, pos_vec = ioi_supervised_dictionary(task_df, activations['k'])

torch.Size([12, 930, 15, 12, 64])


In [332]:
# Pick one prompt and prepare the feature patch for it.
idx = 1
example = task['prompts'][idx]

io, s2 = (example['variables'][key] for key in ('IO', 'S2'))

# Patch in the S2 vector of the IO name and patch out the S2 vector of the
# prompt's actual S2 name.
in_patch, out_patch = s2_vec[io], s2_vec[s2]

In [334]:
# Compare feature, corrupted-prompt and zero patching on the same nodes and component.
nodes = ['NMH', 'bNMH']
patch_example = partial(ioi_circuit.run_with_patch, example, nodes, 'q')

f_patch = patch_example(method='feature', patches=(in_patch, out_patch))
c_patch = patch_example(method='corr')
z_patch = patch_example(method='zero')

In [335]:
# Drop in logit difference under feature, corrupted-prompt and zero patching, in that order.
print(f_patch.item(), c_patch.item(), z_patch.item())

0.19946670532226562 5.596443176269531 2.8782215118408203


tensor(-0.1915, device='cuda:0')

tensor(1.6125, device='cuda:0')