In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:
import random
import dataclasses
import jax
import optax

import jax.numpy as jnp
import numpy as np

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector
from micrlhf.utils.load_sae import get_sae
from functools import partial

In [3]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from task_vector_utils import load_tasks, ICLDataset, ICLSequence
tasks = load_tasks()

Cloning into 'itv'...
fatal: unable to access 'https://github.com/roeehendel/icl_task_vectors data/itv/': URL using bad/illegal format or missing URL


In [6]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [7]:
def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [8]:
prompt = "<user>Follow the pattern\n{}"

In [9]:

target_layer = 17

In [10]:
task_names = [
    "en_es"
]

task = "en_es"

n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 32, 64, 256

In [11]:
def get_logprob_diff(logits: jnp.ndarray, target_tokens, print_results=False):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs


In [12]:
from micrlhf.llama import LlamaBlock
from functools import partial

def make_get_resids(llama, layer_target):
    get_resids = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer_target
                                                                              ).apply(lambda x:
        pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag=f"resid_pre"),
            x
        ])
    )
    get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
    return get_resids

In [13]:
import dataclasses
def get_loss(weights, dictionary, inputs, target_tokens, taker, initial_resids):
    weights = jax.nn.relu(weights)
    

    recon = jnp.einsum("fv,f->v", dictionary, weights)
    recon = recon.astype('bfloat16')

    modified = pz.nx.nmap(lambda a, b: a.at[-1].add(b))(
        initial_resids.untag("seq", "embedding"), pz.nx.wrap(recon, "embedding").untag("embedding")
        ).tag("seq", "embedding")

    inputs = dataclasses.replace(inputs, tokens=modified)

    logits = taker(inputs).unwrap("batch", "seq", "vocabulary")

    logprob_diff = get_logprob_diff(logits, target_tokens)
    loss = -logprob_diff.mean()

    return loss + 2e-2 * jnp.linalg.norm(weights, ord=1), ((weights != 0).sum(), loss)
    # return loss, ((weights != 0).sum(), loss)

In [14]:
def train_step(weights, opt_state, dictionary, inputs, target_tokens, taker, initial_resids, optimizer, lwg, pos_only=True):
    (loss, (l0, loss_)), grad = lwg(weights, dictionary, inputs, target_tokens, taker, initial_resids)

    updates, opt_state = optimizer.update(grad, opt_state, weights)
    weights = optax.apply_updates(weights, updates)
    # weights_abs = jnp.abs(weights)
    # weights = jnp.sign(weights) * jax.nn.relu(weights_abs - shrinkage)

    return loss, weights, opt_state, dict(l0=l0, loss=loss_)

In [15]:
jittify = lambda x: partial(jax.jit(lambda lr, *args, **kwargs: lr(*args, **kwargs)[1][0].value), x)

In [16]:
from task_vector_utils import FeatureSearch

In [17]:
import datasets
from datasets import load_dataset

dataset = load_dataset("Helsinki-NLP/opus-100", "en-es", split="validation")

In [18]:
sample = dataset.shuffle()[:200]
sample = sample["translation"]

In [19]:
pairs = [(x["en"], x["es"]) for x in sample]
# pairs

In [38]:
picked_features = jnp.array(list(set([
    604, 34292,  6661, 24549,  6198, 14166, 44525, 41973, 35027,
       15949, 18911, 13590, 40721, 45024,  8023, 14153, 35597, 36343,
       24992, 23008, 34292, 24549, 6198, 14166, 44525, 41973, 18911, 13590
])))
fs = FeatureSearch("en_es", pairs, 18, l1_coeff=2e-2, init_w=0.5, lr=1e-2, n_shot=1, n_first=3, sae_v=4, early_stopping_steps=2000, max_seq_len=500, picked_features=picked_features)
weights, metrics = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [19]:
pairs = tasks["en_es"]
pairs = list(pairs.items())

In [20]:
from micrlhf.utils.vector_storage import load_vector
tv = load_vector("task_vectors/en_es:18")

In [22]:
from micrlhf.utils.load_sae import get_sae, sae_encode_gated

sae = get_sae(18, 4)

--2024-05-28 13:17:59--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-4-8.86E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.90, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/fa68513c10a8cdd065e4a0e66c05816325e4d72fb272857ca70564fca7fa808f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717161479&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzE2MTQ3OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZmE2ODUxM2

In [23]:
_, post_relu, _ = sae_encode_gated(sae, tv)

In [26]:
jax.lax.top_k(post_relu, 40)

