In [1]:
import os
# Make relative paths such as "models/..." resolve: when no "models" entry is
# visible here, step two directories up (presumably to the repository root).
if not any(entry == "models" for entry in os.listdir(".")):
    os.chdir("../..")

In [2]:
# Auto-reload edited project modules (sprint, micrlhf, ...) without restarting the kernel.
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
# Enable jax_smi memory tracking (presumably best done before large device allocations).
jax_smi.initialise_tracking()
from penzai import pz
# Register penzai's pz.ts renderer as the default display, plus its autovisualize
# magic and interactive context, for richer output of model objects below.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
# Disable JAX traceback filtering both through the environment variable and the
# runtime config (setting both is assumed to cover code reading either — confirm).
%env JAX_TRACEBACK_FILTERING=off
import jax
jax.config.update('jax_traceback_filtering', 'off')


env: JAX_TRACEBACK_FILTERING=off


In [4]:
from sprint.icl_sfc_utils import Circuitizer

In [5]:
from micrlhf.llama import LlamaTransformer
# Load the local Gemma-2B-it GGUF checkpoint eagerly onto device "tpu:0".
llama = LlamaTransformer.from_pretrained(
    "models/gemma-2b-it.gguf",
    from_type="gemma",
    load_eager=True,
    device_map="tpu:0",
)

In [6]:
from transformers import AutoTokenizer
# NOTE(review): the tokenizer comes from the base "gemma-2b" repo while the model
# above is "gemma-2b-it"; they are assumed to share a vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad at the end of each sequence.
tokenizer.padding_side = "right"

In [7]:
from sprint.task_vector_utils import load_tasks, ICLRunner
# Load every ICL task dictionary. The "destination path 'data/itv' already exists"
# message in the output suggests load_tasks git-clones the dataset on first use.
tasks = load_tasks()

fatal: destination path 'data/itv' already exists and is not an empty directory.


In [8]:
def check_if_single_token(token):
    """Return True if `token` is tokenized into exactly one token."""
    return len(tokenizer.tokenize(token)) == 1

task_name = "country_capital"
# When True, keep only pairs whose input and output are both single tokens.
filter_single_token = False

task = tasks[task_name]

# Number of pairs before filtering.
print(len(task))

if filter_single_token:
    task = {
        k: v for k, v in task.items() if check_if_single_token(k) and check_if_single_token(v)
    }

# Number of pairs after the optional filter (same as above when disabled).
print(len(task))

pairs = list(task.items())

batch_size = 8
n_shot = 16
max_seq_len = 128
seed = 10

prompt = "Follow the pattern:\n{}"

runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_shot, max_seq_len=max_seq_len, seed=seed, prompt=prompt, use_same_examples=True)

139
139


In [9]:
# Inspect the demonstration pairs chosen by the runner (use_same_examples=True above).
runner.train_pairs

In [10]:
# Build the circuit-discovery helper over layers 6..17 (inclusive).
layers = [*range(6, 18)]
circuitizer = Circuitizer(
    llama,
    tokenizer,
    runner,
    layers,
    prompt,
)

Setting up masks...
Running metrics...
Setting up RMS...


  0%|          | 0/18 [00:00<?, ?it/s]

Loading SAEs...


  0%|          | 0/12 [00:00<?, ?it/s]

Running node IEs...


  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

In [15]:

# Restrict the ablation experiments to the middle layers 10..14.
layers = list(range(10, 15))
mean_ablate = False

# Metric of the unablated model.
orig_metric = circuitizer.ablated_metric(llama).tolist()

# Baseline with a threshold far above any IE (presumably ablating every node in `layers`).
zero_metric = circuitizer.run_ablated_metrics(
    [100000],
    mean_ablate=mean_ablate,
    layers=layers,
)[0][0]

print(orig_metric, zero_metric)

  0%|          | 0/1 [00:00<?, ?it/s]

-17.0 -173.0


In [16]:
import numpy as np
# Sweep IE thresholds on a log grid; each threshold yields a differently sized circuit.
thresholds = np.logspace(-4, -1, 150)
topks = [4, 6, 12, 16, 24, 32]

