In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:

pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [3]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [4]:
import plotly.express as px

In [5]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
from micrlhf.sampling import sample
from transformers import AutoTokenizer
import jax
# tokenizer = load_tokenizer(filename)



In [6]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
from task_vector_utils import load_tasks, ICLDataset, ICLSequence
tasks = load_tasks()

Cloning into 'itv'...
fatal: unable to access 'https://github.com/roeehendel/icl_task_vectors data/itv/': URL using bad/illegal format or missing URL


In [8]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [9]:
def generate_task_prompt(task, n_shots):
    prompt = "<user>Follow the pattern\n{}"
    examples = []

    for s, t in random.sample(list(tasks[task].items()), n_shots):
        examples.append(f"{s} -> {t}")
    prompt = prompt.format("\n".join(examples))

    # print(prompt)

    return prompt

def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [12]:
task_names = [
    "en_es"
]
layer = 18
n_seeds = 10
seed = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 32, 128

In [13]:
from task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder

In [14]:
from micrlhf.utils.load_sae import get_sae, sae_encode_gated
sae = get_sae(layer, 9)

--2024-05-30 17:13:03--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-9-7.50E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.95, 108.156.211.51, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/d935d57ccf7d09a79bc7533bb2ed37d9e2c1ed747d5808d39c7faac01d561449?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717348383&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzM0ODM4M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZDkzNWQ1N2

In [15]:
task_names = [
    "en_it", "en_fr", "en_de", "person_profession", "country_capital", "location_religion", "location_continent", "location_language", "es_en", "fr_en"
]

In [16]:
task = "antonyms"

In [17]:
n_few_shots = 40

pairs = list(tasks[task].items())

runner = ICLRunner(task, pairs, batch_size=32, n_shot=n_few_shots-1, max_seq_len=256, seed=10)

In [18]:
tokenized = runner.get_tokens(runner.train_pairs, tokenizer)

train_inputs = tokenized_to_inputs(**tokenized)
train_tokens = tokenized["input_ids"]

In [15]:
logits, resids = get_resids_call(train_inputs)

loss = logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), train_tokens, shift=1 if task.startswith("algo") else 0, n_first=2
)

print(
    f"Full: {task}, loss: {loss}, n_shot: {n_few_shots}"
)

mask = train_tokens == 1599

resids = resids[layer].value.unwrap(
    "batch", "seq", "embedding"
)

resids = resids[mask]

        

Full: antonyms, loss: 0.339844, n_shot: 40


In [16]:
tv = resids.mean(axis=0)

In [16]:
tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
inputs = tokenized_to_inputs(**tokenized)
tokens = tokenized["input_ids"]

In [17]:


logits = llama(inputs)

add_act = make_act_adder(llama, tv.astype('bfloat16'), tokens, layer, length=1, shift= 0)

logits = add_act(inputs)

loss = logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
)

print(
    f"TV: {task}, L: {layer}, Loss: {loss}"  
)

TV: antonyms, L: 18, Loss: 1.65625


In [18]:
from micrlhf.utils.load_sae import get_sae, sae_encode_gated
sae = get_sae(layer, 4)

--2024-05-30 15:36:27--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-4-8.86E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.51, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/fa68513c10a8cdd065e4a0e66c05816325e4d72fb272857ca70564fca7fa808f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717342587&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzM0MjU4N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZmE2ODUxM2

In [19]:
_, pr, _ = sae_encode_gated(sae, tv)

In [20]:
from task_vector_utils import FeatureSearch

fs = FeatureSearch(task, pairs, layer, llama, tokenizer, n_shot=1, seed=seed+100, init_w=pr, early_stopping_steps=100, n_first=2)

w, m = fs.find_weights()

  0%|          | 0/2000 [00:00<?, ?it/s]

In [21]:
weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(sae["s_gate"]) * sae["scaling_factor"] + sae["b_gate"])   

recon = jnp.einsum("fv,f->v", sae["W_dec"], weights) + sae["b_dec"]
recon = recon.astype('bfloat16')

add_act = make_act_adder(llama, recon, tokens, layer, length=1, shift= 0)

logits = add_act(inputs)

loss = logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=1 if task.startswith("algo") else 0, n_first=2
)

