In [1]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# Notebook display setup: make penzai's renderer (pz.ts) the default, register its
# autovisualize magic, and enable penzai's interactive context.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [2]:
import random
import jax
import optax

import jax.numpy as jnp
import numpy as np

from matplotlib import pyplot as plt
from tqdm.auto import tqdm, trange
from penzai.data_effects.side_output import SideOutputValue
from micrlhf.utils.activation_manipulation import add_vector

In [3]:
# Path of the GGUF checkpoint to load (a Phi-3 model, judging by the file name).
filename = "models/phi-3-16.gguf"
from micrlhf.llama import LlamaTransformer
# Build the model through micrlhf's Llama implementation; "tpu:0" is passed as `device_map`.
llama = LlamaTransformer.from_pretrained(filename, device_map="tpu:0")

In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Reuse EOS as the padding token and pad on the left, so that the last position of
# every row in a padded batch is a real prompt token (later cells read logits at index -1).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
!git clone https://github.com/roeehendel/icl_task_vectors data/itv
import glob
import json
import os
tasks = {}
for g in glob.glob("data/itv/data/**/*.json"):
    tasks[os.path.basename(g).partition(".")[0]] = json.load(open(g))

  pid, fd = os.forkpty()


fatal: destination path 'data/itv' already exists and is not an empty directory.


In [6]:
from micrlhf.llama import LlamaBlock
from micrlhf.sampling import sample, jit_wrapper
def _with_resid_pre_tap(i, x):
    """Return block `x` preceded by a TellIntermediate layer tagged ``resid_pre_{i}``."""
    tap = pz.de.TellIntermediate.from_config(tag=f"resid_pre_{i}")
    return pz.nn.Sequential([tap, x])


# Place a tap in front of every LlamaBlock, then install a handler that collects every
# side output whose tag starts with "resid_pre".
get_resids = (
    llama.select()
    .at_instances_of(LlamaBlock)
    .apply_with_selected_index(_with_resid_pre_tap)
)
get_resids = pz.de.CollectingSideOutputs.handling(
    get_resids,
    tag_predicate=lambda tag: tag.startswith("resid_pre"),
)
# Jitted callable; later cells unpack its result as (logits, resids).
get_resids_call = jit_wrapper.Jitted(get_resids)

In [7]:
from typing import List

class ICLSequence:
    '''
    One in-context-learning sequence built from (x, y) word pairs.

    Pairs are rendered as "{x} -> {y}" and joined with ", ". The prompt ends right after
    the final "->"; the final y is the expected completion.
    '''
    # An earlier variant of the template rendered pairs as "Q: {x}" / "A: {y}" lines
    # separated by blank lines; the arrow template below replaced it.

    def __init__(self, word_pairs: List[List[str]]):
        self.word_pairs = word_pairs
        # Transpose the pairs into parallel tuples: all first words, all second words.
        firsts, seconds = zip(*word_pairs)
        self.x, self.y = firsts, seconds

    def __len__(self):
        '''Number of word pairs in the sequence (the final, incomplete one included).'''
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        '''Index (or slice) straight into the underlying list of word pairs.'''
        return self.word_pairs[idx]

    def prompt(self):
        '''Returns the prompt: every pair rendered in full, except that the last pair stops at "->".'''
        rendered_pairs = []
        for x, y in self.word_pairs:
            rendered_pairs.append(f"{x} -> {y}")
        full_text = ", ".join(rendered_pairs)
        # Cut off the final answer together with the single space in front of it.
        cut = len(self.completion()) + 1
        return full_text[:-cut]

    def completion(self):
        '''Returns the second element of the last word pair (no padding is added).'''
        last_answer = self.y[-1]
        return "" + last_answer

    def __str__(self):
        '''Readable tuple-style summary: "(x, y)" for each earlier pair, then "x ->" for the last one.'''
        shown = [f'({x}, {y})' for x, y in self[:-1]]
        text = f"{', '.join(shown)}, {self.x[-1]} ->"
        # With no earlier pairs the text starts with ", "; strip removes that leftover.
        return text.strip(", ")


# Small hand-written example to sanity-check both string renderings of a sequence.
word_list = [
    ["hot", "cold"],
    ["yes", "no"],
    ["in", "out"],
    ["up", "down"],
]
seq = ICLSequence(word_list)

print("Tuple-representation of the sequence:", seq, sep="\n")
print("\nActual prompt, which will be fed into the model:", seq.prompt(), sep="\n")

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
hot -> cold, yes -> no, in -> out, up ->