In [30]:
fs = FeatureSearch("en_es", pairs, 18, l1_coeff=2e-2, init_w=post_relu, lr=1e-2, n_shot=1, n_first=2, sae_v=4, early_stopping_steps=300, max_seq_len=500, seed=111)
weights, metrics = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [32]:
jax.lax.top_k(weights, 25)

In [31]:
results2 = (weights, metrics)

In [None]:
results1 = (weights, metrics)

In [25]:
jax.lax.top_k(weights, 20)

In [33]:
jax.lax.top_k(jnp.abs(results2[0]), 10)

In [37]:
pairs

In [34]:

i = jax.lax.top_k(jnp.abs(results1[0]), 10)[1]
fs.check_features(i, scale=20)

  0%|          | 0/10 [00:00<?, ?it/s]

In [35]:

i = jax.lax.top_k(jnp.abs(results2[0]), 10)[1]
fs.check_features(i, scale=20)

  0%|          | 0/10 [00:00<?, ?it/s]

In [39]:
fs = FeatureSearch("en_es", pairs, 18, l1_coeff=1e-2, init_w=0.5, lr=5e-2, n_shot=1, n_first=10, sae_v=4, early_stopping_steps=150)

In [40]:
w = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [41]:
_, i = jax.lax.top_k(w[0], 20)

In [42]:
fs.check_features(i, 20)

  0%|          | 0/20 [00:00<?, ?it/s]

In [50]:
fs.check_feature(34107, 20)

In [49]:
steering_vector = fs.sae["W_dec"][37312] * 40
steering_vector =  steering_vector.astype('bfloat16')

act_add = make_act_adder(
    llama, steering_vector, fs.eval_tokens, fs.target_layer, length=1
)


logits = act_add(fs.eval_inputs).unwrap("batch", "seq", "vocabulary")

In [50]:
positions = jnp.argwhere(fs.eval_tokens == 1599)[:, -1]


In [36]:
_logits = jax.vmap(
    lambda a, b: a[b]
)(logits, positions)
tokenizer.decode(_logits.argmax(axis=-1))

In [37]:
fs.runner.eval_pairs

In [66]:
fs.eval_tokens

In [61]:
t

In [44]:
def check_feature(feature: int, task:str, target_layer, sae, print_results=False):
    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=10, prepend_space=task.startswith("algo"))

    print(
        dataset.prompts
    )


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    pos = 1
    if task.startswith("algo"):
        pos = 2

    target_tokens = [x[pos] for x in tokenizer(dataset.completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)

    # weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

    recon = sae["W_dec"][feature] * 20
    recon = recon.astype('bfloat16')

    act_add = add_vector(
        llama, recon, target_layer, scale=1.0, position="last"
    )

    logits = act_add(inputs).unwrap("batch", "seq", "vocabulary")

    # print(logits[:, -1].mean(axis=-1))

    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            repr(tokenizer.decode(answer_logprobs.argmax(axis=-1)))
        )

        print(
            repr(tokenizer.decode(target_tokens))
        )

    return (target_logprobs - answer_logprobs.max(axis=-1))

In [27]:
sae = get_sae(18, 4)

In [30]:
sae

In [45]:
sae = get_sae(18, 4)

check_feature(6198, "en_es", 18, sae, print_results=True)

['worry ->', 'money ->', 'hear ->', 'general ->', 'morning ->', 'government ->', 'course ->', 'letter ->', 'oil ->', 'future ->', 'kill ->', 'cut ->', 'risk ->', 'time ->', 'student ->', 'wear ->', 'organization ->', 'simply ->', 'population ->', 'doctor ->', 'attention ->', 'money ->', 'carry ->', 'require ->', 'court ->', 'fund ->', 'summer ->', 'something ->', 'environment ->', 'industry ->', 'reduce ->', 'player ->', 'money ->', 'subject ->', 'dead ->', 'comment ->', 'position ->', 'military ->', 'art ->', 'great ->', 'town ->', 'make ->', 'bill ->', 'fine ->', 'love ->', 'clear ->', 'drug ->', 'foreign ->', 'administration ->', 'discuss ->', 'watch ->', 'recognize ->', 'gun ->', 'gun ->', 'guy ->', 'rule ->', 'on ->', 'ok ->', 'club ->', 'much ->', 'situation ->', 'respond ->', 'cut ->', 'provide ->']


'wor mon here general tarde government curs number pet fut k cort ries  estud vest organiz s población enfer at mon c require co fund ver something natur indust reduce jug mon verb death <  militar art grande town me bill fin am cl dro for administ disc mir reconoc g g ch  en ok _ mucho action res cort segu'
'pre din esc general ma gobierno cur cart ace fut mat cort ries h estud lle organiz sim población méd at din lle ex corte fond ver algo amb indust redu jug din sujet m com pos militar arte est pueblo fabric fact mult amor clar dro extr admin disc re reconoc p p ch reg sobre ok club mucho situ res cort prove'


