In [1]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface'

from transformer_lens import HookedTransformer, ActivationCache, utils
import torch

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial
import ast
tqdm.pandas()

init_notebook_mode(connected=True)

Device: cuda


### Model loading

In [2]:
model_name = 'gemma-2b'

In [3]:
model = HookedTransformer.from_pretrained(model_name, torch_dtype=torch.float32, n_devices=1)

model.eval()
model.set_use_attn_result(True)
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [4]:
def generate_until_stop(prompt, stop_tokens, max_tokens=64, verbose=False, prepend_bos=True):
    if isinstance(stop_tokens[0], str):
        stop_tokens = [model.to_single_token(tok) for tok in stop_tokens]
        
    tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    gen = True
    while gen:
        with torch.no_grad():
            new_tok = model(tokens).argmax(-1)[:, -1]
        
        if verbose: print(model.to_string(new_tok), end='')
        tokens = torch.cat([tokens, new_tok[None].to(tokens.device)], dim=-1)
        if new_tok.item() in stop_tokens or max_tokens == 0:
            gen = False
        max_tokens -= 1

    return model.to_string(tokens)[0]

### Data loading

In [5]:
def check_cot(x):
    try:
        check = all([x == y for x, y in zip(x['cot_gold'], x['cot_pred'])])
    except: check = False

    return check

In [6]:
n_shots = "7shots"

result = pd.read_csv(f'results/results_{n_shots}.csv')
result['correct_pred'] = result['label'] == result['pred']
result['correct_cot'] = result.apply(check_cot, axis=1)

correct_preds = result[result['correct_pred'] & result['correct_cot']]

## Direct Logit Attribution

In [7]:
def compute_dla(prompt, component, a_clean, a_corr=None, prepend_bos=True):
    
    tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    dlas = []

    with torch.no_grad():
        logits, cache = get_cache_fw(tokens, component)

    cache = ActivationCache(cache, model).to('cpu')
    act = cache.stack_activation(component)[:, :, -1]
    if len(act.shape) == 4:
        act = cache.stack_head_results(-1)[:, :, -1]

    dla = model.unembed(act.to(model.W_U.device)).cpu()
    del cache
    
    return dla[..., a_clean].mean(-1) - dla[..., a_corr].mean(-1)

In [8]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def plot_dla(resid_dla, mlp_dla, attn_dla, max_val=50):
    fig = make_subplots(rows=1, cols=3, subplot_titles=("Residual Stream", "MLP", "Attention Heads"))

    # Add images to the subplots
    fig.add_trace(px.imshow(resid_dla.detach().cpu(), zmin=-max_val, zmax=max_val).data[0], row=1, col=1)
    fig.add_trace(px.imshow(mlp_dla.detach().cpu(), zmin=-max_val, zmax=max_val).data[0], row=1, col=2)
    fig.add_trace(px.imshow(attn_dla.detach().cpu(), zmin=-max_val, zmax=max_val).data[0], row=1, col=3)

    # Update layout
    fig.update_layout(
        coloraxis=dict(
            colorscale='RdBu',
            cmin=-max_val,
            cmax=max_val
        ),
        height=600,
        width=1400,
        title_text="Direct Logit Attribution"
    )
    
    return fig

def plot_patterns(prompt, patterns, n_cols, query_offset, key_offset):
    
    n_rows = len(patterns) // n_cols + int(len(patterns) % n_cols != 0)
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=patterns)

    tokens = model.to_tokens(prompt)
    str_tokens = model.to_str_tokens(prompt)
    labels = [f"{tok} ({i})" for i, tok in enumerate(str_tokens)]
    query_labels = labels[query_offset:]
    key_labels = labels[key_offset:]

    with torch.no_grad():
        logits, cache = get_cache_fw(tokens, 'attn')

    for i, pid in enumerate(patterns):
        layer, head = pid.split('H')
        layer = int(layer[1:])
        head = int(head[:-1])

        pattern = cache[f'blocks.{layer}.attn.hook_pattern'][0, head, query_offset:, key_offset:].cpu()

        row = i // n_cols + 1
        col = i % n_cols + 1
    
        fig.add_trace(px.imshow(
            pattern,
            labels=dict(x="Keys", y="Queries", color="Attention Score"),
            x=key_labels,
            y=query_labels
        ).data[0], row=row, col=col)

        fig.update_xaxes(tickangle=35)
        fig.update_layout(coloraxis_colorbar=dict(title="Score"))

    fig.update_layout(
        coloraxis=dict(
            colorscale='Blues',
            cmin=1,
            cmax=0
        ),
        height=700 * n_rows,
        width=800 * n_cols,
        title_text="Attention Patterns"
    )
    return fig

## Attribution patching

In [9]:
import einops

def logits_diff(logits, a_clean, a_corr=None):
    if isinstance(a_clean, str):
        a_clean = model.to_single_token(a_clean)
    if a_corr:
        if isinstance(a_corr, str):
            a_corr = [model.to_single_token(a_corr)]
        
        return logits[0, -1, a_clean] - logits[0, -1, a_corr].mean(-1)
    else:
        return logits[0, -1, a_clean]