print(
    f"Recon fs: {task}, L: {layer}, Loss: {loss}"  
)

Recon fs: antonyms, L: 18, Loss: 1.45312


In [22]:
_, i = jax.lax.top_k(weights, 25)

In [23]:
print(i)

[46421  8787 48142 48437 21441 22813 47882  8281 45093 21885  1065 33278
     0     1     2     3     4     5     6     7     8     9    10    11
    12]


In [25]:
from micrlhf.utils.vector_storage import load_vector
first_layer = 4
w = load_vector(f"fs_antonyms_{first_layer}_v4")

In [24]:
from micrlhf.utils.vector_storage import save_and_upload_vector

save_and_upload_vector(f"fs_{task}_{layer}_v4", w, overwrite=False)

fs_antonyms_18_v4.npz:   0%|          | 0.00/197k [00:00<?, ?B/s]

In [19]:
from micrlhf.utils.vector_storage import load_vector
first_layer = 0
# w = load_vector(f"fs_{task}_{first_layer}_v4")

In [20]:
# first_sae = get_sae(first_layer, 4)

In [22]:
_, initial_resids = get_resids_call(train_inputs)

In [23]:
initial_resids = initial_resids[first_layer].value.unwrap(
    "batch", "seq", "embedding"
)


In [24]:
taker = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
    lambda i, x: x if i >= first_layer else pz.nn.Identity()
).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity()).select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity())

In [25]:
import dataclasses

In [26]:
get_resids_taker = taker.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids_taker = pz.de.CollectingSideOutputs.handling(get_resids_taker, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_taker_call = jit_wrapper.Jitted(get_resids_taker)

In [29]:
mask = train_tokens == 1599

positions = jnp.argwhere(mask)[:, -1]
positions = jnp.column_stack(
    tuple(
        positions + i
        for i in range(1)
    )
)



In [31]:
initial_resids.shape

In [28]:
sae = get_sae(18, 4)

--2024-05-30 17:15:46--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-4-8.86E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.95, 108.156.211.90, 108.156.211.125, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.95|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/fa68513c10a8cdd065e4a0e66c05816325e4d72fb272857ca70564fca7fa808f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1717348546&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNzM0ODU0Nn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZmE2ODUxM2M

In [25]:
w4 = load_vector(f"fs_antonyms_4_v4")

In [32]:
import dataclasses


feature = 21885

def calc_feature(w):
    weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(first_sae["s_gate"]) * first_sae["scaling_factor"] + first_sae["b_gate"])   
    recon = jnp.einsum("fv,f->v", first_sae["W_dec"], weights) + first_sae["b_dec"]
    recon = recon.astype('bfloat16')


    vm = jax.vmap(
        lambda a, b: a.at[b].add(recon)
    )

    ir = vm(initial_resids, positions)
    ir = pz.nx.wrap(ir, "batch", "seq", "embedding")

    taker_inputs = dataclasses.replace(inputs, tokens=ir)

    _, taker_resids = get_resids_taker_call(taker_inputs)

    resid_stream = taker_resids[18].value.unwrap(
        "batch", "seq", "embedding"
    )

    resid_stream = resid_stream[mask]

    tv = resid_stream.mean(axis=0)

    _, pr, _ = sae_encode_gated(sae, tv)

    return pr[21885]

    

In [45]:
def calc_feature(initial_resids):
    w = w4
    weights = (w > 0) * jax.nn.relu(w * jax.nn.softplus(first_sae["s_gate"]) * first_sae["scaling_factor"] + first_sae["b_gate"])   
    recon = jnp.einsum("fv,f->v", first_sae["W_dec"], weights) + first_sae["b_dec"]
    recon = recon.astype('bfloat16')


    vm = jax.vmap(
        lambda a, b: a.at[b].add(recon)
    )

    ir = vm(initial_resids, positions)
    ir = pz.nx.wrap(ir, "batch", "seq", "embedding")

    taker_inputs = dataclasses.replace(inputs, tokens=ir)

    _, taker_resids = get_resids_taker_call(taker_inputs)

    resid_stream = taker_resids[18].value.unwrap(
        "batch", "seq", "embedding"
    )

    resid_stream = resid_stream[mask]

    tv = resid_stream.mean(axis=0)

    _, pr, _ = sae_encode_gated(sae, tv)

    return pr[21885]


