In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:

pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [3]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
from micrlhf.sampling import sample
from transformers import AutoTokenizer
import jax
# tokenizer = load_tokenizer(filename)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [5]:
!git clone https://github.com/roeehendel/icl_task_vectors data/itv
import glob
import json
import os
tasks = {}
for g in glob.glob("data/itv/data/**/*.json"):
    tasks[os.path.basename(g).partition(".")[0]] = json.load(open(g))

  pid, fd = os.forkpty()


fatal: destination path 'data/itv' already exists and is not an empty directory.


In [6]:
tasks.keys()

In [7]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [8]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [9]:
from typing import List

class ICLSequence:
    '''
    Class to store a single antonym sequence.

    Uses the default template "Q: {x}\nA: {y}" (with separate pairs split by "\n\n").
    '''
    def __init__(self, word_pairs: List[List[str]]):
        self.word_pairs = word_pairs
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        return self.word_pairs[idx]

    # def prompt(self):
    #     '''Returns the prompt, which contains all but the second element in the last word pair.'''
    #     p = "\n\n".join([f"Q: {x}\nA: {y}" for x, y in self.word_pairs])
    #     return p[:-len(self.completion())]

    def prompt(self):
        '''Returns the prompt, which contains all but the second element in the last word pair.'''
        p = ", ".join([f"{x} -> {y}" for x, y in self.word_pairs])
        return p[:-len(self.completion())-1]

    def completion(self):
        '''Returns the second element in the last word pair (with padded space).'''
        return "" + self.y[-1]

    def __str__(self):
        '''Prints a readable string representation of the prompt & completion (indep of template).'''
        return f"{', '.join([f'({x}, {y})' for x, y in self[:-1]])}, {self.x[-1]} ->".strip(", ")


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = ICLSequence(word_list)

print("Tuple-representation of the sequence:")
print(seq)
print("\nActual prompt, which will be fed into the model:")
print(seq.prompt())

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
hot -> cold, yes -> no, in -> out, up ->


