In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:

pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [3]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [4]:
import plotly.express as px

In [5]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
from micrlhf.sampling import sample
from transformers import AutoTokenizer
import jax
# tokenizer = load_tokenizer(filename)



In [6]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
from task_vector_utils import load_tasks, ICLDataset, ICLSequence
tasks = load_tasks()

Cloning into 'itv'...
fatal: unable to access 'https://github.com/roeehendel/icl_task_vectors data/itv/': URL using bad/illegal format or missing URL


In [8]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [9]:
def generate_task_prompt(task, n_shots):
    prompt = "<user>Follow the pattern\n{}"
    examples = []

    for s, t in random.sample(list(tasks[task].items()), n_shots):
        examples.append(f"{s} -> {t}")
    prompt = prompt.format("\n".join(examples))

    # print(prompt)

    return prompt

def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [10]:
task_names = [
    "en_es"
]
layer = 18
n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 16, 256

In [11]:
from task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder

In [12]:
from micrlhf.utils.load_sae import get_sae, sae_encode_gated
sae = get_sae(layer, 4)

--2024-05-28 20:56:47--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-4-8.86E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.51, 108.156.211.95, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.51|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/fa68513c10a8cdd065e4a0e66c05816325e4d72fb272857ca70564fca7fa808f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717189007&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzE4OTAwN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZmE2ODUxM2Mx

In [13]:
from micrlhf.utils.ito import grad_pursuit

In [25]:

tasks_names = [
    # "en_es",
    # "en_fr",
    "antonyms"
]

collected_resids = {}
for task in tasks_names:
    pairs = list(tasks[task].items())

    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=10)

    n_shot = 20

    tokenized = runner.get_tokens([
        x[:n_shot] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, resids = get_resids_call(inputs)

    resids = resids[layer].value.unwrap(
        "batch", "seq", "embedding"
    )

    tv = get_tv(resids, train_tokens, shift = 0)

    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    logits = llama(inputs)
    

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"Zero: {task}, loss: {loss}"
    )

    add_act = make_act_adder(llama, tv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"TV: {task}, L: {layer}, Loss: {loss}"  
    )

    _, _, rtv = sae_encode_gated(sae, tv)

    add_act = make_act_adder(llama, rtv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"Recon TV: {task}, L: {layer}, Loss: {loss}"  
    )

    _, gtv = grad_pursuit(tv, sae["W_dec"], 100)

    add_act = make_act_adder(llama, gtv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"Grad pursuit TV: {task}, L: {layer}, Loss: {loss}"
    )


    mask = train_tokens == 1599

    tv = resids[mask]

    _, _, tv = sae_encode_gated(sae, tv)

    tv = tv.mean(axis=0)
    tv = tv.astype('bfloat16')

    add_act = make_act_adder(llama, tv, tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"TV on recon: {task}, L: {layer}, Loss: {loss}"  
    )

    
    

Zero: en_es, loss: 14.625
TV: en_es, L: 18, Loss: 4.65625
Recon TV: en_es, L: 18, Loss: 5.875
Grad pursuit TV: en_es, L: 18, Loss: 7.53125
TV on recon: en_es, L: 18, Loss: 5.84375


In [134]:
task = "en_es"

pairs = list(tasks[task].items())

runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=128, seed=10)

n_shot = 20

tokenized = runner.get_tokens([
    x[:n_shot] for x in runner.train_pairs
], tokenizer)

inputs = tokenized_to_inputs(**tokenized)
train_tokens = tokenized["input_ids"]

logits, resids = get_resids_call(inputs)

loss = logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), train_tokens, shift=1 if task.startswith("algo") else 0, n_first=2
)

print(
    f"Full: {task}, loss: {loss}"
)

resids = resids[layer].value.unwrap(
    "batch", "seq", "embedding"
)

tv = get_tv(resids, train_tokens, shift = 0)

tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
inputs = tokenized_to_inputs(**tokenized)
tokens = tokenized["input_ids"]

add_act = make_act_adder(llama, tv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

logits = add_act(inputs)

loss = logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
)

print(
    f"TV: {task}, L: {layer}, Loss: {loss}"  
)

Full: en_es, loss: 1.51562
TV: en_es, L: 18, Loss: 4.71875


