In [2]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:

pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [4]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
from micrlhf.sampling import sample
from transformers import AutoTokenizer
import jax
# tokenizer = load_tokenizer(filename)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [6]:
!git clone https://github.com/roeehendel/icl_task_vectors data/itv
import glob
import json
import os
tasks = {}
for g in glob.glob("data/itv/data/**/*.json"):
    tasks[os.path.basename(g).partition(".")[0]] = json.load(open(g))

fatal: destination path 'data/itv' already exists and is not an empty directory.


  pid, fd = os.forkpty()


In [7]:
tasks.keys()

In [8]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [9]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [10]:
from typing import List

class ICLSequence:
    '''
    Class to store a single antonym sequence.

    Uses the default template "Q: {x}\nA: {y}" (with separate pairs split by "\n\n").
    '''
    def __init__(self, word_pairs: List[List[str]]):
        self.word_pairs = word_pairs
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        return self.word_pairs[idx]

    # def prompt(self):
    #     '''Returns the prompt, which contains all but the second element in the last word pair.'''
    #     p = "\n\n".join([f"Q: {x}\nA: {y}" for x, y in self.word_pairs])
    #     return p[:-len(self.completion())]

    def prompt(self):
        '''Returns the prompt, which contains all but the second element in the last word pair.'''
        p = ", ".join([f"{x} -> {y}" for x, y in self.word_pairs])
        return p[:-len(self.completion())-1]

    def completion(self):
        '''Returns the second element in the last word pair (with padded space).'''
        return "" + self.y[-1]

    def __str__(self):
        '''Prints a readable string representation of the prompt & completion (indep of template).'''
        return f"{', '.join([f'({x}, {y})' for x, y in self[:-1]])}, {self.x[-1]} ->".strip(", ")


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = ICLSequence(word_list)

print("Tuple-representation of the sequence:")
print(seq)
print("\nActual prompt, which will be fed into the model:")
print(seq.prompt())

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
hot -> cold, yes -> no, in -> out, up ->