In [None]:
for feature in i:
    print(
        f"Feature {feature}: {check_feature(feature)}"
    )

# get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), target_tokens, print_results=True)

Feature 10018: -7.5
Feature 20311: -2.375
Feature 40023: -4.5
Feature 5145: -7.21875
Feature 45145: -7.09375
Feature 2719: -7.09375
Feature 15415: -7.28125
Feature 24802: -7.15625
Feature 20375: -7.125
Feature 19505: -7.09375
Feature 38618: -7.25
Feature 36786: -7.125
Feature 47611: -7.03125
Feature 31726: -7.28125


KeyboardInterrupt: 

In [None]:
check_feature(764, True)

message reduce behavior place se hello cel media club live line throw sport es door shape daughter run pu
 bottom mind mat polit year moment cut hour has sit comment drop un government d wait phone anticip south arm sport material place picture shape papel no en < choice viel fine arrived ii f third pu hang u er must just hello fig
mess rid comport lu sic infer cell media club abit punto get sport così porta col fig corr bel reg super m question polit dec momento tag ora aver sed comment g gi governo pap asp tele asp sud es sport mater lu f col cart no met rela sc pi bene arriv io cin ter bel append fu dire do app infer figura


In [17]:
fs = FeatureSearch("algo_last", 18)

--2024-05-23 13:29:53--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-6-1.01E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.51, 108.156.211.95, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.51|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/f057cb46f3d871ba03c66e707e3b3d8299322f36fa433862dc3fdca956715614?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716730193&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjczMDE5M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZjA1N2NiNDZm

In [18]:
ds = fs.eval_dataset

In [18]:
info = """location_continent 21 -1.09375
football_player_position 21 -3.76562
location_religion 21 -1.09375
location_language 20 -0.457031
person_profession 21 -1.30469
location_country 21 -1.44531
country_capital 18 -1.25781
person_language 18 -0.287109
singular_plural 21 -0.261719
present_simple_past_simple 20 -0.341797
antonyms 14 -0.75
plural_singular 20 -0.133789
present_simple_past_perfect 18 -1.25781
present_simple_gerund 20 -0.171875
en_it 18 -1.29688
it_en 14 -1.75781
en_fr 18 -1.39062
en_es 18 -1.21875
fr_en 17 -1.25781
es_en 17 -1.09375
en_ru 18 ~1.1
en_de 18 ~1
algo_max 24 -0.515625
algo_min 13 -0.945312
algo_last 17 -0.162109
algo_first 18 -0.828125|
algo_sum 14 -1.49219
algo_most_common 12 -1.15625
"""

In [21]:
info = """location_continent: w/o: 7.96875, w/: 1.91406, layer: 25
football_player_position: w/o: 13.125, w/: 4.375, layer: 22
location_religion: w/o: 7.875, w/: 1.55469, layer: 27
location_language: w/o: 7.9375, w/: 1.82812, layer: 21
person_profession: w/o: 11.3125, w/: 3.53125, layer: 24
location_country: w/o: 5.625, w/: 3.67188, layer: 21
country_capital: w/o: 6.40625, w/: 3.73438, layer: 18
person_language: w/o: 7.4375, w/: 1.28906, layer: 21
singular_plural: w/o: 1.96875, w/: 0.910156, layer: 21
present_simple_past_simple: w/o: 2.48438, w/: 0.773438, layer: 20
antonyms: w/o: 5.96875, w/: 2.29688, layer: 18
plural_singular: w/o: 2.53125, w/: 1.53906, layer: 18
present_simple_past_perfect: w/o: 4.65625, w/: 2.375, layer: 21
present_simple_gerund: w/o: 3.51562, w/: 1.17188, layer: 21
en_it: w/o: 15.375, w/: 3.0625, layer: 18
it_en: w/o: 8.0625, w/: 3.8125, layer: 18
en_ru: w/o: 16.125, w/: 3.76562, layer: 21
en_fr: w/o: 12.625, w/: 5.6875, layer: 18
en_es: w/o: 14.625, w/: 4.65625, layer: 18
fr_en: w/o: 7.3125, w/: 2.45312, layer: 18
es_en: w/o: 8.5, w/: 1.34375, layer: 20
en_de: w/o: 12.125, w/: 6.9375, layer: 18
algo_max: w/o: 3.0625, w/: 2.73438, layer: 24
algo_min: w/o: 3, w/: 2.5, layer: 13
algo_last: w/o: 3.39062, w/: 2.53125, layer: 17
algo_first: w/o: 3.20312, w/: 2.57812, layer: 18
algo_sum: w/o: 5.25, w/: 4.71875, layer: 14
algo_most_common: w/o: 2.92188, w/: 2.67188, layer: 12
"""

