In [1]:
import os
if "models" not in os.listdir("."):
    os.chdir("../..")

In [7]:
import json

with open("cleanup_results_final.jsonl") as f:
    lines = f.readlines()
    results = [json.loads(line) for line in lines]

In [32]:
# task_name = "antonyms"

# results = [r for r in results if r["task"] == task_name]


import pandas as pd

df = pd.DataFrame(results)

melted_df = df.melt(id_vars=["layer", "task"], value_vars=["loss", "tv_loss", "ito_loss", "recon_loss"], var_name="loss_type", value_name="loss value")

In [23]:
melted_df.head()

Unnamed: 0,layer,task,loss_type,loss value
0,1,location_continent,loss,10.3125
1,2,location_continent,loss,10.25
2,3,location_continent,loss,10.1875
3,4,location_continent,loss,10.375
4,5,location_continent,loss,10.1875


In [24]:
import plotly.express as px

task_name = "algo_last"

px.line(melted_df[melted_df["task"] == task_name], x="layer", y="loss value", color="loss_type")

In [36]:
task_names = df["task"].unique()
layers = sorted(df["layer"].unique())

task_losses = {task_name: {loss_type: df[df["task"] == task_name][loss_type].to_numpy() for loss_type in ["loss", "tv_loss"]} for task_name in task_names}

normalized_losses = {task_name: {loss_type: (task_losses[task_name][loss_type] - task_losses[task_name][loss_type].min()) / (task_losses[task_name][loss_type].max() - task_losses[task_name][loss_type].min()) for loss_type in ["loss", "tv_loss"]} for task_name in task_names}

In [43]:
heatmap = pd.DataFrame({task_name: normalized_losses[task_name]["loss"] for task_name in task_names}, index=layers)
px.imshow(heatmap)

In [44]:
%load_ext autoreload
%autoreload 2
import penzai
import jax_smi
jax_smi.initialise_tracking()
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [45]:
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True, device_map="tpu:0")

In [46]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

In [47]:
from sprint.task_vector_utils import load_tasks, ICLRunner
tasks = load_tasks()

fatal: destination path 'data/itv' already exists and is not an empty directory.


In [48]:
# task_names = ["en_es", "antonyms", "person_profession", "es_en", "present_simple_gerund", "present_simple_past_simple", "person_profession", "person_language", "country_capital", "football_player_position"]
# task_name = task_names[1]
task_names = list(tasks.keys())

In [49]:
import jax.numpy as jnp
import jax


from sprint.task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper


get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)



def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs


In [52]:
from sprint.icl_sfc_utils import AblatedModule
layer = 12
mask_name = "arrow"

In [54]:
from micrlhf.utils.load_sae import get_nev_it_sae_suite


sae = get_nev_it_sae_suite(layer=layer)

In [58]:
import numpy as np

features = []

for task_name in task_names:
    task_results = [result for result in results if result["task"] == task_name and result["layer"] == layer]

    for result in task_results:
        w = np.array(result["weights"])
        # s = jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"]
        # threshold = jnp.maximum(0, sae["b_gate"] - sae["b_enc"] * s)
        w = w * (w > 0)

        # print(threshold)

        features += np.nonzero(w)[0].tolist()

features = list(set(features))

len(features)

In [60]:
import dataclasses
from tqdm.auto import tqdm
from functools import partial
from micrlhf.utils.activation_manipulation import add_vector

task_losses_positive = {}

n_few_shots, batch_size, max_seq_len = 20, 16, 256
seed = 10

prompt = "Follow the pattern:\n{}"

def make_taker(llama, layer):
    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    return taker

taker = make_taker(llama, layer)

