In [10]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
import random
import dataclasses
import jax
import optax

import jax.numpy as jnp
import numpy as np

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector
from micrlhf.utils.load_sae import get_sae
from functools import partial

In [12]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="tpu:0")

In [13]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [14]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [15]:
jittify = lambda x: partial(jax.jit(lambda lr, *args, **kwargs: lr(*args, **kwargs)[1][0].value), x)

In [16]:
import datasets
from datasets import load_dataset

dataset = load_dataset("Helsinki-NLP/opus-100", "en-es", split="validation")

In [17]:
batch_size, max_seq_len, layer = 64, 256, 18

In [18]:
prompt = "<|user|>\nTranslate the following from English to {}:\n{}\n<|end|>\n<|assistant|>\n{}"

def tokenize_batch(batch, language, code):
    batch = [
        prompt.format(language, x["en"], x[code]) for x in batch["translation"]
    ]

    return tokenizer(batch, truncation=True, padding="max_length", max_length=max_seq_len, return_tensors="np")

In [19]:
def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [20]:
from more_itertools import chunked

batch = dataset[:10]


batch = tokenize_batch(batch, "Spanish", "es")

inputs = tokenized_to_inputs(batch["input_ids"], batch["attention_mask"])

logits, resids = get_resids_call(inputs)



In [2]:
from micrlhf.utils.load_sae import get_sae, sae_encode

In [3]:
sae = get_sae(layer, 8)

NameError: name 'layer' is not defined

In [38]:
def loss(logits, tokens):
    logits = jax.nn.log_softmax(logits)

    logits = logits[:, :-1]

    logits = jnp.take_along_axis(logits, tokens[:, 1:, None], axis=-1).squeeze(-1)

    mask = tokens[:, 1:] == 32001
    mask = jnp.cumsum(mask[:, ::-1], axis=-1)[:, ::-1] > 0
    mask = jnp.logical_not(mask)

    rolled_mask = jnp.roll(mask, 2, axis=-1)
    mask = jnp.logical_and(mask, jnp.logical_not(rolled_mask))

    print(mask)
    print(mask.shape)
    print(tokens[:, 1:][mask])
    print(tokens[:, -12:])

    logits = logits * mask

    return -logits.sum(axis=-1).mean(axis=-1)

In [39]:
loss(
    logits.unwrap("batch", "seq", "vocabulary"), batch["input_ids"]
)

[[False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]
 ...
 [False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]]
(10, 255)
[ 1939  1162  2661   359   317 29983  1939   734 18613  1252   317 29983
 18613 29925  1939  7814  3423   546  5952 10487]
[[32001  1939  1162 25768  1277   439 29948  3576   425  1236 20774 29889]
 [ 7845  2363   427   712  9747  6005 28857   447  3006  1941 11629 29901]
 [29891   560 28766   712  1775 29980   263  2261  2571   387   336 29889]
 [   13 32007 32001  1939   734   298 26712   560 13856   335  2016 29889]
 [ 8712 29877   409   282  1943   346   343  2723  8712 29877   694 29973]
 [ 8241 29892 11131 29889    13 32007 32001   317 29983 29892 11131 29889]
 [29973    13 32007 32001 18613 29925   272   439 29948  3248 27066 29973]
 [  278  7679 29889    13 32007 32001  1939  7814 20470 11092   941 29889]
 [ 

In [34]:

mask = batch["input_ids"] == 32001

tv = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]

In [37]:
tv = tv.mean(axis=0)

In [39]:
jnp.argwhere(mask)[:, 1]

In [42]:
from micrlhf.utils.activation_manipulation import add_vector


act_add = add_vector(
    llama, tv, layer=10, scale=1.0, position=jnp.argwhere(mask)[:, -1]+1
)

In [43]:
act_add(
    inputs
)

In [24]:
from more_itertools import chunked


for batch in tqdm(chunked(dataset, batch_size)):
    batch = tokenize_batch(batch, "Spanish", "es")

    inputs = tokenized_to_inputs(batch["input_ids"], batch["attention_mask"])

    logits, resids = get_resids_call(inputs)

    mask = batch["input_ids"] == 1

    tv = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]

    
    break

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

{'translation': [{'en': "I don't even remember what the fight was about.", 'es': 'No recuerdo por qué fue la pelea.'}, {'en': 'Here are the sites of each of those that have taken place:', 'es': 'Estos son los sitios en que cada Congreso ha tenido lugar:'}, {'en': "I'm the man who killed Blackbeard.", 'es': 'Sí. Soy el hombre que mató a Barbanegra.'}, {'en': "Don't get smart.", 'es': 'No te hagas el inteligente.'}, {'en': 'Is there an exact moment in the life of a soldier before which he is not suffering from shell-shock and after which he is?', 'es': '¿Existe un límite de cuándo se padece y cuándo no?'}, {'en': 'Yes, Joe.', 'es': 'Sí, Joe.'}, {'en': 'Why two months?', 'es': '¿Por qué dos meses?'}, {'en': "- We don't need the rock.", 'es': 'No nos hace falta.'}, {'en': 'I hope Umeko will forget about marrying him.', 'es': 'Espero que a Umeko se le pase lo de casarse con él.'}, {'en': 'He has led several projects and campaigns in Switzerland and abroad.', 'es': 'Ha dirigido varios proyec

In [None]:
def collate_fn(batch):
    return {k: jnp.array(v) for k, v in batch.items()}

In [27]:
import pickle

with open("results.pkl", "rb") as f:
    results = pickle.load(f)



In [28]:
for task, (best_feature, best_loss, mean_loss, losses, metrics) in results.items():
    print(
        f"{task}, best feature {best_feature}, loss ratio {mean_loss / best_loss}"
    )

location_continent, best feature 47188, loss ratio 1.27344
football_player_position, best feature 26021, loss ratio 1.17969