def get_cache_fw(tokens, component):
    if component == 'all':
        filter = lambda name: "_input" not in name
    elif component == 'qkv':
        filter = lambda name: name.split('.')[-1].strip() in ['hook_q', 'hook_k', 'hook_v'] and "_input" not in name
    else:
        filter = lambda name: component in name

    model.reset_hooks()

    cache = {}
    def fw_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    model.add_hook(filter, fw_cache_hook, "fwd")
    
    logits = model(tokens)
    model.reset_hooks()
    return logits, ActivationCache(cache, model)

def get_cache_fw_and_bw(tokens, a_clean, a_corr, corr_logits, component='all'):
    if component == 'all':
        filter = lambda name: "_input" not in name
    elif component == 'qkv':
        filter = lambda name: name.split('.')[-1].strip() in ['hook_q', 'hook_k', 'hook_v'] and "_input" not in name
    else:
        filter = lambda name: component in name
        
    model.reset_hooks()
    
    cache = {}
    def fw_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    model.add_hook(filter, fw_cache_hook, "fwd")
    
    grad_cache = {}
    def bw_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
    
    model.add_hook(filter, bw_cache_hook, "bwd")

    clean_logits = model(tokens).cpu()
    value = logits_diff(clean_logits, a_clean, a_corr) #- logits_diff(corr_logits.cpu(), a_clean, a_corr)
    value.backward()
    
    model.reset_hooks()
    return (
        value.item(),
        ActivationCache(cache, model),
        ActivationCache(grad_cache, model),
    )

def stack_qkv_results(cache):
    q = cache.stack_activation('q')
    k = cache.stack_activation('k')
    v = cache.stack_activation('v')

    c, b, pos, h, d = q.shape
    q = q.reshape(c * h, b, pos, d)
    k = k.reshape(c, b, pos, d)
    v = v.reshape(c, b, pos, d)

    return q, k, v

def attribution_patching(x_clean, x_corr, a_clean, a_corr, component='all', prepend_bos=True):

    if isinstance(x_clean, str):
        clean_tokens = model.to_tokens(x_clean, prepend_bos=prepend_bos)
    else:
        clean_tokens = x_clean

    if isinstance(x_corr, str):
        corr_tokens = model.to_tokens(x_corr, prepend_bos=prepend_bos)
    else:
        corr_tokens = x_corr

    if isinstance(a_clean, str):
        a_clean = model.to_single_token(a_clean)

    if isinstance(a_corr, str):
        a_corr = model.to_single_token(a_corr)

    with torch.no_grad():
        corr_logits, corr_cache = get_cache_fw(corr_tokens, component)
    
    corr_cache = corr_cache
    logits_diff_, clean_cache, clean_grad_cache = get_cache_fw_and_bw(clean_tokens, a_clean, a_corr, corr_logits, component=component)

    clean_grad_cache = ActivationCache(clean_grad_cache, model).to('cpu')
    clean_cache = ActivationCache(clean_cache, model).to('cpu')
    corr_cache = ActivationCache(corr_cache, model).to('cpu')

    corr_act = []
    clean_act = []
    clean_grad_act = []

    if component in ['all', 'resid']:
        corr_act.append(corr_cache.accumulated_resid(-1, incl_mid=True, return_labels=False))
        clean_act.append(clean_cache.accumulated_resid(-1, incl_mid=True, return_labels=False))
        clean_grad_act.append(clean_grad_cache.accumulated_resid(-1, incl_mid=True, return_labels=False))
    if component in ['all', 'mlp']:
        clean_act.append(clean_cache.stack_activation('mlp_out'))
        corr_act.append(corr_cache.stack_activation('mlp_out'))
        clean_grad_act.append(clean_grad_cache.stack_activation('mlp_out'))
    if component in ['all', 'attn']:
        clean_act.append(clean_cache.stack_head_results(-1))
        corr_act.append(corr_cache.stack_head_results(-1))
        clean_grad_act.append(clean_grad_cache.stack_head_results(-1))
    if component in ['qkv']:
        clean_q, clean_k, clean_v = stack_qkv_results(clean_cache)
        corr_q, corr_k, corr_v = stack_qkv_results(corr_cache)
        clean_grad_q, clean_grad_k, clean_grad_v = stack_qkv_results(clean_grad_cache)
        
        clean_act.append(clean_q)
        corr_act.append(corr_q)
        clean_grad_act.append(clean_grad_q)
        clean_act.append(clean_k)
        corr_act.append(corr_k)
        clean_grad_act.append(clean_grad_k)
        clean_act.append(clean_v)
        corr_act.append(corr_v)
        clean_grad_act.append(clean_grad_v)

    patches = []
    for corr, clean, clean_grad in zip(corr_act, clean_act, clean_grad_act):
        patches.append(einops.reduce(
                clean_grad * (clean - corr),
                "component batch pos d_model -> component pos",
                "sum",
            ))
    if len(patches) == 1:
        return patches[0]
    else:
        return patches

