In [1]:
import os
if "models" not in os.listdir("."):
    os.chdir("../..")

In [2]:
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
jax_smi.initialise_tracking()
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

In [5]:
from sprint.task_vector_utils import load_tasks, ICLRunner
tasks = load_tasks()

fatal: destination path 'data/itv' already exists and is not an empty directory.


In [6]:
# task_names = ["en_es", "antonyms", "person_profession", "es_en", "present_simple_gerund", "present_simple_past_simple", "person_profession", "person_language", "country_capital", "football_player_position"]
# task_name = task_names[1]
task_names = list(tasks.keys())

In [7]:
import jax.numpy as jnp
import jax


from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper


get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)



def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs


In [8]:
from sprint.icl_sfc_utils import AblatedModule
layer = 12
mask_name = "arrow"

In [9]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite, sae_encode


sae = get_nev_it_sae_suite(layer=layer)

In [10]:
import json

with open("cleanup_results_final.jsonl") as f:
    lines = f.readlines()
    results = [json.loads(line) for line in lines]

with open("cleanup_results_algo.jsonl") as f:
    lines = f.readlines()
    results_algo = [json.loads(line) for line in lines]

In [11]:
import numpy as np

features = []
task_features = {}
tv_features = {}

for task_name in task_names:
    if not task_name.startswith("algo"):
        task_results = [result for result in results if result["task"] == task_name and result["layer"] == layer]
    else:
        task_results = [result for result in results_algo if result["task"] == task_name and result["layer"] == layer]

    for result in task_results:
        weights = np.array(result["weights"])
        tv = np.array(result["tv"])
        # s = jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"]
        # threshold = jnp.maximum(0, sae["b_gate"] - sae["b_enc"] * s)
        # w = weights
        # w = w * (w > 0)

        _, w, _ = sae_encode(sae, None, pre_relu=weights)

        _, w_tv, _ = sae_encode(sae, tv)
        # print(threshold)

        new_features = np.nonzero(w)[0].tolist()
        features += new_features
        print(task_name, "TVC:", result["loss"], "TV:", result["tv_loss"], new_features)

        task_features[task_name] = new_features
        tv_features[task_name] = np.nonzero(w_tv)[0].tolist()
    

features = list(set(features))

len(features)

location_continent TVC: 7.28125 TV: 5.375 [6780, 7578, 9662, 14612, 24925]
football_player_position TVC: 11.0625 TV: 9.125 [7491, 13458, 27401]
location_religion TVC: 4.96875 TV: 6.75 [7578, 9600, 9662, 11172, 14612, 20832, 24925, 27401]
location_language TVC: 5.375 TV: 5.28125 [7578, 9662, 11172, 24925]
person_profession TVC: 6.5 TV: 6.03125 [6413, 7491, 7578, 13458, 27401]
location_country TVC: 3.625 TV: 4.25 [850, 9662, 11173, 24925]
country_capital TVC: 2.171875 TV: 2.046875 [11173, 17636, 18803]
person_language TVC: 4.5625 TV: 4.25 [850, 7578, 11172, 13458, 24925]
singular_plural TVC: 4.8125 TV: 3.59375 [2930, 6594, 12943, 14612]
present_simple_past_simple TVC: 1.5703125 TV: 0.53515625 [2930, 6594, 15356]
antonyms TVC: 2.234375 TV: 2.40625 [7578, 7739, 10720, 11618, 19097, 19112, 25576]
plural_singular TVC: 3.078125 TV: 3.0625 [2930, 6594]
present_simple_past_perfect TVC: 3.328125 TV: 3.046875 [2930, 6594, 15356]
present_simple_gerund TVC: 3.25 TV: 2.125 [6594, 15554]
en_it TVC: 9

In [12]:
n_few_shots, batch_size, max_seq_len = 20, 32, 256
seed = 10

prompt = "Follow the pattern:\n{}"

sep = 3978
pad = 0
newline = 108

In [13]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector

task_losses_positive = {}

n_few_shots, batch_size, max_seq_len = 20, 32, 64
seed = 10

prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

for task_name in tqdm(task_names):

    sep = 3978
    pad = 0


    pairs = list(tasks[task_name].items())

    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    scale = 15

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = train_tokens == sep
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]

    
    def steer_with_direction(direction):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2, do_ppl=False)

    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature]).tolist() for feature in tqdm(features)]]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2, do_ppl=False).tolist())

  0%|          | 0/23 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

