In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Notebook display setup; all three calls are global, session-wide side effects.
# Make treescope (pz.ts) the default renderer for cell outputs.
pz.ts.register_as_default()
# Register the autovisualize cell magic.
pz.ts.register_autovisualize_magic()
# Enable penzai's interactive context for this session.
pz.enable_interactive_context()

In [2]:

# Install ArrayAutovisualizer as the interactive autovisualizer for the rest of the session.
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [3]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [4]:
# Imports used for model and tokenizer loading, grouped ahead of the statements that need them.
import jax
from transformers import AutoTokenizer

from micrlhf.llama import LlamaTransformer
from micrlhf.sampling import sample

# Local GGUF checkpoint; the path is relative to the notebook's working directory.
filename = "models/phi-3-16.gguf"
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")

# The tokenizer is taken from the Hugging Face hub rather than read from the GGUF file.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the left so every prompt ends at the final sequence position
# (get_logprob_diff reads the prediction at index -1).
tokenizer.padding_side = "left"

In [6]:
from task_vector_utils import load_tasks, ICLDataset, ICLSequence
# Mapping from task name to its source -> target pairs (later cells index it as
# tasks[task].items()).
# NOTE(review): the recorded output of this cell shows a failed `git clone` whose URL and
# destination directory appear fused into a single argument; presumably the data was already
# present locally — confirm inside task_vector_utils.load_tasks.
tasks = load_tasks()

Cloning into 'itv'...
fatal: unable to access 'https://github.com/roeehendel/icl_task_vectors data/itv/': URL using bad/illegal format or missing URL


In [7]:
# Display the names of all available tasks.
tasks.keys()

In [8]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper


def _tap_resid_pre(block_index, block):
    """Put a ``resid_pre_<block_index>`` side-output tap in front of one LlamaBlock."""
    tap = pz.de.TellIntermediate.from_config(tag=f"resid_pre_{block_index}")
    return pz.nn.Sequential([tap, block])


# Copy of the model in which every LlamaBlock is preceded by its tap.
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(_tap_resid_pre)
# Handle only the side outputs whose tag starts with "resid_pre"; later cells unpack the call
# result as `(model_output, resids)` and index `resids` by layer.
get_resids = pz.de.CollectingSideOutputs.handling(
    get_resids,
    tag_predicate=lambda tag: tag.startswith("resid_pre"),
)
# Jitted callable used by the experiment cells.
get_resids_call = jit_wrapper.Jitted(get_resids)

In [9]:
def generate_task_prompt(task, n_shots):
    """Build a few-shot prompt for `task` from randomly sampled example pairs.

    Args:
        task: key into the global `tasks` mapping.
        n_shots: number of distinct (source, target) pairs to sample; `random.sample`
            raises ValueError when this exceeds the number of available pairs.

    Returns:
        The prompt string: the "<user>Follow the pattern" header followed by one
        "source -> target" line per sampled pair.
    """
    # Sampling uses the global `random` state, so callers control reproducibility via random.seed.
    sampled_pairs = random.sample(list(tasks[task].items()), n_shots)
    shot_lines = [f"{source} -> {target}" for source, target in sampled_pairs]
    return "<user>Follow the pattern\n{}".format("\n".join(shot_lines))

def tokenized_to_inputs(input_ids, attention_mask):
    """Turn tokenizer output into model inputs placed on the model's device mesh.

    Args:
        input_ids: 2-D (batch, seq) array-like of token ids.
        attention_mask: accepted so callers can splat tokenizer output with
            ``tokenized_to_inputs(**tokenized)``; it is NOT used (see note below).

    Returns:
        The result of ``llama.inputs.from_basic_segments`` for the sharded, named token array.
    """
    # NOTE(review): the previous version converted, sharded and wrapped the attention mask and
    # then discarded it. That dead work is removed here; the inputs are still built from the
    # tokens alone, so padding positions are not masked by this function. If padding should be
    # masked, that needs support from the inputs constructor — confirm against micrlhf.
    del attention_mask

    # Shard the batch axis over "dp" and the sequence axis over "sp" of the model's mesh.
    sharding = jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))
    token_array = jax.device_put(jnp.asarray(input_ids), sharding)
    # The untag/tag round-trip on "batch" is kept exactly as in the original code.
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    return llama.inputs.from_basic_segments(token_array)

In [10]:
def generate_task_inputs_old(task, n_shots, batch_size, max_length=128, seed=0):
    """Sample `batch_size` few-shot prompts for `task` and convert them to model inputs.

    Args:
        task: key into the global `tasks` mapping.
        n_shots: number of example pairs per prompt.
        batch_size: number of prompts to generate.
        max_length: truncation length handed to the tokenizer.
        seed: value used to re-seed the global `random` module before sampling.

    Returns:
        A `(inputs, tokenized)` tuple: the model inputs and the raw tokenizer output.
    """
    # Re-seeding the global RNG makes the sampled prompts reproducible per seed.
    random.seed(seed)

    texts = []
    for _ in range(batch_size):
        texts.append(generate_task_prompt(task, n_shots))

    tokenized = tokenizer.batch_encode_plus(
        texts,
        padding="longest",
        max_length=max_length,
        truncation=True,
        return_tensors="np",
    )
    return tokenized_to_inputs(**tokenized), tokenized

In [11]:
# Prompt template; the "{}" placeholder is filled with the few-shot text via prompt.format(...).
# This is the same template string that generate_task_prompt builds its prompt from.
prompt = "<user>Follow the pattern\n{}"

In [12]:
def get_logprob_diff(logits: jnp.ndarray, completions: list[str], print_results=False, extra_space=False):
    """Compare the log-probability of each expected completion token with the model's top choice.

    Args:
        logits: array indexed as (batch, seq, vocabulary); only the last sequence position is used.
        completions: one expected completion string per batch row.
        print_results: when True, print the decoded argmax tokens followed by the decoded
            target tokens.
        extra_space: when True, score the token at index 2 of each tokenized completion
            instead of index 1.

    Returns:
        Target log-probability minus the maximum log-probability at the last position,
        per batch row (0 exactly when the target token is the argmax).
    """
    # Index 0 of each tokenized completion is skipped — presumably the BOS token (TODO confirm);
    # `extra_space` skips one further token.
    pos = 2 if extra_space else 1

    final_logprobs = jax.nn.log_softmax(logits, axis=-1)[:, -1]

    encoded_completions = tokenizer(completions)["input_ids"]
    target_tokens = jnp.asarray([ids[pos] for ids in encoded_completions])
    target_logprobs = jnp.take_along_axis(final_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(tokenizer.decode(final_logprobs.argmax(axis=-1)))
        print(tokenizer.decode(target_tokens))

    return target_logprobs - final_logprobs.max(axis=-1)


In [13]:
# Tasks to run; later cells rebind this name (to all tasks, then to (task, layer) tuples).
task_names = [
    "en_it"
]
# Default layer index; cells that read resids[layer] outside a loop use whatever value is
# currently bound, and the sweep / extraction loops overwrite it.
layer = 20
# Number of seeds iterated with trange(n_seeds) in the ICLDataset experiment.
n_seeds = 10

# Alternative (larger) setting kept for reference:
# n_few_shots, batch_size, max_seq_len = 64, 64, 512
# Shots per prompt, prompts per batch, and tokenizer/runner max sequence length.
n_few_shots, batch_size, max_seq_len = 20, 16, 1024

In [14]:
from task_vector_utils import ICLRunner, logprob_loss, get_tv, make_act_adder


In [291]:

# results[task] = [baseline loss, loss with the task vector added at layer 10, ..., layer 27].
results = {}

# Sweep every available task.
task_names = list(tasks.keys())

for task in tqdm(task_names):
    # Losses are scored with shift=1 for the algorithmic tasks and shift=0 otherwise.
    loss_shift = 1 if task.startswith("algo") else 0
    task_losses = []
    results[task] = task_losses

    pairs = list(tasks[task].items())
    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=10)

    # Residual streams of the training prompts: the source of the task vectors.
    tokenized = runner.get_tokens(runner.train_pairs, tokenizer)
    train_tokens = tokenized["input_ids"]
    inputs = tokenized_to_inputs(**tokenized)
    _, resids = get_resids_call(inputs)

    # Evaluation prompts: first the unmodified model as the baseline.
    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    tokens = tokenized["input_ids"]
    inputs = tokenized_to_inputs(**tokenized)

    logits = llama(inputs)
    loss = logprob_loss(logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=loss_shift, n_first=2)
    print(f"T: {task},  zero loss: {loss}")
    task_losses.append(loss)

    # NOTE: `layer` is a notebook-level name, so this loop overwrites the configured default.
    for layer in trange(10, 28):
        # The task vector is always extracted with shift=0, including for the algo tasks.
        layer_resid = resids[layer].value.unwrap("batch", "seq", "embedding")
        tv = get_tv(layer_resid, train_tokens, shift=0)
        tv = tv.astype('bfloat16')

        add_act = make_act_adder(llama, tv, tokens, layer, length=1, shift=0)
        logits = add_act(inputs)

        loss = logprob_loss(logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=loss_shift, n_first=2)
        print(f"T: {task}, L: {layer}, Loss: {loss}")
        task_losses.append(loss)
        
        

  0%|          | 0/28 [00:00<?, ?it/s]