def plot_atp(atp, x_clean, component, n_last_tokens=128, val=1, prepend_bos=True):

    str_tokens = model.to_str_tokens(x_clean, prepend_bos=prepend_bos)
    xs = [f"{tok} | {i}" for i, tok in enumerate(str_tokens[-n_last_tokens:])]
    
    if component in ['z', 'q']:
        ys = [f'L{i}H{j}' for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
    else:
        ys = [f"L{l} {component.upper()}" for l in range(model.cfg.n_layers)]
        
    fig = px.imshow(
        atp[:, -n_last_tokens:].cpu().numpy(), 
        x=xs,
        y=ys,
        color_continuous_scale='RdBu', zmin=-val, zmax=val, aspect='auto'
    )
    
    return fig

In [10]:
# IG
import torch
import einops
import gc
import sys

def attribution_patching(x_clean, x_corr, a_clean, a_corr, component, prepend_bos=True, method='standard', num_alphas=5, n_last_tokens=128):
    if isinstance(x_clean, str):
        clean_tokens = model.to_tokens(x_clean, prepend_bos=prepend_bos)
    else:
        clean_tokens = x_clean

    if isinstance(x_corr, str):
        corr_tokens = model.to_tokens(x_corr, prepend_bos=prepend_bos)
    else:
        corr_tokens = x_corr

    if isinstance(a_clean, str):
        a_clean = model.to_single_token(a_clean)

    if isinstance(a_corr, str):
        a_corr = model.to_single_token(a_corr)

    with torch.no_grad():
        corr_logits, corr_cache = get_cache_fw(corr_tokens, component)

    if method == 'standard':
        logits_diff_, clean_cache, clean_grad_cache = get_cache_fw_and_bw(clean_tokens, a_clean, a_corr, corr_logits, component=component)
    elif method == 'ig':
        with torch.no_grad():
            clean_logits, clean_cache = get_cache_fw(clean_tokens, component)

    clean_cache = ActivationCache(clean_cache, model).to('cpu')
    corr_cache = ActivationCache(corr_cache, model).to('cpu')

    corr_act = clean_cache.stack_activation(component)[:, 0, -n_last_tokens:]
    clean_act = corr_cache.stack_activation(component)[:, 0, -n_last_tokens:] # comp, pos dm
    del clean_cache, corr_cache
    
    if clean_act.ndim > 3:
        clean_act = clean_act.reshape(-1, clean_act.size(1), clean_act.size(3))
        corr_act = corr_act.reshape(-1, corr_act.size(1), corr_act.size(3))
        
    if method == 'standard':
        clean_grad_cache = ActivationCache(clean_grad_cache, model).to('cpu')
        clean_grad_act = clean_grad_cache.stack_activation(component).squeeze()
        if clean_grad_act.ndim > 3:
            clean_grad_act = clean_grad_act.reshape(-1, clean_grad_act.size(1), clean_grad_act.size(3))
        clean_grad_act = clean_grad_act[:, -n_last_tokens:].cpu()
    elif method == 'ig':
        clean_grad_act = []
        alphas = torch.linspace(0, 1, num_alphas)
        k = clean_act.shape[0] // model.cfg.n_layers
        for l in tqdm(range(model.cfg.n_layers)):
            ig_patch = torch.zeros_like(clean_act[k*l:k*(l+1)], device=clean_act.device)
            for alpha in alphas:
                a_alpha = alpha * clean_act[k*l:k*(l+1)] + (1 - alpha) * corr_act[k*l:k*(l+1)]
                logits_alpha, grad_alpha = get_cache_fw_with_modified_activations(clean_tokens, a_alpha, a_clean, a_corr, l, component)
                if grad_alpha.ndim > 3:
                    grad_alpha = grad_alpha.reshape(-1, grad_alpha.size(1), grad_alpha.size(3))
                grad_alpha = grad_alpha[:, -n_last_tokens:].cpu()
                ig_patch += grad_alpha * (clean_act[k*l:k*(l+1)] - corr_act[k*l:k*(l+1)])
                del a_alpha, logits_alpha, grad_alpha
            clean_grad_act.append(ig_patch / num_alphas)
            torch.cuda.empty_cache()
            gc.collect()
        clean_grad_act = torch.cat(clean_grad_act, dim=0)

    print("Gradients collected! Computing the patch...")
    patch = einops.reduce(
        clean_grad_act * (corr_act - clean_act),
        "component pos d_model -> component pos",
        "sum",
    )
    del clean_act, corr_act, clean_grad_act
    torch.cuda.empty_cache()

    return patch

def get_cache_fw(tokens, component):
    filter = lambda name: utils.get_act_name(component) in name

    model.reset_hooks()

    cache = {}
    def fw_cache_hook(act, hook):
        cache[hook.name] = act.detach()

    model.add_hook(filter, fw_cache_hook, "fwd")
    logits = model(tokens)
    model.reset_hooks()
    return logits, ActivationCache(cache, model)

def get_cache_fw_with_modified_activations(tokens, x_int, a_clean, a_corr, layer, component):
    hook_point = utils.get_act_name(component, layer)
    model.reset_hooks()
    
    def fw_hook(act, mod_act, hook):
        act = mod_act

    fw_hook_fn = partial(fw_hook, mod_act=x_int.squeeze())
    model.add_hook(hook_point, fw_hook_fn, "fwd")
    
    grad_cache = {}
    def bw_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()
    
    model.add_hook(hook_point, bw_cache_hook, "bwd")
    logits = model(tokens)
    value = logits_diff(logits, a_clean, a_corr)
    value.backward()
    
    model.reset_hooks()
    return value.item(), grad_cache[hook_point]

## Subtasks

We then explore each subtask mechanistically to understand which are the components responsible for each choice made by the model.

### S1 - Choosing the right species
The first step is choosing the right species to focus on. This is a key step since it preceeds the attribute check. It is also the most difficult one since the choice doesn't depend only on the species and the entity, but has to be made already considering the attribute.

In [11]:
all_species = ['grimpus', 'lorpus', 'wumpus', 'zumpus', 'sterpus', 'numpus', 'jompus', 'brimpus', 'yumpus', 'tumpus', 'dumpus', 'vumpus', 'rompus', 'lempus', 'gorpus', 'shumpus', 'impus']

In [24]:
idx = 2
x_clean = correct_preds['prompt'].iloc[idx]
cot_gold = correct_preds['cot_gold'].iloc[idx]
label = correct_preds['label'].iloc[idx]
cot_gold = ast.literal_eval(cot_gold)
print(x_clean)

Answer True or False to the following question. Answer as in the examples.

Wumpuses are shumpuses. Sterpuses are not discordant. Each yumpus is feisty. Each wumpus is a yumpus. Each yumpus is a lempus. Each yumpus is a sterpus. Each sterpus is a vumpus. Lempuses are fast. Shumpuses are large. Impuses are not metallic. Each zumpus is dull. Each zumpus is a brimpus. Every grimpus is an impus. Every grimpus is a gorpus. Vumpuses are not hot. Grimpuses are sunny. Wumpuses are not mean. Sterpuses are grimpuses. Fae is a zumpus. Fae is a grimpus.
Question: Is Fae sunny?
Think step-by-step.

(1) Fae is a grimpus.
(2) Grimpuses are sunny.
(3) Fae is sunny.
Answer: True

Impuses are lorpuses. Jompuses are not windy. Each brimpus is a jompus. Each sterpus is opaque. Every jompus is an impus. Every numpus is cold. Impuses are lempuses. Lempuses are not happy. Each impus is luminous. Numpuses are wumpuses. Brimpuses are dull. Jompuses are sterpuses. Every zumpus is melodic. Every brimpus is a zum

In [25]:
example = x_clean.split('\n\n')[-2]
context, question = example.split('Question: ')

subject = question.split()[1]
species = [' ' + x.strip().split()[-1] for x in context.split('.') if subject in x]
species_token = [model.to_tokens(s, prepend_bos=False)[:, 0] for s in species]

for id_, s in enumerate(species):
    if s in cot_gold[0]:
        break

a_clean = species_token[id_].cpu()
a_corr = torch.cat(species_token[:id_] + species_token[id_+1:]).cpu()

In [26]:
stop_tokens = [' a',]

print("Generating...")
clean_out = generate_until_stop(x_clean, stop_tokens, prepend_bos=True)

Generating...


In [15]:
resid_dla = compute_dla(clean_out, 'resid_pre', a_clean, a_corr, prepend_bos=False)
mlp_dla = compute_dla(clean_out, 'mlp_out', a_clean, a_corr, prepend_bos=False)
attn_dla = compute_dla(clean_out, 'result', a_clean, a_corr, prepend_bos=False)

In [None]:
fig = plot_dla(resid_dla, mlp_dla, attn_dla)
fig.update_layout(title_text=f"Direct Logit Attribution (Subtask 1) | {species[id_]} -{species[1-id_]}")
fig.show()
fig.write_html('fig/s1_DLA.html')

In [None]:
patterns = ['L10H7+', 'L14H0+', 'L14H4+']
fig = plot_patterns(clean_out, patterns, n_cols=2, query_offset=1000, key_offset=1000)
fig.show()
fig.write_html('fig/s1_patterns.html')

In [27]:
# AtP
assert len(species) == 2, "More than two species detected!"
assert len(model.to_tokens(species[0])[0]) == len(model.to_tokens(species[1])[0]), "Species with different token length!"

In [28]:
cot_corr = cot_gold.copy()

cot_corr = [step.lower().replace(species[id_][1:], species[1-id_][1:]).capitalize() for step in cot_gold]
context = x_clean.split('\n\n')

context[-2] = context[-2].replace(cot_gold[1], cot_corr[1])
x_corr = '\n\n'.join(context)

In [29]:
corr_out = generate_until_stop(x_corr, stop_tokens)

In [30]:
atp_q = attribution_patching(clean_out, corr_out, a_clean, a_corr, component='q', prepend_bos=False, method='standard')
atp_k = attribution_patching(clean_out, corr_out, a_clean, a_corr, component='k', prepend_bos=False, method='standard')
atp_v = attribution_patching(clean_out, corr_out, a_clean, a_corr, component='v', prepend_bos=False, method='standard')

Gradients collected! Computing the patch...
Gradients collected! Computing the patch...
Gradients collected! Computing the patch...


In [31]:
from plotly.subplots import make_subplots

def plot_qkv_atp(atp_q, atp_k, atp_v, **kwargs):
    fig = make_subplots(rows=3, cols=1, subplot_titles=("Queries", "Keys", "Values"), shared_xaxes=True, vertical_spacing=0.05)
    
    for i, (atp, hook) in enumerate(zip([atp_q, atp_k, atp_v], ['q', 'k', 'v'])):
        plot = plot_atp(atp, clean_out, component=hook, prepend_bos=False)
        for trace in plot.data:
            fig.add_trace(trace, row=1+i, col=1)
        
    fig.update_layout(
        coloraxis1=dict(colorscale='RdBu', cmin=-0.5, cmax=0.5),
        showlegend=False,
        **kwargs
    )
    
    return fig

In [32]:
fig = plot_qkv_atp(atp_q, atp_k, atp_v)
fig.write_html("fig/s1_AtP_qkv.html")

In [33]:
atp_q = attribution_patching(clean_out, corr_out, a_clean, a_corr, component='q', prepend_bos=False, method='ig')
atp_k = attribution_patching(clean_out, corr_out, a_clean, a_corr, component='k', prepend_bos=False, method='ig')
atp_v = attribution_patching(clean_out, corr_out, a_clean, a_corr, component='v', prepend_bos=False, method='ig')

100%|██████████| 18/18 [02:48<00:00,  9.37s/it]


Gradients collected! Computing the patch...


100%|██████████| 18/18 [01:53<00:00,  6.30s/it]


Gradients collected! Computing the patch...


100%|██████████| 18/18 [01:53<00:00,  6.31s/it]

Gradients collected! Computing the patch...





In [34]:
fig = plot_qkv_atp(atp_q, atp_k, atp_v)
fig.write_html(f"fig/s1_AtP_qkv_ig.html")

### S2 - Attribute check

In [47]:
stop_tokens = ['2']
clean_out = generate_until_stop(x_clean, stop_tokens)
stop_tokens = [' is', ' are']
clean_out = generate_until_stop(clean_out, stop_tokens, prepend_bos=False)

In [48]:
cot_steps = cot_gold[1].split()
for i, step in enumerate(cot_steps):
    if "is" in step or "are" in step:
        attribute = cot_steps[i+1].replace('.', '')
        a_clean = ' ' + attribute

a_corr = [model.to_single_token(tok) for tok in [' not',]]
a_clean = [model.to_single_token(a_clean)]

In [49]:
clean_tokens = model.to_tokens(clean_out, prepend_bos=False)
with torch.no_grad():
    clean_logits = model(clean_tokens).cpu()

clean_logit_diff = logits_diff(clean_logits, a_clean, a_corr)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

Clean logit difference: 3.683


In [None]:
resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, a_clean, a_corr, prepend_bos=False)