In [14]:
normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (base_loss - losses) / base_loss
    
    max_loss = np.max(losses)
    # min_loss = np.min(losses)

    losses = losses / max_loss

    losses[losses < 0.2] = 0.0
    losses[losses > 1.0] = 1.0
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

In [15]:
avg_heatmap  = np.max(heatmap, axis=0)

sorted_idx = np.argsort(-avg_heatmap)

heatmap = heatmap[:, sorted_idx]

# min_heatmap = np.min(heatmap, axis=1)   

min_pos = np.argmax(heatmap, axis=1)

y_sorted_idx = np.argsort(min_pos)

heatmap = heatmap[y_sorted_idx]


In [16]:
sorted_tasks=[task_names[x] for x in y_sorted_idx]
sorted_features=[features_dropped[x] for x in sorted_idx]

In [17]:
best_features = {}

for i, task_name in enumerate(sorted_tasks):
    best_features[task_name] = sorted_features[np.argmax(heatmap[i])]

In [18]:
best_features["singular_plural"] = 14612
best_features["fr_en"] = 5579
best_features["present_simple_past_perfect"] = 15356

In [19]:
n_few_shots, batch_size, max_seq_len = 20, 32, 256
seed = 10

prompt = "Follow the pattern:\n{}"

sep = 3978
pad = 0
newline = 108

In [45]:
task_name = "antonyms"

pairs = list(tasks[task_name].items())
n_shot = n_few_shots - 1
if task_name.startswith("algo"):
    n_shot = 16

runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_few_shots, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

tokenized = runner.get_tokens([
    x[:n_few_shots] for x in runner.train_pairs
], tokenizer)

inputs = tokenized_to_inputs(**tokenized)
train_tokens = tokenized["input_ids"]

# _, all_resids = get_resids_call(inputs)

# resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

# activations_pre, activations, _ = sae_encode(sae, resids)

In [46]:
prompt_length = len(tokenizer.tokenize(prompt))
tokens = train_tokens

masks = [
    ("prompt", jnp.zeros_like(tokens).at[:, :prompt_length].set(1).astype(bool)),
    # ("input", jnp.roll(tokens == sep, -1, axis=-1).at[:, :prompt_length].set(False)),
    ("arrow", jnp.array(tokens == sep).at[:, :prompt_length].set(False)), 
    # ("output", jnp.roll(tokens == newline, -1, axis=-1).at[:, :prompt_length].set(False)),
    ("newline", jnp.array(tokens == newline).at[:, :prompt_length].set(False)),
]

input_mask = (tokens == sep) * -1
input_mask += tokens == newline 
input_mask = np.cumsum(input_mask, axis=1)
input_mask -= tokens == newline 
input_mask[:, :prompt_length] = 0
input_mask = input_mask == 1

masks.append(("input", input_mask))

output_mask = (tokens == newline) * -1
output_mask += tokens == sep
output_mask[:, :prompt_length] = 0
output_mask = np.cumsum(output_mask, axis=1)
output_mask -= tokens == sep
output_mask = output_mask == 1

masks.append(("output", output_mask))


In [51]:
import plotly.express as px
from plotly.subplots import make_subplots

from collections import defaultdict

feature_masses = defaultdict(lambda: defaultdict(lambda: 1e-6))

