In [1]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:

pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

In [3]:
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained(filename, device_map="auto")
from micrlhf.sampling import sample
from transformers import AutoTokenizer
import jax
# tokenizer = load_tokenizer(filename)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

In [5]:
!git clone https://github.com/roeehendel/icl_task_vectors data/itv
import glob
import json
import os
tasks = {}
for g in glob.glob("data/itv/data/**/*.json"):
    tasks[os.path.basename(g).partition(".")[0]] = json.load(open(g))

fatal: destination path 'data/itv' already exists and is not an empty directory.


  pid, fd = os.forkpty()


In [6]:
tasks.keys()

In [7]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
get_resids = llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(lambda i, x:
    pz.nn.Sequential([
        pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}"),
        x
    ])
)
get_resids = pz.de.CollectingSideOutputs.handling(get_resids, tag_predicate=lambda x: x.startswith("resid_pre"))
get_resids_call = jit_wrapper.Jitted(get_resids)

In [8]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import random
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [27]:
from typing import List

class ICLSequence:
    '''
    Class to store a single antonym sequence.

    Uses the default template "Q: {x}\nA: {y}" (with separate pairs split by "\n\n").
    '''
    def __init__(self, word_pairs: List[List[str]]):
        self.word_pairs = word_pairs
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        return self.word_pairs[idx]

    # def prompt(self):
    #     '''Returns the prompt, which contains all but the second element in the last word pair.'''
    #     p = "\n\n".join([f"Q: {x}\nA: {y}" for x, y in self.word_pairs])
    #     return p[:-len(self.completion())]

    def prompt(self):
        '''Returns the prompt, which contains all but the second element in the last word pair.'''
        p = "\n".join([f"{x} -> {y}" for x, y in self.word_pairs])
        return p[:-len(self.completion())-1]

    def completion(self):
        '''Returns the second element in the last word pair (with padded space).'''
        return "" + self.y[-1]

    def __str__(self):
        '''Prints a readable string representation of the prompt & completion (indep of template).'''
        return f"{', '.join([f'({x}, {y})' for x, y in self[:-1]])}, {self.x[-1]} ->".strip(", ")


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = ICLSequence(word_list)

print("Tuple-representation of the sequence:")
print(seq)
print("\nActual prompt, which will be fed into the model:")
print(seq.prompt())

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
hot -> cold
yes -> no
in -> out
up ->