inverse = False
do_abs = False
average_over_positions = True


ablated_metrics, n_nodes_counts = circuitizer.run_ablated_metrics(thresholds, inverse=inverse,
                                                                  do_abs=do_abs, mean_ablate=mean_ablate,
                                                                  average_over_positions=average_over_positions,
                                                                  token_prefix=None, layers=layers)

# Faithfulness rescales the metric so the fully-ablated baseline is 0 and the
# unablated model is 1. Guard the degenerate case instead of producing inf/nan.
metric_range = orig_metric - zero_metric
if metric_range == 0:
    raise ValueError(
        f"orig_metric == zero_metric ({orig_metric}); faithfulness is undefined"
    )
faithfullness = (np.array(ablated_metrics) - zero_metric) / metric_range

  0%|          | 0/150 [00:00<?, ?it/s]

In [17]:
# Confirm which layers the threshold sweep was restricted to.
layers

In [19]:
import matplotlib.pyplot as plt
import plotly.express as px

# Faithfulness as a function of circuit size for the threshold sweep.
fig = px.line(
    x=n_nodes_counts,
    y=faithfullness,
    title=f"inverse={inverse}, abs={do_abs}, mean={mean_ablate}, aop={average_over_positions}, layers={layers}",
)
fig.update_xaxes(title="Number of nodes")
fig.update_yaxes(title="Faithfulness")

fig


In [20]:
target_faithfullness = 0.6

# Walk thresholds from largest to smallest and take the first whose faithfulness
# exceeds the target (larger thresholds presumably keep fewer nodes).
target_threshold = next(
    (
        threshold
        for threshold, metric in reversed(list(zip(thresholds, faithfullness)))
        if metric > target_faithfullness
    ),
    None,
)
if target_threshold is None:
    raise ValueError(
        f"No threshold in the sweep reaches faithfulness > {target_faithfullness}; "
        "lower the target or widen the threshold grid."
    )

target_threshold

In [21]:
from tqdm.auto import tqdm

# Node masks for every layer and SAE type at the selected threshold.
selected_threshold = target_threshold

mask_kwargs = dict(
    do_abs=do_abs,
    average_over_positions=average_over_positions,
    inverse=inverse,
)

ablation_masks = {}

for layer in tqdm(layers):
    mask_attn_out, _ = circuitizer.mask_ie(circuitizer.ie_attn[layer], selected_threshold, None, **mask_kwargs)
    mask_resid, _ = circuitizer.mask_ie(circuitizer.ie_resid[layer], selected_threshold, None, **mask_kwargs)

    # A KeyError (presumably: no transcoder IEs for this layer) means "no transcoder mask".
    try:
        mask_transcoder, _ = circuitizer.mask_ie(circuitizer.ie_transcoder[layer], selected_threshold, None, **mask_kwargs)
    except KeyError:
        mask_transcoder = None

    ablation_masks[layer] = {
        "attn_out": mask_attn_out,
        "resid": mask_resid,
        "transcoder": mask_transcoder,
    }

  0%|          | 0/5 [00:00<?, ?it/s]

In [22]:
# Flatten the per-layer masks into a list of selected nodes:
# (layer, mask_type, token_type, feature, position). position is None when the
# mask is 1-D (no position axis).
circuit_nodes = []
n_nodes = 0

for layer, masks in ablation_masks.items():
    for mask_type, typed_mask in masks.items():
        if typed_mask is None:
            continue
        # Use a distinct name for the per-token mask so it doesn't clobber typed_mask.
        for token_type, token_mask in typed_mask.items():
            n_nodes += token_mask.sum()

            node_ids = np.where(token_mask)

            if len(node_ids) == 2:
                # 2-D mask indexed as (position, feature).
                for pos, feat in zip(*node_ids):
                    circuit_nodes.append((layer, mask_type, token_type, feat, pos))
            else:
                for feat in node_ids[0]:
                    circuit_nodes.append((layer, mask_type, token_type, feat, None))

n_nodes