In [36]:
def calc_feature(initial_resids):

    initial_resids = pz.nx.wrap(initial_resids, "batch", "seq", "embedding")

    taker_inputs = dataclasses.replace(train_inputs, tokens=initial_resids)

    _, taker_resids = get_resids_taker_call(taker_inputs)

    resid_stream = taker_resids[18].value.unwrap(
        "batch", "seq", "embedding"
    )

    resid_stream = resid_stream[mask]

    tv = resid_stream.mean(axis=0)

    _, pr, _ = sae_encode_gated(sae, tv)

    return pr[21885]


In [44]:
jax.lax.top_k(w4, 10)[1].tolist()

In [38]:
lwg = jax.value_and_grad(calc_feature)

In [39]:
_, grad2 = lwg(initial_resids)

In [40]:
norm = jnp.linalg.norm(grad2, axis=-1)

In [45]:
i = 0

px.imshow(norm[i:i+1, :100], x= [f"{j}_{tokenizer.decode(x)}" for j,x in enumerate(train_tokens[i, :100])])

In [43]:
[tokenizer.decode(x) for x in train_tokens[i, :100]]

In [39]:
gw, i = jax.lax.top_k(grad, 10)

In [42]:
i.tolist()

In [25]:
save_and_upload_vector(f"antonyms_16_recon_wlito_v4", recon)

antonyms_16_recon_wlito_v4.npz:   0%|          | 0.00/6.41k [00:00<?, ?B/s]

In [261]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
first_get_resids = add_act.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
first_get_resids = pz.de.CollectingSideOutputs.handling(first_get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
first_get_resids_call = jit_wrapper.Jitted(first_get_resids)

In [262]:
_, resids = first_get_resids_call(inputs)

In [263]:
check_layer = 18
sae = get_sae(check_layer, 4)

In [264]:
resids = resids[check_layer].value.unwrap(
    "batch", "seq", "embedding"
)

In [265]:
mask = tokens == 1599

tv = resids[mask]

In [266]:
tv = tv.mean(axis=0)

In [267]:

_, pr, _ = sae_encode_gated(sae, tv)

In [268]:
_, i = jax.lax.top_k(pr, 25)

In [269]:
i.tolist()

In [224]:
i.tolist()

In [76]:
print(i)

[45690 45844 24958 36610  9420   409 40693 30498 35966     0     1     2
     3     4     5     6     7     8     9    10    11    12    13    14
    15]


In [63]:
print(i)

[24305 45899   690 42632 37517 25177  8738     0     1     2     3     4
     5     6     7     8     9    10    11    12    13    14    15    16
    17]


In [145]:
selected_features = i.tolist()

In [146]:
selected_features

In [147]:
tokenized = runner.get_tokens(runner.train_pairs, tokenizer)

inputs = tokenized_to_inputs(**tokenized)
train_tokens = tokenized["input_ids"]

_, resids = get_resids_call(inputs)

mask = train_tokens == 1599

resids = resids[layer].value.unwrap(
    "batch", "seq", "embedding"
)

In [148]:
_, sae_resids, _ = sae_encode_gated(sae, resids)

In [149]:
n_tokens = 180

In [150]:
heatmap = np.zeros((len(selected_features), n_tokens - 1))

feature_to_index = {f: i for i, f in enumerate(selected_features)}

for i, r in enumerate(sae_resids[0][1:n_tokens]):
    for j, f in enumerate(selected_features):
        heatmap[j, i] = r[f]


In [151]:
token_labels = [tokenizer.decode([t]) for t in train_tokens[0, 1:n_tokens]]
# token_labels = [t.replace("->", "") for t in token_labels]
token_labels = [f"{i}_{t}" for i, t in enumerate(token_labels)]

In [152]:
px.imshow(heatmap, y = [str(x) for x in selected_features], x = token_labels, title=f"Feature heatmap layer {layer}")

In [134]:
px.imshow(heatmap, y = [str(x) for x in selected_features], x = token_labels, title="Feature heatmap")

In [107]:
px.imshow(heatmap, y = [str(x) for x in selected_features], x = token_labels, title="Feature heatmap")

In [28]:
for r in sae_resids[0]:
    print(r.shape)

(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(49152,)
(