In [2]:
import os
# Make relative paths such as "models/..." resolve: when the current directory
# has no "models" entry, move two directory levels up.
# NOTE(review): assumes the notebook lives two levels below the repo root — confirm.
if "models" not in os.listdir("."):
    os.chdir("../..")

In [4]:
# Re-import edited modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
# Start jax_smi's tracking for this process.
jax_smi.initialise_tracking()
from penzai import pz
# Penzai notebook integration: register its renderer as the default display,
# register the auto-visualization magic, and enable the interactive context.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [10]:
from micrlhf.llama import LlamaTransformer
# Load the GGUF checkpoint at models/gemma-2b-it.gguf as a "gemma"-type model,
# eagerly, with device_map "tpu:0". The path is relative to the working
# directory set by the first cell.
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

In [11]:
from transformers import AutoTokenizer
# Tokenizer is fetched from the "alpindale/gemma-2b" repository.
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right so real tokens start at position 0 of every sequence.
tokenizer.padding_side = "right"

In [12]:
from sprint.task_vector_utils import load_tasks, ICLRunner
# `tasks` is indexed by task name below; each entry is consumed via `.items()`.
# NOTE(review): the recorded output ("fatal: destination path 'data/itv' already
# exists ...") suggests load_tasks attempts a git clone on every call — confirm.
tasks = load_tasks()

fatal: destination path 'data/itv' already exists and is not an empty directory.


In [13]:
import jax.numpy as jnp
import jax


from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper


# Build a copy of the model in which every LlamaBlock is replaced by a
# Sequential of [TellIntermediate tagged "resid_pre_<i>", original block],
# where <i> is the index supplied by apply_with_selected_index.
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
# Collect the side outputs whose tag starts with "resid_pre"; callers below
# unpack the result of a call as `_, all_resids`.
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
# Wrap the instrumented model with the project's Jitted wrapper.
get_resids_call = jit_wrapper.Jitted(get_resids)



def tokenized_to_inputs(input_ids, attention_mask):
    """Turn a tokenizer batch into model inputs for ``llama``.

    Args:
        input_ids: 2-D (batch, seq) token ids; anything ``jnp.asarray`` accepts.
        attention_mask: Accepted only so a tokenizer output dict can be passed
            with ``**tokenized``. It is NOT used: the inputs are built from
            the token ids alone via ``from_basic_segments``. The previous
            version converted the mask and moved it to device, then discarded
            it; that dead work has been removed.
            NOTE(review): if padding must be masked explicitly, it has to be
            wired into the input construction — confirm against the model API.

    Returns:
        Whatever ``llama.inputs.from_basic_segments`` returns for the sharded,
        axis-named token array.
    """
    # Shard the batch axis over "dp" and the sequence axis over "sp" of the model mesh.
    token_sharding = jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))
    token_array = jax.device_put(jnp.asarray(input_ids), token_sharding)
    # Name both axes, then round-trip "batch" through untag/tag exactly as before.
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    return llama.inputs.from_basic_segments(token_array)


In [14]:
from sprint.icl_sfc_utils import AblatedModule
# Layer index used by later cells to pick the SAE, the result rows and the
# collected residual stream.
layer = 12
# NOTE(review): mask_name is not referenced anywhere else in the visible notebook — confirm it is still needed.
mask_name = "arrow"

In [15]:
from sprint.icl_sfc_utils import Circuitizer

from sprint.task_vector_utils import load_tasks, ICLRunner
import numpy as np