for task_name in tqdm(task_names):
    pairs = list(tasks[task_name].items())
    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 16

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=n_few_shots, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    activations_pre, activations, _ = sae_encode(sae, resids)

    prompt_length = len(tokenizer.tokenize(prompt))
    tokens = train_tokens

    masks = [
        ("prompt", jnp.zeros_like(tokens).at[:, :prompt_length].set(1).astype(bool)),
        # ("input", jnp.roll(tokens == sep, -1, axis=-1).at[:, :prompt_length].set(False)),
        ("arrow", jnp.array(tokens == sep).at[:, :prompt_length].set(False)), 
        # ("output", jnp.roll(tokens == newline, -1, axis=-1).at[:, :prompt_length].set(False)),
        ("newline", jnp.array(tokens == newline).at[:, :prompt_length].set(False)),
    ]

    input_mask = (tokens == sep) * -1
    input_mask += tokens == newline 
    input_mask = np.cumsum(input_mask, axis=1)
    input_mask -= tokens == newline 
    input_mask[:, :prompt_length] = 0
    input_mask = input_mask == 1

    masks.append(("input", input_mask))

    output_mask = (tokens == newline) * -1
    output_mask += tokens == sep
    output_mask[:, :prompt_length] = 0
    output_mask = np.cumsum(output_mask, axis=1)
    output_mask -= tokens == sep
    output_mask = output_mask == 1

    masks.append(("output", output_mask))
    
    remaining_mask = tokens != pad
    for mask_name, mask in masks:
        remaining_mask = jnp.logical_and(remaining_mask, jnp.logical_not(mask))

    masks.append(("remaining", remaining_mask))

    masks = {
        k: np.array(v) for k, v in masks
    }
    
    for task, feature in best_features.items():
        if task != task_name:
            continue
        feature_activations = activations[:, :, feature]

        feature_activations = np.array(feature_activations)

        mask_activations = {
            mask_name: feature_activations[mask].flatten()
            for mask_name, mask in masks.items()
        }
        
        mask_masses = {
            mask_name: sum(mask_activations[mask_name])
            for mask_name in masks
        }

        for mask_name, mask_mass in mask_masses.items():
            feature_masses[feature][mask_name] += mask_mass

        # for i, (mask_name, mask) in enumerate(masks.items()):

        #     macts = mask_activations[mask_name].tolist()

            # print(sum(macts))

            # if len(macts) > 0:
            #     fig.add_trace(px.histogram(macts, nbins=50).data[0], row=1 + f_idx, col=i + 1)

        # fig.update_layout(title=f"Feature {feature}, Task {task}")

    # fig.update_layout(height=1000)
    # fig

total_masses = {
    f: sum(m.values()) for f, m in feature_masses.items()
}

  0%|          | 0/23 [00:00<?, ?it/s]

In [52]:

for f, m in feature_masses.items():
    print(f"{'Feature ' + str(f):>{8 + 8}}: ", end="")
    for k, v in m.items():
        v = v / total_masses[f]
        print(k, f"{v:>6.3f}", end=", ")
    print() 

     Feature 850: prompt  0.000, arrow  0.995, newline  0.000, input  0.005, output  0.000, remaining  0.000, 
   Feature 13458: prompt  0.000, arrow  0.969, newline  0.000, input  0.030, output  0.001, remaining  0.000, 
   Feature 11172: prompt  0.000, arrow  0.983, newline  0.009, input  0.000, output  0.008, remaining  0.000, 
    Feature 7491: prompt  0.000, arrow  1.000, newline  0.000, input  0.000, output  0.000, remaining  0.000, 
   Feature 11173: prompt  0.000, arrow  0.989, newline  0.001, input  0.009, output  0.000, remaining  0.000, 
   Feature 14612: prompt  0.000, arrow  0.866, newline  0.000, input  0.000, output  0.134, remaining  0.000, 
   Feature 15356: prompt  0.000, arrow  0.961, newline  0.000, input  0.000, output  0.039, remaining  0.000, 
   Feature 11618: prompt  0.000, arrow  0.999, newline  0.000, input  0.000, output  0.001, remaining  0.000, 
    Feature 2930: prompt  0.000, arrow  0.922, newline  0.070, input  0.000, output  0.007, remaining  0.000, 
 

In [56]:
mean_masses = defaultdict(lambda : 0)


for f, m in feature_masses.items():
    for k, v in m.items():
        mean_masses[k] += v / total_masses[f]

for k, v in mean_masses.items():
    mean_masses[k] = v / len(feature_masses)

# for f, m in feature_masses.items():
#     for k, v in m.items():
#         mean_masses[k] = mean_masses.get(k, 0) / total_masses[f] + v

# for k, v in mean_masses.items():
#     mean_masses[k] = v / len(feature_masses)

# total_mass = sum(mean_masses.values())

# for k, v in mean_masses.items():
#     mean_masses[k] = v / total_mass


for k, v in mean_masses.items():
    print(k, round(v * 100, 2))

mean_masses

prompt 0.0
arrow 89.8
newline 0.54
input 3.2
output 6.46
remaining 0.0


In [43]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (base_loss - losses) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    # losses = (losses - min_loss) / (max_loss - min_loss)

    # losses = losses / max_loss

    # losses[losses < 0.2] = 0.0
    # losses[losses > 1.0] = 1.0

    # losses = np.clip(losses, 0.2, 1.0)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