In [8]:
class ICLDataset:
    '''
    Dataset of ICL prompts built from word pairs (e.g. antonyms). Prompt `n` is generated
    from its own seed (`seed + n`), so a clean and a corrupted dataset created with the same
    seed select the same pairs in the same order.

    Inputs:
        word_pairs:
            list of ICL task pairs, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
        size:
            number of prompts to generate
        n_prepended:
            number of complete pairs shown before the final, single-word query
        bidirectional:
            if True, each selected pair is reversed with probability 1/2
        seed:
            random seed, for consistency & reproducibility
        corrupted:
            if True, the second word of every prepended pair (all pairs except the final
            query pair) is replaced with a word drawn at random from all words in `word_pairs`

    Generation uses a private `np.random.RandomState` per prompt, so constructing a dataset
    does not reseed or advance NumPy's global random state.
    '''

    def __init__(
        self,
        word_pairs: List[List[str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        # Flat vocabulary of every word in every pair; source of replacement words when corrupting.
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            # RandomState(seed + n) produces the same draws as np.random.seed(seed + n)
            # followed by np.random.* calls, without touching the global generator.
            rng = np.random.RandomState(seed + n)
            random_pairs = rng.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = rng.choice([1, -1], n_prepended+1)
            if not bidirectional:
                random_orders[:] = 1
            # Slicing with [::order] yields a fresh sequence per pair (reversed when order
            # is -1), so for list pairs the in-place corruption below cannot alter
            # self.word_pairs.
            word_pairs = [self.word_pairs[pair][::order] for pair, order in zip(random_pairs, random_orders)]
            if corrupted:
                # The final pair is the query and is left intact.
                for i in range(len(word_pairs) - 1):
                    word_pairs[i][1] = rng.choice(self.word_list)
            seq = ICLSequence(word_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        '''Number of prompts requested at construction time.'''
        return self.size

    def __getitem__(self, idx: int):
        '''Returns the `ICLSequence` at position `idx`.'''
        return self.seqs[idx]

In [9]:
def tokenized_to_inputs(input_ids, attention_mask):
    """Convert tokenizer output into the model's input structure.

    Both arrays are placed on `llama.mesh` with a ("dp", "sp") partition spec and wrapped
    as named arrays with axes "batch" and "seq".

    Args:
        input_ids: token ids, indexed as (batch, seq).
        attention_mask: mask with the same layout as `input_ids`.

    Returns:
        The result of `llama.inputs.from_basic_segments` applied to the token array.
    """
    sharding = jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))

    def _place_and_name(array):
        # Shard onto the mesh, then attach the axis names.
        on_device = jax.device_put(array, sharding)
        return pz.nx.wrap(on_device, "batch", "seq").untag("batch").tag("batch")

    token_array = _place_and_name(jnp.asarray(input_ids))

    # NOTE(review): the mask is built but never handed to the model, so padding positions
    # are not masked via this path. Confirm whether that is intended.
    mask_array = _place_and_name(jnp.asarray(attention_mask, dtype=jnp.bool))

    return llama.inputs.from_basic_segments(token_array)

In [10]:
# Prompt wrapper; "{}" is filled with the few-shot text via prompt.format(...).
prompt = "<user>Follow the pattern\n{}"
# Index used to pick the residual stream (resids[target_layer]) and to load the SAE.
target_layer = 17

In [11]:
# Task(s) under study; `task` is the key used to index `tasks` in later cells.
task_names = [
    "en_es"
]
task = "en_es"

# NOTE(review): n_seeds is not referenced by any cell visible in this notebook.
n_seeds = 10

# n_few_shots, batch_size, max_seq_len = 64, 64, 512
# Few-shot pairs per prompt, prompts per batch, and the tokenizer's truncation length.
n_few_shots, batch_size, max_seq_len = 10, 32, 128

In [28]:
seed = 0

# Few-shot prompts for the task: every JSON entry becomes a mutable [key, value] pair.
pairs = [list(item) for item in tasks[task].items()]
dataset = ICLDataset(
    pairs,
    size=batch_size,
    n_prepended=n_few_shots,
    bidirectional=False,
    seed=seed,
)

few_shot_texts = [prompt.format(x) for x in dataset.prompts]
tokenized = tokenizer.batch_encode_plus(
    few_shot_texts,
    padding="longest",
    max_length=max_seq_len,
    truncation=True,
    return_tensors="np",
)
inputs = tokenized_to_inputs(**tokenized)

# One forward pass that also returns the collected pre-block residuals.
logits, resids = get_resids_call(inputs)

tokens = tokenized["input_ids"]

# Positions holding token id 1599 (presumably the "->" separator; TODO confirm).
mask = (inputs.tokens == 1599).unwrap("batch", "seq")

# Task vector: mean residual at `target_layer` over all selected positions.
layer_resids = resids[target_layer].value.unwrap("batch", "seq", "embedding")
tv = layer_resids[mask].mean(axis=0)


In [12]:
from micrlhf.utils.load_sae import get_sae
# Fetch the sparse autoencoder for `target_layer`. The second argument (6) is forwarded
# to micrlhf's get_sae; presumably a training-run id (TODO confirm).
sae = get_sae(target_layer, 6)

--2024-05-21 17:23:12--  https://huggingface.co/nev/phi-3-4k-saex-test/resolve/main/l17-test-run-6-4.52E-06/sae_weights.safetensors?download=true
Resolving huggingface.co (huggingface.co)... 108.156.211.125, 108.156.211.51, 108.156.211.90, ...
Connecting to huggingface.co (huggingface.co)|108.156.211.125|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/eb/d8/ebd889d6ac58573e8e8a7aa1176d4d357581a6da60135b94aca378fddf4e9e54/1623d8da38be3171fcc8516a4cbe9fdb80e3d77e370aa5690895697649d688f3?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27sae_weights.safetensors%3B+filename%3D%22sae_weights.safetensors%22%3B&Expires=1716571392&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNjU3MTM5Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2ViL2Q4L2ViZDg4OWQ2YWM1ODU3M2U4ZThhN2FhMTE3NmQ0ZDM1NzU4MWE2ZGE2MDEzNWI5NGFjYTM3OGZkZGY0ZTllNTQvMTYyM2Q4ZG

In [13]:
# SAE decoder matrix. Its first axis is treated as the feature axis later on
# (one weight per row; combined via einsum "fv,f->v").
dictionary = sae["W_dec"]
dictionary.shape

In [15]:
seed = 0
pairs = [list(item) for item in tasks[task].items()]

# Zero-shot prompts (no prepended examples), drawn with seed + 1.
dataset = ICLDataset(
    pairs,
    size=batch_size,
    n_prepended=0,
    bidirectional=False,
    seed=seed+1,
)

zero_shot_texts = [prompt.format(x) for x in dataset.prompts]
tokenized = tokenizer.batch_encode_plus(
    zero_shot_texts,
    padding="longest",
    max_length=max_seq_len,
    truncation=True,
    return_tensors="np",
)
inputs = tokenized_to_inputs(**tokenized)

# Target = the token at index 1 of each tokenized completion
# (presumably index 0 is a BOS token; TODO confirm).
completion_ids = tokenizer(dataset.completions)["input_ids"]
target_tokens = jnp.asarray([ids[1] for ids in completion_ids])

In [16]:
def get_logprob_diff(logits: jnp.ndarray, target_tokens, print_results=False):
    """Return the log-probability of each row's target token at the last sequence position.

    Despite the name, no difference is taken: the result is the log-softmax value of the
    target token itself.

    Args:
        logits: array indexed as (batch, seq, vocabulary).
        target_tokens: 1-D integer array with one target token id per batch row.
        print_results: if True, print the decoded argmax predictions followed by the
            decoded targets (uses the notebook-level `tokenizer`).

    Returns:
        Array of shape (batch,) holding the target log-probabilities.
    """
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    # Only the final position is scored.
    answer_logprobs = logprobs[:, -1]

    # Squeeze just the gathered axis: a bare .squeeze() would also drop the batch axis
    # whenever the batch holds a single row.
    target_logprobs = jnp.take_along_axis(
        answer_logprobs, target_tokens[:, None], axis=-1
    ).squeeze(axis=-1)

    if print_results:
        print(
            tokenizer.decode(answer_logprobs.argmax(axis=-1))
        )

        print(
            tokenizer.decode(target_tokens)
        )

    return target_logprobs


In [17]:
from micrlhf.llama import LlamaBlock
from functools import partial


# Index of the block whose input residual is captured by `get_resids_initial` below.
layer_source = 17
def make_get_resids(llama, layer_target):
    """Build a copy of `llama` that reports the value entering block `layer_target`.

    A TellIntermediate layer tagged "resid_pre" is placed in front of the selected
    LlamaBlock, and the model is wrapped in a handler that collects side outputs whose
    tag starts with "resid_pre".

    Args:
        llama: the model to instrument.
        layer_target: index, among the LlamaBlock instances, of the block to tap.

    Returns:
        The wrapped model produced by `CollectingSideOutputs.handling`.
    """
    def _tap_input(block):
        return pz.nn.Sequential([
            pz.de.TellIntermediate.from_config(tag="resid_pre"),
            block,
        ])

    chosen_block = llama.select().at_instances_of(LlamaBlock).pick_nth_selected(layer_target)
    tapped_model = chosen_block.apply(_tap_input)
    return pz.de.CollectingSideOutputs.handling(
        tapped_model,
        tag_predicate=lambda tag: tag.startswith("resid_pre"),
    )
# jittify(model) returns a callable that passes `model` as the first argument of a jitted
# function, calls it with the remaining arguments, and returns `.value` of the first
# collected side output (the call result is indexed [1][0]).
jittify = lambda x: partial(jax.jit(lambda lr, *args, **kwargs: lr(*args, **kwargs)[1][0].value), x)
get_resids_initial = make_get_resids(llama, layer_source)
get_resids_initial = jittify(get_resids_initial)

In [18]:
# Residuals entering block `layer_source` for the current `inputs`; get_loss reads this
# global. (The spelling "intital" is kept because get_loss refers to the name as written.)
intital_resids = get_resids_initial(inputs)

In [19]:
layer_target = 17

# "Tail" model: every LlamaBlock with index below `layer_target`, every EmbeddingLookup,
# and the first ConstantRescale are replaced by Identity. get_loss feeds residuals in
# through the `tokens` field, so presumably they reach block `layer_target` unchanged
# (TODO confirm against the model structure).
taker = jit_wrapper.Jitted(llama.select().at_instances_of(LlamaBlock).apply_with_selected_index(
    lambda i, x: x if i >= layer_target else pz.nn.Identity()
).select().at_instances_of(pz.nn.EmbeddingLookup).apply(lambda _: pz.nn.Identity())
                .select().at_instances_of(pz.nn.ConstantRescale).pick_nth_selected(0).apply(lambda _: pz.nn.Identity()))

In [71]:
import dataclasses


# Only referenced by a disabled normalisation experiment (rescaling the steering vector
# to a fixed norm); kept so the notebook namespace is unchanged.
scale = 20


def get_loss(weights, dictionary, inputs, target_tokens):
    """Loss for learning a sparse, non-negative combination of dictionary rows.

    The weights are passed through ReLU, combined into one vector via `dictionary`, cast
    to bfloat16 and added to the module-level `intital_resids` at the last sequence
    position. The patched residuals replace `inputs.tokens` and are run through the
    module-level `taker` model.

    Args:
        weights: 1-D array with one coefficient per dictionary row.
        dictionary: 2-D array combined with `weights` via einsum "fv,f->v".
        inputs: model input structure whose `tokens` field is replaced.
        target_tokens: 1-D integer array of target token ids, one per batch row.

    Returns:
        (total, (l0, task_loss)): `task_loss` is the negated mean target log-probability,
        `total` adds 1e-2 times the L1 norm of the rectified weights, and `l0` counts
        the non-zero rectified weights.
    """
    # Alternatives tried and disabled: normalising the weights by their maximum, steering
    # through `add_vector`, and other sparsity penalties (a saturating |w| / (|w| + 0.1)
    # term, an ord=0.5 norm, or no penalty at all).
    active = jax.nn.relu(weights)

    steering = jnp.einsum("fv,f->v", dictionary, active).astype('bfloat16')

    def _add_at_last_position(resid, vec):
        return resid.at[-1].add(vec)

    patched = pz.nx.nmap(_add_at_last_position)(
        intital_resids.untag("seq", "embedding"),
        pz.nx.wrap(steering, "embedding").untag("embedding"),
    ).tag("seq", "embedding")

    patched_inputs = dataclasses.replace(inputs, tokens=patched)
    logits = taker(patched_inputs).unwrap("batch", "seq", "vocabulary")

    task_loss = -get_logprob_diff(logits, target_tokens).mean()
    l1_penalty = 1e-2 * jnp.linalg.norm(active, ord=1)
    l0 = (active != 0).sum()

    return task_loss + l1_penalty, (l0, task_loss)

In [73]:
from functools import partial

# Optimiser for the feature weights.
# NaNs are zeroed *before* Adam as well as after it: with the guard only on the output,
# a single NaN gradient would enter Adam's moment estimates and stay there, so that
# coordinate's update would be NaN (and therefore zeroed) on every later step. For
# finite gradients both guards are the identity, so normal runs are unaffected.
optimizer = optax.chain(
    optax.zero_nans(),
    # optax.clip_by_global_norm(1.0),
    # optax.adam(1e-3, b1=0, b2=0.99),
    optax.adam(5e-2),
    optax.zero_nans(),
)

# Calling lwg(...) returns ((total_loss, aux), grad), with the gradient taken with
# respect to the first argument of get_loss (`weights`).
lwg = jax.value_and_grad(get_loss, has_aux=True)
# lwg = jax.jit(lwg)
# Threshold for a disabled soft-thresholding step in train_step; currently unused.
shrinkage = 0

# @partial(jax.jit, donate_argnums=(0, 1))
def train_step(weights, opt_state, dictionary, inputs, target_tokens, pos_only=True):
    """Run one optimisation step on `weights`.

    Args:
        weights: current feature weights.
        opt_state: optimiser state matching the module-level `optimizer`.
        dictionary, inputs, target_tokens: forwarded unchanged to the loss via `lwg`.
        pos_only: accepted but not used by the body.

    Returns:
        (total_loss, new_weights, new_opt_state, metrics), where `metrics` is a dict with
        keys "l0" and "loss" taken from the auxiliary output of the loss.
    """
    (total_loss, aux), grads = lwg(weights, dictionary, inputs, target_tokens)
    l0, task_loss = aux

    updates, new_opt_state = optimizer.update(grads, opt_state, weights)
    new_weights = optax.apply_updates(weights, updates)
    # A soft-thresholding (shrinkage) step on the updated weights was tried here and is disabled.

    metrics = {"l0": l0, "loss": task_loss}
    return total_loss, new_weights, new_opt_state, metrics

In [74]:
iterations = 1000

# Start every feature weight at 0.1 (one weight per row of `dictionary`).
weights = jnp.ones(dictionary.shape[0]) * 0.1
opt_state = optimizer.init(weights)


for _ in (bar := trange(iterations)):
    loss, weights, opt_state, metrics = train_step(weights, opt_state, dictionary, inputs, target_tokens)

    # Two largest weights: tk[0] holds the values, tk[1] their indices.
    tk = jax.lax.top_k(weights, 2)

    # Progress bar: optimised loss, auxiliary metrics, index of the top feature, and the
    # gap between the top two weights relative to the largest one.
    bar.set_postfix(loss_optim=loss, **metrics, top=tk[1][0], top_diff=(tk[0][0] - tk[0][1]) / tk[0][0])

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [75]:
# Indices of the 10 largest-magnitude weights; displayed with the corresponding weights
# divided by the largest magnitude.
_, i = jax.lax.top_k(jnp.abs(weights), 10)
weights[i] / jnp.abs(weights).max(), i

In [80]:
# Evaluation batch: zero-shot prompts drawn with seed 10 (the optimisation batch above
# was drawn with seed + 1).
pairs = tasks[task]
pairs = [list(x) for x in pairs.items()]
dataset = ICLDataset(pairs, size=batch_size, n_prepended=0, bidirectional=False, seed=10)


tokenized = tokenizer.batch_encode_plus([prompt.format(x) for x in dataset.prompts], padding="longest", max_length=max_seq_len, truncation=True, return_tensors="np")
inputs = tokenized_to_inputs(
    **tokenized
)

# Target = the token at index 1 of each tokenized completion
# (presumably index 0 is a BOS token; TODO confirm).
target_tokens = [x[1] for x in tokenizer(dataset.completions)["input_ids"]]
target_tokens = jnp.asarray(target_tokens)

# weights, recon = grad_pursuit(tv * 2, sae["W_dec"], k, pos_only=True)

# Steering vector: a single SAE decoder row (index 37312) scaled by 20.
recon = sae["W_dec"][37312] * 20

# Add the vector at `layer_target`, last position only.
act_add = add_vector(
    llama, recon.astype('bfloat16'), layer_target, scale=1.0, position="last"
)

logits = act_add(inputs).unwrap("batch", "seq", "vocabulary")

# Score with the shared helper rather than a second inline copy of the same
# log-softmax / gather logic. It prints the decoded argmax predictions, then the targets.
target_logprobs = get_logprob_diff(logits, target_tokens, print_results=True)

# Mean log-probability assigned to the correct first completion token.
target_logprobs.mean()

wor mon o general tarde g cur number pet fut mat cort ries time estud vest organiz s población do at mon c require corte fund ver something natur indust redu jug
pre din esc general ma gobierno cur cart ace fut mat cort ries h estud lle organiz sim población méd at din lle ex corte fond ver algo amb indust redu jug