def important_features(task_name, target_faithfulness=0.6):
    """List the nodes selected at the faithfulness-chosen IE threshold for a task.

    Builds an ``ICLRunner`` and a ``Circuitizer`` for ``task_name``, sweeps 200
    log-spaced thresholds, and picks the largest threshold whose normalized
    faithfulness ``(metric - zero_metric) / (orig_metric - zero_metric)``
    exceeds ``target_faithfulness``. It then reports every node whose mask
    entry is nonzero at that threshold, together with its averaged indirect
    effect.

    Args:
        task_name: Key into the module-level ``tasks`` mapping.
        target_faithfulness: Faithfulness a threshold must exceed to be
            eligible (default 0.6, the value previously hard-coded).

    Returns:
        A list of tuples ``(layer, mask_type, token_type, node_id, ie)`` where
        ``mask_type`` is one of ``"attn_out"``, ``"resid"``, ``"transcoder"``,
        ``node_id`` is a plain int, and ``ie`` is the ``.tolist()`` of the
        node's entry in ``circuitizer.mask_average``.

    Raises:
        KeyError: ``task_name`` is not in ``tasks``.
        IndexError: no threshold in the sweep exceeds ``target_faithfulness``.
    """
    task = tasks[task_name]
    print(len(task))
    pairs = list(task.items())

    batch_size = 8
    # Tasks whose name starts with "algo" use fewer shots.
    n_shot = 12 if task_name.startswith("algo") else 16
    max_seq_len = 128
    seed = 10
    prompt = "Follow the pattern:\n{}"

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_shot,
                       max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    # The circuitizer is built over layers 10-16, while ablation and masking
    # below only cover layers 10-14 (the same two ranges as before).
    circuit_layers = list(range(10, 17))
    circuitizer = Circuitizer(llama, tokenizer, runner, circuit_layers, prompt)

    layers = list(range(10, 15))
    average_over_positions = True
    inverse = False
    do_abs = False
    mean_ablate = False

    # Reference points for normalization: the unmodified model, and a sweep
    # value far above the threshold range.
    orig_metric = circuitizer.ablated_metric(llama).tolist()
    zero_metric = circuitizer.run_ablated_metrics(
        [100000], layers=layers, average_over_positions=average_over_positions)[0][0]
    print(orig_metric, zero_metric)

    thresholds = np.logspace(-5, 0, 200)
    ablated_metrics, _ = circuitizer.run_ablated_metrics(
        thresholds, inverse=inverse, do_abs=do_abs, mean_ablate=mean_ablate,
        average_over_positions=average_over_positions,
        token_prefix=None, layers=layers)

    faithfulness = (np.array(ablated_metrics) - zero_metric) / (orig_metric - zero_metric)

    # Scan from the largest threshold down and keep the first one that clears the target.
    selected_threshold = next(
        (threshold
         for threshold, faith in reversed(list(zip(thresholds, faithfulness)))
         if faith > target_faithfulness),
        None,
    )
    if selected_threshold is None:
        raise IndexError(
            f"no threshold reaches faithfulness > {target_faithfulness} "
            f"for task {task_name!r}"
        )
    print(selected_threshold)

    def select_mask(ie):
        # Only the first element of mask_ie's result (the mask) is used here.
        mask, _ = circuitizer.mask_ie(
            ie, selected_threshold, None, inverse=inverse,
            average_over_positions=average_over_positions, do_abs=do_abs)
        return mask

    ablation_masks = {}
    for layer in layers:
        mask_attn_out = select_mask(circuitizer.ie_attn[layer])
        mask_resid = select_mask(circuitizer.ie_resid[layer])
        try:
            mask_transcoder = select_mask(circuitizer.ie_transcoder[layer])
        except KeyError:
            # A KeyError here is treated as "no transcoder mask for this layer".
            mask_transcoder = None

        ablation_masks[layer] = {
            "attn_out": mask_attn_out,
            "resid": mask_resid,
            "transcoder": mask_transcoder,
        }

    ie_by_type = {
        "attn_out": circuitizer.ie_attn,
        "resid": circuitizer.ie_resid,
        "transcoder": circuitizer.ie_transcoder,
    }

    nodes_with_ie = []
    for layer, masks in ablation_masks.items():
        for mask_type, token_masks in masks.items():
            if token_masks is None:
                continue
            for token_type, token_mask in token_masks.items():
                node_ids = np.where(token_mask)[0]
                if len(node_ids) == 0:
                    continue
                # Averaged IEs depend only on (layer, type, token_type), so
                # compute them once per group instead of once per node.
                masked_ies = circuitizer.mask_average(ie_by_type[mask_type][layer], token_type)
                for node_id in node_ids:
                    node_index = node_id.tolist()
                    nodes_with_ie.append(
                        (layer, mask_type, token_type, node_index,
                         masked_ies[node_index].tolist())
                    )

    return nodes_with_ie

