## Circuit Components

In [84]:
from typing import Tuple, List
from functools import partial

import os
import numpy as np 
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from einops import rearrange

import einops
from fancy_einsum import einsum

from utils.data_processing import (
    load_edge_scores_into_dictionary,
    get_ckpts,
    load_metrics,
    get_ckpts
)
from utils.backup_analysis import load_model, get_past_nmhs_for_checkpoints, compute_copy_score
from utils.data_utils import generate_data_and_caches
from utils.cspa_main import prepare_data
from path_patching_cm.ioi_dataset import IOIDataset
from path_patching_cm.path_patching import Node, path_patch

from utils.visualization import imshow_p, plot_attention_heads

#%%
def convert_head_names_to_tuple(head_name):
    """Parse an attention-head node name such as 'a7.h9' into a (layer, head) tuple.

    Query-input node names such as 'a7.h9<q>' (the form used when selecting
    edges into name mover heads) are accepted too: a trailing '<q>' marker
    (and, by analogy, '<k>' / '<v>') is stripped before parsing. Previously
    such names raised ValueError because int('9<q>') fails.

    Args:
        head_name: node name of the form 'a<layer>.h<head>' with an optional
            '<q>'/'<k>'/'<v>' suffix.

    Returns:
        (layer, head) as ints.

    Raises:
        ValueError: if the name is not of that form (e.g. MLP nodes like 'm5'
            or 'input').
    """
    for suffix in ('<q>', '<k>', '<v>'):
        if head_name.endswith(suffix):
            head_name = head_name[: -len(suffix)]
            break
    # Drop the 'a' (attention) and 'h' (head) prefixes, leaving '<layer>.<head>'.
    layer, head = head_name.replace('a', '').replace('h', '').split('.')
    return (int(layer), int(head))

def collate_fn(ds, device):
    """Batch a list of per-example dicts into one dict of stacked tensors.

    Every example must share the keys of the first one; for each key the
    per-example tensors are stacked along a new leading batch dimension and
    moved to ``device``. An empty list yields an empty dict.
    """
    batched = {}
    if not ds:
        return batched
    for key in ds[0]:
        batched[key] = torch.stack([example[key] for example in ds]).to(device)
    return batched