In [None]:
fig = plot_dla(resid_dla, mlp_dla, attn_dla)
fig.update_layout(title_text=f"Direct Logit Attribution (Subtask 2) | {attribute} - not")
fig.show()
fig.write_html('fig/s2_DLA.html')

In [None]:
patterns = ['L8H6-', 'L10H5+', 'L10H7+', 'L13H2+', 'L13H4+', 'L14H0+', 'L14H3-', 'L14H4+', 'L16H0+', 'L16H4+']
fig = plot_patterns(clean_out, patterns, n_cols=3, query_offset=1000, key_offset=1000)
fig.show()
fig.write_html('fig/s2_patterns.html')

In [60]:
# AtP
import random
def not_a_species(x):
    for s in all_species:
        if s in x.lower():
            return False
    return True

def is_single_token(x):
    try:
        model.to_single_token(x)
        return True
    except:
        return False

icl_examples = '\n\n'.join(clean_out.split('\n\n')[:-2])
test_example = '\n\n'.join(clean_out.split('\n\n')[-2:])
context, clean_question = test_example.split('Question: ')
clues = context.split('. ')

for s in all_species:
    if s in clean_question.lower(): 
        s_star = s
        break

a_clean = ' ' + clean_question.split()[2][:-1]
other_attributes = [' '+c.split()[-1] for c in clues if not_a_species(c.split()[-1])]
other_attributes = [a for a in other_attributes if is_single_token(a) and a != a_clean]
a_corr = other_attributes[random.randint(0, len(other_attributes)-1)]
corr_question = clean_question.replace(a_clean, a_corr)

