In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:

pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [3]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [4]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
from micrlhf.sampling import sample
from transformers import AutoTokenizer
import jax
# tokenizer = load_tokenizer(filename)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [29]:
from task_vector_utils import load_tasks, ICLDataset, ICLSequence
tasks = load_tasks()

Cloning into 'itv'...
fatal: unable to access 'https://github.com/roeehendel/icl_task_vectors data/itv/': URL using bad/illegal format or missing URL


In [13]:
tasks.keys()

In [14]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
hot -> cold, yes -> no, in -> out, up ->


In [17]:
def generate_task_prompt(task, n_shots):
    prompt = "<user>Follow the pattern\n{}"
    examples = []

    for s, t in random.sample(list(tasks[task].items()), n_shots):
        examples.append(f"{s} -> {t}")
    prompt = prompt.format("\n".join(examples))

    # print(prompt)

    return prompt

def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [18]:
def generate_task_inputs_old(task, n_shots, batch_size, max_length=128, seed=0):
    random.seed(seed)

    texts = [generate_task_prompt(task, n_shots) for _ in range(batch_size)]
    tokenized = tokenizer.batch_encode_plus(texts, padding="longest", max_length=max_length, truncation=True, return_tensors="np")

    inputs = tokenized_to_inputs(
        **tokenized
    )

    return inputs, tokenized

In [19]:
prompt = "<user>Follow the pattern\n{}"

