In [1]:
import os
if "models" not in os.listdir("."):
    os.chdir("../..")

In [2]:
import json

with open("cleanup_results_final_detectors_2.jsonl") as f:
    lines = f.readlines()
    results = [json.loads(line) for line in lines]

In [3]:
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
jax_smi.initialise_tracking()
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [4]:
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

In [5]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

In [6]:
from sprint.task_vector_utils import load_tasks, ICLRunner
tasks = load_tasks()

fatal: destination path 'data/itv' already exists and is not an empty directory.


In [7]:
# task_names = ["en_es", "antonyms", "person_profession", "es_en", "present_simple_gerund", "present_simple_past_simple", "person_profession", "person_language", "country_capital", "football_player_position"]
# task_name = task_names[1]
task_names = list(tasks.keys())

In [8]:
import jax.numpy as jnp
import jax


from sprint.task_vector_utils import ICLRunner, logprob_loss
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper


get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)



def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs


In [9]:
from sprint.icl_sfc_utils import AblatedModule
layer = 11
# mask_name = "arrow"

In [10]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite


sae = get_nev_it_sae_suite(layer=layer)

In [11]:
import numpy as np
from micrlhf.utils.load_sae import sae_encode

features = []
task_features = {}
tv_features = {}

for task_name in task_names:
    task_results = [result for result in results if result["task"] == task_name and result["layer"] == layer]
    

    for result in task_results:
        weights = np.array(result["weights"])
        tv = np.array(result["tv"])
        # s = jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"]
        # threshold = jnp.maximum(0, sae["b_gate"] - sae["b_enc"] * s)
        # w = weights
        # w = w * (w > 0)

        _, w, _ = sae_encode(sae, None, pre_relu=weights)

        _, w_tv, _ = sae_encode(sae, tv)
        # print(threshold)

        new_features = np.nonzero(w)[0].tolist()
        features += new_features
        print(task_name, "TVC:", result["loss"], "TV:", result["tv_loss"], new_features)

        task_features[task_name] = new_features
        tv_features[task_name] = np.nonzero(w_tv)[0].tolist()
    

features = list(set(features))

len(features)

location_continent TVC: 7.40625 TV: 6.71875 [11459, 24964, 25185, 25334, 27001, 29362, 30338]
football_player_position TVC: 13.5 TV: 14.0625 [3844, 13181, 19916, 31427]
location_religion TVC: 6.4375 TV: 8.125 [3466, 10685, 12898, 30338]
location_language TVC: 7.90625 TV: 7.9375 [1132, 10884, 12898, 13181, 20079, 32677]
person_profession TVC: 11.875 TV: 10.625 [4258, 7323, 9995, 13181, 16205, 26436]
location_country TVC: 6.125 TV: 5.125 [11459, 13181, 20983, 28297]
country_capital TVC: 3.96875 TV: 2.515625 [6267, 13181, 13529, 20983, 26783]
person_language TVC: 8.1875 TV: 8.6875 [3775, 10884, 13181, 14996, 32677]
singular_plural TVC: 5.4375 TV: 4.09375 [1322, 10672, 12898, 13181, 15764, 27936]
present_simple_past_simple TVC: 4.96875 TV: 2.453125 [10672, 12898, 13181, 16172, 19628, 21327, 27936]
antonyms TVC: 3.046875 TV: 2.53125 [5971, 11050, 12898, 18472, 26470, 30338, 32142]
plural_singular TVC: 4.28125 TV: 4.15625 [1753, 15764, 16205, 18472, 28297]
present_simple_past_perfect TVC: 5.

In [12]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector

task_losses_positive = {}

n_few_shots, batch_size, max_seq_len = 20, 12, 256
seed = 10

prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

prompt_length = len(tokenizer.encode(prompt))

for task_name in tqdm(task_names):

    sep = 3978
    pad = 0
    newline = 108


    pairs = list(tasks[task_name].items())

    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt, vector_type="detector")

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.eval_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    scale = 25

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = tokens == newline
    mask = jnp.roll(mask, -1, axis=-1)
    mask = mask.at[:, :prompt_length].set(False)

    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]
    
    def steer_with_direction(direction):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2)

    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature]).tolist() for feature in tqdm(features)]]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2).tolist())

  0%|          | 0/23 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

  0%|          | 0/53 [00:00<?, ?it/s]