for i, c in enumerate(clues):
    if a_corr in c:
        for s in all_species:
            if s in c.lower(): break

        clues[i] = c.lower().replace(s, s_star).capitalize()

clean_out_new = icl_examples + '\n\n' + '. '.join(clues) + clean_question
corr_out_new = icl_examples + '\n\n' + '. '.join(clues) + corr_question

In [54]:
atp = attribution_patching(clean_out_new, corr_out_new, a_clean, a_corr, component='resid', prepend_bos=False)

In [58]:
fig = plot_atp(atp, clean_out_new, component="resid", prepend_bos=False)
fig.write_html("fig/s2_AtP_resid.html")

### S3 - The right connection

In [59]:
example = x_clean.split('\n\n')[-2]
context, question = example.split('Question: ')

subject = question.split()[1]
attribute = question.split()[2][:-1]

if 'not' in cot_gold[2]:
    a_clean = [model.to_single_token(' not')]
    a_corr = [model.to_single_token(' ' + attribute)]
    clean_label = 'not'
    corr_label = attribute
else:
    a_clean = [model.to_single_token(' ' + attribute)]
    a_corr = [model.to_single_token(' not')]
    clean_label = attribute
    corr_label = 'not'

In [24]:
stop_tokens = ['3']
clean_out = generate_until_stop(x_clean, stop_tokens)
stop_tokens = [' is']
clean_out = generate_until_stop(clean_out, stop_tokens, prepend_bos=False)

In [25]:
clean_tokens = model.to_tokens(clean_out, prepend_bos=False)
with torch.no_grad():
    clean_logits = model(clean_tokens).cpu()

clean_logit_diff = logits_diff(clean_logits, a_clean, a_corr)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

Clean logit difference: 0.566