T: location_continent,  zero loss: 7.96875


  0%|          | 0/18 [00:00<?, ?it/s]

T: location_continent, L: 10, Loss: 8.125
T: location_continent, L: 11, Loss: 8.3125
T: location_continent, L: 12, Loss: 8.125
T: location_continent, L: 13, Loss: 7.46875
T: location_continent, L: 14, Loss: 6.8125
T: location_continent, L: 15, Loss: 6.21875
T: location_continent, L: 16, Loss: 6.15625
T: location_continent, L: 17, Loss: 5.96875
T: location_continent, L: 18, Loss: 5.3125
T: location_continent, L: 19, Loss: 4.96875
T: location_continent, L: 20, Loss: 4.09375
T: location_continent, L: 21, Loss: 3.3125
T: location_continent, L: 22, Loss: 3.23438
T: location_continent, L: 23, Loss: 2.85938
T: location_continent, L: 24, Loss: 2.10938
T: location_continent, L: 25, Loss: 1.91406
T: location_continent, L: 26, Loss: 2.21875
T: location_continent, L: 27, Loss: 2.04688
T: football_player_position,  zero loss: 13.125


  0%|          | 0/18 [00:00<?, ?it/s]

T: football_player_position, L: 10, Loss: 13.0625
T: football_player_position, L: 11, Loss: 13.375
T: football_player_position, L: 12, Loss: 12.5625
T: football_player_position, L: 13, Loss: 10.875
T: football_player_position, L: 14, Loss: 9.625
T: football_player_position, L: 15, Loss: 9.25
T: football_player_position, L: 16, Loss: 9
T: football_player_position, L: 17, Loss: 8.3125
T: football_player_position, L: 18, Loss: 6.34375
T: football_player_position, L: 19, Loss: 5.96875
T: football_player_position, L: 20, Loss: 4.96875
T: football_player_position, L: 21, Loss: 4.65625
T: football_player_position, L: 22, Loss: 4.375
T: football_player_position, L: 23, Loss: 4.59375
T: football_player_position, L: 24, Loss: 4.875
T: football_player_position, L: 25, Loss: 4.6875
T: football_player_position, L: 26, Loss: 4.875
T: football_player_position, L: 27, Loss: 5.15625
T: location_religion,  zero loss: 7.875


  0%|          | 0/18 [00:00<?, ?it/s]

T: location_religion, L: 10, Loss: 8.4375
T: location_religion, L: 11, Loss: 8.625
T: location_religion, L: 12, Loss: 8.5
T: location_religion, L: 13, Loss: 8.1875
T: location_religion, L: 14, Loss: 7.5625
T: location_religion, L: 15, Loss: 7
T: location_religion, L: 16, Loss: 7.09375
T: location_religion, L: 17, Loss: 6.5625
T: location_religion, L: 18, Loss: 5.8125
T: location_religion, L: 19, Loss: 5.4375
T: location_religion, L: 20, Loss: 3.875
T: location_religion, L: 21, Loss: 3.10938
T: location_religion, L: 22, Loss: 2.57812
T: location_religion, L: 23, Loss: 2.3125
T: location_religion, L: 24, Loss: 2
T: location_religion, L: 25, Loss: 1.625
T: location_religion, L: 26, Loss: 1.89062
T: location_religion, L: 27, Loss: 1.55469
T: location_language,  zero loss: 7.9375


  0%|          | 0/18 [00:00<?, ?it/s]

T: location_language, L: 10, Loss: 8.6875
T: location_language, L: 11, Loss: 8.8125
T: location_language, L: 12, Loss: 8.875
T: location_language, L: 13, Loss: 8.375
T: location_language, L: 14, Loss: 7.3125
T: location_language, L: 15, Loss: 6.15625
T: location_language, L: 16, Loss: 5.84375
T: location_language, L: 17, Loss: 4.9375
T: location_language, L: 18, Loss: 4
T: location_language, L: 19, Loss: 3.70312
T: location_language, L: 20, Loss: 2.10938
T: location_language, L: 21, Loss: 1.82812
T: location_language, L: 22, Loss: 2.34375
T: location_language, L: 23, Loss: 2.54688
T: location_language, L: 24, Loss: 2.01562
T: location_language, L: 25, Loss: 2.125
T: location_language, L: 26, Loss: 2.54688
T: location_language, L: 27, Loss: 2.875
T: person_profession,  zero loss: 11.3125


  0%|          | 0/18 [00:00<?, ?it/s]

T: person_profession, L: 10, Loss: 11.375
T: person_profession, L: 11, Loss: 11.5
T: person_profession, L: 12, Loss: 11
T: person_profession, L: 13, Loss: 8.9375
T: person_profession, L: 14, Loss: 6.90625
T: person_profession, L: 15, Loss: 6.28125
T: person_profession, L: 16, Loss: 6.78125
T: person_profession, L: 17, Loss: 6
T: person_profession, L: 18, Loss: 4.84375
T: person_profession, L: 19, Loss: 4.5625
T: person_profession, L: 20, Loss: 3.59375
T: person_profession, L: 21, Loss: 3.64062
T: person_profession, L: 22, Loss: 3.67188
T: person_profession, L: 23, Loss: 3.67188
T: person_profession, L: 24, Loss: 3.53125
T: person_profession, L: 25, Loss: 3.5625
T: person_profession, L: 26, Loss: 3.64062
T: person_profession, L: 27, Loss: 3.57812
T: location_country,  zero loss: 5.625


  0%|          | 0/18 [00:00<?, ?it/s]

T: location_country, L: 10, Loss: 6.0625
T: location_country, L: 11, Loss: 6.125
T: location_country, L: 12, Loss: 6.21875
T: location_country, L: 13, Loss: 5.9375
T: location_country, L: 14, Loss: 5.0625
T: location_country, L: 15, Loss: 4.84375
T: location_country, L: 16, Loss: 4.84375
T: location_country, L: 17, Loss: 4.28125
T: location_country, L: 18, Loss: 4.125
T: location_country, L: 19, Loss: 3.9375
T: location_country, L: 20, Loss: 3.6875
T: location_country, L: 21, Loss: 3.67188
T: location_country, L: 22, Loss: 3.82812
T: location_country, L: 23, Loss: 3.875
T: location_country, L: 24, Loss: 3.82812
T: location_country, L: 25, Loss: 3.82812
T: location_country, L: 26, Loss: 3.96875
T: location_country, L: 27, Loss: 4.09375
T: country_capital,  zero loss: 6.40625


  0%|          | 0/18 [00:00<?, ?it/s]

T: country_capital, L: 10, Loss: 6.90625
T: country_capital, L: 11, Loss: 6.875
T: country_capital, L: 12, Loss: 7.0625
T: country_capital, L: 13, Loss: 6.1875
T: country_capital, L: 14, Loss: 5
T: country_capital, L: 15, Loss: 4.625
T: country_capital, L: 16, Loss: 4.3125
T: country_capital, L: 17, Loss: 4.125
T: country_capital, L: 18, Loss: 3.73438
T: country_capital, L: 19, Loss: 4.34375
T: country_capital, L: 20, Loss: 3.875
T: country_capital, L: 21, Loss: 3.82812
T: country_capital, L: 22, Loss: 3.9375
T: country_capital, L: 23, Loss: 4.4375
T: country_capital, L: 24, Loss: 4.5
T: country_capital, L: 25, Loss: 4.59375
T: country_capital, L: 26, Loss: 4.875
T: country_capital, L: 27, Loss: 5.25
T: person_language,  zero loss: 7.4375


  0%|          | 0/18 [00:00<?, ?it/s]

T: person_language, L: 10, Loss: 8.8125
T: person_language, L: 11, Loss: 8.875
T: person_language, L: 12, Loss: 8.5625
T: person_language, L: 13, Loss: 7.1875
T: person_language, L: 14, Loss: 5.96875
T: person_language, L: 15, Loss: 5.3125
T: person_language, L: 16, Loss: 4.875
T: person_language, L: 17, Loss: 4.25
T: person_language, L: 18, Loss: 3.6875
T: person_language, L: 19, Loss: 3.45312
T: person_language, L: 20, Loss: 1.91406
T: person_language, L: 21, Loss: 1.28906
T: person_language, L: 22, Loss: 1.64844
T: person_language, L: 23, Loss: 1.96875
T: person_language, L: 24, Loss: 1.59375
T: person_language, L: 25, Loss: 1.65625
T: person_language, L: 26, Loss: 1.88281
T: person_language, L: 27, Loss: 1.94531
T: singular_plural,  zero loss: 1.96875


  0%|          | 0/18 [00:00<?, ?it/s]