def make_heatmap():
    normalized_losses = {}

    drop_features = []
    drop_ids = [features.index(feature) for feature in drop_features]

    features_dropped = [feature for feature in features if feature not in drop_features]

    for task_name, losses in task_losses_positive.items():
        base_loss = losses[1]
        losses = losses[0]
        losses = np.array(losses)
        losses = np.delete(losses, drop_ids)

        losses = np.minimum(losses, base_loss)

        losses = (base_loss - losses) / base_loss
        
        max_loss = np.max(losses)
        # min_loss = np.min(losses)

        losses = losses / max_loss

        losses[losses < 0.2] = 0.0
        losses[losses > 1.0] = 1.0
        
        # mean_loss = np.mean(losses - base_loss)

        # base_acc = losses[1][1]

        # accs = [base_acc - loss[1] for loss in losses[0]]
        normalized_losses[task_name] = losses


    heatmap = np.zeros((len(task_names), len(features_dropped)))

    for i, task_name in enumerate(task_names):
        for j, feature in enumerate(features_dropped):
            heatmap[i, j] = normalized_losses[task_name][j]

    return heatmap


# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.log(1 + heatmap * 40)

# heatmap = np.clip(heatmap, -5, 5)

# avg_heatmap = np.sum(heatmap != 1, axis=0)

sort_map = make_heatmap()
# sort_map = heatmap

avg_heatmap  = np.max(sort_map, axis=0)

sorted_idx = np.argsort(-avg_heatmap)

heatmap = heatmap[:, sorted_idx]

sort_map = sort_map[:, sorted_idx]

# min_heatmap = np.min(heatmap, axis=1)   

min_pos = np.argmax(sort_map, axis=1)

y_sorted_idx = np.argsort(min_pos)

heatmap = heatmap[y_sorted_idx]


heatmap = np.flip(heatmap, axis=0)

fig = px.imshow(heatmap, x=[str(features_dropped[x]) for x in sorted_idx], y=[task_names[x] for x in y_sorted_idx][::-1], width=1000, height=600, aspect="auto", color_continuous_scale="Blues")

def make_filp_mask():
    heatmap = np.zeros((len(task_names), len(features)))
    feat_to_idx = {feat: i for i, feat in enumerate(features)}


    for i, task_name in enumerate(task_names):
        feats = tv_features[task_name]
        for feat in feats:
            if feat in feat_to_idx:
                heatmap[i, feat_to_idx[feat]] = 2


        feats = task_features[task_name]
        for feat in feats:
            if feat in feat_to_idx:
                heatmap[i, feat_to_idx[feat]] = 1

    heatmap = heatmap[y_sorted_idx]
    heatmap = heatmap[:, sorted_idx]

    return heatmap

filp_mask = make_filp_mask()


filp_mask = np.flip(filp_mask, axis=0)

annotation_style = dict(
    showarrow=False,
    font=dict(color="black", size=7),
    xanchor="center",
    yanchor="middle"
)

special_cells = {
    tuple(x): ["🟢"] for x in np.argwhere(filp_mask == 1)
}

special_cells.update({
    tuple(x): ["🔴"] for x in np.argwhere(filp_mask == 2)
})

# special_cells.update({
#     tuple(x): ["✔", "*"] for x in np.argwhere(filp_mask == 3)
# })

annotations = [
    {**annotation_style, **dict(x=cell[1], y=cell[0], text=icon, xshift=0, yshift=0 - i * 6)}
    for cell, icons in special_cells.items()
    for i, icon in enumerate(icons)
]
fig.update_layout(annotations=annotations)

fig.show()


In [45]:
fig = px.imshow(heatmap, x=[str(features_dropped[x]) for x in sorted_idx], y=[task_names[x] for x in y_sorted_idx][::-1], width=1200, height=600, aspect="auto", color_continuous_scale="Blues", labels=dict(x="Feature", y="Task", color="Relative loss decrease"))

annotation_style = dict(
    showarrow=False,
    font=dict(color="MediumSeaGreen", size=7),
    xanchor="center",
    yanchor="middle"
)

special_cells = {
    tuple(x): ["●"] for x in np.argwhere(filp_mask == 1)
}