for task_name in tqdm(task_names):

    sep = 3978
    pad = 0


    pairs = list(tasks[task_name].items())

    n_shot = n_few_shots - 1
    if task_name.startswith("algo"):
        n_shot = 8

    runner = ICLRunner(task_name, pairs, batch_size=batch_size, n_shot=1, max_seq_len=max_seq_len, seed=seed, prompt=prompt)

    tokenized = runner.get_tokens([
        x[:n_few_shots] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, all_resids = get_resids_call(inputs)

    scale = 25

    resids = all_resids[layer].value.unwrap("batch", "seq", "embedding")

    mask = train_tokens == sep
    col_indices = jnp.arange(mask.shape[1])
    col_indices_broadcasted = mask * col_indices
    sorted_indices = jnp.sort(col_indices_broadcasted, axis=1, descending=True)
    positions = sorted_indices[:, :1]
    
    def steer_with_direction(direction):
        direction = direction / jnp.linalg.norm(direction)
        direction = direction * scale
        
        modified = jax.vmap(lambda a, b: a.at[b].add(direction))(
            resids, positions
        )
        modified = pz.nx.wrap(modified, "batch", "seq", "embedding")

        _inputs = dataclasses.replace(inputs, tokens=modified)
        logits = taker(_inputs).unwrap("batch", "seq", "vocabulary")

        return logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2)

    task_losses_positive[task_name] = [[steer_with_direction(sae["W_dec"][feature]).tolist() for feature in tqdm(features)]]

    logits = llama(inputs)

    logits = logits.unwrap("batch", "seq", "vocabulary")

    task_losses_positive[task_name].append(logprob_loss(logits, train_tokens, sep=sep, pad_token=pad, n_first=2).tolist())

  0%|          | 0/23 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

  0%|          | 0/129 [00:00<?, ?it/s]

In [61]:
task_losses_positive["antonyms"]

In [119]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = np.minimum(losses, base_loss)

    losses = (losses - base_loss) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names)

fig.show()


In [120]:
features_dropped.index(7491)

In [121]:
mean_heatmap = np.mean(heatmap, axis=0)

drop_features = np.where(mean_heatmap < 0.9)[0]

drop_features = [features_dropped[x] for x in drop_features]

drop_features

In [115]:
heatmap.shape

In [116]:
(22 + 0.1)/23

In [127]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in task_losses_positive.items():
    base_loss = losses[1]
    losses = losses[0]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)


    losses = np.minimum(losses, base_loss * 1.5)

    losses = (losses - base_loss) / base_loss
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / ((max_loss - min_loss) + 1e-6)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

# heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)

heatmap = np.log(heatmap + 1)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names, color_continuous_scale='YlGnBu')

fig.show()

In [125]:
task_losses_positive["present_simple_past_perfect"]

In [84]:
mean_heatmap

In [71]:
import plotly.express as px
import pandas as pd
import numpy as np

normalized_losses = {}

drop_features = []
drop_ids = [features.index(feature) for feature in drop_features]

features_dropped = [feature for feature in features if feature not in drop_features]

for task_name, losses in negative_task_losses.items():
    base_loss = losses[1]
    losses = [x[0].tolist() for x in losses[0]]
    losses = np.array(losses)
    losses = np.delete(losses, drop_ids)

    losses = losses - base_loss[0]
    
    max_loss = np.max(losses)
    min_loss = np.min(losses)

    losses = (losses - min_loss) / (max_loss - min_loss)
    
    # mean_loss = np.mean(losses - base_loss)

    # base_acc = losses[1][1]

    # accs = [base_acc - loss[1] for loss in losses[0]]
    normalized_losses[task_name] = losses


heatmap = np.zeros((len(task_names), len(features_dropped)))

for i, task_name in enumerate(task_names):
    for j, feature in enumerate(features_dropped):
        heatmap[i, j] = normalized_losses[task_name][j]

heatmap /= np.mean(heatmap, axis=0, keepdims=True)



# heatmap = np.where(heatmap > 0, np.log(heatmap), -10)
# heatmap[np.isnan(heatmap)] = np.min(heatmap[np.isfinite(heatmap)])
# heatmap[np.isinf(heatmap)] = np.max(heatmap[np.isfinite(heatmap)])

# heatmap = np.clip(heatmap, -5, 5)


fig = px.imshow(heatmap, x=[str(x) for x in features_dropped], y=task_names)

fig.show()


NameError: name 'negative_task_losses' is not defined