T: singular_plural, L: 10, Loss: 2.34375
T: singular_plural, L: 11, Loss: 2.20312
T: singular_plural, L: 12, Loss: 2.20312
T: singular_plural, L: 13, Loss: 1.89062
T: singular_plural, L: 14, Loss: 1.5
T: singular_plural, L: 15, Loss: 1.44531
T: singular_plural, L: 16, Loss: 1.45312
T: singular_plural, L: 17, Loss: 1.375
T: singular_plural, L: 18, Loss: 1.3125
T: singular_plural, L: 19, Loss: 1.16406
T: singular_plural, L: 20, Loss: 0.921875
T: singular_plural, L: 21, Loss: 0.910156
T: singular_plural, L: 22, Loss: 1.20312
T: singular_plural, L: 23, Loss: 1.39844
T: singular_plural, L: 24, Loss: 1.51562
T: singular_plural, L: 25, Loss: 1.49219
T: singular_plural, L: 26, Loss: 1.55469
T: singular_plural, L: 27, Loss: 1.75781
T: present_simple_past_simple,  zero loss: 2.48438


  0%|          | 0/18 [00:00<?, ?it/s]

T: present_simple_past_simple, L: 10, Loss: 3.75
T: present_simple_past_simple, L: 11, Loss: 3.76562
T: present_simple_past_simple, L: 12, Loss: 3.64062
T: present_simple_past_simple, L: 13, Loss: 2.375
T: present_simple_past_simple, L: 14, Loss: 2.14062
T: present_simple_past_simple, L: 15, Loss: 2.15625
T: present_simple_past_simple, L: 16, Loss: 2.26562
T: present_simple_past_simple, L: 17, Loss: 1.73438
T: present_simple_past_simple, L: 18, Loss: 1.25781
T: present_simple_past_simple, L: 19, Loss: 1.07812
T: present_simple_past_simple, L: 20, Loss: 0.773438
T: present_simple_past_simple, L: 21, Loss: 0.839844
T: present_simple_past_simple, L: 22, Loss: 0.867188
T: present_simple_past_simple, L: 23, Loss: 1.0625
T: present_simple_past_simple, L: 24, Loss: 1.46094
T: present_simple_past_simple, L: 25, Loss: 1.5625
T: present_simple_past_simple, L: 26, Loss: 1.60156
T: present_simple_past_simple, L: 27, Loss: 1.82812
T: antonyms,  zero loss: 5.96875


  0%|          | 0/18 [00:00<?, ?it/s]

T: antonyms, L: 10, Loss: 5.875
T: antonyms, L: 11, Loss: 6.15625
T: antonyms, L: 12, Loss: 5.6875
T: antonyms, L: 13, Loss: 3.8125
T: antonyms, L: 14, Loss: 3.10938
T: antonyms, L: 15, Loss: 2.75
T: antonyms, L: 16, Loss: 2.8125
T: antonyms, L: 17, Loss: 2.5625
T: antonyms, L: 18, Loss: 2.29688
T: antonyms, L: 19, Loss: 2.45312
T: antonyms, L: 20, Loss: 2.78125
T: antonyms, L: 21, Loss: 3.82812
T: antonyms, L: 22, Loss: 4.15625
T: antonyms, L: 23, Loss: 4.09375
T: antonyms, L: 24, Loss: 4.375
T: antonyms, L: 25, Loss: 4.65625
T: antonyms, L: 26, Loss: 4.90625
T: antonyms, L: 27, Loss: 4.9375
T: plural_singular,  zero loss: 2.53125


  0%|          | 0/18 [00:00<?, ?it/s]

T: plural_singular, L: 10, Loss: 2.84375
T: plural_singular, L: 11, Loss: 2.75
T: plural_singular, L: 12, Loss: 2.90625
T: plural_singular, L: 13, Loss: 2.26562
T: plural_singular, L: 14, Loss: 1.95312
T: plural_singular, L: 15, Loss: 2.0625
T: plural_singular, L: 16, Loss: 1.94531
T: plural_singular, L: 17, Loss: 1.73438
T: plural_singular, L: 18, Loss: 1.53906
T: plural_singular, L: 19, Loss: 1.54688
T: plural_singular, L: 20, Loss: 1.59375
T: plural_singular, L: 21, Loss: 1.6875
T: plural_singular, L: 22, Loss: 1.82812
T: plural_singular, L: 23, Loss: 2.1875
T: plural_singular, L: 24, Loss: 2.5
T: plural_singular, L: 25, Loss: 2.57812
T: plural_singular, L: 26, Loss: 2.60938
T: plural_singular, L: 27, Loss: 2.53125
T: present_simple_past_perfect,  zero loss: 4.65625


  0%|          | 0/18 [00:00<?, ?it/s]

T: present_simple_past_perfect, L: 10, Loss: 5.1875
T: present_simple_past_perfect, L: 11, Loss: 5.28125
T: present_simple_past_perfect, L: 12, Loss: 4.9375
T: present_simple_past_perfect, L: 13, Loss: 3.9375
T: present_simple_past_perfect, L: 14, Loss: 3.76562
T: present_simple_past_perfect, L: 15, Loss: 3.8125
T: present_simple_past_perfect, L: 16, Loss: 3.98438
T: present_simple_past_perfect, L: 17, Loss: 3.51562
T: present_simple_past_perfect, L: 18, Loss: 3
T: present_simple_past_perfect, L: 19, Loss: 2.89062
T: present_simple_past_perfect, L: 20, Loss: 2.48438
T: present_simple_past_perfect, L: 21, Loss: 2.375
T: present_simple_past_perfect, L: 22, Loss: 2.65625
T: present_simple_past_perfect, L: 23, Loss: 2.54688
T: present_simple_past_perfect, L: 24, Loss: 2.71875
T: present_simple_past_perfect, L: 25, Loss: 2.65625
T: present_simple_past_perfect, L: 26, Loss: 2.98438
T: present_simple_past_perfect, L: 27, Loss: 3.3125
T: present_simple_gerund,  zero loss: 3.51562


  0%|          | 0/18 [00:00<?, ?it/s]

T: present_simple_gerund, L: 10, Loss: 4.4375
T: present_simple_gerund, L: 11, Loss: 4.1875
T: present_simple_gerund, L: 12, Loss: 4.15625
T: present_simple_gerund, L: 13, Loss: 3.40625
T: present_simple_gerund, L: 14, Loss: 2.75
T: present_simple_gerund, L: 15, Loss: 2.92188
T: present_simple_gerund, L: 16, Loss: 2.875
T: present_simple_gerund, L: 17, Loss: 2.54688
T: present_simple_gerund, L: 18, Loss: 2.60938
T: present_simple_gerund, L: 19, Loss: 2.65625
T: present_simple_gerund, L: 20, Loss: 1.19531
T: present_simple_gerund, L: 21, Loss: 1.17188
T: present_simple_gerund, L: 22, Loss: 1.35156
T: present_simple_gerund, L: 23, Loss: 1.39062
T: present_simple_gerund, L: 24, Loss: 1.51562
T: present_simple_gerund, L: 25, Loss: 1.55469
T: present_simple_gerund, L: 26, Loss: 1.71094
T: present_simple_gerund, L: 27, Loss: 2.14062
T: en_it,  zero loss: 15.375


  0%|          | 0/18 [00:00<?, ?it/s]

T: en_it, L: 10, Loss: 14.5625
T: en_it, L: 11, Loss: 14.8125
T: en_it, L: 12, Loss: 13.4375
T: en_it, L: 13, Loss: 10.625
T: en_it, L: 14, Loss: 9.375
T: en_it, L: 15, Loss: 9.1875
T: en_it, L: 16, Loss: 9.6875
T: en_it, L: 17, Loss: 4.9375
T: en_it, L: 18, Loss: 3.0625
T: en_it, L: 19, Loss: 4.375
T: en_it, L: 20, Loss: 7.53125
T: en_it, L: 21, Loss: 6.21875
T: en_it, L: 22, Loss: 5.71875
T: en_it, L: 23, Loss: 7.03125
T: en_it, L: 24, Loss: 7.09375
T: en_it, L: 25, Loss: 7.46875
T: en_it, L: 26, Loss: 8.625
T: en_it, L: 27, Loss: 9.625
T: it_en,  zero loss: 8.0625


  0%|          | 0/18 [00:00<?, ?it/s]