In [20]:
normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (base_loss - losses) / base_loss
    
    max_loss = np.max(losses)
    # min_loss = np.min(losses)

    losses = losses / max_loss

    losses[losses < 0.2] = 0.0
    losses[losses > 1.0] = 1.0
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

In [21]:
avg_heatmap  = np.max(heatmap, axis=0)

sorted_idx = np.argsort(-avg_heatmap)

heatmap = heatmap[:, sorted_idx]

# min_heatmap = np.min(heatmap, axis=1)   

min_pos = np.argmax(heatmap, axis=1)

y_sorted_idx = np.argsort(min_pos)

heatmap = heatmap[y_sorted_idx]


In [22]:
sorted_tasks=[task_names[x] for x in y_sorted_idx]
sorted_features=[features_dropped[x] for x in sorted_idx]

In [23]:
best_features = {}

for i, task_name in enumerate(sorted_tasks):
    best_features[task_name] = sorted_features[np.argmax(heatmap[i])]

In [29]:
best_features["present_simple_gerund"] = 19628

In [31]:
import plotly.express as px
from plotly.subplots import make_subplots

from collections import defaultdict

feature_masses = defaultdict(lambda: defaultdict(lambda: 1e-6))

for task_name in tqdm(task_names):
    pairs = list(tasks[task_name].items())
    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 16

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_few_shots, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    activations_pre, activations, _ = sae_encode(sae, resids)

    prompt_length = len(tokenizer.tokenize(prompt))
    tokens = train_tokens

    masks = [
        ("prompt", jnp.zeros_like(tokens).at[:, :prompt_length].set(1).astype(bool)),
        # ("input", jnp.roll(tokens == sep, -1, axis=-1).at[:, :prompt_length].set(False)),
        ("arrow", jnp.array(tokens == sep).at[:, :prompt_length].set(False)), 
        # ("output", jnp.roll(tokens == newline, -1, axis=-1).at[:, :prompt_length].set(False)),
        ("newline", jnp.array(tokens == newline).at[:, :prompt_length].set(False)),
    ]

    input_mask = (tokens == sep) * -1
    input_mask += tokens == newline 
    input_mask = np.cumsum(input_mask, axis=1)
    input_mask -= tokens == newline 
    input_mask[:, :prompt_length] = 0
    input_mask = input_mask == 1

    masks.append(("input", input_mask))

    output_mask = (tokens == newline) * -1
    output_mask += tokens == sep
    output_mask[:, :prompt_length] = 0
    output_mask = np.cumsum(output_mask, axis=1)
    output_mask -= tokens == sep
    output_mask = output_mask == 1

    masks.append(("output", output_mask))

    remaining_mask = tokens != pad
    for mask_name, mask in masks:
        remaining_mask = jnp.logical_and(remaining_mask, jnp.logical_not(mask))

    masks.append(("remaining", remaining_mask))

    masks = {
        k: np.array(v) for k, v in masks
    }
    
    for task, feature in best_features.items():
        if task != task_name:
            continue
        feature_activations = activations[:, :, feature]

        feature_activations = np.array(feature_activations)

        mask_activations = {
            mask_name: feature_activations[mask].flatten()
            for mask_name, mask in masks.items()
        }
        
        mask_masses = {
            mask_name: sum(mask_activations[mask_name])
            for mask_name in masks
        }

        for mask_name, mask_mass in mask_masses.items():
            feature_masses[feature][mask_name] += mask_mass

        # for i, (mask_name, mask) in enumerate(masks.items()):

        #     macts = mask_activations[mask_name].tolist()

            # print(sum(macts))

            # if len(macts) > 0:
            #     fig.add_trace(px.histogram(macts, nbins=50).data[0], row=1 + f_idx, col=i + 1)

        # fig.update_layout(title=f"Feature {feature}, Task {task}")

    # fig.update_layout(height=1000)
    # fig

total_masses = {
    f: sum(m.values()) for f, m in feature_masses.items()
}

  0%|          | 0/23 [00:00<?, ?it/s]