In [27]:
task_names = [
    "en_es", "en_fr", "antonyms", "es_en", "en_it", "fr_en", "en_de", "location_continent", "location_language", "person_profession" 
]

In [14]:
from task_vector_utils import FeatureSearch

seed = 10


collected_weights = []

for task in tqdm(task_names):
    pairs = list(tasks[task].items())

    n_shot = 40

    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=seed)


    tokenized = runner.get_tokens([
        x[:n_shot] for x in runner.train_pairs
    ], tokenizer)

    inputs = tokenized_to_inputs(**tokenized)
    train_tokens = tokenized["input_ids"]

    _, resids = get_resids_call(inputs)

    resids = resids[layer].value.unwrap(
        "batch", "seq", "embedding"
    )

    tv = get_tv(resids, train_tokens, shift = 0)

    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]

    add_act = make_act_adder(llama, tv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"TV: {task}, L: {layer}, Loss: {loss}"  
    )

    _, pr, rtv = sae_encode_gated(sae, tv)

    add_act = make_act_adder(llama, rtv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"Recon TV: {task}, L: {layer}, Loss: {loss}"  
    )

    fs = FeatureSearch(task, pairs, 18, llama, tokenizer, n_shot=1, seed=seed+100, init_w=pr, early_stopping_steps=200, n_first=2)

    w, m = fs.find_weights()

    collected_weights.append(
        w
    )

    weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"] + sae["b_gate"])   

    recon = jnp.einsum("fv,f->v", sae["W_dec"], weights) + sae["b_dec"]
    recon = recon.astype('bfloat16')

    add_act = make_act_adder(llama, recon, tokens, layer, length=1, shift= 0)

    logits = add_act(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
    )

    print(
        f"Recon fs: {task}, L: {layer}, Loss: {loss}"  
    )



  0%|          | 0/1 [00:00<?, ?it/s]

TV: en_es, L: 18, Loss: 4.71875
Recon TV: en_es, L: 18, Loss: 7.875


  0%|          | 0/2000 [00:00<?, ?it/s]

Recon fs: en_es, L: 18, Loss: 1.95312


In [15]:
jax.lax.top_k(weights, 22)

In [13]:
import pickle

with open("weights.pkl", "rb") as f:
    collected_weights = pickle.load(f)

In [40]:
with open("weights_dict.pkl", "wb") as f:
    pickle.dump(collected_weights_dict, f)

In [39]:
collected_weights_dict = {
    t:w.tolist() for t, w in zip(task_names, collected_weights)
}

In [15]:
(collected_weights[0] > 0).sum()

In [19]:
top_features = [
    jax.lax.top_k(weights, (weights > 0).sum()) for weights in collected_weights
]

In [29]:
all_features = set()

for w, i in top_features:
    all_features.update(i.tolist())

In [30]:
all_features = list(all_features)

In [34]:
heatmap = np.zeros((len(top_features), len(all_features)))

feature_to_idx = {f: i for i, f in enumerate(all_features)}

for i, (w, j) in enumerate(top_features):
    for f, v in zip(j.tolist(), w):
        heatmap[i, feature_to_idx[f]] = v / max(w)

In [26]:
len(task_names)

In [35]:
px.imshow(heatmap, x=[str(x) for x in all_features], y=[str(x) for x in task_names], title="loss ito l18 v4")

In [17]:

task_names = [
    # "en_es",
    "antonyms"
]

n_few_shots = 60

collected_resids_new = {}
for task in task_names:
    collected_resids_new[task] = []
    pairs = list(tasks[task].items())

    runner = ICLRunner(task, pairs, batch_size=32, n_shot=n_few_shots-1, max_seq_len=512, seed=10)


    for n_shot in tqdm(list(range(0, 50))):
        tokenized = runner.get_tokens([
            x[:n_shot] for x in runner.train_pairs
        ], tokenizer)

        inputs = tokenized_to_inputs(**tokenized)
        train_tokens = tokenized["input_ids"]

        logits, resids = get_resids_call(inputs)

        loss = logprob_loss(
            logits.unwrap("batch", "seq", "vocabulary"), train_tokens, shift=1 if task.startswith("algo") else 0, n_first=2
        )

        print(
            f"Full: {task}, loss: {loss}, n_shot: {n_shot}"
        )

        mask = train_tokens == 1599

        resids = resids[layer].value.unwrap(
            "batch", "seq", "embedding"
        )

        resids = resids[mask]

        collected_resids_new[task].append(resids) 


        

  0%|          | 0/50 [00:00<?, ?it/s]