annotations = [
    {**annotation_style, **dict(x=cell[1], y=cell[0], text=icon, xshift=0, yshift=0 - i * 6)}
    for cell, icons in special_cells.items()
    for i, icon in enumerate(icons)
]

annotation_style = dict(
    showarrow=False,
    font=dict(color="red", size=7),
    xanchor="center",
    yanchor="middle"
)

special_cells ={
    tuple(x): ["●"] for x in np.argwhere(filp_mask == 2)
}

# special_cells.update({
#     tuple(x): ["✔", "*"] for x in np.argwhere(filp_mask == 3)
# })


annotations += [
    {**annotation_style, **dict(x=cell[1], y=cell[0], text=icon, xshift=0, yshift=0 - i * 6)}
    for cell, icons in special_cells.items()
    for i, icon in enumerate(icons)
]

# legend_annotations = [
#     dict(
#         x=1.05, y=0.8, text="🟢 Feature is present after cleaning", showarrow=False,
#         font=dict(size=8), xanchor="left", yanchor="top"
#     ),
#     dict(
#         x=1.05, y=0.95, text="🔴 Feature is in task vector", showarrow=False,
#         font=dict(size=8), xanchor="left", yanchor="top"
#     )
# ]

# fig.update_layout(width =600, height=250, 
#                 font_family="Serif", font_size=7, 
#                 margin_l=5, margin_t=5, margin_b=5, margin_r=5,
#         title=dict(
#         text=legend_html,
#         x=0.525,  # Center the title
#         y=0.09,  # Place the legend below the plot
#         xanchor="center",
#         yanchor="top"
#     ))


# legend_html = """
# <span style="font-size:8px;">
# <span style="color:red">•</span> Feature is in task vector | <span style="color:green">•</span> Feature is present after cleaning
# </span>
# """

# Update layout with annotations
fig.update_layout(annotations=annotations)

legend_html = """
<span style="font-size:8px;">
<span style="color:red">●</span> Feature is in task vector | <span style="color:MediumSeaGreen">●</span> Feature is present after cleaning
</span>
"""

fig.update_layout(width =600, height=300, 
                font_family="Serif", font_size=7,
                margin_l=5, margin_t=5, margin_b=5, margin_r=5,
        title=dict(
        text=legend_html,
        x=0.525,  # Center the title
        y=0.079,  # Place the legend below the plot
        xanchor="center",
        yanchor="top"
    ))


# fig.update_layout(annotations=annotations)

import plotly.io as pio
pio.write_image(fig, "micrlhf-progress/images/executor_heatmap_l12_new_full_non_n.pdf", width =600, height=300)
fig

In [42]:
fig = px.imshow(heatmap[:, :23], x=[str(features_dropped[x]) for x in sorted_idx[:23]], y=[task_names[x] for x in y_sorted_idx][::-1], width=800, height=600, aspect="auto", color_continuous_scale="Blues", labels=dict(x="Feature", y="Task", color="Effect strength"))

annotation_style = dict(
    showarrow=False,
    font=dict(color="MediumSeaGreen", size=7),
    xanchor="center",
    yanchor="middle"
)

special_cells = {
    tuple(x): ["●"] for x in np.argwhere(filp_mask == 1)
}

annotations = [
    {**annotation_style, **dict(x=cell[1], y=cell[0], text=icon, xshift=0, yshift=0 - i * 6)}
    for cell, icons in special_cells.items()
    for i, icon in enumerate(icons)
]

annotation_style = dict(
    showarrow=False,
    font=dict(color="red", size=7),
    xanchor="center",
    yanchor="middle"
)

special_cells ={
    tuple(x): ["●"] for x in np.argwhere(filp_mask == 2)
}

annotations += [
    {**annotation_style, **dict(x=cell[1], y=cell[0], text=icon, xshift=0, yshift=0 - i * 6)}
    for cell, icons in special_cells.items()
    for i, icon in enumerate(icons)
]
fig.update_layout(width =400, height=300, 
                font_family="Serif", font_size=7,
                margin_l=5, margin_t=5, margin_b=5, margin_r=5,
        )


# fig.update_layout(annotations=annotations)

import plotly.io as pio
pio.write_image(fig, "micrlhf-progress/images/executor_heatmap_l12_new.pdf", width =400, height=300)
fig

In [98]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector, add_vector_many