In [32]:
for f, m in feature_masses.items():
    print(f"{'Feature ' + str(f):>{8 + 8}}: ", end="")
    for k, v in m.items():
        v = v / total_masses[f]
        print(k, f"{v:>6.3f}", end=", ")
    print() 

   Feature 11459: prompt  0.000, arrow  0.000, newline  0.000, input  0.082, output  0.918, remaining  0.000, 
   Feature 19916: prompt  0.000, arrow  0.000, newline  0.000, input  0.000, output  1.000, remaining  0.000, 
    Feature 3466: prompt  0.000, arrow  0.000, newline  0.000, input  0.014, output  0.986, remaining  0.000, 
   Feature 10884: prompt  0.000, arrow  0.000, newline  0.000, input  0.107, output  0.893, remaining  0.000, 
   Feature 26436: prompt  0.000, arrow  0.000, newline  0.000, input  0.000, output  1.000, remaining  0.000, 
   Feature 13529: prompt  0.000, arrow  0.000, newline  0.000, input  0.103, output  0.897, remaining  0.000, 
    Feature 1132: prompt  0.000, arrow  0.000, newline  0.000, input  0.000, output  1.000, remaining  0.000, 
    Feature 1322: prompt  0.000, arrow  0.000, newline  0.000, input  0.037, output  0.963, remaining  0.000, 
   Feature 21327: prompt  0.000, arrow  0.000, newline  0.000, input  0.000, output  1.000, remaining  0.000, 
 

In [33]:
mean_masses = defaultdict(lambda : 0)


for f, m in feature_masses.items():
    for k, v in m.items():
        mean_masses[k] += v / total_masses[f]

for k, v in mean_masses.items():
    mean_masses[k] = v / len(feature_masses)

# for f, m in feature_masses.items():
#     for k, v in m.items():
#         mean_masses[k] = mean_masses.get(k, 0) / total_masses[f] + v

# for k, v in mean_masses.items():
#     mean_masses[k] = v / len(feature_masses)

# total_mass = sum(mean_masses.values())

# for k, v in mean_masses.items():
#     mean_masses[k] = v / total_mass


for k, v in mean_masses.items():
    print(k, round(v * 100, 2))

mean_masses

prompt 0.0
arrow 0.0
newline 0.01
input 3.22
output 96.76
remaining 0.0


In [20]:
mean_masses_exec = {
'prompt': 1.2580385270555209e-08,
'input': 0.002129799117443222,
'arrow': 0.8980206741304761,
'output': 0.051759925762009384,
'newline': 0.005361420477192891,
'remaining': 0.04272816793249293,
}

In [33]:

print(f"{' '*16}  {'Detector':<8} {'Executor':<8}")
print("-" * 30)

for k, v in mean_masses.items():
    print(f"{k:<20} {round(v, 3):<10} {round(mean_masses_exec[k], 3):<10}")

                  Detector Executor
------------------------------
prompt               0.01       0.0       
input                0.021      0.002     
arrow                0.01       0.898     
output               0.889      0.052     
newline              0.01       0.005     
remaining            0.061      0.043     


: 

In [None]:
n_few_shots, batch_size, max_seq_len = 20, 32, 256
seed = 10

prompt = "Follow the pattern:\n{}"

sep = 3978
pad = 0
newline = 108

In [18]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector

negative_task_losses = {}

n_few_shots, batch_size, max_seq_len = 20, 16, 256
seed = 10

prompt = "Follow the pattern:\n{}"

def calc_acc(tokens, sep, logits, runner):
    arrow_pos = jnp.nonzero(tokens == sep)
    arrow_pos_single = []
    for i in range(batch_size):
        arrow_pos_single.append(arrow_pos[1][arrow_pos[0] == i].max())

    arrow_pos_single = np.array(arrow_pos_single)

    hits = 0

    for i, (ap, l) in enumerate(zip(arrow_pos_single, logits)):
        l = l.argmax(-1)
        tgt = runner.eval_pairs[i][-1][1]
        hits += int(tgt in repr(tokenizer.decode(l[ap:ap+3])))
    return hits / runner.eval_batch_size

def calc_acc(tokens, sep, logits, runner):
    logits = logits.argmax(-1)
    logits = logits[:, :-1]
    tokens = tokens[:, 1:]

    mask = tokens == sep

    hits = tokens == logits

    hits = hits * mask

    hits = hits.sum()
    return hits / mask.sum()




def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