Full: antonyms, loss: -0, n_shot: 0
Full: antonyms, loss: 5.21875, n_shot: 1
Full: antonyms, loss: 1.94531, n_shot: 2
Full: antonyms, loss: 2.125, n_shot: 3
Full: antonyms, loss: 2.70312, n_shot: 4
Full: antonyms, loss: 2.5625, n_shot: 5
Full: antonyms, loss: 1.67188, n_shot: 6
Full: antonyms, loss: 1.16406, n_shot: 7
Full: antonyms, loss: 1.03125, n_shot: 8
Full: antonyms, loss: 1.39062, n_shot: 9
Full: antonyms, loss: 0.546875, n_shot: 10
Full: antonyms, loss: 0.585938, n_shot: 11
Full: antonyms, loss: 0.259766, n_shot: 12
Full: antonyms, loss: 0.40625, n_shot: 13
Full: antonyms, loss: 0.402344, n_shot: 14
Full: antonyms, loss: 0.263672, n_shot: 15
Full: antonyms, loss: 0.408203, n_shot: 16
Full: antonyms, loss: 0.917969, n_shot: 17
Full: antonyms, loss: 0.550781, n_shot: 18
Full: antonyms, loss: 0.660156, n_shot: 19
Full: antonyms, loss: 0.929688, n_shot: 20
Full: antonyms, loss: 0.462891, n_shot: 21
Full: antonyms, loss: 0.458984, n_shot: 22
Full: antonyms, loss: 1.14844, n_shot: 2

In [18]:
tv = collected_resids_new["antonyms"][-1].mean(0)

In [19]:
_, pr, rtv = sae_encode_gated(sae, tv)

In [20]:
fs = FeatureSearch(task, pairs, 18, llama, tokenizer, n_shot=1, seed=seed+100, init_w=pr, early_stopping_steps=200, n_first=2)

w, m = fs.find_weights()

weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"] + sae["b_gate"])   

  0%|          | 0/2000 [00:00<?, ?it/s]

In [40]:

layer = 12

sae = get_sae(layer, 4)

task_names = [
    # "en_es",
    "antonyms"
]

n_few_shots = 60

results = []