In [23]:
# IE tables keyed by the first letter of the SAE type name ("resid", "attn_out", "transcoder").
typed_ies = {
    "r": circuitizer.ie_resid,
    "a": circuitizer.ie_attn,
    "t": circuitizer.ie_transcoder,
}

circuit_nodes_with_ies = []

for node in circuit_nodes:
    layer, sae_type, token_type, node_id, pos = node
    masked_ies = circuitizer.mask_average(
        typed_ies[sae_type[0]][layer],
        token_type,
        average_over_positions=bool(average_over_positions),
    )
    # Averaged IEs are indexed by feature only; otherwise by (position, feature).
    node_ie = masked_ies[node_id] if average_over_positions else masked_ies[pos, node_id]
    circuit_nodes_with_ies.append((*node, node_ie.tolist()))

# Highest IE first.
circuit_nodes_with_ies = sorted(circuit_nodes_with_ies, key=lambda entry: entry[-1], reverse=True)

In [24]:
circuit_nodes_with_ies[:10]

In [25]:
from tqdm.auto import tqdm
import numpy as np

combined_ies = {}

# Map (layer, token mask, SAE-type letter, feature[, position]) -> IE.
# The position is part of the key only when IEs are not averaged over positions.
for layer, sae_type, token_type, idx, pos, ie in circuit_nodes_with_ies:
    node_key = (layer, token_type, sae_type[0], idx)
    if not average_over_positions:
        node_key += (pos,)
    combined_ies[node_key] = ie

In [26]:
# Flatten the node -> IE mapping into tuples of (*node_key, ie).
combined_ies = [(*node_key, node_ie) for node_key, node_ie in combined_ies.items()]

In [27]:
typed_ies_error = {
    "er": circuitizer.ie_error_resid,
    "ea": circuitizer.ie_error_attn,
    "et": circuitizer.ie_error_transcoder,
}

# Add SAE error nodes (feature index 0) for each layer and token mask.
# With position averaging every error node is added; otherwise only positions
# whose IE exceeds selected_threshold are kept.
# NOTE: appends to combined_ies in place, so re-running this cell duplicates entries.
for layer in tqdm(layers):
    for error_type, error_ies in typed_ies_error.items():
        if layer not in error_ies:
            continue
        ies = error_ies[layer]
        for mask in circuitizer.masks:
            ies_mask = circuitizer.mask_average(ies, mask, average_over_positions=average_over_positions)
            if average_over_positions:
                combined_ies.append((layer, mask, error_type, 0, ies_mask.tolist()))
                continue
            for pos, ie in enumerate(ies_mask):
                if ie > selected_threshold:
                    combined_ies.append((layer, mask, error_type, 0, pos, ie))

  0%|          | 0/5 [00:00<?, ?it/s]

In [28]:
# Rank all nodes (features and error nodes) by IE, largest first.
combined_ies = sorted(combined_ies, key=lambda entry: entry[-1], reverse=True)

In [29]:

from collections import defaultdict
circuit_node_dict = defaultdict(list)

# Group circuit nodes by (SAE type, layer, token mask). Each member is a feature
# index when IEs are position-averaged, otherwise a (position, feature) pair.
for node in combined_ies:
    if average_over_positions:
        layer, mask, type, idx, weight = node
        circuit_node_dict[(type, layer, mask)].append(idx)
    else:
        layer, mask, type, idx, pos, weight = node
        circuit_node_dict[(type, layer, mask)].append((pos, idx))

circuit_node_dict = {group: np.array(members) for group, members in circuit_node_dict.items()}

In [30]:
import jax.numpy as jnp
from tqdm.auto import trange