positive_task_losses_testing = {}

n_few_shots, batch_size, max_seq_len = 1, 16, 64
seed = 10

prompt = "Follow the pattern:\n{}"


def calc_acc(tokens, sep, logits):
    preds = logits.argmax(-1)
    preds = preds[:, :-1]
    targets = tokens[:, 1:]

    mask = tokens[:, :-1] == sep

    hits = targets == preds

    hits = hits * mask

    hits = hits.sum()
    return hits / mask.sum()




def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

for task_name in tqdm(task_names):

    sep = 3978
    pad = 0
    pairs = list(tasks[task_name].items())

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = train_tokens == sep
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]

    def steer_with_direction(direction, scale):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)

        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        acc = calc_acc(train_tokens, sep, logits)

        return logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2), acc

    # _features = task_features[task_name]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    acc = calc_acc(train_tokens, sep, logits)


    for feature in tqdm(features):
        positive_task_losses_testing[(task_name, feature)] = [[steer_with_direction(sae["W_dec"][feature], scale) for scale in np.logspace(0, 1.5, 50)]]
        positive_task_losses_testing[(task_name, feature)].append((logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2), acc))

  0%|          | 0/23 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



  0%|          | 0/43 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [103]:


feature = 11618

heatmap = np.zeros((len(task_names), 50))

for i, task_name in enumerate(task_names):
    losses, base_loss = positive_task_losses_testing[(task_name, feature)]
    # print(task_name, base_loss[0])
    if task_name.startswith("algo"):
        continue
    # if base_loss[1] < 0.4:
    #     continue
    # negative_task_losses_testing[(task_name, feature)]
    heatmap[i] =  float(base_loss[0]) - np.array([float(loss) for loss, _  in losses])
    heatmap[i] = heatmap[i] / float(base_loss[0])
    # heatmap[i] = heatmap[i] / float(base_loss[1])
# clipped = np.maximum(heatmap, 0)
# heatmap /= clipped.mean(axis=0)
# print(clipped.mean(0).shape)

px.imshow(heatmap, x=np.logspace(0, 1.5, 50), y=task_names, height=600)

In [112]:
import plotly.express as px

heatmap = np.zeros((len(task_names), len(features)))    

for i, feature in enumerate(features):
    run_data = np.zeros((len(task_names), 50))
    for j, task_name in enumerate(task_names):
        losses, base_loss = positive_task_losses_testing[(task_name, feature)]
        n_loss = 0
        # if task_name.startswith("algo"):
        #     continue
        # if base_loss[1] < 0.3:
        #     continue
        # if task_name == "present_simple_gerund":
        #     continue
        run_data[j] = float(base_loss[n_loss]) - np.array([float(loss[n_loss]) for loss  in losses])
        # print(
        #     task_name, base_loss
        # )
        run_data[j] = run_data[j] / float(base_loss[n_loss])
        # run_data[j] *= -1
    # run_data = run_data.max(0)
    # find first idx where > 1
    first_idx = np.argmax(run_data.max(0))

    # print(first_idx)
    
    # run_data = np.maximum(run_data, 0)
    run_data[run_data < 0] = 0
    # run_data[run_data > 0.1] = 0.3
    heatmap[:, i] = run_data[:, first_idx]
    # heatmap[:, i] = run_data[:, first_idx]
    # heatmap[:, i] = heatmap[:, i] / heatmap[:, i].mean() 

# heatmap = heatmap / heatmap.mean(0)


# heatmap = np.maximum(heatmap, 0)
sorted_indices = np.argsort((heatmap).max(0))

sorted_indices = sorted_indices[::-1]

heatmap = heatmap[:, sorted_indices]

min_pos = np.argmax(heatmap, axis=1)

y_sorted_idx = np.argsort(min_pos)

heatmap = heatmap[y_sorted_idx]

def make_filp_mask():
    heatmap = np.zeros((len(task_names), len(features)))
    feat_to_idx = {feat: i for i, feat in enumerate(features)}


    for i, task_name in enumerate(task_names):
        feats = task_features[task_name]
        for feat in feats:
            if feat in feat_to_idx:
                heatmap[i, feat_to_idx[feat]] = -2

    heatmap = heatmap[y_sorted_idx]
    heatmap = heatmap[:, sorted_indices]

    return heatmap + 1