In [20]:
def get_logprob_diff(logits: jnp.ndarray, completions: List[str], print_results=False):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_tokens = [x[1] for x in tokenizer(completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)
    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs - answer_logprobs.max(axis=-1)


In [21]:
task_names = [
    "en_it"
]
layer = 18
n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 64, 256

In [68]:
tvs = []

for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_zero = [[] for _ in range(2)]
    for seed in trange(n_seeds):
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=seed)

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )
        
        logits, resids = get_resids_call(inputs)
    
        tokens = tokenized["input_ids"]

        shot_logprobs_orig[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_orig[1].append(
            shot_logprobs_orig[0][-1] == 0.
        )

        mask = inputs.tokens == 1599
        mask = mask.unwrap("batch", "seq")

        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
            mask
        ]

        tv = tv.mean(
            axis=0
        )

        tvs.append(tv)

        print(
            tv.mean(), tv.std()
        )

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1)


        print(
            dataset.prompts, dataset.completions
        )

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )

        logits = act_add(inputs)

        tokens = tokenized["input_ids"]

        shot_logprobs_added[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_added[1].append(
            shot_logprobs_added[0][-1] == 0.
        )

        logits, _ = get_resids_call(inputs)
        
        shot_logprobs_zero[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_zero[1].append(
            shot_logprobs_zero[0][-1] == 0.
        )
        
    print(f"orig: {shot_logprobs_orig}")
    print(f"zero: {shot_logprobs_zero}")
    print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_zero = [list(map(np.mean, x)) for x in shot_logprobs_zero]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


print(f"orig: {shot_logprobs_orig}")
print(f"zero: {shot_logprobs_zero}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

pre situ sic s dest cres inter reg se un le t sens propri in tag prepar serv ind risult se m coin f c diff città lu come ass cart chiam
s situ sic s gi cres totale disco se gi le t sens propri in tag prepar serv ind risult se pare coin f stesso diff citt lu ven ass cart chiam
-0.00430298 1.09375
['legal ->', 'along ->', 'window ->', 'information ->', 'over ->', 'field ->', 'current ->', 'better ->', 'power ->', 'message ->', 'reduce ->', 'behavior ->', 'place ->', 'security ->', 'hell ->', 'cell ->', 'media ->', 'club ->', 'live ->', 'point ->', 'throw ->', 'sport ->', 'so ->', 'door ->', 'color ->', 'daughter ->', 'run ->', 'pretty ->', 'rule ->', 'top ->', 'mind ->', 'matter ->'] ['legale', 'lungo', 'finestra', 'informazione', 'sopra', 'campo', 'attuale', 'meglio', 'potenza', 'messaggio', 'ridurre', 'comportamento', 'luogo', 'sicurezza', 'inferno', 'cellula', 'media', 'club', 'abitare', 'punto', 'gettare', 'sport', 'così', 'porta', 'colore', 'figlia', 'correre', 'bello', 'regola', 's

100%|██████████| 1/1 [02:06<00:00, 126.93s/it]

orig: [[Array(-1.95312, dtype=bfloat16), Array(-1.85938, dtype=bfloat16), Array(-2.20312, dtype=bfloat16), Array(-2.20312, dtype=bfloat16), Array(-2.20312, dtype=bfloat16), Array(-2.09375, dtype=bfloat16), Array(-2.09375, dtype=bfloat16), Array(-1.91406, dtype=bfloat16), Array(-1.625, dtype=bfloat16), Array(-1.65625, dtype=bfloat16)], [Array(0.71875, dtype=float32), Array(0.75, dtype=float32), Array(0.71875, dtype=float32), Array(0.71875, dtype=float32), Array(0.71875, dtype=float32), Array(0.75, dtype=float32), Array(0.75, dtype=float32), Array(0.78125, dtype=float32), Array(0.78125, dtype=float32), Array(0.75, dtype=float32)]]
zero: [[Array(-8.0625, dtype=bfloat16), Array(-8.0625, dtype=bfloat16), Array(-7.84375, dtype=bfloat16), Array(-7.6875, dtype=bfloat16), Array(-7.9375, dtype=bfloat16), Array(-7.90625, dtype=bfloat16), Array(-8, dtype=bfloat16), Array(-7.9375, dtype=bfloat16), Array(-7.5, dtype=bfloat16), Array(-7.53125, dtype=bfloat16)], [Array(0.125, dtype=float32), Array(0.1




In [64]:
a

  0%|          | 0/1 [00:00<?, ?it/s]

-0.0090332 0.875
along window information over field current better power message reduce behavior place security hell cell media club live point throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00256348 0.898438
along window information over field current better power message reduce behavior place security hel cell media club live point throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00260925 0.957031
along window inform over campo current migli pot mess rid behavior place security h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00460815 1.09375
along fen inform over campo current migli pot mess rid comport pl seg h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00552368 1.17188


100%|██████████| 1/1 [00:22<00:00, 22.50s/it]

along window inform over campo current migli pot mess rid behavior place sic h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
orig: [[Array(-1.46094, dtype=bfloat16)], [Array(0.8, dtype=float32)]]
added: [[Array(-4.96875, dtype=bfloat16), Array(-4.75, dtype=bfloat16), Array(-1.53125, dtype=bfloat16), Array(-1.40625, dtype=bfloat16), Array(-1.72656, dtype=bfloat16)], [Array(0.15, dtype=float32), Array(0.15, dtype=float32), Array(0.45000002, dtype=float32), Array(0.6, dtype=float32), Array(0.55, dtype=float32)]]





In [61]:
shot_logprobs_added

In [25]:

from micrlhf.utils.load_sae import get_sae
sae = get_sae(layer, 6)

--2024-05-22 01:35:47--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-6-1.01E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.51, 108.156.211.95, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.51|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/f057cb46f3d871ba03c66e707e3b3d8299322f36fa433862dc3fdca956715614?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716600947&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjYwMDk0N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZjA1N2NiNDZm

In [23]:
from micrlhf.utils.ito import grad_pursuit

In [45]:
k = 10

weights, recon = grad_pursuit(tvs[0], sae["W_dec"], k, pos_only=True)
w, i = jax.lax.top_k(jnp.abs(weights), k)

i

In [46]:
jnp.linalg.norm(tvs[0] - recon)

In [36]:
k = 5

[
    jax.lax.top_k(jnp.abs(grad_pursuit(x, sae["W_dec"], k, pos_only=True)[0]), k) for x in tvs
]

In [27]:

task_names = [
    "en_ru"
]

layer = 18

for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_sae = [[] for _ in range(2)]

    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=10)

    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )
    
    logits, resids = get_resids_call(inputs)

    tokens = tokenized["input_ids"]

    shot_logprobs_orig[0].append(
        get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False)
    )

    shot_logprobs_orig[1].append(
        shot_logprobs_orig[0][-1] == 0.
    )

    mask = inputs.tokens == 1599
    mask = mask.unwrap("batch", "seq")

    tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
        mask
    ]

    tv = tv.mean(
        axis=0
    )

    print(
        tv.mean(), tv.std()
    )

    act_add = add_vector(
        llama, tv, layer, scale=2.0, position="last"
    )

    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=11)


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    logits = act_add(inputs)

    tokens = tokenized["input_ids"]

    shot_logprobs_added[0].append(
        get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
    )

    shot_logprobs_added[1].append(
        shot_logprobs_added[0][-1] == 0.
    )

    for k in range(0, 40):
        weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

        act_add = add_vector(
            llama, recon.astype('bfloat16'), layer, scale=1.0, position="last"
        )

        logits = act_add(inputs)
        
        shot_logprobs_sae[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_sae[1].append(
            shot_logprobs_sae[0][-1] == 0.
        )
        
    # print(f"orig: {shot_logprobs_orig}")
    # print(f"sae: {shot_logprobs_sae}")
    # print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_sae = [list(map(np.mean, x)) for x in shot_logprobs_sae]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


print(f"orig: {shot_logprobs_orig}")
print(f"sae: {shot_logprobs_sae}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

-0.00294495 1.10156
об amb ре  пи пи вопро ка кни пи день ок кни пу  г ве kitchen  less  ме вто ко<|placeholder1|> мате ко<|placeholder1|> ви мо би bus ф е г  та е date пи проду тан су ме прода tennis мате<|placeholder1|> ми п би би ве пе amb<|placeholder1|> пи check проду з<|placeholder1|> bus price са
зав ско ри дере сви га вопро ка кни х день ок те команди ло г ос ку мо у мо а се ко сез мате ко от с го би авто ф пи г дере сто е да на проду та су сту прода т мате сез ми учи би би з ру ско сез пи реги проду з от авто це са
br amb rice trees p bur answer calendar books bread money wind not vac boat be aut kitchen ice less ice ph second hat season math hat vac viol mount bi bus fl food be trees table me month drink product dan soup chair sell tennis math season minutes engineer business bi winter p amb season pie check product * vac bus price sal
зав ско ри дере сви га вопро ка кни х день ок те команди ло г ос ку мо у мо а се ко сез мате ко от с го би авто ф пи г дере сто е да на проду 

In [28]:
shot_logprobs_sae

In [40]:
shot_logprobs_sae[0][18]

In [41]:
k = 18

weights, recon = grad_pursuit(tv, sae["W_dec"], k, pos_only=True)
w, i = jax.lax.top_k(jnp.abs(weights), k)

i

In [38]:
w

In [None]:
#37312

In [44]:
pairs = tasks[task]
pairs = [list(x) for x in pairs.items()]
dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=101)


tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
inputs = tokenized_to_inputs(
    **tokenized
)

# weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

recon = sae["W_dec"][27215] * 20

act_add = add_vector(
    llama, recon.astype('bfloat16'), layer, scale=1.0, position="last"
)

logits = act_add(inputs)

get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True).mean()

less snow inst train st car t window concert the air new guitar vis table unique take a sal there actor sal higher emer per sal unique inst feu football library light river wine higher any time flight sal football milk less river су concert sa air higher service sal k summer milk meat football no name name milk tor ly sal plan less
у с университе мет сту ма вра ок кон те а ве ги ви сто с га су са по филь пи не ско ба сотруд день университе роман т би со ре ви не от время по профес т га у ре де кон са само не ус день пе ве мо г спо на час город мо би га це п у


In [26]:
dataset.completions

In [24]:
def prepare_inputs(dataset: ICLDataset):
    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )
    
    tokens = tokenized["input_ids"]

    return inputs, tokens

In [25]:
task_names = list(tasks.keys())
task_names = [x for x in task_names if x.startswith("algo")]
# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 64, 256

In [26]:
results = {}

for task in tqdm(task_names):
    results[task] = []
    
    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=10)

    clean_inputs, clean_tokens = prepare_inputs(dataset)

    _, resids = get_resids_call(clean_inputs)

    mask = clean_inputs.tokens == 1599
    mask = mask.unwrap("batch", "seq")

    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=11)

    add_inputs, add_tokens = prepare_inputs(dataset)

    for layer in trange(10, 22):
        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]
        tv = tv.mean(axis=0)
        tv = tv.astype('bfloat16')

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        logits = act_add(add_inputs)

        diff = get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False)

        results[task].append(diff.mean())
        
        

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

In [27]:
for k in results:
    print(
        k, np.argmax([float(x) for x in results[k]]) + 10, max(results[k])
    )

algo_max 18 0
algo_min 17 0
algo_last 18 -0.0116577
algo_first 18 0
algo_sum 20 0
algo_most_common 18 -0.00585938


In [31]:
for k in results:
    print(
        k, np.argmax([float(x) for x in results[k]]) + 10, max(results[k])
    )

location_continent 21 -1.09375
football_player_position 21 -3.76562
location_religion 21 -1.09375
location_language 20 -0.457031
person_profession 21 -1.30469
location_country 21 -1.44531
country_capital 18 -1.25781
person_language 18 -0.287109
singular_plural 21 -0.261719
present_simple_past_simple 20 -0.341797
antonyms 14 -0.75
plural_singular 20 -0.133789
present_simple_past_perfect 18 -1.25781
present_simple_gerund 20 -0.171875
en_it 18 -1.29688
it_en 14 -1.75781
en_fr 18 -1.39062
en_es 18 -1.21875
fr_en 17 -1.25781
es_en 17 -1.09375