In [22]:
tasks_to_check = [x.split(": ") for x in info.split("\n") if x]
tasks_to_check = [(x[0], int(x[-1])) for x in tasks_to_check]
# tasks_to_check = [(x[0], int(x[1])) for x in tasks_to_check]
tasks_to_check = [
    ("en_es", 18),
#     # ("en_it", 18),
#     # ("en_fr", 18),
#     # ("en_de", 18),
#     # ("en_ru", 18)
#     ("algo_last", 18),
#     ("algo_first", 18),
#     ("algo_sum", 14),
] 
# tasks_to_check = [
#     x for x in tasks_to_check if x[0].startswith("algo")
# ]

tasks_to_check

In [23]:

results = {}

In [26]:
from datasets import load_dataset


for task, target_layer in tqdm(tasks_to_check):
    if target_layer == 21:
        target_layer = 20
    if target_layer == 14:
        target_layer = 12
    if target_layer == 13:
        target_layer = 12
    if target_layer == 25:
        target_layer = 24
    if target_layer == 27:
        target_layer = 24
    if target_layer == 22:
        target_layer = 24
    # if target_layer == 17:
    #     target_layer = 18

    if "en_" in task:
        l,r = task.split("_")

        binding = f"{l}-{r}"
        if task == "en_de":
            binding = "de-en"


        dataset = load_dataset("Helsinki-NLP/opus-100", binding, split="validation")
        sample = dataset.shuffle()[:40]
        sample = sample["translation"]
        pairs = [(x[l], x[r]) for x in sample]
    else:
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]


    fs = FeatureSearch(task, pairs, target_layer, l1_coeff=5e-3, init_w=0.5, lr=5e-2, n_shot=1, n_first=10, sae_v=8, early_stopping_steps=200, max_seq_len=500)
    weights, metrics = fs.find_weights()
    i = jax.lax.top_k(jnp.abs(weights), 30)[1]
    best_feature, best_loss, mean_loss, losses = fs.check_features(i, scale=20)

    metrics = {k:float(v) for k, v in metrics.items()}

    results[task] = (best_feature, best_loss, mean_loss, losses, metrics, i)
    print(
        f"Task {task}, best feature {best_feature}, best loss {best_loss}, mean loss {mean_loss}, metrics {metrics}"
    )

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Task en_ru, best feature 13504, best loss 66, mean loss 67, metrics {'l0': 22.0, 'loss': 59.75}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Task en_fr, best feature 75057, best loss 56.75, mean loss 58.75, metrics {'l0': 44.0, 'loss': 49.5}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Task en_es, best feature 22683, best loss 63.25, mean loss 68.5, metrics {'l0': 115.0, 'loss': 45.0}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Task fr_en, best feature 46522, best loss 3.96875, mean loss 6.09375, metrics {'l0': 19.0, 'loss': 1.5703125}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Task es_en, best feature 28802, best loss 5.09375, mean loss 7.53125, metrics {'l0': 24.0, 'loss': 1.4140625}


  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Task en_de, best feature 86736, best loss 65.5, mean loss 66.5, metrics {'l0': 96.0, 'loss': 51.0}


In [22]:
results

In [None]:
import pickle

with open("results_v8_2.pkl", "wb") as f:
    pickle.dump(results, f)

In [None]:
import pickle
with open("results_v8_np.pkl", "rb") as f:
    results = pickle.load(f)


In [9]:
import numpy as np

results = {k: [np.array(x) for x in v] for k, v in results.items()}


array(51149, dtype=int32)

In [11]:
with open("results_v8_np.pkl", "wb") as f:
    pickle.dump(results, f)
# results

In [1]:
import pickle
with open("results_v8_np.pkl", "rb") as f:
    results = pickle.load(f)

In [5]:
results

In [58]:
with open("results_s_s.pkl", "wb") as f:
    import pickle
    pickle.dump(results, f)

In [None]:
import pickle

with open("results.pkl", "wb") as f:
    pickle.dump(results, f)
# with open("results.json", "w") as f:
#     json.dump(results, f)

In [26]:
for task, (best_feature, best_loss, mean_loss, losses, metrics) in results.items():
    print(
        f"{task}, best feature {best_feature}, loss ratio {mean_loss / best_loss}, loss {best_loss}"
    )