In [46]:
resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, a_clean, a_corr, prepend_bos=False)

In [None]:
fig = plot_dla(resid_dla, mlp_dla, attn_dla)
fig.update_layout(title_text=f"Direct Logit Attribution (Subtask 3) | {clean_label} - {corr_label}")
fig.show()
fig.write_html('fig/s3_DLA.html')

In [None]:
patterns = ['L12H0+', 'L14H0+', 'L14H4-', 'L14H5+', 'L14H6-', 'L16H7-', 'L17H7+']
fig = plot_patterns(clean_out, patterns, n_cols=3, query_offset=1000, key_offset=1000)
fig.show()
fig.write_html('fig/s3_patterns.html')

### S4 - Answering

In [45]:
clean_out = generate_until_stop(x_clean, stop_tokens=[':'])
a_clean = model.to_single_token(' ' + str(label))
a_corr = model.to_single_token(' False' if label else ' True')

In [46]:
clean_tokens = model.to_tokens(clean_out, prepend_bos=False)
with torch.no_grad():
    clean_logits = model(clean_tokens).cpu()

clean_logit_diff = logits_diff(clean_logits, a_clean, a_corr)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

Clean logit difference: 0.200


In [49]:
resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, [a_clean], [a_corr], prepend_bos=False)

In [None]:
fig = plot_dla(resid_dla, mlp_dla, attn_dla)
fig.update_layout(title_text=f"Direct Logit Attribution (Subtask 4) | '{model.to_string(a_clean)}' - '{model.to_string(a_corr)}'")
fig.show()
fig.write_html('fig/s4_DLA.html')

### Patterns

In [None]:
patterns = ['L9H2+', 'L9H4+', 'L10H7+', 'L11H6-' 'L14H0+', 'L14H1-', 'L15H1+', 'L15H4-', 'L17H2+', 'L17H7-']
fig = plot_patterns(clean_out, patterns, n_cols=3, query_offset=1000, key_offset=1000)
fig.show()
fig.write_html('fig/s1_patterns.html')

## Aggregate measures

### S1

In [20]:
import ast
stop_tokens = [' a',]

resid_dlas = []
mlp_dlas = []
attn_dlas = []

for idx in tqdm(range(len(correct_preds))):
    x_clean = correct_preds['prompt'].iloc[idx]
    cot_gold = correct_preds['cot_gold'].iloc[idx]
    cot_gold = ast.literal_eval(cot_gold)

    example = x_clean.split('\n\n')[-2]
    context, question = example.split('Question: ')
    
    subject = question.split()[1]
    species = [' ' + x.strip().split()[-1] for x in context.split('.') if subject in x]
    species_token = [model.to_tokens(s, prepend_bos=False)[:, 0] for s in species]
    
    for id_, s in enumerate(species):
        if s in cot_gold[0]:
            break
    
    a_clean = species_token[id_].cpu()
    a_corr = torch.cat(species_token[:id_] + species_token[id_+1:]).cpu()

    clean_out = generate_until_stop(x_clean, stop_tokens, prepend_bos=True)
    resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, a_clean, a_corr, prepend_bos=False)
    resid_dlas.append(resid_dla.cpu())
    mlp_dlas.append(mlp_dla.cpu())
    attn_dlas.append(attn_dla.cpu())
    del resid_dla, mlp_dla, attn_dla

100%|██████████| 26/26 [02:41<00:00,  6.22s/it]


In [23]:
resid_dla_agg = torch.stack(resid_dlas).mean(0)
mlp_dla_agg = torch.stack(mlp_dlas).mean(0)
attn_dla_agg = torch.stack(attn_dlas).mean(0)

In [None]:
fig = plot_dla(resid_dla_agg, mlp_dla_agg, attn_dla_agg)
fig.update_layout(title_text=f"Aggregated Direct Logit Attribution (Subtask 1) | {species[id_]} -{species[1-id_]}")
fig.show()
fig.write_html('fig/s1_DLA_agg.html')

In [None]:
stop_tokens = [' a',]
components = ['resid', 'mlp_out', 'attn']

resid_atps = []
mlp_atps = []
attn_atps = []

for idx in tqdm(range(len(correct_preds))):
    for component in components:
        x_clean = correct_preds['prompt'].iloc[idx]
        cot_gold = correct_preds['cot_gold'].iloc[idx]
        cot_gold = ast.literal_eval(cot_gold)
    
        example = x_clean.split('\n\n')[-2]
        context, question = example.split('Question: ')
        
        subject = question.split()[1]
        species = [' ' + x.strip().split()[-1] for x in context.split('.') if subject in x]
        species_token = [model.to_tokens(s, prepend_bos=False)[:, 0] for s in species]
        
        for id_, s in enumerate(species):
            if s in cot_gold[0]:
                break
        
        a_clean = species_token[id_].cpu()
        a_corr = torch.cat(species_token[:id_] + species_token[id_+1:]).cpu()
    
        try:
            assert len(species) == 2, "More than two species detected!"
            assert len(model.to_tokens(species[0])[0]) == len(model.to_tokens(species[1])[0]), "Species with different token length!"
    
            cot_corr = cot_gold.copy()
            
            cot_corr = [step.lower().replace(species[id_][1:], species[1-id_][1:]).capitalize() for step in cot_gold]
            context = x_clean.split('\n\n')
            
            context[-2] = context[-2].replace(cot_gold[1], cot_corr[1])
            x_corr = '\n\n'.join(context)
    
            clean_out = generate_until_stop(x_clean, stop_tokens)    
            corr_out = generate_until_stop(x_corr, stop_tokens)

            assert len(corr_tokens[0]) == len(clean_tokens[0]), f"Clean and corrupted tokens have different length, {len(clean_tokens[0])} and {len(corr_tokens[0])}, respectively."
    
            atp = attribution_patching(clean_out, corr_out, a_clean, a_corr, component=component, prepend_bos=False)

            if 'resid' in component:
                resid_atps.append(atp)
            elif 'mlp' in component:
                mlp_atps.append(atp)
            elif 'attn' in component:
                attn_atps.append(atp)
        
        except Exception as e: print(e)