for task_name in tqdm(task_names):

    sep = 3978
    pad = 0


    pairs = list(tasks[task_name].items())

    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 16

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_shot, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    scale = 30

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = train_tokens == sep
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)

    k = jnp.sum(mask[0]).astype(int)

    positions = sorted_indices[:, :k]
    
    def steer_with_direction(direction):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        acc = calc_acc(train_tokens, sep, logits, runner)

        return logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2), acc

    negative_task_losses[task_name] = [[steer_with_direction(-sae["W_dec"][feature]) for feature in tqdm(features)]]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    acc = calc_acc(train_tokens, sep, logits, runner)
    negative_task_losses[task_name].append((logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2), acc))

  0%|          | 0/23 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [21]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (losses - base_loss) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

avg_heatmap = np.sum(heatmap != 1, axis=0)

sorted_idx = np.argsort(avg_heatmap)

heatmap = heatmap[:, sorted_idx]


fig = px.imshow(heatmap, x=[str(features_dropped[x]) for x in sorted_idx], y=task_names)

fig.show()


In [20]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (base_loss - losses) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)

    losses = np.clip(losses, 0.2, 10)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.log(1 + heatmap * 40)

# heatmap = np.clip(heatmap, -5, 5)

# avg_heatmap = np.sum(heatmap != 1, axis=0)

avg_heatmap  = np.max(heatmap, axis=0)

sorted_idx = np.argsort(-avg_heatmap)

heatmap = heatmap[:, sorted_idx]

# min_heatmap = np.min(heatmap, axis=1)   

min_pos = np.argmax(heatmap, axis=1)

y_sorted_idx = np.argsort(min_pos)

heatmap = heatmap[y_sorted_idx]

fig = px.imshow(heatmap[:, :100], x=[str(features_dropped[x]) for x in sorted_idx][:100], y=[task_names[x] for x in y_sorted_idx], width=2000, height=600, aspect="auto", color_continuous_scale="Blues", title="Positive steering with detectors on L11")

fig.show()


In [21]:
with open("micrlhf-progress/detector_heatmap_l11.json", "w") as f:
    json.dump({"heatmap": heatmap.tolist(), "features": [features_dropped[x] for x in sorted_idx], "task_names": [task_names[x] for x in y_sorted_idx]}, f)

: 

In [31]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector

task_losses_positive = {}

n_few_shots, batch_size, max_seq_len = 20, 12, 256
seed = 10

prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

prompt_length = len(tokenizer.encode(prompt))

for task_name in tqdm(["antonyms"]):

    sep = 3978
    pad = 0
    newline = 108


    pairs = list(tasks[task_name].items())

    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt, vector_type="detector")

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.eval_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    # scale = 25

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = tokens == newline
    mask = jnp.roll(mask, -1, axis=-1)
    mask = mask.at[:, :prompt_length].set(False)

    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]

    feature = 11050
    
    def steer_with_direction(direction, scale):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2)

    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature], scale).tolist() for scale in tqdm(np.logspace(0, 2, 100))]]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, tokens, sep=sep, pad_token=pad, n_first=2).tolist())

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [32]:
px.line(
    x=np.logspace(0, 2, 100),
    y=task_losses_positive["antonyms"][0]
)

In [46]:
heatmap.shape

In [45]:
sorted_idx

In [29]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = [22113]
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    base_loss = losses[1]
    losses = [x[0].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = losses - base_loss[0]
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

std_heatmap = np.std(heatmap, axis=0)

sorted_idx = np.argsort(-std_heatmap)

heatmap = heatmap[:, sorted_idx]

labels = [str(features_dropped[x]) for x in sorted_idx]

fig = px.imshow(heatmap, x=labels, y=task_names)

fig.show()


In [120]:
features_dropped.index(7491)

In [121]:
mean_heatmap = np.mean(heatmap, axis=0)

drop_features = np.where(mean_heatmap < 0.9)[0]

drop_features = [features_dropped[x] for x in drop_features]

drop_features

In [115]:
heatmap.shape

In [116]:
(22 + 0.1)/23

In [127]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)


    losses = np.minimum(losses, base_loss * 1.5)

    losses = (losses - base_loss) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / ((max_loss - min_loss) + 1e-6)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

heatmap = np.log(heatmap + 1)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names, color_continuous_scale='YlGnBu')

fig.show()

In [125]:
task_losses_positive["present_simple_past_perfect"]

In [84]:
mean_heatmap

In [71]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    base_loss = losses[1]
    losses = [x[0].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = losses - base_loss[0]
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names)

fig.show()


NameError: name 'negative_task_losses' is not defined