algo_max, best feature 28655, loss ratio 1.61719, loss -0.816406
algo_min, best feature 42195, loss ratio 1.14062, loss -1.3125
algo_last, best feature 42756, loss ratio 1.46094, loss -0.839844
algo_first, best feature 39302, loss ratio 11.6875, loss -0.0683594
algo_sum, best feature 10061, loss ratio 1.22656, loss -1.08594
algo_most_common, best feature 3788, loss ratio 1.34375, loss -0.933594


In [None]:
fs = FeatureSearch("en_de", 17)

--2024-05-22 00:38:44--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l17-test-run-6-4.52E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.90, 108.156.211.95, 108.156.211.125, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.90|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/1623d8da38be3171fcc8516a4cbe9fdb80e3d77e370aa5690895697649d688f3?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716597524&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjU5NzUyNH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvMTYyM2Q4ZGE

In [20]:
fs = FeatureSearch("en_ru", 18, early_stopping_steps=100, seed=10)
w = fs.find_weights()

--2024-05-22 01:20:39--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-6-1.01E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.51, 108.156.211.90, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.51|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/f057cb46f3d871ba03c66e707e3b3d8299322f36fa433862dc3fdca956715614?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716600039&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjYwMDAzOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZjA1N2NiNDZm

  0%|          | 0/2000 [00:00<?, ?it/s]

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f12cc5bf290>>
Traceback (most recent call last):
  File "/home/dmitrii/.cache/pypoetry/virtualenvs/micrlhf-progress-_SD4q1c9-py3.12/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


In [32]:
fs = FeatureSearch("algo_last", 18, early_stopping_steps=100, seed=10)

In [33]:
w = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [21]:
jax.lax.top_k(w[0], 20)

In [34]:
fs.check_features(jax.lax.top_k(w[0], 20)[1], scale=20)

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
fs.check_features(jax.lax.top_k(w[0], 20)[1], scale=20)

  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
jax.lax.top_k(w, 20)

In [None]:
def find_features(task, target_layer):
    scale = 20
    early_stopping_steps = 50
    init_w = 0.1
    iterations = 2000

    sae = get_sae(target_layer, 6)
    dictionary = sae["W_dec"]

    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    seed = 0

    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1)


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    target_tokens = [x[1] for x in tokenizer(dataset.completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)

    jittify = lambda x: partial(jax.jit(lambda lr, *args, **kwargs: lr(*args, **kwargs)[1][0].value), x)
    get_resids_initial = make_get_resids(llama, target_layer)
    get_resids_initial = jittify(get_resids_initial)

    intital_resids = get_resids_initial(inputs)

    taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
        lambda i, x: x if i >= target_layer else pz.nn.Identity()
    ).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                    .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

    optimizer = optax.chain(
        optax.adam(3e-2),
        optax.zero_nans(),    
    )

    lwg = jax.value_and_grad(get_loss, has_aux=True)
    

    weights = jnp.ones(dictionary.shape[0]) * init_w
    opt_state = optimizer.init(weights)

    min_loss = 1e9
    early_stopping_counter = 0

    for _ in (bar := trange(iterations)):
        loss, weights, opt_state, metrics = train_step(weights, opt_state, dictionary, inputs, target_tokens, taker, intital_resids, optimizer, lwg)

        if metrics["loss"] < min_loss:
            min_loss = metrics["loss"]
            early_stopping_counter = 0
        

        tk = jax.lax.top_k(weights, 2)

        bar.set_postfix(loss_optim=loss, **metrics, top=tk[1][0], top_diff=(tk[0][0] - tk[0][1]) / tk[0][0])

        early_stopping_counter += 1
        if early_stopping_counter > early_stopping_steps:
            break

    
    return jax.lax.top_k(jnp.abs(weights), 20)


In [None]:
find_features(
    "en_es", 18
)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [35]:
sae = get_sae(target_layer, 6)

--2024-05-22 18:45:26--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l17-test-run-6-4.52E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.90, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/1623d8da38be3171fcc8516a4cbe9fdb80e3d77e370aa5690895697649d688f3?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716662726&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjY2MjcyNn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvMTYyM2Q4ZG

In [36]:
check_feature(21195, "algo_last", 18, sae)

NameError: name 'check_feature' is not defined

In [26]:
from micrlhf.utils.vector_storage import load_vector

task = "en_es"
layer = 18

tv = load_vector(f"task_vectors/{task}:{layer}.npz")

In [28]:
from micrlhf.utils.ito import grad_pursuit
from micrlhf.utils.load_sae import get_sae

sae = get_sae(18, 4)

w, r = grad_pursuit(tv, sae["W_dec"], 20, pos_only=True)

In [29]:
jax.lax.top_k(w, 20)