In [28]:
class ICLDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        corrupted:
            if True, then the second word in each pair is replaced with a random word
        seed:
            random seed, for consistency & reproducibility
    '''

    def __init__(
        self,
        word_pairs: List[List[str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            np.random.seed(seed + n)
            random_pairs = np.random.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = np.random.choice([1, -1], n_prepended+1)
            if not(bidirectional): random_orders[:] = 1
            word_pairs = [self.word_pairs[pair][::order] for pair, order in zip(random_pairs, random_orders)]
            if corrupted:
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = np.random.choice(self.word_list)
            seq = ICLSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

In [29]:
def generate_task_prompt(task, n_shots):
    prompt = "<user>Here is a rule:\n{}"
    examples = []

    for s, t in random.sample(list(tasks[task].items()), n_shots):
        examples.append(f"{s} -> {t}")
    prompt = prompt.format("\n".join(examples))

    # print(prompt)

    return prompt

def tokenized_to_inputs(input_ids, attention_mask):
    token_array = jnp.asarray(input_ids)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    mask_array = jnp.asarray(attention_mask, dtype=jnp.bool)
    mask_array = jax.device_put(mask_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    mask_array = pz.nx.wrap(mask_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs

In [30]:
from tqdm import tqdm

datasets_dict = {}

for task, pairs in tqdm(tasks.items()):
  pairs = [list(x) for x in pairs.items()]
  dataset = ICLDataset(pairs, size=20, n_prepended=10, bidirectional=False, seed=0)

  prompts = dataset

  # ld = logit_diff(dataset, model.run_with_saes(prompts, saes=[sae]))

  # prompts_dict[task] = [x for x, y in zip(dataset.prompts, ld) if y >= -1e-6]
  datasets_dict[task] = dataset


100%|██████████| 20/20 [00:00<00:00, 381.98it/s]


In [31]:
zero_datasets_dict = {}

for task, pairs in tqdm(tasks.items()):
  pairs = [list(x) for x in pairs.items()]
  dataset = ICLDataset(pairs, size=20, n_prepended=0, bidirectional=False, seed=1)

  prompts = dataset

  zero_datasets_dict[task] = dataset

100%|██████████| 20/20 [00:00<00:00, 430.87it/s]


In [32]:
def generate_task_inputs_old(task, n_shots, batch_size, max_length=128, seed=0):
    random.seed(seed)

    texts = [generate_task_prompt(task, n_shots) for _ in range(batch_size)]
    tokenized = tokenizer.batch_encode_plus(texts, padding="longest", max_length=max_length, truncation=True, return_tensors="np")

    inputs = tokenized_to_inputs(
        **tokenized
    )

    return inputs, tokenized

In [33]:
def get_logprob_diff_old(logits: jnp.ndarray, tokens: jnp.ndarray):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    
    last_arrows = np.repeat(np.arange(tokens.shape[1])[None, :], tokens.shape[0], axis=0) * (tokens == 1599)
    last_arrows = last_arrows.max(axis=-1)

    answer_logprobs = jnp.take_along_axis(logprobs, last_arrows[:, None, None], axis=-1).squeeze()

    target_tokens = jnp.take_along_axis(tokens, last_arrows[:, None], axis=-1).squeeze()
    target_logprobs = jnp.take_along_axis(logprobs, target_tokens[:, None], axis=-1).squeeze()

    return target_logprobs - answer_logprobs.max(axis=-1)


In [40]:
prompt = "<user>Here is a rule:\n{}"

In [19]:
inputs, tokenized = generate_task_inputs_old("antonyms", 20, 12, 128, seed=1)       
logits, resids = get_resids_call(inputs)
tokens = tokenized["input_ids"]

logits = logits.unwrap("batch", "seq", "vocabulary")


In [34]:
dataset: ICLDataset = datasets_dict["antonyms"]
tokenized = tokenizer.batch_encode_plus(dataset.prompts, padding="longest", max_length=128, truncation=True, return_tensors="np")
inputs = tokenized_to_inputs(
    **tokenized
)

logits, resids = get_resids_call(inputs)

tokens = tokenized["input_ids"]

In [35]:
def get_logprob_diff(logits: jnp.ndarray, completions: List[str], print_results=False):
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    answer_logprobs = logprobs[:, -1]

    target_tokens = [x[1] for x in tokenizer(completions)["input_ids"]]
    target_tokens = jnp.asarray(target_tokens)
    target_logprobs = jnp.take_along_axis(answer_logprobs, target_tokens[:, None], axis=-1).squeeze()

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs - answer_logprobs.max(axis=-1)


In [36]:


get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)

lost dead active far off fun small mismatch rural sharp weak h war retre cold wrong even stop bad rough
lost dead active far off fun tiny sick rural sharp weak h war retre cold wrong even finish bad rough


In [58]:

task_names = [
    "antonyms"
]
layer = 16

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
n_few_shots, batch_size, max_seq_len = 10, 20, 128
for task in tqdm(task_names):
    shot_logprobs_orig = [[] for _ in range(2)]
    shot_logprobs_added = [[] for _ in range(2)]
    shot_logprobs_zero = [[] for _ in range(2)]
    for _ in trange(1):
        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=n_few_shots, bidirectional=False, seed=0)

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )
        
        logits, resids = get_resids_call(inputs)
    
        tokens = tokenized["input_ids"]

        shot_logprobs_orig[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_orig[1].append(
            shot_logprobs_orig[0][-1] == 0.
        )

        mask = inputs.tokens == 1599
        mask = mask.unwrap("batch", "seq")

        tv = resids[layer].value.unwrap("batch", "seq", "embedding")[
            mask
        ]

        tv = tv.mean(
            axis=0
        )

        print(
            tv.mean(), tv.std()
        )

        act_add = add_vector(
            llama, tv, layer, scale=2.0, position="last"
        )

        pairs = tasks[task]
        pairs = [list(x) for x in pairs.items()]
        dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=1)


        print(
            dataset.prompts, dataset.completions
        )

        tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
        inputs = tokenized_to_inputs(
            **tokenized
        )

        logits = act_add(inputs)

        tokens = tokenized["input_ids"]

        shot_logprobs_added[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_added[1].append(
            shot_logprobs_added[0][-1] == 0.
        )

        logits, _ = get_resids_call(inputs)
        
        shot_logprobs_zero[0].append(
            get_logprob_diff(logits.unwrap("batch", "seq", "vocabulary"), dataset.completions, print_results=True)
        )

        shot_logprobs_zero[1].append(
            shot_logprobs_zero[0][-1] == 0.
        )
        
    print(f"orig: {shot_logprobs_orig}")
    print(f"zero: {shot_logprobs_zero}")
    print(f"added: {shot_logprobs_added}")


    shot_logprobs_orig = list(map(np.mean, shot_logprobs_orig))
    shot_logprobs_zero = list(map(np.mean, shot_logprobs_zero))
    shot_logprobs_added = list(map(np.mean, shot_logprobs_added))


print(f"orig: {shot_logprobs_orig}")
print(f"zero: {shot_logprobs_zero}")
print(f"added: {shot_logprobs_added}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

lost dead active far off fun small miss rural sharp weak h war retre cold left even stop bad rough
lost dead active far off fun tiny sick rural sharp weak h war retre cold wrong even finish bad rough
(220, 3072)
-0.00166321 0.886719
['lazy ->', 'old ->', 'pretty ->', 'heavy ->', 'cheap ->', 'deep ->', 'less ->', 'lazy ->', 'wild ->', 'bitter ->', 'cheap ->', 'hot ->', 'alive ->', 'horizontal ->', 'pro ->', 'proud ->', 'exit ->', 'divide ->', 'come ->', 'rare ->'] ['active', 'young', 'ugly', 'light', 'expensive', 'shallow', 'more', 'active', 'tame', 'sweet', 'expensive', 'cold', 'dead', 'vertical', 'con', 'humble', 'enter', 'unite', 'go', 'common']


100%|██████████| 1/1 [00:04<00:00,  4.95s/it]

lazy new ugly light expensive shall more lazy wild sweet expensive cold alive vertical positive happy exit divide go rare
active young ugly light expensive shall more active t sweet expensive cold dead vertical con hum enter un go common
lazy old pretty heavy cheap deep less lazy wild bitter cheap hot alive horizontal pro ar good / come r
active young ugly light expensive shall more active t sweet expensive cold dead vertical con hum enter un go common
orig: [[Array([0, 0, 0, 0, 0, 0, -0.875, -6.375, 0, 0, 0, 0, 0, 0, 0, -0.125, 0,
       -2, 0, 0], dtype=bfloat16)], [Array([ True,  True,  True,  True,  True,  True, False, False,  True,
        True,  True,  True,  True,  True,  True, False,  True, False,
        True,  True], dtype=bool)]]
zero: [[Array([-6.125, -9.25, -2.3125, -3.125, -1.32031, -7.5625, -7.25, -6.125,
       -5.125, -4.125, -1.32031, -6.5625, -1.25, -2.375, -6.875, -2.625,
       -5, -8.0625, -6.3125, -10.6875], dtype=bfloat16)], [Array([False, False, False, False, F