if average_over_positions:
    # Circuit nodes per token mask as (sae_type, layer, feature); combined_ies is
    # (layer, mask, sae_type, feature, ie) at this point.
    important_feats_masks = {}
    for mask in circuitizer.masks:
        important_feats_masks[mask] = [
            (node_type, layer, feat) for layer, f_mask, node_type, feat, _ in combined_ies if f_mask == mask
            ]


    # Downstream features to explain, keyed by (mask, sae_type, layer).
    flat_feats = defaultdict(list)
    for k, v in important_feats_masks.items():
        for node_type, layer, feat in v:
            flat_feats[(k, node_type, layer)].append(feat)


    graph = []

    batch_size = 16
    # Visit groups from the last layer backwards. `group` avoids shadowing the builtin `type`.
    for group, features in tqdm(sorted(flat_feats.items(), key=lambda x: (-x[0][-1], x[0][-2], x[0][-3]))):
        mask, feature_type, layer = group
        mask = jnp.array(list(circuitizer.masks.keys()).index(mask))
        for batch in trange(0, len(features), batch_size, postfix=str(group)):
            batch_features = features[batch:batch+batch_size]
            orig_length = len(batch_features)
            # Pad to a fixed batch size so vmap sees a constant shape; padded rows are skipped below.
            batch_features = batch_features + [0] * (batch_size - len(batch_features))
            feature_effectss = jax.vmap(lambda x: circuitizer.compute_feature_effects(feature_type, layer, x, mask, layer_window=1, position=None))(jnp.asarray(batch_features))
            top_effects = defaultdict(list)
            for key, featuress in feature_effectss.items():
                for elem, feature_effects in enumerate(featuress):
                    if elem >= orig_length:
                        continue
                    if feature_effects.ndim == 0:
                        # Scalar effect per element (error-node keys, per the debug output shapes).
                        top_effects[elem].append((float(feature_effects), key, 0))
                        continue

                    # Keep only edges from upstream nodes that are themselves in the circuit.
                    nodes_to_keep = circuit_node_dict.get(key, np.empty(0, dtype=np.int32))
                    effects = feature_effects[nodes_to_keep]
                    for idx, effect in zip(nodes_to_keep, effects):
                        top_effects[elem].append((float(effect), key, int(idx)))
            for elem, effects in top_effects.items():
                effects.sort(reverse=True)
                graph.extend([(weight,  key + (upstream_feature,), (group[1], group[2], group[0], batch_features[elem],) ) for weight, key, upstream_feature in effects])



    # Reorder node tuples to (sae_type, layer, mask, feature, ie) for export.
    # NOTE: rebinding combined_ies makes this cell non-idempotent.
    combined_ies = [
        (node_type, layer, mask, idx, weight) for layer, mask, node_type, idx, weight in combined_ies
    ]


    sorted_graph = sorted(graph, reverse=True, key=lambda x: x[0])

    # Edge-weight cutoff keeping roughly k_connections edges per node. Clamp the
    # index so a graph with fewer edges doesn't raise IndexError (keeps them all).
    n_nodes = sum(map(len, important_feats_masks.values()))
    k_connections = 4
    if sorted_graph:
        weight_threshold = sorted_graph[min(n_nodes * k_connections, len(sorted_graph) - 1)][0]
    else:
        weight_threshold = 0.0

  0%|          | 0/142 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s, ('arrow', 'a', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'ea', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'ea', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'ea', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'ea', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'ea', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'ea', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'er', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'er', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'er', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'er', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'er', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'er', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'et', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'et', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'et', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'et', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'et', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'et', 14)]

  0%|          | 0/3 [00:00<?, ?it/s, ('arrow', 'r', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'r', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'r', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'r', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'r', 14)]

  0%|          | 0/3 [00:00<?, ?it/s, ('arrow', 't', 14)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'a', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'ea', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'ea', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'ea', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'ea', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'ea', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'ea', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'er', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'er', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'er', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'er', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'er', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'er', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'et', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'et', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'et', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'et', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'et', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'et', 13)]

  0%|          | 0/2 [00:00<?, ?it/s, ('arrow', 'r', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'r', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'r', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'r', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'r', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'r', 13)]

  0%|          | 0/2 [00:00<?, ?it/s, ('arrow', 't', 13)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'a', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'ea', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'ea', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'ea', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'ea', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'ea', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'ea', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'er', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'er', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'er', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'er', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'er', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'er', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'et', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'et', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'et', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'et', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'et', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'et', 12)]

  0%|          | 0/2 [00:00<?, ?it/s, ('arrow', 'r', 12)]

  0%|          | 0/2 [00:00<?, ?it/s, ('input', 'r', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'r', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'r', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'r', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'r', 12)]

  0%|          | 0/2 [00:00<?, ?it/s, ('arrow', 't', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 't', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 't', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 't', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 't', 12)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'a', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'ea', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'ea', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'ea', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'ea', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'ea', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'ea', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'er', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'er', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'er', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'er', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'er', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'er', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'et', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'et', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'et', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'et', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'et', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'et', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'r', 11)]

  0%|          | 0/2 [00:00<?, ?it/s, ('input', 'r', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'r', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'r', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'r', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'r', 11)]

  0%|          | 0/2 [00:00<?, ?it/s, ('arrow', 't', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 't', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 't', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 't', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 't', 11)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'a', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'a', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'ea', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'ea', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'ea', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'ea', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'ea', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'ea', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'er', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'er', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'er', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'er', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'er', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'er', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'et', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'et', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'et', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'et', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'et', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'et', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 'r', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 'r', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 'r', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 'r', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 'r', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('remaining', 'r', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('arrow', 't', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('input', 't', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('newline', 't', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('output', 't', 10)]

  0%|          | 0/1 [00:00<?, ?it/s, ('prompt', 't', 10)]

In [31]:
if not average_over_positions:
    # Circuit nodes per token mask as (sae_type, layer, feature, position); combined_ies
    # is (layer, mask, sae_type, feature, position, ie) at this point.
    important_feats_masks = {}
    for mask in circuitizer.masks:
        important_feats_masks[mask] = [
            (node_type, layer, feat, pos) for layer, f_mask, node_type, feat, pos, _ in combined_ies if f_mask == mask
            ]


    # Downstream (position, feature) pairs to explain, keyed by (mask, sae_type, layer).
    flat_feats = defaultdict(list)
    for k, v in important_feats_masks.items():
        for node_type, layer, feat, pos in v:
            flat_feats[(k, node_type, layer)].append((pos, feat))


    graph = []

    batch_size = 16
    # Visit groups from the last layer backwards. `group` avoids shadowing the builtin `type`.
    for group, features in tqdm(sorted(flat_feats.items(), key=lambda x: (-x[0][-1], x[0][-2], x[0][-3]))):
        mask, feature_type, layer = group
        mask = jnp.array(list(circuitizer.masks.keys()).index(mask))
        for batch in trange(0, len(features), batch_size, postfix=str(group)):
            batch_features = features[batch:batch+batch_size]
            orig_length = len(batch_features)
            # Pad with (0, 0) pairs so vmap sees a constant shape; padded rows are skipped below.
            batch_features = batch_features + [(0, 0)] * (batch_size - len(batch_features))
            feature_effectss = jax.vmap(lambda x: circuitizer.compute_feature_effects(feature_type, layer, x[1], mask, layer_window=1, position=x[0]))(jnp.asarray(batch_features))
            top_effects = defaultdict(list)
            for key, featuress in feature_effectss.items():
                # Upstream circuit members for this key as (position, feature) rows.
                nodes_to_keep = circuit_node_dict.get(key, np.empty((0, 2), dtype=np.int32))

                for elem, feature_effects in enumerate(featuress):
                    if elem >= orig_length:
                        continue
                    if feature_effects.ndim == 1:
                        # Per-position effects (error-node keys): feature index is 0.
                        for upos, _feat in nodes_to_keep:
                            top_effects[elem].append((float(feature_effects[upos]), key, 0, upos))
                        continue
                    effects = feature_effects[nodes_to_keep[:, 0], nodes_to_keep[:, 1]]

                    for idx, effect in zip(nodes_to_keep, effects):
                        top_effects[elem].append((float(effect), key, int(idx[1]), int(idx[0])))


            for elem, effects in top_effects.items():
                effects.sort(reverse=True)
                graph.extend([(weight,  key + (upstream_feature,upos,), (group[1], group[2], group[0], batch_features[elem][1], batch_features[elem][0],) ) for weight, key, upstream_feature, upos in effects])



    # Reorder node tuples to (sae_type, layer, mask, feature, position, ie) for export.
    # NOTE: rebinding combined_ies makes this cell non-idempotent.
    combined_ies = [
        (node_type, layer, mask, idx, pos, weight) for layer, mask, node_type, idx, pos, weight in combined_ies
    ]


    sorted_graph = sorted(graph, reverse=True, key=lambda x: x[0])

    # Edge-weight cutoff keeping roughly k_connections edges per node. Clamp the
    # index so a graph with fewer edges doesn't raise IndexError (keeps them all).
    n_nodes = sum(map(len, important_feats_masks.values()))
    k_connections = 4
    if sorted_graph:
        weight_threshold = sorted_graph[min(n_nodes * k_connections, len(sorted_graph) - 1)][0]
    else:
        weight_threshold = 0.0

In [32]:
# Cast trailing id fields of each edge to plain ints so json.dump can serialize them.
# Position-averaged edges end in a feature id on the downstream side only; otherwise
# both endpoints end in (feature, position) and both are cast.
_graph = []
for w, l, r in sorted_graph:
    if average_over_positions:
        _graph.append((w, l, (*r[:-1], int(r[-1]))))
    else:
        _graph.append((w, (*l[:-2], int(l[-2]), int(l[-1])), (*r[:-2], int(r[-2]), int(r[-1]))))

In [33]:
# Normalize node tuples to plain Python numbers for JSON export.
if average_over_positions:
    _combined_ies = [(t, lyr, m, int(i), w) for t, lyr, m, i, w in combined_ies]
else:
    _combined_ies = [(t, lyr, m, int(i), int(p), float(w)) for t, lyr, m, i, p, w in combined_ies]

In [34]:
# Readable token strings per prompt: drop "<pad>" and turn the "Ġ"/"▁" space
# markers and newlines into plain spaces.
tokens_decoded = [
    [
        piece.replace("Ġ", " ").replace("▁", " ").replace("\n", " ")
        for piece in tokenizer.convert_ids_to_tokens(ids)
        if piece != "<pad>"
    ]
    for ids in circuitizer.train_tokens
]

In [35]:
if not average_over_positions:
    # For each node id "sae_type:layer:mask:feature", map position -> IE.
    position_maps = defaultdict(defaultdict)

    for sae_type, layer, mask, idx, pos, weight in _combined_ies:
        node_name = ":".join(map(str, (sae_type, layer, mask, idx)))
        position_maps[node_name][pos] = weight

In [36]:
import json
# Export the circuit graph. The filename records the actual faithfulness target
# (previously hardcoded as 0.6 in the averaged branch) and the layer range.
layer_range = f"l{min(layers)}_l{max(layers)}"
if average_over_positions:
    out_path = f"micrlhf-progress/graph-rebirth-{task_name}_faith_{target_faithfullness}_{layer_range}.json"
    payload = {"edges": _graph, "nodes": _combined_ies, "threshold": weight_threshold, "tokens": None}
else:
    out_path = f"micrlhf-progress/graph-rebirth-{task_name}_faith_{target_faithfullness}_non_aop_n_shot_{n_shot}_{layer_range}_mean_{mean_ablate}.json"
    payload = {"edges": _graph, "nodes": _combined_ies, "threshold": weight_threshold, "tokens": tokens_decoded, "position_maps": position_maps}

# Create the output directory if needed so the write doesn't fail on a fresh checkout.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as f:
    json.dump(payload, f)

In [42]:
import json
# Load a previously exported graph (antonyms task) for comparison.
with open("micrlhf-progress/all-graph-antonyms.json") as graph_file:
    all_graph = json.load(graph_file)

nodes = all_graph["nodes"]

# Show the first 100 node records.
nodes[:100]

In [41]:
# Group circuit nodes by token mask as (sae_type, layer, feature).
# In top-down execution the edge-building cell has already reordered combined_ies
# to (sae_type, layer, mask, feature, ie), so unpack in that order; the old
# (layer, mask, type, ...) unpacking compared an int layer to mask names and
# silently produced empty lists.
important_feats_masks = {}
for mask in circuitizer.masks:
    important_feats_masks[mask] = [
        (node_type, layer, feat) for node_type, layer, f_mask, feat, _ in combined_ies if f_mask == mask
        ]

In [20]:
from collections import defaultdict
# Regroup features by (mask, sae_type, layer) for batched effect computation.
flat_feats = defaultdict(list)
for mask_name, feats in important_feats_masks.items():
    for node_type, layer, feat in feats:
        flat_feats[(mask_name, node_type, layer)].append(feat)

In [36]:
from tqdm.auto import trange
import jax.numpy as jnp
# Legacy edge builder: for each downstream feature keep the k strongest upstream
# effects per key, then the k strongest overall.
graph = []

batch_size = 16
k = 32
# `group` avoids shadowing the builtin `type`.
for group, features in tqdm(sorted(flat_feats.items(), key=lambda x: (-x[0][-1], x[0][-2], x[0][-3]))):
    mask, feature_type, layer = group
    mask = jnp.array(list(circuitizer.masks.keys()).index(mask))
    for batch in trange(0, len(features), batch_size, postfix=str(group)):
        batch_features = features[batch:batch+batch_size]
        orig_length = len(batch_features)
        # Pad to a constant batch size for vmap; padded rows are skipped below.
        batch_features = batch_features + [0] * (batch_size - len(batch_features))
        feature_effectss = jax.vmap(lambda x: circuitizer.compute_feature_effects(feature_type, layer, x, mask, layer_window=1))(jnp.asarray(batch_features))
        top_effects = defaultdict(list)
        for key, featuress in feature_effectss.items():
            for elem, feature_effects in enumerate(featuress):
                if elem >= orig_length:
                    continue
                if feature_effects.ndim == 0:
                    top_effects[elem].append((float(feature_effects), key, 0))
                    continue
                # NOTE(review): ranks by |effect| and stores the magnitude, so edge signs
                # are dropped here (scalar effects above keep their sign).
                effects, indices = jax.lax.top_k(jnp.abs(feature_effects), k)
                for i, e in zip(indices.tolist(), effects.tolist()):
                    top_effects[elem].append((e, key, i))
        for elem, effects in top_effects.items():
            effects.sort(reverse=True)
            edges = effects[:k]
            graph.extend([(weight,  key + (upstream_feature,), (group[1], group[2], group[0], batch_features[elem],) ) for weight, key, upstream_feature in edges])

  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s, ('arrow', 'a', 16)]

dict_keys([('er', 16, 'arrow'), ('er', 16, 'input'), ('er', 16, 'newline'), ('er', 16, 'output'), ('er', 16, 'prompt'), ('r', 16, 'arrow'), ('r', 16, 'input'), ('r', 16, 'newline'), ('r', 16, 'output'), ('r', 16, 'prompt')]) [26950, 25539, 1085, 14538, 19232, 7123, 20911, 12264, 12539, 12413, 32120, 26574, 23636, 23114, 178, 12793]
('er', 16, 'arrow') (16,)
('er', 16, 'input') (16,)
('er', 16, 'newline') (16,)
('er', 16, 'output') (16,)
('er', 16, 'prompt') (16,)
('r', 16, 'arrow') (16, 32768)
('r', 16, 'input') (16, 32768)
('r', 16, 'newline') (16, 32768)
('r', 16, 'output') (16, 32768)
('r', 16, 'prompt') (16, 32768)
[(6.0595695686060935e-05, ('er', 16, 'input', 0), ('a', 16, 'arrow', 26950)), (4.51656014774926e-05, ('r', 16, 'prompt', 5241), ('a', 16, 'arrow', 26950)), (3.955366264563054e-05, ('r', 16, 'arrow', 29818), ('a', 16, 'arrow', 26950)), (3.92732436012011e-05, ('r', 16, 'arrow', 24991), ('a', 16, 'arrow', 26950)), (3.684913463075645e-05, ('er', 16, 'prompt', 0), ('a', 16, '

ZeroDivisionError: division by zero

In [37]:
# Inspect the first node tuple (combined_ies was sorted by descending IE earlier).
combined_ies[0]

In [None]:
# Strongest edges first.
sorted_graph = sorted(graph, key=lambda edge: edge[0], reverse=True)