T: it_en, L: 10, Loss: 7.3125
T: it_en, L: 11, Loss: 7.25
T: it_en, L: 12, Loss: 6.59375
T: it_en, L: 13, Loss: 4.34375
T: it_en, L: 14, Loss: 4.28125
T: it_en, L: 15, Loss: 4.625
T: it_en, L: 16, Loss: 5.03125
T: it_en, L: 17, Loss: 4.3125
T: it_en, L: 18, Loss: 3.8125
T: it_en, L: 19, Loss: 4.125
T: it_en, L: 20, Loss: 4.25
T: it_en, L: 21, Loss: 4.75
T: it_en, L: 22, Loss: 5.09375
T: it_en, L: 23, Loss: 4.84375
T: it_en, L: 24, Loss: 4.84375
T: it_en, L: 25, Loss: 5.1875
T: it_en, L: 26, Loss: 5.21875
T: it_en, L: 27, Loss: 5.375
T: en_ru,  zero loss: 16.125


  0%|          | 0/18 [00:00<?, ?it/s]

T: en_ru, L: 10, Loss: 15.75
T: en_ru, L: 11, Loss: 15.5625
T: en_ru, L: 12, Loss: 15
T: en_ru, L: 13, Loss: 14
T: en_ru, L: 14, Loss: 12.6875
T: en_ru, L: 15, Loss: 12.9375
T: en_ru, L: 16, Loss: 13.1875
T: en_ru, L: 17, Loss: 8.8125
T: en_ru, L: 18, Loss: 5.8125
T: en_ru, L: 19, Loss: 5.09375
T: en_ru, L: 20, Loss: 7.46875
T: en_ru, L: 21, Loss: 3.76562
T: en_ru, L: 22, Loss: 4.125
T: en_ru, L: 23, Loss: 4.5
T: en_ru, L: 24, Loss: 4.21875
T: en_ru, L: 25, Loss: 4.1875
T: en_ru, L: 26, Loss: 5.1875
T: en_ru, L: 27, Loss: 5.875
T: en_fr,  zero loss: 12.625


  0%|          | 0/18 [00:00<?, ?it/s]

T: en_fr, L: 10, Loss: 12.125
T: en_fr, L: 11, Loss: 12
T: en_fr, L: 12, Loss: 11.4375
T: en_fr, L: 13, Loss: 9.8125
T: en_fr, L: 14, Loss: 8.9375
T: en_fr, L: 15, Loss: 9.125
T: en_fr, L: 16, Loss: 9.1875
T: en_fr, L: 17, Loss: 6.71875
T: en_fr, L: 18, Loss: 5.6875
T: en_fr, L: 19, Loss: 6.0625
T: en_fr, L: 20, Loss: 7.375
T: en_fr, L: 21, Loss: 6.03125
T: en_fr, L: 22, Loss: 5.90625
T: en_fr, L: 23, Loss: 6.4375
T: en_fr, L: 24, Loss: 6.78125
T: en_fr, L: 25, Loss: 6.71875
T: en_fr, L: 26, Loss: 7.375
T: en_fr, L: 27, Loss: 7.6875
T: en_es,  zero loss: 14.625


  0%|          | 0/18 [00:00<?, ?it/s]

T: en_es, L: 10, Loss: 13.875
T: en_es, L: 11, Loss: 14
T: en_es, L: 12, Loss: 13.375
T: en_es, L: 13, Loss: 12
T: en_es, L: 14, Loss: 10.875
T: en_es, L: 15, Loss: 11
T: en_es, L: 16, Loss: 11.1875
T: en_es, L: 17, Loss: 6
T: en_es, L: 18, Loss: 4.65625
T: en_es, L: 19, Loss: 5.8125
T: en_es, L: 20, Loss: 8.125
T: en_es, L: 21, Loss: 6.0625
T: en_es, L: 22, Loss: 6.34375
T: en_es, L: 23, Loss: 7.15625
T: en_es, L: 24, Loss: 7.625
T: en_es, L: 25, Loss: 7.46875
T: en_es, L: 26, Loss: 8.3125
T: en_es, L: 27, Loss: 8.8125
T: fr_en,  zero loss: 7.3125


  0%|          | 0/18 [00:00<?, ?it/s]

T: fr_en, L: 10, Loss: 6.1875
T: fr_en, L: 11, Loss: 6.09375
T: fr_en, L: 12, Loss: 5.4375
T: fr_en, L: 13, Loss: 3.0625
T: fr_en, L: 14, Loss: 3.29688
T: fr_en, L: 15, Loss: 3.35938
T: fr_en, L: 16, Loss: 3.60938
T: fr_en, L: 17, Loss: 2.79688
T: fr_en, L: 18, Loss: 2.45312
T: fr_en, L: 19, Loss: 2.85938
T: fr_en, L: 20, Loss: 2.84375
T: fr_en, L: 21, Loss: 3.3125
T: fr_en, L: 22, Loss: 3.6875
T: fr_en, L: 23, Loss: 3.60938
T: fr_en, L: 24, Loss: 3.875
T: fr_en, L: 25, Loss: 4.0625
T: fr_en, L: 26, Loss: 4.15625
T: fr_en, L: 27, Loss: 4.3125
T: es_en,  zero loss: 8.5


  0%|          | 0/18 [00:00<?, ?it/s]

T: es_en, L: 10, Loss: 7.46875
T: es_en, L: 11, Loss: 7.21875
T: es_en, L: 12, Loss: 6.25
T: es_en, L: 13, Loss: 3.17188
T: es_en, L: 14, Loss: 3.32812
T: es_en, L: 15, Loss: 4.09375
T: es_en, L: 16, Loss: 4.53125
T: es_en, L: 17, Loss: 2.70312
T: es_en, L: 18, Loss: 1.71094
T: es_en, L: 19, Loss: 1.92969
T: es_en, L: 20, Loss: 1.34375
T: es_en, L: 21, Loss: 1.96094
T: es_en, L: 22, Loss: 2.03125
T: es_en, L: 23, Loss: 2.32812
T: es_en, L: 24, Loss: 2.76562
T: es_en, L: 25, Loss: 3.09375
T: es_en, L: 26, Loss: 3.21875
T: es_en, L: 27, Loss: 3.5
T: en_de,  zero loss: 12.125


  0%|          | 0/18 [00:00<?, ?it/s]

T: en_de, L: 10, Loss: 11.3125
T: en_de, L: 11, Loss: 11
T: en_de, L: 12, Loss: 10.5
T: en_de, L: 13, Loss: 9.5
T: en_de, L: 14, Loss: 8.5625
T: en_de, L: 15, Loss: 8.9375
T: en_de, L: 16, Loss: 9.0625
T: en_de, L: 17, Loss: 7.4375
T: en_de, L: 18, Loss: 6.9375
T: en_de, L: 19, Loss: 7.1875
T: en_de, L: 20, Loss: 8.125
T: en_de, L: 21, Loss: 7.53125
T: en_de, L: 22, Loss: 7.40625
T: en_de, L: 23, Loss: 7.6875
T: en_de, L: 24, Loss: 8
T: en_de, L: 25, Loss: 7.8125
T: en_de, L: 26, Loss: 8.3125
T: en_de, L: 27, Loss: 8.5
T: algo_max,  zero loss: 3.0625


  0%|          | 0/18 [00:00<?, ?it/s]

T: algo_max, L: 10, Loss: 3.85938
T: algo_max, L: 11, Loss: 3.15625
T: algo_max, L: 12, Loss: 3.65625
T: algo_max, L: 13, Loss: 3.92188
T: algo_max, L: 14, Loss: 3.23438
T: algo_max, L: 15, Loss: 3.14062
T: algo_max, L: 16, Loss: 3.03125
T: algo_max, L: 17, Loss: 2.76562
T: algo_max, L: 18, Loss: 2.89062
T: algo_max, L: 19, Loss: 2.90625
T: algo_max, L: 20, Loss: 2.95312
T: algo_max, L: 21, Loss: 2.89062
T: algo_max, L: 22, Loss: 2.78125
T: algo_max, L: 23, Loss: 2.8125
T: algo_max, L: 24, Loss: 2.73438
T: algo_max, L: 25, Loss: 2.75
T: algo_max, L: 26, Loss: 2.75
T: algo_max, L: 27, Loss: 2.76562
T: algo_min,  zero loss: 3


  0%|          | 0/18 [00:00<?, ?it/s]

T: algo_min, L: 10, Loss: 4.03125
T: algo_min, L: 11, Loss: 3.23438
T: algo_min, L: 12, Loss: 2.64062
T: algo_min, L: 13, Loss: 2.5
T: algo_min, L: 14, Loss: 2.5625
T: algo_min, L: 15, Loss: 2.71875
T: algo_min, L: 16, Loss: 2.82812
T: algo_min, L: 17, Loss: 2.65625
T: algo_min, L: 18, Loss: 2.73438
T: algo_min, L: 19, Loss: 2.71875
T: algo_min, L: 20, Loss: 2.78125
T: algo_min, L: 21, Loss: 2.79688
T: algo_min, L: 22, Loss: 2.8125
T: algo_min, L: 23, Loss: 3
T: algo_min, L: 24, Loss: 2.89062
T: algo_min, L: 25, Loss: 2.84375
T: algo_min, L: 26, Loss: 2.79688
T: algo_min, L: 27, Loss: 2.79688
T: algo_last,  zero loss: 3.39062


  0%|          | 0/18 [00:00<?, ?it/s]