In [58]:
resid_atp_agg = torch.stack([atp.max(dim=-1).values for atp in resid_atps]).mean(0).unsqueeze(-1)
mlp_atp_agg = torch.stack([atp.max(dim=-1).values for atp in mlp_atps]).mean(0).unsqueeze(-1)
attn_atp_agg = torch.stack([atp.max(dim=-1).values for atp in attn_atps]).mean(0).reshape(model.cfg.n_layers, -1)

In [None]:
fig = plot_dla(resid_atp_agg, mlp_atp_agg, attn_atp_agg, max_val=1)
fig.update_layout(title_text=f"Aggregated Attribution Patching (Subtask 1)")
fig.show()
fig.write_html('fig/s1_AtP_agg.html')

In [None]:
# Recursive AtP

### S2

In [27]:
resid_dlas = []
mlp_dlas = []
attn_dlas = []

for idx in tqdm(range(len(correct_preds))):
    x_clean = correct_preds['prompt'].iloc[idx]
    cot_gold = correct_preds['cot_gold'].iloc[idx]
    cot_gold = ast.literal_eval(cot_gold)
        
    stop_tokens = ['2']
    clean_out = generate_until_stop(x_clean, stop_tokens)
    stop_tokens = [' is', ' are']
    clean_out = generate_until_stop(clean_out, stop_tokens, prepend_bos=False)

    cot_steps = cot_gold[1].split()
    try:
        for i, step in enumerate(cot_steps):
            if "is" in step or "are" in step:
                attribute = cot_steps[i+1].replace('.', '')
                a_clean = ' ' + attribute
        
        a_corr = [model.to_single_token(tok) for tok in [' not',]]
        a_clean = [model.to_single_token(a_clean)]
        
        resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, a_clean, a_corr, prepend_bos=False)
        resid_dlas.append(resid_dla.cpu())
        mlp_dlas.append(mlp_dla.cpu())
        attn_dlas.append(attn_dla.cpu())
        del resid_dla, mlp_dla, attn_dla
    except: pass

100%|██████████| 26/26 [05:28<00:00, 12.62s/it]


In [28]:
resid_dla_agg = torch.stack(resid_dlas).mean(0)
mlp_dla_agg = torch.stack(mlp_dlas).mean(0)
attn_dla_agg = torch.stack(attn_dlas).mean(0)

In [None]:
fig = plot_dla(resid_dla_agg, mlp_dla_agg, attn_dla_agg)
fig.update_layout(title_text=f"Aggregated Direct Logit Attribution (Subtask 2) | attribute - not")
fig.show()
fig.write_html('fig/s2_DLA_agg.html')

In [18]:
import random
import ast
import einops

def not_a_species(x):
    for s in all_species:
        if s in x.lower():
            return False
    return True

def is_single_token(x):
    try:
        model.to_single_token(x)
        return True
    except:
        return False

In [None]:
stop_tokens = [' a',]
components = ['resid', 'mlp_out', 'attn']

resid_atps = []
mlp_atps = []
attn_atps = []

for idx in tqdm(range(len(correct_preds))):
    x_clean = correct_preds['prompt'].iloc[idx]
    cot_gold = correct_preds['cot_gold'].iloc[idx]
    cot_gold = ast.literal_eval(cot_gold)
        
    stop_tokens = ['2']
    clean_out = generate_until_stop(x_clean, stop_tokens)
    stop_tokens = [' is', ' are']
    clean_out = generate_until_stop(clean_out, stop_tokens, prepend_bos=False)

    icl_examples = '\n\n'.join(clean_out.split('\n\n')[:-2])
    test_example = '\n\n'.join(clean_out.split('\n\n')[-2:])
    context, clean_question = test_example.split('Question: ')
    clues = context.split('. ')
    
    for s in all_species:
        if s in clean_question.lower(): 
            s_star = s
            break
    
    a_clean = ' ' + clean_question.split()[2][:-1]
    other_attributes = [' '+c.split()[-1] for c in clues if not_a_species(c.split()[-1])]
    other_attributes = [a for a in other_attributes if is_single_token(a) and a != a_clean]
    a_corr = other_attributes[random.randint(0, len(other_attributes)-1)]
    corr_question = clean_question.replace(a_clean, a_corr)
    
    for i, c in enumerate(clues):
        if a_corr in c:
            for s in all_species:
                if s in c.lower(): break
    
            clues[i] = c.lower().replace(s, s_star).capitalize()
    
    clean_out_new = icl_examples + '\n\n' + '. '.join(clues) + clean_question
    corr_out_new = icl_examples + '\n\n' + '. '.join(clues) + corr_question

    for component in components:    
        try:
            assert len(corr_tokens[0]) == len(clean_tokens[0]), f"Clean and corrupted tokens have different length, {len(clean_tokens[0])} and {len(corr_tokens[0])}, respectively."
    
            atp = attribution_patching(clean_out_new, corr_out_new, a_clean, a_corr, component=component, prepend_bos=False)

            if 'resid' in component:
                resid_atps.append(atp)
            elif 'mlp' in component:
                mlp_atps.append(atp)
            elif 'attn' in component:
                attn_atps.append(atp)
        
        except Exception as e: print(e)