In [16]:
# Run the threshold sweep for one task; the returned list holds one tuple per
# selected node, with the node's IE value as the last element.
task_name = "antonyms"

features = important_features(task_name)

162
162
Setting up masks...
Running metrics...
Setting up RMS...


  0%|          | 0/18 [00:00<?, ?it/s]

Loading SAEs...


  0%|          | 0/7 [00:00<?, ?it/s]

Running node IEs...


  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

-20.0 -105.0


  0%|          | 0/200 [00:00<?, ?it/s]

0.0010843659686896108


In [19]:
# Order nodes by their last tuple element (the IE value), largest first.
features = sorted(features, key=lambda x: x[-1], reverse=True)

# Show the ten highest-ranked nodes.
features[:10]

In [20]:
# Build few-shot prompts for one task and run them through the instrumented
# model to collect the per-layer "resid_pre_*" side outputs.
n_few_shots, batch_size, max_seq_len = 8, 12, 128
seed = 10

task_name = "antonyms"

# NOTE(review): sep and pad are not referenced anywhere else in the visible
# notebook; presumably token ids — confirm or remove.
sep = 3978
pad = 0

prompt = "Follow the pattern:\n{}"

pairs = list(tasks[task_name].items())

# NOTE(review): n_shot is computed here but never used — the runner below is
# given n_few_shots, so the "algo" override has no effect. Confirm which value
# was intended.
n_shot = n_few_shots
if task_name.startswith("algo"):
    n_shot = 8

runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_few_shots, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

# Keep at most n_few_shots entries of each element of runner.train_pairs.
tokenized = runner.get_tokens([
    x[:n_few_shots] for x in runner.train_pairs
], tokenizer)

# `tokenized` is splatted as keyword arguments, so it must provide exactly
# the keys input_ids and attention_mask.
inputs = tokenized_to_inputs(**tokenized)
train_tokens = tokenized["input_ids"]

# First element of the result is discarded; the second holds the collected side outputs.
_, all_resids = get_resids_call(inputs)

# scale = 25

# resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

In [22]:
# Keep only nodes whose mask type (tuple position 1) is "resid", then take the
# first 30 in the current order of `features`.
resid_features = [
    x for x in features if x[1] == "resid"
][:30]

In [None]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite
from micrlhf.utils.load_sae import sae_encode
# One SAE suite per layer 10..14, keyed by layer index.
# NOTE(review): here get_nev_it_sae_suite is called as (layer, llama), while a
# later cell calls it as (layer=layer) only — confirm the expected signature.
saes = {
    layer: get_nev_it_sae_suite(layer, llama) for layer in range(10, 15)
}

# Placeholder: left empty in this (unexecuted) cell.
acts = {
    
}


In [8]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite


# SAE suite for the single layer chosen in the `layer` config cell.
sae = get_nev_it_sae_suite(layer=layer)

In [9]:
import json

# Read the JSON Lines results file: one JSON object per line.
with open("cleanup_results_final.jsonl") as f:
    lines = f.readlines()
# Skip blank lines (e.g. a trailing empty line), which would otherwise make
# json.loads raise.
results = [json.loads(line) for line in lines if line.strip()]

In [10]:
import pandas as pd

# One row per parsed JSONL record; later cells read the "layer", "task",
# "tv" and "weights" columns.
df = pd.DataFrame(results)

In [11]:

# Restrict the results to the layer chosen in the `layer` config cell.
layer_df = df[df["layer"] == layer]

In [12]:
import numpy as np
from micrlhf.utils.load_sae import sae_encode