T: algo_last, L: 10, Loss: 4.0625
T: algo_last, L: 11, Loss: 3.45312
T: algo_last, L: 12, Loss: 3.15625
T: algo_last, L: 13, Loss: 3.40625
T: algo_last, L: 14, Loss: 2.78125
T: algo_last, L: 15, Loss: 2.92188
T: algo_last, L: 16, Loss: 2.90625
T: algo_last, L: 17, Loss: 2.53125
T: algo_last, L: 18, Loss: 2.79688
T: algo_last, L: 19, Loss: 2.84375
T: algo_last, L: 20, Loss: 2.95312
T: algo_last, L: 21, Loss: 2.90625
T: algo_last, L: 22, Loss: 2.85938
T: algo_last, L: 23, Loss: 3
T: algo_last, L: 24, Loss: 2.92188
T: algo_last, L: 25, Loss: 2.95312
T: algo_last, L: 26, Loss: 2.9375
T: algo_last, L: 27, Loss: 2.95312
T: algo_first,  zero loss: 3.20312


  0%|          | 0/18 [00:00<?, ?it/s]

T: algo_first, L: 10, Loss: 3.92188
T: algo_first, L: 11, Loss: 3.65625
T: algo_first, L: 12, Loss: 3.98438
T: algo_first, L: 13, Loss: 3.8125
T: algo_first, L: 14, Loss: 3.57812
T: algo_first, L: 15, Loss: 3.10938
T: algo_first, L: 16, Loss: 3.25
T: algo_first, L: 17, Loss: 2.92188
T: algo_first, L: 18, Loss: 2.57812
T: algo_first, L: 19, Loss: 2.76562
T: algo_first, L: 20, Loss: 2.8125
T: algo_first, L: 21, Loss: 2.8125
T: algo_first, L: 22, Loss: 2.82812
T: algo_first, L: 23, Loss: 3.09375
T: algo_first, L: 24, Loss: 3.07812
T: algo_first, L: 25, Loss: 2.98438
T: algo_first, L: 26, Loss: 2.98438
T: algo_first, L: 27, Loss: 2.98438
T: algo_sum,  zero loss: 5.25


  0%|          | 0/18 [00:00<?, ?it/s]

T: algo_sum, L: 10, Loss: 5.4375
T: algo_sum, L: 11, Loss: 4.9375
T: algo_sum, L: 12, Loss: 4.84375
T: algo_sum, L: 13, Loss: 4.84375
T: algo_sum, L: 14, Loss: 4.71875
T: algo_sum, L: 15, Loss: 4.84375
T: algo_sum, L: 16, Loss: 4.71875
T: algo_sum, L: 17, Loss: 4.78125
T: algo_sum, L: 18, Loss: 4.78125
T: algo_sum, L: 19, Loss: 4.875
T: algo_sum, L: 20, Loss: 4.8125
T: algo_sum, L: 21, Loss: 4.84375
T: algo_sum, L: 22, Loss: 4.96875
T: algo_sum, L: 23, Loss: 5.03125
T: algo_sum, L: 24, Loss: 5.03125
T: algo_sum, L: 25, Loss: 5
T: algo_sum, L: 26, Loss: 5.0625
T: algo_sum, L: 27, Loss: 5.0625
T: algo_most_common,  zero loss: 2.92188


  0%|          | 0/18 [00:00<?, ?it/s]

T: algo_most_common, L: 10, Loss: 3.71875
T: algo_most_common, L: 11, Loss: 2.95312
T: algo_most_common, L: 12, Loss: 2.67188
T: algo_most_common, L: 13, Loss: 2.78125
T: algo_most_common, L: 14, Loss: 2.79688
T: algo_most_common, L: 15, Loss: 2.89062
T: algo_most_common, L: 16, Loss: 2.9375
T: algo_most_common, L: 17, Loss: 2.70312
T: algo_most_common, L: 18, Loss: 2.79688
T: algo_most_common, L: 19, Loss: 2.79688
T: algo_most_common, L: 20, Loss: 2.84375
T: algo_most_common, L: 21, Loss: 2.95312
T: algo_most_common, L: 22, Loss: 2.95312
T: algo_most_common, L: 23, Loss: 3.0625
T: algo_most_common, L: 24, Loss: 2.98438
T: algo_most_common, L: 25, Loss: 2.96875
T: algo_most_common, L: 26, Loss: 2.92188
T: algo_most_common, L: 27, Loss: 2.9375


In [16]:

# Task name -> bfloat16 task vector extracted at the layer chosen for that task.
tvs = {}

# (task, layer) pairs for which a task vector is extracted and later saved.
task_names = [
    ("en_es", 18),
    ("en_fr", 18),
    ("en_it", 18),
    ("en_de", 18),
    ("en_ru", 18),
    ("es_en", 18),
    ("fr_en", 18),
    ("it_en", 20),
    ("antonyms", 18),
    ("plural_singular", 18),
    ("present_simple_past_simple", 20),
    ("country_capital", 18),
    ("algo_first", 18),
    ("location_country", 21)
]

for task, layer in tqdm(task_names):
    # Losses are scored with shift=1 for the algorithmic tasks and shift=0 otherwise.
    loss_shift = 1 if task.startswith("algo") else 0

    pairs = list(tasks[task].items())
    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=10)

    # Residual streams of the training prompts: the source of the task vector.
    tokenized = runner.get_tokens(runner.train_pairs, tokenizer)
    train_tokens = tokenized["input_ids"]
    inputs = tokenized_to_inputs(**tokenized)
    _, resids = get_resids_call(inputs)

    # Baseline loss of the unmodified model on the evaluation prompts.
    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    tokens = tokenized["input_ids"]
    inputs = tokenized_to_inputs(**tokenized)

    logits = llama(inputs)
    loss = logprob_loss(logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=loss_shift, n_first=2)
    print(f"T: {task},  zero loss: {loss}")

    # The task vector is always extracted with shift=0, including for the algo tasks.
    layer_resid = resids[layer].value.unwrap("batch", "seq", "embedding")
    tv = get_tv(layer_resid, train_tokens, shift=0)
    tv = tv.astype('bfloat16')

    # Loss with the task vector added at `layer`, printed for comparison with the baseline.
    add_act = make_act_adder(llama, tv, tokens, layer, length=1, shift=0)
    logits = add_act(inputs)

    loss = logprob_loss(logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=loss_shift, n_first=2)
    print(f"T: {task}, L: {layer}, Loss: {loss}")

    tvs[task] = tv

  0%|          | 0/14 [00:00<?, ?it/s]

T: en_es,  zero loss: 14.625
T: en_es, L: 18, Loss: 4.65625
T: en_fr,  zero loss: 12.625
T: en_fr, L: 18, Loss: 5.6875
T: en_it,  zero loss: 15.375
T: en_it, L: 18, Loss: 3.0625
T: en_de,  zero loss: 12.125
T: en_de, L: 18, Loss: 6.9375
T: en_ru,  zero loss: 16.125
T: en_ru, L: 18, Loss: 5.8125
T: es_en,  zero loss: 8.5
T: es_en, L: 18, Loss: 1.71094
T: fr_en,  zero loss: 7.3125
T: fr_en, L: 18, Loss: 2.45312
T: it_en,  zero loss: 8.0625
T: it_en, L: 20, Loss: 4.25
T: antonyms,  zero loss: 5.96875
T: antonyms, L: 18, Loss: 2.29688
T: plural_singular,  zero loss: 2.53125
T: plural_singular, L: 18, Loss: 1.53906
T: present_simple_past_simple,  zero loss: 2.48438
T: present_simple_past_simple, L: 20, Loss: 0.773438
T: country_capital,  zero loss: 6.40625
T: country_capital, L: 18, Loss: 3.73438
T: algo_first,  zero loss: 3.20312
T: algo_first, L: 18, Loss: 2.57812
T: location_country,  zero loss: 5.625
T: location_country, L: 21, Loss: 3.67188


In [19]:
from micrlhf.utils.vector_storage import save_and_upload_vector

# Persist every extracted task vector under task_vectors/<task>:<layer>.npz.
# Expects `task_names` to hold (task, layer) tuples and `tvs` to be the dict keyed by task name.
for task, layer in task_names:
    tv = tvs[task]
    # NOTE(review): the ":" in the file name is not valid on every filesystem — confirm the
    # storage backend used by save_and_upload_vector accepts it.
    save_and_upload_vector(f"task_vectors/{task}:{layer}.npz", tv)

In [297]:
# Summarize the sweep: baseline loss, best steered loss, and the layer that achieved it.
# Each results[task] list is [baseline, loss@layer10, loss@layer11, ...], so the first steered
# entry corresponds to layer 10 (the start of the sweep's trange(10, 28)).
first_swept_layer = 10
for task, r in results.items():
    baseline, steered = r[0], r[1:]
    best = min(steered)
    # Search only the steered entries: looking the value up in the full list would return
    # index 0 (reported as the non-existent "layer 9") whenever the baseline equals the best loss.
    best_layer = first_swept_layer + steered.index(best)
    print(f"{task}: w/o: {baseline}, w/: {best}, layer: {best_layer}")