for task in task_names:
    collected_resids_new[task] = []
    pairs = list(tasks[task].items())

    runner = ICLRunner(task, pairs, batch_size=32, n_shot=n_few_shots-1, max_seq_len=512, seed=10)


    for n_shot in tqdm(list(range(10, 50))):
        tokenized = runner.get_tokens([
            x[:n_shot] for x in runner.train_pairs
        ], tokenizer)

        inputs = tokenized_to_inputs(**tokenized)
        train_tokens = tokenized["input_ids"]

        _, resids = get_resids_call(inputs)

        resids = resids[layer].value.unwrap(
            "batch", "seq", "embedding"
        )

        tv = get_tv(resids, train_tokens, shift = 0)

        tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
        inputs = tokenized_to_inputs(**tokenized)
        tokens = tokenized["input_ids"]

        add_act = make_act_adder(llama, tv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

        logits = add_act(inputs)

        loss = logprob_loss(
            logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
        )

        print(
            f"TV: {task}, L: {layer}, N: {n_shot}, Loss: {loss}"  
        )

        _, pr, rtv = sae_encode_gated(sae, tv)

        add_act = make_act_adder(llama, rtv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

        logits = add_act(inputs)

        rloss = logprob_loss(
            logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
        )

        print(
            f"RTV: {task}, L: {layer}, N: {n_shot}, Loss: {rloss}"  
        )

        results.append(
            (loss, rloss)
        )


        

--2024-05-28 23:54:41--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l12-test-run-4-3.94E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.90, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/b91f7bb2110b9b8c9ba8c0aab31ba19c91d3acd15287732c379fc9a14329ee1d?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717199682&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzE5OTY4Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvYjkxZjdiYj

  0%|          | 0/40 [00:00<?, ?it/s]

TV: antonyms, L: 12, N: 10, Loss: 3.28125
RTV: antonyms, L: 12, N: 10, Loss: 3.125
TV: antonyms, L: 12, N: 11, Loss: 3.23438
RTV: antonyms, L: 12, N: 11, Loss: 3.10938
TV: antonyms, L: 12, N: 12, Loss: 3.23438
RTV: antonyms, L: 12, N: 12, Loss: 3.125
TV: antonyms, L: 12, N: 13, Loss: 3.23438
RTV: antonyms, L: 12, N: 13, Loss: 3.07812
TV: antonyms, L: 12, N: 14, Loss: 3.21875
RTV: antonyms, L: 12, N: 14, Loss: 3.09375
TV: antonyms, L: 12, N: 15, Loss: 3.20312
RTV: antonyms, L: 12, N: 15, Loss: 3.04688
TV: antonyms, L: 12, N: 16, Loss: 3.21875
RTV: antonyms, L: 12, N: 16, Loss: 3
TV: antonyms, L: 12, N: 17, Loss: 3.1875
RTV: antonyms, L: 12, N: 17, Loss: 3
TV: antonyms, L: 12, N: 18, Loss: 3.17188
RTV: antonyms, L: 12, N: 18, Loss: 2.95312


TypeError: Cannot determine dtype of Traced<ShapedArray(bfloat16[])>with<BatchTrace(level=3/0)> with
  val = Array([0.484375, 0.492188, 0.503906, ..., 0.490234, 0.490234, 0.474609],      dtype=bfloat16)
  batch_dim = 0

In [159]:
import pandas as pd

df = pd.DataFrame([list(x) + [i + 10] for i, x in enumerate(results)], columns=["TV", "Recon TV", "n_shot"])
# df

In [158]:
px.line(
    df, x="n_shot", y=["TV", "Recon TV"], title="en_es l18 v4"
)

In [23]:
tv_prs = []
for resid in tqdm(collected_resids_new["antonyms"]):
    tv = resid.mean(0)
    _, pr, rtv = sae_encode_gated(sae, tv)
    l0 = (pr > 0).sum() 
    tv_prs.append(
        (pr, l0)
    )

  0%|          | 0/50 [00:00<?, ?it/s]

In [168]:
selected_features = [720,
 1065,
 9218,
 13590,
 14166,
 15865,
 19597,
 23008,
 24549,
 27626,
 35202,
 36287,
 41973,
 44525]

In [30]:
jax.lax.top_k(weights, 10)

In [32]:
selected_features = jax.lax.top_k(weights, 10)[1]

In [38]:
selected_features.tolist()

In [33]:
top_features = []
for pr, l0 in tv_prs:
    top_features.append(
        # jax.lax.top_k(pr, l0)
        (pr[selected_features] / pr.sum(), selected_features)    
    )

In [34]:
all_features = set()
for w, i in top_features:
    # print(i)
    all_features.update(i.tolist())

In [35]:
feature_to_idx = {f: i for i, f in enumerate(all_features)}
heatmap = np.zeros((len(tv_prs), len(all_features)))

for i, (w, j) in enumerate(top_features):
    j = j.tolist()
    w = w.tolist()
    for f, v in zip(j, w):
        heatmap[i, feature_to_idx[f]] = v / sum(w)


In [36]:
px.imshow(heatmap, x=[str(x) for x in all_features], y=[str(x) for x in range(0, 50)], title="en")

In [64]:
px.imshow(heatmap, x=[str(x) for x in all_features], y=[str(x) for x in range(10, 30)], title="antonyms")

In [None]:
all_features

In [18]:
fs = FeatureSearch(task, pairs, 18, llama, tokenizer, n_shot=1, seed=seed+100, init_w=pr, early_stopping_steps=100, n_first=2)

In [19]:
w, m = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [20]:
weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"] + sae["b_gate"])   

In [32]:
_, i = jax.lax.top_k(weights, 25)

In [24]:
recon = jnp.einsum("fv,f->v", sae["W_dec"], weights) + sae["b_dec"]
recon = recon.astype('bfloat16')


In [55]:
_, _, atv = sae_encode_gated(sae, tv, ablate_features=i)

In [56]:
add_act = make_act_adder(llama, atv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

logits = add_act(inputs)

loss = logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
)

print(
    f"Recon fs: {task}, L: {layer}, Loss: {loss}"  
)

Recon fs: en_fr, L: 18, Loss: 12.25


In [12]:

tasks = [
    "en_es",
    "antonyms"
]

collected_resids = {}
for task in tasks:
    pairs = list(tasks[task].items())

    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=10)


    for n_shot in tqdm(list(range(1, 20))):
        tokenized = runner.get_tokens([
            x[:n_shot] for x in runner.train_pairs
        ], tokenizer)

        inputs = tokenized_to_inputs(**tokenized)
        train_tokens = tokenized["input_ids"]

        _, resids = get_resids_call(inputs)

        mask = train_tokens == 1599
        mask = jnp.logical_and(jnp.roll(jnp.cumsum(mask[:, ::-1], axis=1)[:, ::-1] == 0, shift=-1, axis=1), mask)

        resids = resids[layer].value.unwrap(
            "batch", "seq", "embedding"
        )

        resids = resids[mask]

        collected_resids[task].append(resids) 
        

TypeError: list indices must be integers or slices, not str

In [178]:
means_2 = jnp.stack([x.mean(0) for x in collected_resids_2])

In [179]:
cosine_sims = jnp.dot(means, means_2.T) / jnp.linalg.norm(means, axis=1)[:, None] / jnp.linalg.norm(means_2, axis=1)[None, :]

In [180]:
import plotly.express as px

px.imshow(cosine_sims)

In [157]:
from micrlhf.utils.ito import grad_pursuit

In [181]:
collected_features_2 = []
all_features_2 = set()
for resid in collected_resids_2:
    resid = resid.mean(0)
    # _, features, _ = sae_encode_gated(sae, resid)

    # w, i = jax.lax.top_k(features, 100)

    w, _ = grad_pursuit(resid, sae["W_dec"], 100)

    w, i = jax.lax.top_k(w, 100)

    all_features_2.update(i.tolist())

    collected_features_2.append((w, i))

In [166]:
all_features_2 = list(all_features_2)

In [None]:
heatmap_2 = np.zeros((len(collected_features_2), len(all_features_2)))

feature_map_2 = {v: i for i, v in enumerate(all_features_2)}

for i, (w, f) in enumerate(collected_features_2):
    for j, v in zip(w, f.tolist()):
        heatmap_2[i, feature_map_2[v]] = j

# idx = np.argsort(heatmap_2.sum(0))

# heatmap_2 = heatmap_2[:, idx]

In [184]:
len(set(all_features) - set(all_features_2))

In [None]:
r = [heatmap_2[:, x] for set(all_features_2) - set(all_features)

In [185]:
len(all_features)

In [168]:
heatmap_s = heatmap[1:, :] - heatmap[:-1, :]

In [169]:
n_last = 100

In [172]:
feature_map[44525]

KeyError: 44525

In [170]:
px.imshow([x[-n_last:] for x in heatmap], x=[str(x) for x in [all_features[i] for i in idx]][-n_last:], y=[str(x) for x in range(1, 20)])

In [122]:
px.imshow([x[-n_last:] for x in heatmap], x=[str(x) for x in [all_features[i] for i in idx]][-n_last:], y=[str(x) for x in range(1, 20)])

In [171]:
px.imshow([x[-n_last:] for x in heatmap_s], x=[str(x) for x in [all_features[i] for i in idx]][-n_last:], y=[str(x) for x in range(2, 20)], color_continuous_scale="rdbu", color_continuous_midpoint=0)

In [151]:
px.imshow([x[-n_last:] for x in heatmap], x=[str(x) for x in [all_features[i] for i in idx]][-n_last:], y=[str(x) for x in range(1, 20)])

In [152]:
px.imshow([x[-n_last:] for x in heatmap_s], x=[str(x) for x in [all_features[i] for i in idx]][-n_last:], y=[str(x) for x in range(2, 20)], color_continuous_scale="rdbu", color_continuous_midpoint=0)

In [107]:
px.imshow(heatmap, x=[str(x) for x in all_features], y=[str(x) for x in range(1, 20)])