# Map each task name to the indices of the SAE features that have a strictly
# positive encoded value for that task's stored "weights" at the chosen layer.
# Note: this rebinds `features`, which earlier cells use for a list of nodes.
features = {}

for task in tasks:
    # Filter the frame once per task; [0] takes the first matching row and
    # raises IndexError when the task has no row for this layer.
    task_rows = layer_df[layer_df["task"] == task]
    weights = np.array(task_rows["weights"].to_numpy()[0])

    # The stored weights are passed as pre-activation values; only the middle
    # element of sae_encode's 3-tuple is used.
    _, w, _ = sae_encode(sae, None, pre_relu=weights)

    features[task] = np.nonzero(w * (w > 0))[0]

features

In [13]:
# Same pipeline as the earlier residual-collection cell, with larger settings
# (20 shots, batch 16, sequences up to 256 tokens), this time also extracting
# the residual stream of the chosen layer.
n_few_shots, batch_size, max_seq_len = 20, 16, 256
seed = 10

task_name = "antonyms"

# NOTE(review): sep and pad are not referenced anywhere else in the visible
# notebook; presumably token ids — confirm or remove.
sep = 3978
pad = 0

prompt = "Follow the pattern:\n{}"

pairs = list(tasks[task_name].items())

# NOTE(review): n_shot (n_few_shots - 1, or 8 for "algo" tasks) is computed but
# never used — the runner below is given n_few_shots. Confirm which was intended.
n_shot = n_few_shots - 1
if task_name.startswith("algo"):
    n_shot = 8

runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_few_shots, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

# Keep at most n_few_shots entries of each element of runner.train_pairs.
tokenized = runner.get_tokens([
    x[:n_few_shots] for x in runner.train_pairs
], tokenizer)

inputs = tokenized_to_inputs(**tokenized)
train_tokens = tokenized["input_ids"]

# First element of the result is discarded; the second holds the collected side outputs.
_, all_resids = get_resids_call(inputs)

# scale = 25

# Positional array with axes ordered (batch, seq, embedding).
# NOTE(review): indexing by `layer` assumes the collected outputs are ordered
# by layer index — confirm.
resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

In [14]:
# Encode the layer's residuals with the SAE; only the middle element of the
# returned 3-tuple is kept. It is indexed below as [batch, position, feature].
_, feature_activations, _ = sae_encode(sae, resids)

In [31]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Feature indices found for the "antonyms" task by the per-task features cell.
task_features = features["antonyms"]

# One subplot row per feature, all sharing the x axis.
fig = make_subplots(rows=len(task_features),
                    subplot_titles=[f"Feature {x}" for x in task_features],
                    shared_xaxes=True)

for i, feature in enumerate(task_features):
    # All batch rows, first 100 positions, for this single feature.
    acts = feature_activations[:, :100, feature]

    heatmap = go.Heatmap(
        z=acts,
        colorscale="Viridis",
    )

    # Plotly subplot rows are 1-based.
    fig.add_trace(heatmap, row=i + 1, col=1)

# The figure size is fixed and does not scale with the number of features.
fig.update_layout(height=1000, width=800)
fig.show()


In [None]:
import os
import json
import numpy as np
from tqdm import tqdm

# Script-style setup: it reloads the model, tokenizer and tasks that the
# notebook's opening cells already load.
# NOTE(review): this moves one level up, whereas the first cell of the notebook
# moves two levels up — confirm which location this code is meant to run from.
if "models" not in os.listdir("."):
    os.chdir("..")

import penzai
from penzai import pz

from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
# Pad on the right so real tokens start at position 0 of every sequence.
tokenizer.padding_side = "right"

from sprint.icl_sfc_utils import Circuitizer
from sprint.task_vector_utils import load_tasks, ICLRunner

# Load tasks
tasks = load_tasks()

# Module-level settings read by the code further down in this cell.
batch_size = 8 
n_shot = 12
max_seq_len = 128
seed = 10