In [23]:
resid_atp_agg = torch.stack([atp.max(dim=-1).values for atp in resid_atps]).mean(0).unsqueeze(-1)
mlp_atp_agg = torch.stack([atp.max(dim=-1).values for atp in mlp_atps]).mean(0).unsqueeze(-1)
attn_atp_agg = torch.stack([atp.max(dim=-1).values for atp in attn_atps]).mean(0).reshape(model.cfg.n_layers, -1)

In [None]:
fig = plot_dla(resid_atp_agg, mlp_atp_agg, attn_atp_agg, max_val=1)
fig.update_layout(title_text=f"Aggregated Attribution Patching (Subtask 2)")
fig.show()
fig.write_html('fig/s2_AtP_agg.html')

In [136]:
clean_tokens = model.to_tokens(clean_out_new, prepend_bos=False)
with torch.no_grad():
    clean_logits = model(clean_tokens).cpu()

clean_logit_diff = logits_diff(clean_logits, a_clean, a_corr)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corr_tokens = model.to_tokens(corr_out_new, prepend_bos=False)
with torch.no_grad():
    corr_logits = model(corr_tokens).cpu()

corr_logit_diff = logits_diff(corr_logits, a_clean, a_corr)
print(f"Corrupted logit difference: {corr_logit_diff.item():.3f}")

Clean logit difference: 2.838
Corrupted logit difference: -3.524


### S3

In [32]:
resid_dlas = []
mlp_dlas = []
attn_dlas = []

for idx in tqdm(range(len(correct_preds))):
    x_clean = correct_preds['prompt'].iloc[idx]
    cot_gold = correct_preds['cot_gold'].iloc[idx]
    cot_gold = ast.literal_eval(cot_gold)
        
    example = x_clean.split('\n\n')[-2]
    context, question = example.split('Question: ')
    
    subject = question.split()[1]
    attribute = question.split()[2][:-1]

    try:
        if 'not' in cot_gold[2]:
            a_clean = [model.to_single_token(' not')]
            a_corr = [model.to_single_token(' ' + attribute)]
            clean_label = 'not'
            corr_label = attribute
        else:
            a_clean = [model.to_single_token(' ' + attribute)]
            a_corr = [model.to_single_token(' not')]
            clean_label = attribute
            corr_label = 'not'

        resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, a_clean, a_corr, prepend_bos=False)
        resid_dlas.append(resid_dla.cpu())
        mlp_dlas.append(mlp_dla.cpu())
        attn_dlas.append(attn_dla.cpu())
        del resid_dla, mlp_dla, attn_dla
    except: pass

100%|██████████| 26/26 [02:22<00:00,  5.46s/it]


In [33]:
resid_dla_agg = torch.stack(resid_dlas).mean(0)
mlp_dla_agg = torch.stack(mlp_dlas).mean(0)
attn_dla_agg = torch.stack(attn_dlas).mean(0)

In [None]:
fig = plot_dla(resid_dla_agg, mlp_dla_agg, attn_dla_agg)
fig.update_layout(title_text=f"Aggregated Direct Logit Attribution (Subtask 3) | a1 - a2")
fig.show()
fig.write_html('fig/s3_DLA_agg.html')

### S4

In [56]:
resid_dlas = []
mlp_dlas = []
attn_dlas = []

for idx in tqdm(range(len(correct_preds))):
    x_clean = correct_preds['prompt'].iloc[idx]
    label = correct_preds['label'].iloc[idx]

    clean_out = generate_until_stop(x_clean, stop_tokens=[':'])
    a_clean = model.to_single_token(' ' + str(label))
    a_corr = model.to_single_token(' False' if label else ' True')
        
    try:
        resid_dla, mlp_dla, attn_dla = compute_dla(clean_out, a_clean, a_corr, prepend_bos=False)
        resid_dlas.append(resid_dla.cpu())
        mlp_dlas.append(mlp_dla.cpu())
        attn_dlas.append(attn_dla.cpu())
        del resid_dla, mlp_dla, attn_dla
    except Exception as e: print(e)

100%|██████████| 26/26 [08:43<00:00, 20.14s/it]


In [60]:
resid_dla_agg = torch.stack(resid_dlas).mean(0).unsqueeze(-1)
mlp_dla_agg = torch.stack(mlp_dlas).mean(0).unsqueeze(-1)
attn_dla_agg = torch.stack(attn_dlas).mean(0).reshape(model.cfg.n_layers, -1)

In [None]:
fig = plot_dla(resid_dla_agg, mlp_dla_agg, attn_dla_agg)
fig.update_layout(title_text=f"Aggregated Direct Logit Attribution (Subtask 4) | a1 - a2")
fig.show()
fig.write_html('fig/s4_DLA_agg.html')