filp_mask = make_filp_mask()
fig = px.imshow(heatmap, x=[str(features[i]) for i in sorted_indices] ,  y=[task_names[i] for i in y_sorted_idx] , height=600, color_continuous_scale="Blues")

special_cells = {
    tuple(x): ["🔴"] for x in np.argwhere(filp_mask == -1)
}

for cell, icons in special_cells.items():
    for i, icon in enumerate(icons):
        fig.add_annotation(
            x=cell[1],  # Column index
            y=cell[0],  # Row index
            text=icon,  # Icon or text
            showarrow=False,
            font=dict(color="white", size=5),
            xanchor="center",
            yanchor="middle",
            yshift=i * 10  # Adjust vertical position to stack icons
        )

fig


In [19]:
with open("micrlhf-progress/executor_heatmap_l12.json", "w") as f:
    json.dump({"heatmap": heatmap.tolist(), "features": [features_dropped[x] for x in sorted_idx], "task_names": [task_names[x] for x in y_sorted_idx]}, f)

In [24]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (losses - base_loss) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

# avg_heatmap = np.sum(heatmap != 1, axis=0)

avg_heatmap  = np.mean(heatmap, axis=0)

sorted_idx = np.argsort(avg_heatmap)

heatmap = heatmap[:, sorted_idx]


fig = px.imshow(heatmap, x=[str(features_dropped[x]) for x in sorted_idx], y=task_names, width=1400, height=400, aspect="auto")

fig.show()


In [34]:
tvs = [x["tv"] for x in results if x["layer"] == layer]

len(tvs)

In [43]:
avg_tv = np.mean(tvs, axis=0)

avg_tv = avg_tv.astype(jnp.float32)

In [58]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector

task_losses_positive = {}

n_few_shots, batch_size, max_seq_len = 20, 16, 256
seed = 10

prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

for task_name in tqdm(["location_religion"]):

    sep = 3978
    pad = 0


    pairs = list(tasks[task_name].items())

    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    # scale = 25

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = train_tokens == sep
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]
    
    feature = 11172

    def steer_with_direction(direction, scale):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale + avg_tv * 0.5
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2)

    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature], scale).tolist() for scale in tqdm(np.logspace(0, 2, 100))]]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2).tolist())

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]


scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=bfloat16 with jax_numpy_dtype_promotion='standard'. In future JAX releases this will result in an error.



In [59]:
px.line(x=np.logspace(0, 2, 100), y=task_losses_positive["location_religion"][0])

In [54]:
px.line(x=np.logspace(0, 2, 100), y=task_losses_positive["antonyms"][0])

In [47]:
task_losses_positive["antonyms"][0]

In [46]:
heatmap.shape

In [45]:
sorted_idx

In [67]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    base_loss = losses[1]
    losses = [x[1].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = (losses - base_loss[1]) 
    
    # max_loss = np.max(losses)
    # min_loss = np.min(losses)

    # losses = (losses - min_loss) / (max_loss - min_loss)

    # losses = (losses - min_loss) / (max_loss - min_loss)
    # loss = (base_loss - losses) / base_loss
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

std_heatmap = np.std(heatmap, axis=0)

sorted_idx = np.argsort(-std_heatmap)

heatmap = heatmap[:, sorted_idx]

labels = [str(features_dropped[x]) for x in sorted_idx]

fig = px.imshow(heatmap, x=labels, y=task_names, width=1400, height=800, aspect="auto")

fig.show()


In [66]:
task_features

In [120]:
features_dropped.index(7491)

In [121]:
mean_heatmap = np.mean(heatmap, axis=0)

drop_features = np.where(mean_heatmap < 0.9)[0]

drop_features = [features_dropped[x] for x in drop_features]

drop_features

In [115]:
heatmap.shape

In [116]:
(22 + 0.1)/23

In [127]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)


    losses = np.minimum(losses, base_loss * 1.5)

    losses = (losses - base_loss) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / ((max_loss - min_loss) + 1e-6)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

heatmap = np.log(heatmap + 1)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names, color_continuous_scale='YlGnBu')

fig.show()

In [125]:
task_losses_positive["present_simple_past_perfect"]

In [84]:
mean_heatmap

In [71]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    base_loss = losses[1]
    losses = [x[0].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = losses - base_loss[0]
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names)

fig.show()


NameError: name 'negative_task_losses' is not defined