location_continent: w/o: 7.96875, w/: 1.91406, layer: 25
football_player_position: w/o: 13.125, w/: 4.375, layer: 22
location_religion: w/o: 7.875, w/: 1.55469, layer: 27
location_language: w/o: 7.9375, w/: 1.82812, layer: 21
person_profession: w/o: 11.3125, w/: 3.53125, layer: 24
location_country: w/o: 5.625, w/: 3.67188, layer: 21
country_capital: w/o: 6.40625, w/: 3.73438, layer: 18
person_language: w/o: 7.4375, w/: 1.28906, layer: 21
singular_plural: w/o: 1.96875, w/: 0.910156, layer: 21
present_simple_past_simple: w/o: 2.48438, w/: 0.773438, layer: 20
antonyms: w/o: 5.96875, w/: 2.29688, layer: 18
plural_singular: w/o: 2.53125, w/: 1.53906, layer: 18
present_simple_past_perfect: w/o: 4.65625, w/: 2.375, layer: 21
present_simple_gerund: w/o: 3.51562, w/: 1.17188, layer: 21
en_it: w/o: 15.375, w/: 3.0625, layer: 18
it_en: w/o: 8.0625, w/: 3.8125, layer: 18
en_ru: w/o: 16.125, w/: 3.76562, layer: 21
en_fr: w/o: 12.625, w/: 5.6875, layer: 18
en_es: w/o: 14.625, w/: 4.65625, layer: 18


In [282]:
# Baseline loss of the unmodified model on the *evaluation* prompts of the algorithmic tasks.
#
# This cell used to iterate the notebook-level `task_names`. In notebook order that name holds
# (task, layer) tuples by this point, so `tasks[task]` fails on a fresh run; the recorded output
# shows the cell was executed over the six algo_* tasks. The list is therefore built here
# explicitly, without rebinding `task_names`.
algo_task_names = [name for name in tasks.keys() if name.startswith("algo")]

for task in tqdm(algo_task_names):
    # shift=1 for algorithmic tasks, matching the other cells (always 1 for this task list).
    loss_shift = 1 if task.startswith("algo") else 0

    pairs = list(tasks[task].items())
    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=10)

    tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]
    logits = llama(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=loss_shift, n_first=2
    )

    print(
        f"T: {task},  Loss: {loss}"
    )

  0%|          | 0/6 [00:00<?, ?it/s]

T: algo_max,  Loss: 3.0625
T: algo_min,  Loss: 3
T: algo_last,  Loss: 3.39062
T: algo_first,  Loss: 3.20312
T: algo_sum,  Loss: 5.25
T: algo_most_common,  Loss: 2.92188


In [283]:
# Baseline loss of the unmodified model on the *training* prompts of the algorithmic tasks.
#
# This cell used to iterate the notebook-level `task_names`. In notebook order that name holds
# (task, layer) tuples by this point, so `tasks[task]` fails on a fresh run; the recorded output
# shows the cell was executed over the six algo_* tasks. The list is therefore built here
# explicitly, without rebinding `task_names`.
algo_task_names = [name for name in tasks.keys() if name.startswith("algo")]

for task in tqdm(algo_task_names):
    # shift=1 for algorithmic tasks, matching the other cells (always 1 for this task list).
    loss_shift = 1 if task.startswith("algo") else 0

    pairs = list(tasks[task].items())
    runner = ICLRunner(task, pairs, batch_size=batch_size, n_shot=n_few_shots-1, max_seq_len=max_seq_len, seed=10)

    tokenized = runner.get_tokens(runner.train_pairs, tokenizer)
    inputs = tokenized_to_inputs(**tokenized)
    tokens = tokenized["input_ids"]
    logits = llama(inputs)

    loss = logprob_loss(
        logits.unwrap("batch", "seq", "vocabulary"), tokens, shift=loss_shift, n_first=2
    )

    print(
        f"T: {task},  Loss: {loss}"
    )

  0%|          | 0/6 [00:00<?, ?it/s]

T: algo_max,  Loss: 1.82031
T: algo_min,  Loss: 2.45312
T: algo_last,  Loss: 0.0549316
T: algo_first,  Loss: 1.6875
T: algo_sum,  Loss: 4.09375
T: algo_most_common,  Loss: 3.04688


In [263]:
# Debug: show the prompt built for the first training pair of the currently bound `runner`
# (left over from whichever loop cell ran last).
runner.get_prompt(runner.train_pairs[0])

In [264]:
# Scratch: extract a task vector for the currently bound `runner` / `task` / `layer`.
# All three names come from earlier cells, so the result depends on kernel state.
tokenized = runner.get_tokens(runner.train_pairs, tokenizer)
inputs = tokenized_to_inputs(**tokenized)
tokens = tokenized["input_ids"]

# One call returns both the model output and the collected resid_pre_* side outputs.
logits, resids = get_resids_call(inputs)


# Unlike the sweep cells (which pass shift=0), shift is 1 here for the algo tasks.
tv = get_tv(resids[layer].value.unwrap("batch", "seq", "embedding"), tokens, shift = 1 if task.startswith("algo") else 0)

In [265]:
# Scratch: loss on the current `logits` / `tokens`, with n_first=1 (the loop cells use n_first=2).
logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"), tokens, n_first=1, shift=1 if task.startswith("algo") else 0
)

In [275]:
from task_vector_utils import make_act_adder

# Switch to the evaluation prompts; this rebinds the globals `tokenized`, `inputs`, `tokens`.
tokenized = runner.get_tokens(runner.eval_pairs, tokenizer)
inputs = tokenized_to_inputs(**tokenized)
tokens = tokenized["input_ids"]

# `tv * 0` passes an all-zero vector with the shape/dtype of `tv`.
# NOTE(review): presumably a sanity check that the adder is a no-op for a zero
# vector — swap in `tv` to actually steer; confirm intent.
add_act = make_act_adder(llama, tv * 0, tokens, layer, length=1, shift=1 if task.startswith("algo") else 0)

In [276]:
logits = add_act(inputs)

In [269]:
# Greedy next-token prediction at the position of token id 1599 in every prompt.
# NOTE(review): 1599 is presumably the id of the "->" separator for this tokenizer — confirm.
# For 2-D `tokens`, argwhere returns (row, column) index pairs; `[:, -1]` keeps the column,
# i.e. the sequence position.
# NOTE(review): this only lines up with the batch axis if each row contains exactly one
# matching token — verify for few-shot prompts.
positions = jnp.argwhere(tokens == 1599)[:, -1] 

# Pick the logits at those positions, then take the argmax over the vocabulary axis.
pred = jnp.take_along_axis(logits.unwrap("batch", "seq", "vocabulary"), positions[:, None, None], axis=1).squeeze().argmax(axis=-1)

In [270]:
pred

In [271]:
tokenizer.decode(pred)

In [247]:
tokens[:, 1:]

In [277]:
# Loss of whatever `logits` currently holds (the `add_act` output when the cells
# are run in file order), scored on the first answer token only.
logprob_loss(
    logits.unwrap("batch", "seq", "vocabulary"),
    tokens,
    shift=int(task.startswith("algo")),
    n_first=1,
)

In [16]:
# Task-vector experiment repeated over `n_seeds` seeds for every task:
#   orig  - few-shot prompts, unmodified model (ICL baseline)
#   added - zero-shot prompts with the extracted task vector added at `layer`
#   zero  - zero-shot prompts, unmodified model
tvs = []
# Aggregated numbers per task. Previously the aggregated prints sat after the
# task loop, so only the last task was ever reported.
shot_summaries = {}