# List of task names
task_names = list(tasks.keys())

# Prepare output file for jsonl
# output_filepath = "task_faithfulness_metrics.jsonl"

# Initialize tqdm for progress tracking
# task_pairs_progress = tqdm(total=len(task_names) * (len(task_names)), desc="Processing task pairs")


def main(task_name, output_filepath="task_faithfulness_metrics.jsonl"):
    """Sweep ablation thresholds for one task and append the results as JSON Lines.

    Builds an ``ICLRunner`` and a ``Circuitizer`` for ``task_name`` over layers
    11-16, measures the un-ablated metric and the metric at a very large sweep
    value, then runs the 200-point log-spaced threshold sweep twice: first
    with ``inverse=True``, then with ``inverse=False``. One JSON record is
    appended to ``output_filepath`` after each sweep, so the first result is
    on disk even if the second sweep fails.

    Reads the module-level ``tasks``, ``llama``, ``tokenizer``, ``batch_size``,
    ``n_shot``, ``max_seq_len`` and ``seed``.

    Args:
        task_name: Key into ``tasks``.
        output_filepath: JSONL file to append to.

    Raises:
        KeyError: ``task_name`` is not in ``tasks``.
    """
    # Load and prepare the task
    first_task = tasks[task_name]
    first_pairs = list(first_task.items())
    prompt = "Follow the pattern:\n{}"
    layers = list(range(11, 17))
    # Tasks whose name starts with "algo" use fewer shots.
    n_few_shot = 8 if task_name.startswith("algo") else n_shot

    # Define the runner and circuitizer
    first_runner = ICLRunner(task_name, first_pairs, batch_size=batch_size, n_shot=n_few_shot,
                             max_seq_len=max_seq_len, seed=seed, prompt=prompt,
                             use_same_examples=False, use_same_target=False)
    circuitizer = Circuitizer(llama, tokenizer, first_runner, layers, prompt=prompt)

    # Reference points used to normalize the sweep into a faithfulness score
    first_orig_metric = circuitizer.ablated_metric(llama).tolist()
    first_zero_metric = circuitizer.run_ablated_metrics([100000], layers=layers)[0][0]

    thresholds = np.logspace(-5, 0, 200)
    do_abs = False
    mean_ablate = False
    average_over_positions = True

    # Same sweep for both mask directions, in the original order (True, then False).
    for inverse in (True, False):
        first_ablated_metrics, first_n_nodes_counts = circuitizer.run_ablated_metrics(
            thresholds,
            inverse=inverse,
            do_abs=do_abs,
            mean_ablate=mean_ablate,
            average_over_positions=average_over_positions,
            token_prefix=None,
            layers=layers,
        )
        first_faithfullness = (np.array(first_ablated_metrics) - first_zero_metric) / (first_orig_metric - first_zero_metric)

        first_metrics_data = {
            # NOTE(review): this stores the whole task mapping, not the task
            # name — confirm that is what readers of the file expect.
            "task": first_task,
            "inverse": inverse,
            "orig_metric": first_orig_metric,
            "zero_metric": first_zero_metric,
            "thresholds": thresholds.tolist(),
            "n_nodes_counts": first_n_nodes_counts,
            "ablated_metrics": first_ablated_metrics,
            "faithfullness": first_faithfullness.tolist(),
            "layers": layers
        }

        # Append immediately so each sweep is persisted as soon as it finishes.
        with open(output_filepath, 'a') as jsonl_file:
            jsonl_file.write(json.dumps(first_metrics_data) + "\n")


from argparse import ArgumentParser


# Command-line entry point: takes one positional argument, the task name, and
# passes it to main().
# NOTE(review): inside a Jupyter kernel sys.argv holds the kernel's own
# arguments, so parse_args() would presumably reject them; this looks intended
# for running the code as a standalone script — confirm.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("task_name", type=str, help="Name of the task to run the circuit ablation on.")
    args = parser.parse_args()

    main(args.task_name)