class BatchIOIDataset(IOIDataset):
    """IOIDataset that yields one dict of tensors per prompt, for use with a DataLoader.

    Each item holds the prompt tokens, the IO and S token ids, and one
    '<word>_pos' entry per key of ``self.word_idx`` (e.g. 'IO_pos', 'S1_pos',
    'S2_pos', 'end_pos' as consumed elsewhere in this notebook).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __len__(self):
        return self.N

    def __getitem__(self, i):
        item = {
            'toks': self.toks[i],
            'io_token_id': torch.tensor(self.io_tokenIDs[i]),
            's_token_id': torch.tensor(self.s_tokenIDs[i]),
        }
        # Per-prompt token positions of each named word.
        for word, positions in self.word_idx.items():
            item[f'{word}_pos'] = positions[i]
        return item
    
#%%

def make_s2i(layer, head):
    """Return a path-patching Node on the z output (attn.hook_z) of head ``layer.head``."""
    hook_name = f'blocks.{layer}.attn.hook_z'
    return Node(hook_name, layer, head)
def make_nmh(layer, head):
    """Return a path-patching Node on the query input (hook_q_input) of head ``layer.head``."""
    hook_name = f'blocks.{layer}.hook_q_input'
    return Node(hook_name, layer, head)

from utils.head_metrics import S2I_head_metrics, S2I_token_pos

In [85]:
TASK = 'ioi'
PERFORMANCE_METRIC = 'logit_diff'
BASE_MODEL = "pythia-160m"
VARIANT = None #"EleutherAI/pythia-160m-attndropout"
# Short name used in result paths: the variant's repo name without its org
# prefix (e.g. "EleutherAI/pythia-160m-attndropout" -> "pythia-160m-attndropout"),
# or the base model when no variant is set. Splitting on '/' works for any
# org, unlike a fixed VARIANT[11:] slice, which only fits "EleutherAI/".
MODEL_SHORTNAME = VARIANT.split('/')[-1] if VARIANT else BASE_MODEL
CACHE = "model_cache"
DATASET_SIZE = 100
SEED = 42
BATCH_SIZE = 70
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f4abd566890>

In [86]:
# Edge scores for every saved checkpoint graph of this model/task.
folder_path = f'results/graphs/{MODEL_SHORTNAME}/{TASK}'
df = load_edge_scores_into_dictionary(folder_path)

# Task-performance metrics and the checkpoint schedule.
directory_path = 'results'
perf_metrics = load_metrics(directory_path)
ckpts = get_ckpts(schedule="exp_plus_detail")

# Filter out every checkpoint earlier than step 1000.
df = df.loc[df['checkpoint'] >= 1000]

# Edges are named "source->target"; split each name into two columns.
df[['source', 'target']] = df['edge'].str.split('->', expand=True)

# Number of distinct edge targets (displayed by the notebook).
len(df['target'].unique())

Processing file 1/153: results/graphs/pythia-160m/ioi/57000.json
Processing file 2/153: results/graphs/pythia-160m/ioi/141000.json
Processing file 3/153: results/graphs/pythia-160m/ioi/95000.json
Processing file 4/153: results/graphs/pythia-160m/ioi/107000.json
Processing file 5/153: results/graphs/pythia-160m/ioi/34000.json
Processing file 6/153: results/graphs/pythia-160m/ioi/6000.json
Processing file 7/153: results/graphs/pythia-160m/ioi/37000.json
Processing file 8/153: results/graphs/pythia-160m/ioi/39000.json
Processing file 9/153: results/graphs/pythia-160m/ioi/104000.json
Processing file 10/153: results/graphs/pythia-160m/ioi/59000.json
Processing file 11/153: results/graphs/pythia-160m/ioi/67000.json
Processing file 12/153: results/graphs/pythia-160m/ioi/111000.json
Processing file 13/153: results/graphs/pythia-160m/ioi/16.json
Processing file 14/153: results/graphs/pythia-160m/ioi/76000.json
Processing file 15/153: results/graphs/pythia-160m/ioi/1.json
Processing file 16/153:

445

In [87]:
# Load the saved NMH backup metrics for this model (not used in the visible cells).
experiment_metrics = torch.load(f'results/backup/{MODEL_SHORTNAME}/nmh_backup_metrics.pt')

In [88]:
from utils.cspa_main import get_result_mean, get_cspa_results_batched, get_performance_recovered

def get_cspa_for_head(model, data_toks, cspa_semantic_dict, layer, head, verbose=False,
                      current_batch_size=17, current_seq_len=61, n_mean_samples=100):
    """Run CSPA on a single attention head and return the performance it recovers.

    NOTE(review): "CSPA" presumably stands for copy-suppression-preserving
    ablation as implemented in utils.cspa_main; confirm there.

    Args:
        model: model forwarded to the utils.cspa_main helpers.
        data_toks: 2-D token tensor, sliced as [sequences, positions].
        cspa_semantic_dict: semantic dict (from prepare_data) forwarded as
            ``semantic_dict``.
        layer, head: the head to analyse.
        verbose: print the per-head result and let the cspa_main helpers show
            their progress output. Previously both helpers were always called
            with verbose=True, so progress bars appeared even for verbose=False.
        current_batch_size: number of sequences evaluated (small default so
            many checkpoints can be processed in reasonable time).
        current_seq_len: number of leading positions kept per sequence.
        n_mean_samples: number of sequences used to compute the mean result.

    Returns:
        The value reported by get_performance_recovered for the CSPA results.
    """
    result_mean = get_result_mean([(layer, head)], data_toks[:n_mean_samples, :], model, verbose=verbose)
    cspa_results_qk_ov = get_cspa_results_batched(
        model=model,
        toks=data_toks[:current_batch_size, :current_seq_len],
        max_batch_size=1,  # 50,
        negative_head=(layer, head),
        interventions=["ov", "qk"],
        only_keep_negative_components=True,
        K_unembeddings=0.05,  # most interesting in range 3-8 (out of 80)
        K_semantic=1,  # either 1 or up to 8 to capture all sem similar
        semantic_dict=cspa_semantic_dict,
        result_mean=result_mean,
        use_cuda=True,
        verbose=verbose,
        compute_s_sstar_dict=False,
        computation_device="cpu",  # device
    )
    head_results = get_performance_recovered(cspa_results_qk_ov)

    if verbose:
        print(f"Layer {layer}, head {head} done. Performance: {head_results:.2f}")

    return head_results

In [89]:
def get_attention_to_ioi_token(
        model: HookedTransformer, 
        ioi_dataset: IOIDataset,  
        head_list: List[Tuple[int, int]], 
        batch_size
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 
    """Measure how strongly each listed head attends from the END token to S1, S2 and IO.

    Args:
        model: HookedTransformer to run.
        ioi_dataset: IOI prompts. NOTE: its ``__class__`` is reassigned to
            BatchIOIDataset in place so the DataLoader yields dicts of tensors.
        head_list: (layer, head) pairs to measure.
        batch_size: DataLoader batch size.

    Returns:
        (s1, s2, io) attention maps, each an (n_layers, n_heads) tensor on
        model.cfg.device. Entries for heads in head_list hold the attention
        averaged over all prompts; all other entries are 0.
    """
    device = model.cfg.device
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads

    # Accumulate per-head attention *sums* over prompts.
    s1_attention_accum = torch.zeros((n_layers, n_heads), device=device)
    s2_attention_accum = torch.zeros((n_layers, n_heads), device=device)
    io_attention_accum = torch.zeros((n_layers, n_heads), device=device)
    if not head_list:
        # Nothing to measure; zip(*head_list) below cannot unpack an empty list.
        return s1_attention_accum, s2_attention_accum, io_attention_accum

    ioi_dataset.__class__ = BatchIOIDataset
    ioi_dataloader = DataLoader(ioi_dataset, batch_size=batch_size, collate_fn=partial(collate_fn, device=device))

    NMH_layers, NMH_heads = zip(*head_list)
    NMH_layers = torch.tensor(NMH_layers, device=device)
    NMH_heads = torch.tensor(NMH_heads, device=device)

    n_examples = 0
    for batch in ioi_dataloader:
        toks = batch['toks']
        end_pos = batch['end_pos']
        batch_idx = torch.arange(len(toks), device=device)
        n_examples += len(toks)

        cache, caching_hooks, _ = model.get_caching_hooks(lambda name: 'hook_pattern' in name)
        with model.hooks(caching_hooks):
            model(toks)  # forward pass only needed to populate the pattern cache

        # (layer, batch, head, query, key)
        attention_patterns = torch.stack([cache[f'blocks.{n}.attn.hook_pattern'] for n in range(n_layers)])
        # The two index tensors (dims 0 and 2) are separated by a slice, so the
        # broadcast listed-head dimension moves to the front:
        # (n_listed_heads, batch, query, key).
        attention_patterns_by_head = attention_patterns[NMH_layers, :, NMH_heads]

        # Each is (n_listed_heads, batch): attention from END to the token, per prompt.
        nmh_s1_attention_values = attention_patterns_by_head[:, batch_idx, end_pos, batch['S1_pos']]
        nmh_s2_attention_values = attention_patterns_by_head[:, batch_idx, end_pos, batch['S2_pos']]
        nmh_io_attention_values = attention_patterns_by_head[:, batch_idx, end_pos, batch['IO_pos']]

        # Row i belongs to head_list[i]. (Indexing column i would instead pick
        # prompt i averaged over all heads.) Summing here and dividing by the
        # total prompt count below avoids over-weighting a short final batch.
        for i, (layer, head) in enumerate(head_list):
            s1_attention_accum[layer, head] += nmh_s1_attention_values[i].sum()
            s2_attention_accum[layer, head] += nmh_s2_attention_values[i].sum()
            io_attention_accum[layer, head] += nmh_io_attention_values[i].sum()

    s1_attention_means = s1_attention_accum / n_examples
    s2_attention_means = s2_attention_accum / n_examples
    io_attention_means = io_attention_accum / n_examples

    return s1_attention_means, s2_attention_means, io_attention_means


In [90]:
# get NMH candidates
def evaluate_direct_effect_heads(model, edge_df, dataset, verbose=False, batch_size=70):
    """Score every attention head with an in-circuit edge directly into the logits.

    Args:
        model: HookedTransformer for the checkpoint under study.
        edge_df: edge table with 'source', 'target' and 'in_circuit' columns.
        dataset: IOI dataset used for copy scores and attention measurements.
        verbose: forwarded to get_cspa_for_head.
        batch_size: batch size for the attention measurement (default 70,
            the value previously hard-coded).

    Returns:
        dict with keys 'copy_scores', 's1_attn_scores', 's2_attn_scores',
        'io_attn_scores' and 'copy_suppression_scores'. Each value is an
        (n_layers, n_heads) tensor in which only candidate heads are non-zero.
    """
    direct_effect_heads = edge_df[edge_df['target'] == 'logits']
    direct_effect_heads = direct_effect_heads[direct_effect_heads['in_circuit'] == True]

    # Keep attention-head sources only (drop MLP nodes 'm*' and the input node).
    head_list = direct_effect_heads['source'].unique().tolist()
    head_list = [convert_head_names_to_tuple(c) for c in head_list if (c[0] != 'm' and c != 'input')]

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    head_data = dict()

    # Test for NMH behavior
    head_data['copy_scores'] = torch.zeros((n_layers, n_heads))
    for layer, head in head_list:
        head_data['copy_scores'][layer, head] = compute_copy_score(model, layer, head, dataset, verbose=False, neg=False)

    # Test for attention to IOI tokens
    if head_list:
        s1_attn_scores, s2_attn_scores, io_attn_scores = get_attention_to_ioi_token(model, dataset, head_list, batch_size=batch_size)
    else:
        # Skip the attention pass when there is nothing to measure.
        zeros = torch.zeros((n_layers, n_heads), device=model.cfg.device)
        s1_attn_scores, s2_attn_scores, io_attn_scores = zeros, zeros.clone(), zeros.clone()
    head_data['s1_attn_scores'], head_data['s2_attn_scores'], head_data['io_attn_scores'] = s1_attn_scores, s2_attn_scores, io_attn_scores

    # Test for copy suppression behavior. CSPA runs with split q/k/v inputs
    # disabled; the flag is switched back on afterwards (as before), now inside
    # `finally` so an exception cannot leave the model in the wrong mode.
    # NOTE(review): later steps presumably need it on (make_nmh targets
    # hook_q_input) — confirm before restoring the previous value instead.
    head_data['copy_suppression_scores'] = torch.zeros((n_layers, n_heads))
    model.cfg.use_split_qkv_input = False
    try:
        if head_list:
            data_toks, _, cspa_semantic_dict, _ = prepare_data(model)
            for layer, head in head_list:
                head_data['copy_suppression_scores'][layer, head] = get_cspa_for_head(model, data_toks, cspa_semantic_dict, layer, head, verbose=verbose)
    finally:
        model.cfg.use_split_qkv_input = True

    return head_data

In [91]:
def filter_name_movers(direct_effect_scores, copy_score_threshold):
    """Select name mover heads and record a thresholded copy-score map.

    Side effect: stores ``direct_effect_scores['filtered_copy_scores']``, a
    copy of 'copy_scores' with every entry below the threshold set to 0.

    A head is reported as a name mover when its copy score is strictly above
    the threshold AND its IO attention score strictly exceeds both its S1 and
    its S2 attention scores.

    Returns:
        list of (layer, head) int tuples in layer-major order.
    """
    copy_scores = direct_effect_scores['copy_scores']
    io_attn = direct_effect_scores['io_attn_scores']

    direct_effect_scores['filtered_copy_scores'] = copy_scores.masked_fill(copy_scores < copy_score_threshold, 0)

    # The attention maps may live on a different device than the copy scores,
    # so compare them where they are and move only the resulting mask.
    prefers_io = (io_attn > direct_effect_scores['s1_attn_scores']) & (io_attn > direct_effect_scores['s2_attn_scores'])
    is_name_mover = (copy_scores > copy_score_threshold) & prefers_io.to(copy_scores.device)

    return [tuple(index) for index in is_name_mover.nonzero().tolist()]

In [92]:
def evaluate_s2i_candidates(model, checkpoint_df, ioi_dataset, name_mover_heads, batch_size, verbose=False):
    """Patch each in-circuit head feeding the NMH queries and keep those acting as S2I heads.

    Candidates are attention heads with an in-circuit edge into the query
    input ('a{layer}.h{head}<q>') of any name mover head. For each candidate,
    S2I_token_pos reports logit differences and NMH attention under several
    patched datasets. Relative changes vs. the unpatched 'token_same_pos_same'
    run are recorded as (patched - baseline) / baseline; a zero baseline
    yields inf/nan for that entry.

    A candidate counts as a true S2I head when, on 'token_same_pos_oppo', the
    logit difference and NMH IO attention both drop (< 0) while NMH S1
    attention rises (> 0).

    Args:
        model: HookedTransformer for the checkpoint.
        checkpoint_df: edge table for this checkpoint ('source', 'target',
            'in_circuit' columns).
        ioi_dataset: IOI prompts forwarded to S2I_token_pos.
        name_mover_heads: list of (layer, head) NMHs. An empty list now yields
            no candidates instead of a KeyError (np.logical_or.reduce over an
            empty array returns a scalar False, which pandas treats as a column
            label).
        batch_size: forwarded to S2I_token_pos.
        verbose: print per-head, per-dataset percentages.

    Returns:
        (scores, true_s2i_heads). ``scores`` maps 's2i_ablated_logit_diff_deltas',
        's2i_io_attention_deltas', 's2i_s1_attention_deltas' and
        's2i_s2_attention_deltas' to {patch dataset name: (n_layers, n_heads)
        tensor}, zeroed outside the true S2I heads.
    """
    patch_dataset_names = ['token_same_pos_oppo', 'token_oppo_pos_same', 'token_oppo_pos_oppo']
    baseline_dataset = 'token_same_pos_same'
    selection_dataset = 'token_same_pos_oppo'
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads

    # Output key -> key of the corresponding raw values in S2I_token_pos results.
    metric_sources = {
        's2i_ablated_logit_diff_deltas': 'ablated_logit_diffs',
        's2i_io_attention_deltas': 'io_attention_values',
        's2i_s1_attention_deltas': 's1_attention_values',
        's2i_s2_attention_deltas': 's2_attention_values',
    }

    nmh_query_targets = [f'a{layer}.h{head}<q>' for layer, head in name_mover_heads]
    candidate_s2i = checkpoint_df[checkpoint_df['target'].isin(nmh_query_targets)]
    candidate_s2i = candidate_s2i[candidate_s2i['in_circuit'] == True]

    # Keep attention-head sources only (drop MLP nodes 'm*' and the input node).
    candidate_list = candidate_s2i['source'].unique().tolist()
    candidate_list = [convert_head_names_to_tuple(c) for c in candidate_list if (c[0] != 'm' and c != 'input')]

    deltas = {
        out_key: {name: torch.zeros((n_layers, n_heads)) for name in patch_dataset_names}
        for out_key in metric_sources
    }
    true_s2i_mask = torch.zeros((n_layers, n_heads))
    true_s2i_heads = []

    for head in candidate_list:
        token_pos_results = S2I_token_pos(model, ioi_dataset, [head], name_mover_heads, batch_size)
        if verbose:
            print(f'Head {head}')
        layer, head_idx = head
        baselines = {src: token_pos_results[src][baseline_dataset].mean() for src in metric_sources.values()}

        for dataset_name in patch_dataset_names:
            head_deltas = {}
            for out_key, src in metric_sources.items():
                patched = token_pos_results[src][dataset_name].mean()
                head_deltas[out_key] = (patched - baselines[src]) / baselines[src]
                deltas[out_key][dataset_name][layer, head_idx] = head_deltas[out_key]

            if verbose:
                print(dataset_name)
                print(f"Logit diff after patching: {100 * head_deltas['s2i_ablated_logit_diff_deltas']:.2f}%")
                # should be high with pos = same, low with pos = diff
                print(f"NMH IO Attention Change: {100 * head_deltas['s2i_io_attention_deltas']:.2f}%")
                # should be low with pos = same, high with pos = diff
                print(f"NMH S1 Attention Change: {100 * head_deltas['s2i_s1_attention_deltas']:.2f}%")
                # shouldn't change much
                print(f"NMH S2 Attention Change: {100 * head_deltas['s2i_s2_attention_deltas']:.2f}%")
                print('\n')

        if deltas['s2i_ablated_logit_diff_deltas'][selection_dataset][layer, head_idx] < 0 \
                and deltas['s2i_io_attention_deltas'][selection_dataset][layer, head_idx] < 0 \
                and deltas['s2i_s1_attention_deltas'][selection_dataset][layer, head_idx] > 0:
            true_s2i_mask[layer, head_idx] = 1
            true_s2i_heads.append(head)

    # Zero out every head that did not pass the S2I test.
    masked_deltas = {
        out_key: {name: values * true_s2i_mask for name, values in per_dataset.items()}
        for out_key, per_dataset in deltas.items()
    }
    return masked_deltas, true_s2i_heads
        
        
            

In [93]:
def get_induction_scores(model, seq_len=100, batch_size=2):
    """Score every head for induction, previous-token and duplicate-token attention.

    Runs the model on ``batch_size`` random token sequences of length
    ``seq_len``, each repeated twice (total length 2 * seq_len). It then
    averages attention along three diagonals (query position minus key
    position equal to a fixed offset):

    * previous-token: offset 1 (attend to the immediately preceding token);
    * duplicate-token: offset seq_len (attend to the same token in the first copy);
    * induction: offset seq_len - 1 (attend to the token that followed the
      current token's earlier occurrence).

    Args:
        model: HookedTransformer exposing '...hook_pattern' activations.
        seq_len: length of the random sequence before repetition (default 100).
        batch_size: number of random sequences (default 2).

    Returns:
        (induction_scores, prev_token_scores, duplicate_token_scores), each an
        (n_layers, n_heads) tensor on model.cfg.device. Tokens are drawn from
        the global RNG without a fixed seed.
    """
    # Use the model's device rather than a hard-coded "cuda" so this works on
    # CPU models and matches masks built on model.cfg.device by callers.
    device = model.cfg.device
    shape = (model.cfg.n_layers, model.cfg.n_heads)
    induction_scores = torch.zeros(shape, device=device)
    prev_token_scores = torch.zeros(shape, device=device)
    duplicate_token_scores = torch.zeros(shape, device=device)

    def diagonal_mean(pattern, offset):
        # Assumes pattern is (batch, head, query, key), the TransformerLens
        # hook_pattern layout. With dim1=key and dim2=query, the diagonal holds
        # the entries where query - key == offset.
        diagonal = pattern.diagonal(offset=offset, dim1=-1, dim2=-2)
        return einops.reduce(diagonal, "batch head_index diagonal -> head_index", "mean")

    def prev_token_hook(pattern, hook):
        prev_token_scores[hook.layer()] = diagonal_mean(pattern, 1)

    def duplicate_token_hook(pattern, hook):
        duplicate_token_scores[hook.layer()] = diagonal_mean(pattern, seq_len)

    def induction_hook(pattern, hook):
        induction_scores[hook.layer()] = diagonal_mean(pattern, seq_len - 1)

    original_tokens = torch.randint(100, 20000, size=(batch_size, seq_len))
    repeated_tokens = einops.repeat(original_tokens, "batch seq_len -> batch (2 seq_len)").to(device)

    pattern_filter = lambda act_name: act_name.endswith("hook_pattern")
    # The forward pass is run only for its hooks; the returned loss is unused.
    model.run_with_hooks(
        repeated_tokens,
        return_type="loss",
        fwd_hooks=[
            (pattern_filter, prev_token_hook),
            (pattern_filter, duplicate_token_hook),
            (pattern_filter, induction_hook),
        ],
    )

    return induction_scores, prev_token_scores, duplicate_token_scores

In [94]:
def evaluate_induction_scores(model, checkpoint_df):
    """Return induction / previous-token / duplicate-token scores restricted to circuit heads.

    A head is a circuit head when it appears as the source of an in-circuit
    edge in ``checkpoint_df`` (MLP nodes and the input node are ignored).
    Scores of all other heads are zeroed.

    Returns:
        dict with 'induction_scores', 'prev_token_scores' and
        'duplicate_token_scores', each an (n_layers, n_heads) tensor.
    """
    in_circuit_edges = checkpoint_df[checkpoint_df['in_circuit'] == True]
    head_names = [
        name for name in in_circuit_edges['source'].unique().tolist()
        if not (name[0] == 'm' or name == 'input')
    ]

    circuit_mask = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)
    for layer, head in map(convert_head_names_to_tuple, head_names):
        circuit_mask[layer, head] = 1

    induction, prev_token, duplicate = get_induction_scores(model)
    return {
        'induction_scores': induction * circuit_mask,
        'prev_token_scores': prev_token * circuit_mask,
        'duplicate_token_scores': duplicate * circuit_mask,
    }

In [103]:
def main(overwrite=False, checkpoints=(5000, 10000, 30000)):
    """Identify IOI circuit components at each checkpoint and save them to disk.

    For every checkpoint this scores direct-effect heads, filters name mover
    heads (NMHs), tests S2I candidates (presumably S2-inhibition heads) feeding
    the NMH queries, and computes induction-style scores for circuit heads.
    Results are accumulated in two dicts keyed by checkpoint and written after
    every checkpoint, so an interrupted run can resume.

    Args:
        overwrite: recompute checkpoints already present in the saved results.
        checkpoints: training steps to analyse (default: the previously
            hard-coded 5000, 10000, 30000).

    Returns:
        dict mapping checkpoint -> component score dict.
    """
    results_dir = f'results/components/{MODEL_SHORTNAME}'
    components_path = f'{results_dir}/components_over_time.pt'
    heads_path = f'{results_dir}/heads_over_time.pt'

    # The IOI dataset is generated once, with the step-143000 checkpoint.
    model = load_model(BASE_MODEL, VARIANT, 143000, CACHE)
    ioi_dataset, abc_dataset = generate_data_and_caches(model, 70, verbose=True, prepend_bos=True)

    # Load prior results once and check each file separately. Previously the
    # heads file was loaded whenever the components file existed, which
    # crashed if only the latter was present, and both were re-read per loop.
    os.makedirs(results_dir, exist_ok=True)
    components_over_time = torch.load(components_path) if os.path.exists(components_path) else dict()
    heads_over_time = torch.load(heads_path) if os.path.exists(heads_path) else dict()

    for checkpoint in checkpoints:
        if checkpoint in components_over_time and not overwrite:
            continue

        model = load_model(BASE_MODEL, VARIANT, checkpoint, CACHE)
        checkpoint_df = df[df['checkpoint'] == checkpoint].copy()
        component_scores = dict()
        model_heads = dict()

        component_scores['direct_effect_scores'] = evaluate_direct_effect_heads(model, checkpoint_df, ioi_dataset, verbose=False)
        nmh_list = filter_name_movers(component_scores['direct_effect_scores'], copy_score_threshold=10)

        model_heads['nmh'] = nmh_list
        print(f"Found {len(nmh_list)} NMHs")
        print(nmh_list)

        if nmh_list:
            component_scores['s2i_scores'], s2i_list = evaluate_s2i_candidates(model, checkpoint_df, ioi_dataset, nmh_list, batch_size=BATCH_SIZE, verbose=False)
            print(f"Found {len(s2i_list)} S2I heads")
            print(s2i_list)
        else:
            component_scores['s2i_scores'] = None
            s2i_list = []

        model_heads['s2i'] = s2i_list

        component_scores['tertiary_head_scores'] = evaluate_induction_scores(model, checkpoint_df)

        components_over_time[checkpoint] = component_scores
        heads_over_time[checkpoint] = model_heads

        # Persist after every checkpoint so progress survives an interruption.
        torch.save(components_over_time, components_path)
        torch.save(heads_over_time, heads_path)

    return components_over_time
   
# Run (or resume) the component analysis; checkpoints already in the saved results are skipped unless overwrite=True.
components_over_time = main()

# # #baseline_logit_diffs, end_s2_attention_values, baseline_nmh_s1_attention_values, new_logit_diffs, new_nmh_s1_attention_values
# s2i_results = S2I_head_metrics(model, ioi_dataset, candidate_list, name_mover_heads, batch_size)

# # our three measures are thus:

# # attention (higher is better)
# s2i_s2_attention = s2i_results['end_s2_attention_values'].mean(0)

# # logit diff change (lower is better)
# logit_diff_change = (s2i_results['new_logit_diffs'] - s2i_results['baseline_logit_diffs'].unsqueeze(1)).mean(0)

# # NMH s1 attention change (higher is better)
# nmh_s1_attention_change = (s2i_results['new_nmh_s1_attention_values'] - s2i_results['baseline_nmh_s1_attention_values'].unsqueeze(1)).mean(0).mean(-1)


  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Batch 1/17, shape torch.Size([1, 61]):   0%|          | 0/17 [00:00<?, ?it/s]



Found 0 NMHs
[]


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer
Copy circuit for head 8.1 (sign=1) : Top 5 accuracy: 98.57142857142858%
Copy circuit for head 8.10 (sign=1) : Top 5 accuracy: 99.04761904761905%


In [98]:
# Checkpoints that have component results so far.
components_over_time.keys()

dict_keys([5000])

In [78]:
# NOTE(review): `component_scores` and `checkpoint` are not defined at module
# level in the visible cells (they are locals of main); set them first.
imshow_p(
    component_scores['s2i_scores']['s2i_ablated_logit_diff_deltas']['token_same_pos_oppo'] * 100,
    title=f's2i_ablated_logit_diff_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
    labels=dict(x="Head", y="Layer", color="metric_value"),
    border=True,
    coloraxis={"colorbar_ticksuffix": " %"},
    width=600,
    margin=dict(r=100, l=100),
)

In [79]:
# Relative change (%) in NMH IO attention when each S2I head is patched.
imshow_p(
    component_scores['s2i_scores']['s2i_io_attention_deltas']['token_same_pos_oppo'] * 100,
    title=f's2i_io_attention_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
    labels=dict(x="Head", y="Layer", color="metric_value"),
    border=True,
    coloraxis={"colorbar_ticksuffix": " %"},
    width=600,
    margin=dict(r=100, l=100),
)

In [80]:
# Relative change (%) in NMH S1 attention when each S2I head is patched.
imshow_p(
    component_scores['s2i_scores']['s2i_s1_attention_deltas']['token_same_pos_oppo'] * 100,
    title=f's2i_s1_attention_deltas for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
    labels=dict(x="Head", y="Layer", color="metric_value"),
    border=True,
    coloraxis={"colorbar_ticksuffix": " %"},
    width=600,
    margin=dict(r=100, l=100),
)

In [81]:
# Induction scores (%) of circuit heads; non-circuit heads are masked to 0.
# These are raw scores, not deltas, so the title no longer says "deltas".
imshow_p(
        component_scores['tertiary_head_scores']['induction_scores'] * 100, 
        title=f'induction_scores for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
        labels={"x": "Head", "y": "Layer", "color": "metric_value"},
        border=True,
        coloraxis=dict(colorbar_ticksuffix=" %"),
        width=600,
        margin={"r": 100, "l": 100}
        )

In [82]:
# Previous-token scores (%) of circuit heads.
imshow_p(
    component_scores['tertiary_head_scores']['prev_token_scores'] * 100,
    title=f'prev_token_scores for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
    labels=dict(x="Head", y="Layer", color="metric_value"),
    border=True,
    coloraxis={"colorbar_ticksuffix": " %"},
    width=600,
    margin=dict(r=100, l=100),
)

In [83]:
# Duplicate-token scores (%) of circuit heads.
imshow_p(
    component_scores['tertiary_head_scores']['duplicate_token_scores'] * 100,
    title=f'duplicate_token_scores for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
    labels=dict(x="Head", y="Layer", color="metric_value"),
    border=True,
    coloraxis={"colorbar_ticksuffix": " %"},
    width=600,
    margin=dict(r=100, l=100),
)

In [74]:
# Plot every direct-effect metric. Attention metrics share a fixed [-1, 1]
# colour range; the others auto-scale. The bounds use dedicated names:
# assigning to `min` / `max` would shadow the builtins for the rest of the
# notebook session.
for metric, metric_scores in component_scores['direct_effect_scores'].items():
    if 'attn' in metric:
        color_min, color_max = -1, 1
    else:
        color_min, color_max = None, None
    imshow_p(
        metric_scores,
        title=f'{metric} for {MODEL_SHORTNAME} at checkpoint {checkpoint}',
        labels={"x": "Head", "y": "Layer", "color": "metric_value"},
        border=True,
        coloraxis=dict(cmin=color_min, cmax=color_max),
        width=600,
        margin={"r": 100, "l": 100}
        )

In [13]:
# Metric names stored under 'direct_effect_scores' for this checkpoint.
component_scores['direct_effect_scores'].keys()

dict_keys(['copy_scores', 's1_attn_scores', 's2_attn_scores', 'io_attn_scores', 'copy_suppression_scores'])

In [38]:
# NOTE(review): `checkpoint_nmhs` and `checkpoint` come from notebook state not
# shown in the visible cells.
name_mover_heads = checkpoint_nmhs[checkpoint]
# Edges into a name mover's query input are named 'a{layer}.h{head}<q>'.
# `isin` also handles an empty NMH list (np.logical_or.reduce over an empty
# array returns a scalar False, and df[False] raises KeyError).
nmh_query_targets = [f'a{layer}.h{head}<q>' for layer, head in name_mover_heads]
targeting_nmh = df['target'].isin(nmh_query_targets)
candidate_s2i = df[targeting_nmh]
candidate_s2i = candidate_s2i[candidate_s2i['in_circuit'] == True]

# Attention-head sources at this checkpoint (drop MLP nodes 'm*' and the input node).
candidate_list = candidate_s2i[candidate_s2i['checkpoint']==checkpoint]['source'].unique().tolist()
candidate_list = [convert_head_names_to_tuple(c) for c in candidate_list if (c[0] != 'm' and c != 'input')]

In [34]:
# Checkpoint currently held in the notebook namespace (set by a cell not shown here).
checkpoint

143000

In [37]:
# S2I candidate heads (layer, head) computed by the candidate-selection cell.
candidate_list

[(6, 6),
 (6, 5),
 (6, 11),
 (5, 3),
 (5, 11),
 (7, 4),
 (7, 2),
 (7, 3),
 (7, 10),
 (7, 9)]

In [15]:
# for i in range(len(candidate_list)):
#     print(f'Head {candidate_list[i]}')
#     print(f'Attention  to S2:        {s2i_s2_attention[i]:.3f}')
#     print(f'Logit Diff Change:       {logit_diff_change[i]:.3f}')
#     print(f'NMH S1 Attention Change: {nmh_s1_attention_change[i]:.3f}')

In [30]:

# NOTE(review): relies on notebook globals not defined in the visible cells
# (model, ioi_dataset, batch_size, patch_dataset_names).
s2i_heads = candidate_list # [(7,9), (7,2), (6,6), (6,5),]
print(s2i_heads)

for head in s2i_heads:
    s2i_token_pos_results = S2I_token_pos(model, ioi_dataset, [head], name_mover_heads, batch_size)
    print(f'Head {head}')
    mean_original_logit_diff = s2i_token_pos_results['ablated_logit_diffs']['token_same_pos_same'].mean()
    mean_original_io_attention = s2i_token_pos_results['io_attention_values']['token_same_pos_same'].mean()
    mean_original_s1_attention = s2i_token_pos_results['s1_attention_values']['token_same_pos_same'].mean()
    mean_original_s2_attention = s2i_token_pos_results['s2_attention_values']['token_same_pos_same'].mean()

    for dataset_name in patch_dataset_names:
        print(dataset_name)

        mean_ablated_logit_diff = s2i_token_pos_results['ablated_logit_diffs'][dataset_name].mean()
        mean_ablated_io_attention = s2i_token_pos_results['io_attention_values'][dataset_name].mean()
        mean_ablated_s1_attention = s2i_token_pos_results['s1_attention_values'][dataset_name].mean()
        mean_ablated_s2_attention = s2i_token_pos_results['s2_attention_values'][dataset_name].mean()

        # All printed numbers are relative changes vs. the unpatched
        # 'token_same_pos_same' run, hence "Change" rather than "Value".
        logit_diff_delta = (mean_ablated_logit_diff - mean_original_logit_diff) / mean_original_logit_diff
        print(f"Logit diff after patching: {100 * logit_diff_delta:.2f}%")

        # should be high with pos = same, low with pos = diff
        io_attention_delta = (mean_ablated_io_attention - mean_original_io_attention) / mean_original_io_attention
        print(f"NMH IO Attention Change: {100 * io_attention_delta:.2f}%")

        # should be low with pos = same, high with pos = diff
        s1_attention_delta = (mean_ablated_s1_attention - mean_original_s1_attention) / mean_original_s1_attention
        print(f"NMH S1 Attention Change: {100 * s1_attention_delta:.2f}%")

        # shouldn't change much
        s2_attention_delta = (mean_ablated_s2_attention - mean_original_s2_attention) / mean_original_s2_attention
        print(f"NMH S2 Attention Change: {100 * s2_attention_delta:.2f}%")
        print('\n')


[(6, 6), (8, 9), (6, 5), (7, 2)]
Head (6, 6)
token_same_pos_oppo
Logit diff after patching: -29.44%
NMH IO Attention Value: -19.74%
NMH S1 Attention Value: 40.89%
NMH S2 Attention Value: -24.03%


token_diff_pos_oppo
Logit diff after patching: -27.85%
NMH IO Attention Value: -19.38%
NMH S1 Attention Value: 40.04%
NMH S2 Attention Value: -22.94%


token_oppo_pos_oppo
Logit diff after patching: -27.37%
NMH IO Attention Value: -19.19%
NMH S1 Attention Value: 38.70%
NMH S2 Attention Value: -18.16%


Head (8, 9)
token_same_pos_oppo
Logit diff after patching: -0.43%
NMH IO Attention Value: -0.53%
NMH S1 Attention Value: 1.19%
NMH S2 Attention Value: -0.41%


token_diff_pos_oppo
Logit diff after patching: -1.37%
NMH IO Attention Value: -1.48%
NMH S1 Attention Value: 3.04%
NMH S2 Attention Value: 10.35%


token_oppo_pos_oppo
Logit diff after patching: -2.78%
NMH IO Attention Value: -3.47%
NMH S1 Attention Value: 7.55%
NMH S2 Attention Value: 9.13%


Head (6, 5)
token_same_pos_oppo
Logit diff a