for task in tqdm(task_names):
    # For each condition: slot 0 collects the per-seed log-prob diffs, slot 1 the
    # per-seed boolean arrays marking examples whose diff is exactly 0.
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_zero = [[] for _ in range(2)]
    for seed in trange(n_seeds):
        # ---- few-shot prompts: baseline score + task-vector extraction ----
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=seed, prepend_space=task.startswith("algo"))

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(**tokenized)

        logits, resids = get_resids_call(inputs)

        # Not used below, but other cells read the global `tokens`, so keep it in sync.
        tokens = tokenized["input_ids"]

        shot_logprobs_orig[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True, extra_space=task.startswith("algo"))
        )
        shot_logprobs_orig[1].append(shot_logprobs_orig[0][-1] == 0.)

        # Select every position whose token id is 1599.
        # NOTE(review): presumably the "->" separator of this tokenizer — confirm.
        mask = inputs.tokens == 1599
        mask = mask.unwrap("batch", "seq")

        # Task vector = mean residual (input of block `layer`) over all selected positions.
        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]
        tv = tv.mean(axis=0)

        tvs.append(tv)

        print(tv.mean(), tv.std())

        act_add = add_vector(llama, tv, layer, scale=2.0, position="last")

        # ---- zero-shot prompts drawn with a different seed ----
        # The pair list is rebuilt on purpose, exactly as before, in case ICLDataset
        # modifies the lists it is given.
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1, prepend_space=task.startswith("algo"))

        print(dataset.prompts, dataset.completions)

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(**tokenized)

        # Steered model on the zero-shot prompts.
        logits = act_add(inputs)

        tokens = tokenized["input_ids"]

        shot_logprobs_added[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True, extra_space=task.startswith("algo"))
        )
        shot_logprobs_added[1].append(shot_logprobs_added[0][-1] == 0.)

        # Unmodified model on the same zero-shot prompts.
        logits, _ = get_resids_call(inputs)

        shot_logprobs_zero[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True, extra_space=task.startswith("algo"))
        )
        shot_logprobs_zero[1].append(shot_logprobs_zero[0][-1] == 0.)

    # Raw per-seed values for this task.
    print(f"orig: {shot_logprobs_orig}")
    print(f"zero: {shot_logprobs_zero}")
    print(f"added: {shot_logprobs_added}")

    # Collapse every per-seed array to its mean.
    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_zero = [list(map(np.mean, x)) for x in shot_logprobs_zero]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]

    # Keep and report the aggregates per task. The `shot_logprobs_*` globals still
    # hold the last task's aggregates afterwards, as they did before.
    shot_summaries[task] = {
        "orig": shot_logprobs_orig,
        "zero": shot_logprobs_zero,
        "added": shot_logprobs_added,
    }
    print(f"[{task}] orig: {shot_logprobs_orig}")
    print(f"[{task}] zero: {shot_logprobs_zero}")
    print(f"[{task}] added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [64]:
a

  0%|          | 0/1 [00:00<?, ?it/s]

-0.0090332 0.875
along window information over field current better power message reduce behavior place security hell cell media club live point throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00256348 0.898438
along window information over field current better power message reduce behavior place security hel cell media club live point throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00260925 0.957031
along window inform over campo current migli pot mess rid behavior place security h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00460815 1.09375
along fen inform over campo current migli pot mess rid comport pl seg h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00552368 1.17188


100%|██████████| 1/1 [00:22<00:00, 22.50s/it]

along window inform over campo current migli pot mess rid behavior place sic h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
orig: [[Array(-1.46094, dtype=bfloat16)], [Array(0.8, dtype=float32)]]
added: [[Array(-4.96875, dtype=bfloat16), Array(-4.75, dtype=bfloat16), Array(-1.53125, dtype=bfloat16), Array(-1.40625, dtype=bfloat16), Array(-1.72656, dtype=bfloat16)], [Array(0.15, dtype=float32), Array(0.15, dtype=float32), Array(0.45000002, dtype=float32), Array(0.6, dtype=float32), Array(0.55, dtype=float32)]]





In [61]:
shot_logprobs_added

In [27]:

from micrlhf.utils.load_sae import get_sae
# Load the sparse-autoencoder weights for the current global `layer`; later cells
# index the result as a mapping (`sae["W_dec"]`).
# NOTE(review): the second argument (6) looks like the SAE run/version index — the
# download log below fetches ".../l18-test-run-6-..." — confirm against `get_sae`.
sae = get_sae(layer, 6)

--2024-05-23 18:36:35--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l18-test-run-6-1.01E-05/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.95, 108.156.211.51, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.95|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/f057cb46f3d871ba03c66e707e3b3d8299322f36fa433862dc3fdca956715614?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716748596&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjc0ODU5Nn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvZjA1N2NiNDZm

In [28]:
from micrlhf.utils.ito import grad_pursuit

In [30]:
# Sparse-code the current task vector `tv` over the SAE decoder rows with `k` = 20.
k = 20

# `grad_pursuit` returns the coefficient vector and the reconstruction built from it.
# NOTE(review): `pos_only=True` presumably restricts coefficients to be non-negative — confirm.
weights, recon = grad_pursuit(tv, sae["W_dec"], k, pos_only=True)
# jax.lax.top_k returns (values, indices): the k largest |weights| and their positions.
w, i = jax.lax.top_k(jnp.abs(weights), k)

# Display the selected indices.
i

In [46]:
jnp.linalg.norm(tvs[0] - recon)

In [36]:
# For every task vector collected in `tvs`: sparse-code it with k = 5 and show the
# 5 largest-magnitude coefficients as (values, indices) pairs.
k = 5

[
    jax.lax.top_k(jnp.abs(weights), k)
    for weights, _ in (grad_pursuit(x, sae["W_dec"], k, pos_only=True) for x in tvs)
]

In [27]:

# Compare three ways of solving a zero-shot prompt for each task:
#   orig  - few-shot prompts, unmodified model
#   added - zero-shot prompts + task vector (scale=2.0) added at `layer`
#   sae   - zero-shot prompts + SAE reconstruction of `tv * 2` for k = 0..39
# NOTE: this cell rebinds the notebook globals `task_names`, `layer`, `tv`, `k`,
# `recon`, `inputs`, `tokens`, `dataset`, ... that other cells read.
task_names = [
    "en_ru"
]

layer = 18

for task in tqdm(task_names):
    # Slot 0: log-prob diffs; slot 1: boolean arrays of "diff == 0".
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_sae = [[] for _ in range(2)]

    # ---- few-shot prompts (fixed seed 10) ----
    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=10, prepend_space=task.startswith("algo"))

    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )
    
    logits, resids = get_resids_call(inputs)

    tokens = tokenized["input_ids"]

    shot_logprobs_orig[0].append(
        get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False, extra_space=task.startswith("algo"))
    )

    shot_logprobs_orig[1].append(
        shot_logprobs_orig[0][-1] == 0.
    )

    # Select every position whose token id is 1599.
    # NOTE(review): presumably the "->" separator id — confirm.
    mask = inputs.tokens == 1599
    mask = mask.unwrap("batch", "seq")

    # Task vector: mean of the selected residuals at `layer`.
    tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
        mask
    ]

    tv = tv.mean(
        axis=0
    )

    print(
        tv.mean(), tv.std()
    )

    act_add = add_vector(
        llama, tv, layer, scale=2.0, position="last"
    )

    # ---- zero-shot prompts (fixed seed 11) ----
    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=11, prepend_space=task.startswith("algo"))


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    logits = act_add(inputs)

    tokens = tokenized["input_ids"]

    shot_logprobs_added[0].append(
        get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True, extra_space=task.startswith("algo"))
    )

    shot_logprobs_added[1].append(
        shot_logprobs_added[0][-1] == 0.
    )

    # Sweep the sparsity k of the SAE approximation. `tv * 2` with scale=1.0 mirrors
    # the scale=2.0 used for the raw task vector above.
    # NOTE(review): the sweep starts at k = 0 — confirm `grad_pursuit` handles that.
    for k in range(0, 40):
        weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

        # The reconstruction is cast to bfloat16 before being handed to `add_vector`.
        act_add = add_vector(
            llama, recon.astype('bfloat16'), layer, scale=1.0, position="last"
        )

        logits = act_add(inputs)
        
        shot_logprobs_sae[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True, extra_space=task.startswith("algo"))
        )

        shot_logprobs_sae[1].append(
            shot_logprobs_sae[0][-1] == 0.
        )
        
    # print(f"orig: {shot_logprobs_orig}")
    # print(f"sae: {shot_logprobs_sae}")
    # print(f"added: {shot_logprobs_added}")


    # Collapse every stored array to its mean; for `sae` that is one value per k.
    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_sae = [list(map(np.mean, x)) for x in shot_logprobs_sae]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