In [11]:
class ICLDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        corrupted:
            if True, then the second word in each pair is replaced with a random word
        seed:
            random seed, for consistency & reproducibility
    '''

    def __init__(
        self,
        word_pairs: List[List[str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            np.random.seed(seed + n)
            random_pairs = np.random.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = np.random.choice([1, -1], n_prepended+1)
            if not(bidirectional): random_orders[:] = 1
            word_pairs = [self.word_pairs[pair][::order] for pair, order in zip(random_pairs, random_orders)]
            if corrupted:
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = np.random.choice(self.word_list)
            seq = ICLSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

In [12]:
def generate_task_prompt(task, n_shots):
    prompt = "<user>Follow the pattern\n{}"
    examples = []

    for s, t in random.sample(list(tasks[task].items()), n_shots):
        examples.append(f"{s} -> {t}")
    prompt = prompt.format("\n".join(examples))

    # print(prompt)

    return prompt

def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [13]:
from tqdm import tqdm

datasets_dict = {}

for task, pairs in tqdm(tasks.items()):
  pairs = [list(x) for x in pairs.items()]
  dataset = ICLDataset(pairs, size=20, n_prepended=10, bidirectional=False, seed=0)

  prompts = dataset

  # ld = logit_diff(dataset, model.run_with_saes(prompts, saes=[sae]))

  # prompts_dict[task] = [x for x, y in zip(dataset.prompts, ld) if y >= -1e-6]
  datasets_dict[task] = dataset


100%|██████████| 20/20 [00:00<00:00, 388.44it/s]


In [14]:
zero_datasets_dict = {}

for task, pairs in tqdm(tasks.items()):
  pairs = [list(x) for x in pairs.items()]
  dataset = ICLDataset(pairs, size=20, n_prepended=0, bidirectional=False, seed=1)

  prompts = dataset

  zero_datasets_dict[task] = dataset

100%|██████████| 20/20 [00:00<00:00, 394.89it/s]


In [15]:
def generate_task_inputs_old(task, n_shots, batch_size, max_length=128, seed=0):
    random.seed(seed)

    texts = [generate_task_prompt(task, n_shots) for _ in range(batch_size)]
    tokenized = tokenizer.batch_encode_plus(texts, padding="longest", max_length=max_length, truncation=True, return_tensors="np")

    inputs = tokenized_to_inputs(
        **tokenized
    )

    return inputs, tokenized

In [16]:
def get_logprob_diff_old(logits: jnp.ndarray, tokens: jnp.ndarray):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    
    last_arrows = np.repeat(np.arange(tokens.shape[1])[None, :], tokens.shape[0], axis=0) * (tokens == 1599)
    last_arrows = last_arrows.max(axis=-1)

    answer_logprobs = jnp.take_along_axis(logprobs, last_arrows[:, None, None], axis=-1).squeeze()

    target_tokens = jnp.take_along_axis(tokens, last_arrows[:, None], axis=-1).squeeze()
    target_logprobs = jnp.take_along_axis(logprobs, target_tokens[:, None], axis=-1).squeeze()

    return target_logprobs - answer_logprobs.max(axis=-1)


In [17]:
prompt = "<user>Follow the pattern\n{}"

In [18]:
def get_logprob_diff(logits: jnp.ndarray, completions: List[str], print_results=False):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_tokens = [x[1] for x in tokenizer(completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)
    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs - answer_logprobs.max(axis=-1)


In [24]:

task_names = [
    "en_es"
]
layer = 17

tvs = []

n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 10, 20, 128
for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_zero = [[] for _ in range(2)]
    for seed in trange(n_seeds):
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=seed)

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )
        
        logits, resids = get_resids_call(inputs)
    
        tokens = tokenized["input_ids"]

        shot_logprobs_orig[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_orig[1].append(
            shot_logprobs_orig[0][-1] == 0.
        )

        mask = inputs.tokens == 1599
        mask = mask.unwrap("batch", "seq")

        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
            mask
        ]

        tv = tv.mean(
            axis=0
        )

        tvs.append(tv)

        print(
            tv.mean(), tv.std()
        )

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=seed+1)


        print(
            dataset.prompts, dataset.completions
        )

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )

        logits = act_add(inputs)

        tokens = tokenized["input_ids"]

        shot_logprobs_added[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_added[1].append(
            shot_logprobs_added[0][-1] == 0.
        )

        logits, _ = get_resids_call(inputs)
        
        shot_logprobs_zero[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_zero[1].append(
            shot_logprobs_zero[0][-1] == 0.
        )
        
    print(f"orig: {shot_logprobs_orig}")
    print(f"zero: {shot_logprobs_zero}")
    print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_zero = [list(map(np.mean, x)) for x in shot_logprobs_zero]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


print(f"orig: {shot_logprobs_orig}")
print(f"zero: {shot_logprobs_zero}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

j termin hab comenz perfect oficial min grupo di material cuatro simple tipo cal bu v son vida res región
j termin hab com perfect oficial min grupo di material cuatro simple am cal bien vot son vida res región
-0.00028801 0.96875
['right ->', 'so ->', 'upon ->', 'practice ->', 'detail ->', 'financial ->', 'change ->', 'response ->', 'interesting ->', 'worry ->', 'money ->', 'hear ->', 'general ->', 'morning ->', 'government ->', 'course ->', 'letter ->', 'oil ->', 'future ->', 'kill ->'] ['derecho', 'entonces', 'al', 'práctica', 'detalle', 'financiero', 'cambio', 'respuesta', 'interesante', 'preocuparse', 'dinero', 'escuchar', 'general', 'mañana', 'gobierno', 'curso', 'carta', 'aceite', 'futuro', 'matar']
right s up pract det financi cambio res inter wor mon hear general ma gobierno cur letter oil fut kill
dere entonces al pr det financi cambio res inter pre din esc general ma gobierno cur cart ace fut mat
left so up practice detail financial ch question interesting wor money here gen

100%|██████████| 1/1 [00:46<00:00, 46.31s/it]

wor mon hear general ma gobierno cur letter oil fut kill cut ries tiempo estud vest organiz simple población méd
pre din esc general ma gobierno cur cart ace fut mat cort ries h estud lle organiz sim población méd
wor money here general afternoon government courses number oil fut k cut ris time student wear organization simply population doctor
pre din esc general ma gobierno cur cart ace fut mat cort ries h estud lle organiz sim población méd
orig: [[Array([0, 0, 0, -3.5, 0, 0, 0, 0, 0, 0, 0, 0, -8.625, 0, -3.625, -4.5, 0,
       0, 0, 0], dtype=bfloat16), Array([0, 0, -3.5, 0, 0, 0, 0, 0, 0, 0, 0, -8.625, 0, -3.625, -4.5, 0, 0,
       0, 0, -5.625], dtype=bfloat16), Array([0, -3.5, 0, 0, 0, 0, 0, 0, 0, 0, -8.625, 0, -3.625, -4.5, 0, 0, 0,
       0, -5.625, -5.5], dtype=bfloat16), Array([-3.5, 0, 0, 0, 0, 0, 0, 0, 0, -8.625, 0, -3.625, -4.5, 0, 0, 0, 0,
       -5.625, -5.5, -3.25], dtype=bfloat16), Array([0, 0, 0, 0, 0, 0, 0, 0, -8.625, 0, -3.625, -4.5, 0, 0, 0, 0,
       -5.625, -5.5




In [20]:

from micrlhf.utils.load_sae import get_sae
sae = get_sae(layer, 6)

--2024-05-20 15:37:56--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l17-test-run-6-4.52E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.90, 108.156.211.125, 108.156.211.95, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.90|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/1623d8da38be3171fcc8516a4cbe9fdb80e3d77e370aa5690895697649d688f3?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716478676&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjQ3ODY3Nn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvMTYyM2Q4ZGE

In [31]:
from micrlhf.utils.ito import grad_pursuit

In [45]:
k = 10

weights, recon = grad_pursuit(tvs[0], sae["W_dec"], k, pos_only=True)
w, i = jax.lax.top_k(jnp.abs(weights), k)

i

In [46]:
jnp.linalg.norm(tvs[0] - recon)

In [36]:
k = 5

[
    jax.lax.top_k(jnp.abs(grad_pursuit(x, sae["W_dec"], k, pos_only=True)[0]), k) for x in tvs
]

In [50]:

task_names = [
    "antonyms"
]

for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_sae = [[] for _ in range(2)]
    for k in range(2, 10):
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=100)

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )
        
        logits, resids = get_resids_call(inputs)
    
        tokens = tokenized["input_ids"]

        shot_logprobs_orig[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_orig[1].append(
            shot_logprobs_orig[0][-1] == 0.
        )

        mask = inputs.tokens == 1599
        mask = mask.unwrap("batch", "seq")

        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
            mask
        ]

        tv = tv.mean(
            axis=0
        )

        print(
            tv.mean(), tv.std()
        )

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=101)


        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )

        logits = act_add(inputs)

        tokens = tokenized["input_ids"]

        shot_logprobs_added[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False)
        )

        shot_logprobs_added[1].append(
            shot_logprobs_added[0][-1] == 0.
        )

        weights, recon = grad_pursuit(tv, sae["W_dec"], k, pos_only=True)

        act_add = add_vector(
            llama, recon.astype('bfloat16'), layer, scale=2.0, position="last"
        )

        logits = act_add(inputs)
        
        shot_logprobs_sae[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=False)
        )

        shot_logprobs_sae[1].append(
            shot_logprobs_sae[0][-1] == 0.
        )
        
    print(f"orig: {shot_logprobs_orig}")
    print(f"sae: {shot_logprobs_sae}")
    print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = [list(map(np.mean, x)) for x in shot_logprobs_orig]
    shot_logprobs_sae = [list(map(np.mean, x)) for x in shot_logprobs_sae]
    shot_logprobs_added = [list(map(np.mean, x)) for x in shot_logprobs_added]


print(f"orig: {shot_logprobs_orig}")
print(f"sae: {shot_logprobs_sae}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

cur hum vertical contract down urban lose sad wrong rough pol private under un vertical loose far worst junior loose
cur hum vertical shr down urban lose sad wrong rough pol private under sick vertical loose far worst junior loose
0.00448608 0.960938
cur hum vertical contract down urban lose sad wrong rough pol private under un vertical loose far worst junior loose
cur hum vertical shr down urban lose sad wrong rough pol private under sick vertical loose far worst junior loose
0.00448608 0.960938
cur hum vertical contract down urban lose sad wrong rough pol private under un vertical loose far worst junior loose
cur hum vertical shr down urban lose sad wrong rough pol private under sick vertical loose far worst junior loose
0.00448608 0.960938
cur hum vertical contract down urban lose sad wrong rough pol private under un vertical loose far worst junior loose
cur hum vertical shr down urban lose sad wrong rough pol private under sick vertical loose far worst junior loose
0.00448608 0.960

100%|██████████| 1/1 [01:15<00:00, 75.64s/it]

orig: [[Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16), Array([0, 0, 0, -2.125, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4.75, 0, 0, 0, 0, 0,
       0], dtype=bfloat16)], [Array([ True,  True,  True, False,  True,  True,  True,  True,  True,
        True,  True,  True,  True, False,  True,  True,  True,  True,
        True,  True], dtype=bool), Array([




In [49]:
shot_logprobs_sae