In [10]:
class ICLDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        corrupted:
            if True, then the second word in each pair is replaced with a random word
        seed:
            random seed, for consistency & reproducibility
    '''

    def __init__(
        self,
        word_pairs: List[List[str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            np.random.seed(seed + n)
            random_pairs = np.random.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = np.random.choice([1, -1], n_prepended+1)
            if not(bidirectional): random_orders[:] = 1
            word_pairs = [self.word_pairs[pair][::order] for pair, order in zip(random_pairs, random_orders)]
            if corrupted:
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = np.random.choice(self.word_list)
            seq = ICLSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

In [11]:
def generate_task_prompt(task, n_shots):
    prompt = "<user>Follow the pattern\n{}"
    examples = []

    for s, t in random.sample(list(tasks[task].items()), n_shots):
        examples.append(f"{s} -> {t}")
    prompt = prompt.format("\n".join(examples))

    # print(prompt)

    return prompt

def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [12]:
def generate_task_inputs_old(task, n_shots, batch_size, max_length=128, seed=0):
    random.seed(seed)

    texts = [generate_task_prompt(task, n_shots) for _ in range(batch_size)]
    tokenized = tokenizer.batch_encode_plus(texts, padding="longest", max_length=max_length, truncation=True, return_tensors="np")

    inputs = tokenized_to_inputs(
        **tokenized
    )

    return inputs, tokenized

In [13]:
prompt = "<user>Follow the pattern\n{}"

In [14]:
def get_logprob_diff(logits: jnp.ndarray, completions: List[str], print_results=False):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_tokens = [x[1] for x in tokenizer(completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)
    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs - answer_logprobs.max(axis=-1)


In [16]:
task_names = [
    "en_it"
]
layer = 18
n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 64, 256

In [68]:
tvs = []

for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_zero = [[] for _ in range(2)]
    for seed in trange(n_seeds):
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=seed)

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )
        
        logits, resids = get_resids_call(inputs)
    
        tokens = tokenized["input_ids"]

        shot_logprobs_orig[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_orig[1].append(
            shot_logprobs_orig[0][-1] == 0.
        )

        mask = inputs.tokens == 1599
        mask = mask.unwrap("batch", "seq")

        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
            mask
        ]

        tv = tv.mean(
            axis=0
        )

        tvs.append(tv)

        print(
            tv.mean(), tv.std()
        )

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1)


        print(
            dataset.prompts, dataset.completions
        )

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )

        logits = act_add(inputs)

        tokens = tokenized["input_ids"]

        shot_logprobs_added[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_added[1].append(
            shot_logprobs_added[0][-1] == 0.
        )

        logits, _ = get_resids_call(inputs)
        
        shot_logprobs_zero[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_zero[1].append(
            shot_logprobs_zero[0][-1] == 0.
        )
        
    print(f"orig: {shot_logprobs_orig}")
    print(f"zero: {shot_logprobs_zero}")
    print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_zero = [list(map(np.mean, x)) for x in shot_logprobs_zero]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


print(f"orig: {shot_logprobs_orig}")
print(f"zero: {shot_logprobs_zero}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

pre situ sic s dest cres inter reg se un le t sens propri in tag prepar serv ind risult se m coin f c diff città lu come ass cart chiam
s situ sic s gi cres totale disco se gi le t sens propri in tag prepar serv ind risult se pare coin f stesso diff citt lu ven ass cart chiam
-0.00430298 1.09375
['legal ->', 'along ->', 'window ->', 'information ->', 'over ->', 'field ->', 'current ->', 'better ->', 'power ->', 'message ->', 'reduce ->', 'behavior ->', 'place ->', 'security ->', 'hell ->', 'cell ->', 'media ->', 'club ->', 'live ->', 'point ->', 'throw ->', 'sport ->', 'so ->', 'door ->', 'color ->', 'daughter ->', 'run ->', 'pretty ->', 'rule ->', 'top ->', 'mind ->', 'matter ->'] ['legale', 'lungo', 'finestra', 'informazione', 'sopra', 'campo', 'attuale', 'meglio', 'potenza', 'messaggio', 'ridurre', 'comportamento', 'luogo', 'sicurezza', 'inferno', 'cellula', 'media', 'club', 'abitare', 'punto', 'gettare', 'sport', 'così', 'porta', 'colore', 'figlia', 'correre', 'bello', 'regola', 's

100%|██████████| 1/1 [02:06<00:00, 126.93s/it]

orig: [[Array(-1.95312, dtype=bfloat16), Array(-1.85938, dtype=bfloat16), Array(-2.20312, dtype=bfloat16), Array(-2.20312, dtype=bfloat16), Array(-2.20312, dtype=bfloat16), Array(-2.09375, dtype=bfloat16), Array(-2.09375, dtype=bfloat16), Array(-1.91406, dtype=bfloat16), Array(-1.625, dtype=bfloat16), Array(-1.65625, dtype=bfloat16)], [Array(0.71875, dtype=float32), Array(0.75, dtype=float32), Array(0.71875, dtype=float32), Array(0.71875, dtype=float32), Array(0.71875, dtype=float32), Array(0.75, dtype=float32), Array(0.75, dtype=float32), Array(0.78125, dtype=float32), Array(0.78125, dtype=float32), Array(0.75, dtype=float32)]]
zero: [[Array(-8.0625, dtype=bfloat16), Array(-8.0625, dtype=bfloat16), Array(-7.84375, dtype=bfloat16), Array(-7.6875, dtype=bfloat16), Array(-7.9375, dtype=bfloat16), Array(-7.90625, dtype=bfloat16), Array(-8, dtype=bfloat16), Array(-7.9375, dtype=bfloat16), Array(-7.5, dtype=bfloat16), Array(-7.53125, dtype=bfloat16)], [Array(0.125, dtype=float32), Array(0.1




In [64]:
a

  0%|          | 0/1 [00:00<?, ?it/s]

-0.0090332 0.875
along window information over field current better power message reduce behavior place security hell cell media club live point throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00256348 0.898438
along window information over field current better power message reduce behavior place security hel cell media club live point throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00260925 0.957031
along window inform over campo current migli pot mess rid behavior place security h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00460815 1.09375
along fen inform over campo current migli pot mess rid comport pl seg h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
-0.00552368 1.17188


100%|██████████| 1/1 [00:22<00:00, 22.50s/it]

along window inform over campo current migli pot mess rid behavior place sic h cell media club v punto throw
lungo fin inform sop campo att meg pot mess rid comport lu sic infer cell media club abit punto get
orig: [[Array(-1.46094, dtype=bfloat16)], [Array(0.8, dtype=float32)]]
added: [[Array(-4.96875, dtype=bfloat16), Array(-4.75, dtype=bfloat16), Array(-1.53125, dtype=bfloat16), Array(-1.40625, dtype=bfloat16), Array(-1.72656, dtype=bfloat16)], [Array(0.15, dtype=float32), Array(0.15, dtype=float32), Array(0.45000002, dtype=float32), Array(0.6, dtype=float32), Array(0.55, dtype=float32)]]





In [61]:
shot_logprobs_added

In [20]:

from micrlhf.utils.load_sae import get_sae
sae = get_sae(layer, 6)

--2024-05-20 16:23:03--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l17-test-run-6-4.52E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.95, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/1623d8da38be3171fcc8516a4cbe9fdb80e3d77e370aa5690895697649d688f3?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716481383&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjQ4MTM4M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvMTYyM2Q4ZG

In [21]:
from micrlhf.utils.ito import grad_pursuit

In [45]:
k = 10

weights, recon = grad_pursuit(tvs[0], sae["W_dec"], k, pos_only=True)
w, i = jax.lax.top_k(jnp.abs(weights), k)

i

In [46]:
jnp.linalg.norm(tvs[0] - recon)

In [36]:
k = 5

[
    jax.lax.top_k(jnp.abs(grad_pursuit(x, sae["W_dec"], k, pos_only=True)[0]), k) for x in tvs
]

In [76]:

task_names = [
    "en_fr"
]

layer = 18

for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_sae = [[] for _ in range(2)]

    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=10)

    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )
    
    logits, resids = get_resids_call(inputs)

    tokens = tokenized["input_ids"]

    shot_logprobs_orig[0].append(
        get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False)
    )

    shot_logprobs_orig[1].append(
        shot_logprobs_orig[0][-1] == 0.
    )

    mask = inputs.tokens == 1599
    mask = mask.unwrap("batch", "seq")

    tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
        mask
    ]

    tv = tv.mean(
        axis=0
    )

    print(
        tv.mean(), tv.std()
    )

    act_add = add_vector(
        llama, tv, layer, scale=2.0, position="last"
    )

    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=11)


    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )

    logits = act_add(inputs)

    tokens = tokenized["input_ids"]

    shot_logprobs_added[0].append(
        get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
    )

    shot_logprobs_added[1].append(
        shot_logprobs_added[0][-1] == 0.
    )

    for k in range(5, 30):
        weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

        act_add = add_vector(
            llama, recon.astype('bfloat16'), layer, scale=1.0, position="last"
        )

        logits = act_add(inputs)
        
        shot_logprobs_sae[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_sae[1].append(
            shot_logprobs_sae[0][-1] == 0.
        )
        
    print(f"orig: {shot_logprobs_orig}")
    print(f"sae: {shot_logprobs_sae}")
    print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_sae = [list(map(np.mean, x)) for x in shot_logprobs_sae]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


print(f"orig: {shot_logprobs_orig}")
print(f"sae: {shot_logprobs_sae}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

0.000831604 1.08594
sport r chanson face vers check if  half yet fund press if deux groupe if movie allow benefit office raison or bon compagnie ang tr h sept comm ent b tr
sport cours chanson vis vers ch si chaque demi encore fond pres si deux gr si film autor av b raison ou bon entre américain support sa sept com soit par ess
sport run song face the check ( each half yet fund press ( two group ( movie allow benefit office < or good company amer task her seven committee either by try
sport cours chanson vis vers ch si chaque demi encore fond pres si deux gr si film autor av b raison ou bon entre américain support sa sept com soit par ess
sport run song face < check ( each half yet fund press ( two group ( movie allow benefit office < or good company amer task her seven committee either by try
sport cours chanson vis vers ch si chaque demi encore fond pres si deux gr si film autor av b raison ou bon entre américain support sa sept com soit par ess
sport run song face _ check ( one half

100%|██████████| 1/1 [02:03<00:00, 123.83s/it]

sport run song face towards check if each half yet fund press if two group if movie allow benefit office reason or good company amer follow her seven committee either by try
sport cours chanson vis vers ch si chaque demi encore fond pres si deux gr si film autor av b raison ou bon entre américain support sa sept com soit par ess
orig: [[Array([0, -4.125, -2.125, -6.25, 0, -0.375, 0, 0, 0, -1.125, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, -16.5, 0, 0, 0, -3.25, 0, 0, 0, -3.625],      dtype=bfloat16)], [Array([ True, False, False, False,  True, False,  True,  True,  True,
       False,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True, False,  True,  True,  True,
       False,  True,  True,  True, False], dtype=bool)]]
sae: [[Array([0, -13.625, -7.4375, -8.625, -6, -2.25, -8.6875, -9.3125, -9,
       -8.6875, -4.5, -3, -8.6875, -8.5625, -4.75, -8.6875, -3.375, -7.75,
       -8.0625, -5.25, -8.875, -5.6875, -6.4375, -11.25, -8.125, -5.625




In [77]:
shot_logprobs_sae

: 

In [34]:
shot_logprobs_sae[0][18]

In [40]:
k = 19

weights, recon = grad_pursuit(tvs[0], sae["W_dec"], k, pos_only=True)
w, i = jax.lax.top_k(jnp.abs(weights), k)

i

In [38]:
w

In [None]:
#37312

In [45]:
pairs = tasks[task]
pairs = [list(x) for x in pairs.items()]
dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=101)


tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
inputs = tokenized_to_inputs(
    **tokenized
)

# weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

recon = sae["W_dec"][37312] * 20

act_add = add_vector(
    llama, recon.astype('bfloat16'), layer, scale=1.0, position="last"
)

logits = act_add(inputs)

get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)

oficial number ab datos ii conoc c población hel dec g h orden tra hel republic cor ok for camb
oficial mes padre información i reun club población inf siglo ad cal ped tra inf republic cor ok por cambio


In [18]:
def prepare_inputs(dataset: ICLDataset):
    tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
    inputs = tokenized_to_inputs(
        **tokenized
    )
    
    tokens = tokenized["input_ids"]

    return inputs, tokens

In [24]:
task_names = list(tasks.keys())
# task_names = ["en_es"]
# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 20, 64, 256

In [25]:
results = {}

for task in tqdm(task_names):
    results[task] = []
    
    pairs = tasks[task]
    pairs = [list(x) for x in pairs.items()]
    dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=10)

    clean_inputs, clean_tokens = prepare_inputs(dataset)

    _, resids = get_resids_call(clean_inputs)

    mask = clean_inputs.tokens == 1599
    mask = mask.unwrap("batch", "seq")

    dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=11)

    add_inputs, add_tokens = prepare_inputs(dataset)

    for layer in trange(10, 22):
        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[mask]
        tv = tv.mean(axis=0)
        tv = tv.astype('bfloat16')

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        logits = act_add(add_inputs)

        diff = get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False)

        results[task].append(diff.mean())
        
        

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

In [31]:
for k in results:
    print(
        k, np.argmax([float(x) for x in results[k]]) + 10, max(results[k])
    )

location_continent 21 -1.09375
football_player_position 21 -3.76562
location_religion 21 -1.09375
location_language 20 -0.457031
person_profession 21 -1.30469
location_country 21 -1.44531
country_capital 18 -1.25781
person_language 18 -0.287109
singular_plural 21 -0.261719
present_simple_past_simple 20 -0.341797
antonyms 14 -0.75
plural_singular 20 -0.133789
present_simple_past_perfect 18 -1.25781
present_simple_gerund 20 -0.171875
en_it 18 -1.29688
it_en 14 -1.75781
en_fr 18 -1.39062
en_es 18 -1.21875
fr_en 17 -1.25781
es_en 17 -1.09375