# NOTE(review): these prints sit outside the task loop, so with more than one entry
# in `task_names` only the last task's aggregates are shown.
print(f"orig: {shot_logprobs_orig}")
print(f"sae: {shot_logprobs_sae}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

-0.00294495 1.10156
об amb ре  пи пи вопро ка кни пи день ок кни пу  г ве kitchen  less  ме вто ко<|placeholder1|> мате ко<|placeholder1|> ви мо би bus ф е г  та е date пи проду тан су ме прода tennis мате<|placeholder1|> ми п би би ве пе amb<|placeholder1|> пи check проду з<|placeholder1|> bus price са
зав ско ри дере сви га вопро ка кни х день ок те команди ло г ос ку мо у мо а се ко сез мате ко от с го би авто ф пи г дере сто е да на проду та су сту прода т мате сез ми учи би би з ру ско сез пи реги проду з от авто це са
br amb rice trees p bur answer calendar books bread money wind not vac boat be aut kitchen ice less ice ph second hat season math hat vac viol mount bi bus fl food be trees table me month drink product dan soup chair sell tennis math season minutes engineer business bi winter p amb season pie check product * vac bus price sal
зав ско ри дере сви га вопро ка кни х день ок те команди ло г ос ку мо у мо а се ко сез мате ко от с го би авто ф пи г дере сто е да на проду 

In [28]:
shot_logprobs_sae

In [40]:
shot_logprobs_sae[0][18]

In [41]:
# Re-run the sparse coding of `tv` at k = 18 (the index inspected in the previous
# cell via `shot_logprobs_sae[0][18]`).
# NOTE(review): the sweep cell coded `tv * 2`, whereas this cell codes `tv` — confirm
# the scale difference is intended.
k = 18

weights, recon = grad_pursuit(tv, sae["W_dec"], k, pos_only=True)
# jax.lax.top_k returns (values, indices) of the k largest |weights|.
w, i = jax.lax.top_k(jnp.abs(weights), k)

# Display the selected indices.
i

In [38]:
w

In [None]:
#37312

In [44]:
# Steer zero-shot prompts with a single SAE decoder row instead of a task vector.
# Uses the globals `task` and `layer` left behind by earlier cells.
pairs = tasks[task]
pairs = [list(x) for x in pairs.items()]
dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=101, prepend_space=task.startswith("algo"))


tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
inputs = tokenized_to_inputs(
    **tokenized
)

# weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

# Decoder row 27215 scaled by 20.
# NOTE(review): both numbers are hand-picked magic constants — document where they came from.
recon = sae["W_dec"][27215] * 20

act_add = add_vector(
    llama, recon.astype('bfloat16'), layer, scale=1.0, position="last"
)

logits = act_add(inputs)

# NOTE(review): unlike the other cells, `extra_space` is not passed here; that only
# matches them if the helper's default equals `task.startswith("algo")` — confirm.
get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True).mean()

less snow inst train st car t window concert the air new guitar vis table unique take a sal there actor sal higher emer per sal unique inst feu football library light river wine higher any time flight sal football milk less river су concert sa air higher service sal k summer milk meat football no name name milk tor ly sal plan less
у с университе мет сту ма вра ок кон те а ве ги ви сто с га су са по филь пи не ско ба сотруд день университе роман т би со ре ви не от время по профес т га у ре де кон са само не ус день пе ве мо г спо на час город мо би га це п у


In [26]:
dataset.completions

In [17]:
def prepare_inputs(dataset: ICLDataset):
    """Tokenize the prompts of ``dataset`` and wrap them as model inputs.

    Every prompt is inserted into the notebook-global ``prompt`` template, then
    the batch is encoded with the global ``tokenizer`` (padded to the longest
    prompt, truncated to the global ``max_seq_len``, NumPy tensors).

    Returns:
        A pair ``(inputs, tokens)``: the result of ``tokenized_to_inputs`` on the
        encoded batch, and the batch's ``input_ids`` array.
    """
    formatted_prompts = [prompt.format(x) for x in dataset.prompts]
    encoded = tokenizer.batch_encode_plus(
        formatted_prompts,
        padding="longest",
        max_length=max_seq_len,
        truncation=True,
        return_tensors="np",
    )
    return tokenized_to_inputs(**encoded), encoded["input_ids"]

In [18]:
# Restrict the following experiments to the tasks whose name starts with "algo".
task_names = [name for name in tasks.keys() if name.startswith("algo")]

# Alternative (larger) setting: n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots = 20
batch_size = 64
max_seq_len = 256

In [24]:
# Layer sweep: for every task, extract a task vector at each layer in 10..21 from
# few-shot prompts, add it (scale 2.0, last position) to zero-shot prompts and record
# the mean log-prob diff. `results[task][j]` belongs to layer `10 + j`.
# NOTE: the inner loop rebinds the notebook-global `layer`.
results = {}

for task in tqdm(task_names):
    results[task] = []

    pairs = [list(x) for x in tasks[task].items()]
    dataset = ICLDataset(
        pairs,
        size=batch_size,
        n_prepended=n_few_shots,
        bidirectional=False,
        seed=10,
        prepend_space=task.startswith("algo"),
    )

    clean_inputs, clean_tokens = prepare_inputs(dataset)
    _, resids = get_resids_call(clean_inputs)

    # Positions whose token id is 1599; for "algo" tasks the mask is rolled one step
    # towards higher sequence indices (jnp.roll wraps around at the end).
    mask = (clean_inputs.tokens == 1599).unwrap("batch", "seq")
    if task.startswith("algo"):
        mask = jnp.roll(mask, 1, axis=1)

    # Zero-shot prompts that receive the task vector.
    dataset = ICLDataset(
        pairs,
        size=batch_size,
        n_prepended=0,
        bidirectional=False,
        seed=11,
        prepend_space=task.startswith("algo"),
    )

    add_inputs, add_tokens = prepare_inputs(dataset)

    for layer in trange(10, 22):
        selected = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]
        tv = selected.mean(axis=0).astype('bfloat16')

        act_add = add_vector(llama, tv, layer, scale=2.0, position="last")
        logits = act_add(add_inputs)

        diff = get_logprob_diff(
            logits.unwrap("batch", "seq", "vocabulary"),
            dataset.completions,
            print_results=False,
            extra_space=task.startswith("algo"),
        )

        results[task].append(diff.mean())
        
        

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

In [25]:
# Report, per task, the layer whose task vector gave the highest mean log-prob diff.
# The sweep iterates `trange(10, 22)`, so list index 0 corresponds to layer 10.
# The loop no longer uses `k`: that name is the sparsity passed to `grad_pursuit` /
# `jax.lax.top_k` in other cells and was being overwritten with a task-name string.
for task_name, layer_scores in results.items():
    best_idx = int(np.argmax([float(x) for x in layer_scores]))
    print(task_name, best_idx + 10, layer_scores[best_idx])

algo_max 21 -0.515625
algo_min 21 -0.945312
algo_last 20 -0.162109
algo_first 13 -0.828125
algo_sum 17 -1.49219
algo_most_common 14 -1.15625


In [31]:
# Same report as the cell above, captured for a different `results` dict (the
# output below lists the non-"algo" tasks).
# List index 0 corresponds to layer 10 because the sweep iterates `trange(10, 22)`.
# The loop no longer uses `k`: that name is the sparsity passed to `grad_pursuit` /
# `jax.lax.top_k` in other cells and was being overwritten with a task-name string.
for task_name, layer_scores in results.items():
    best_idx = int(np.argmax([float(x) for x in layer_scores]))
    print(task_name, best_idx + 10, layer_scores[best_idx])

location_continent 21 -1.09375
football_player_position 21 -3.76562
location_religion 21 -1.09375
location_language 20 -0.457031
person_profession 21 -1.30469
location_country 21 -1.44531
country_capital 18 -1.25781
person_language 18 -0.287109
singular_plural 21 -0.261719
present_simple_past_simple 20 -0.341797
antonyms 14 -0.75
plural_singular 20 -0.133789
present_simple_past_perfect 18 -1.25781
present_simple_gerund 20 -0.171875
en_it 18 -1.29688
it_en 14 -1.75781
en_fr 18 -1.39062
en_es 18 -1.21875
fr_en 17 -1.25781
es_en 17 -1.09375


In [22]:
# for task in tqdm(task_names):

# Single-task version of the layer sweep, hard-wired to "algo_last".
# Requires an existing `results` dict and rebinds the notebook-global `layer`.
task = "algo_last"

results[task] = []

pairs = [list(x) for x in tasks[task].items()]
dataset = ICLDataset(
    pairs,
    size=batch_size,
    n_prepended=n_few_shots,
    bidirectional=False,
    seed=10,
    prepend_space=task.startswith("algo"),
)

clean_inputs, clean_tokens = prepare_inputs(dataset)
_, resids = get_resids_call(clean_inputs)

# Positions whose token id is 1599, rolled one step towards higher sequence indices
# (unconditionally here, since the task is an "algo" task).
mask = (clean_inputs.tokens == 1599).unwrap("batch", "seq")
mask = jnp.roll(mask, 1, axis=1)

# Zero-shot prompts that receive the task vector.
dataset = ICLDataset(
    pairs,
    size=batch_size,
    n_prepended=0,
    bidirectional=False,
    seed=11,
    prepend_space=task.startswith("algo"),
)

add_inputs, add_tokens = prepare_inputs(dataset)

for layer in trange(10, 22):
    selected = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]
    tv = selected.mean(axis=0).astype('bfloat16')

    act_add = add_vector(llama, tv, layer, scale=2.0, position="last")
    logits = act_add(add_inputs)

    diff = get_logprob_diff(
        logits.unwrap("batch", "seq", "vocabulary"),
        dataset.completions,
        print_results=False,
        extra_space=task.startswith("algo"),
    )

    results[task].append(diff.mean())

  0%|          | 0/12 [00:00<?, ?it/s]

In [23]:
results[task]

In [10]:
from task_vector_utils import logprob_loss, ICLRunner

NameError: name 'batch_size' is not defined