In [1]:
import os

# The notebook reads "models/..." and "data/..." relative paths, so it has to run
# from the repository root. If launched one directory below it (no "models"
# entry in the current directory), step up a level.
_here = os.listdir(".")
if "models" not in _here:
    os.chdir("..")

In [2]:
# Re-import edited project modules (e.g. micrlhf) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
# penzai interactive setup. NOTE(review): presumably makes treescope the default
# renderer for cell outputs, registers its autovisualize magic, and enables
# penzai's interactive named-axis context — confirm against penzai docs.
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [3]:
from micrlhf.llama import LlamaTransformer

# Gemma-2B instruction-tuned checkpoint in GGUF format, converted via the
# "gemma" loader. load_eager=True (presumably: materialize weights up front).
GEMMA_GGUF_PATH = "models/gemma-2b-it.gguf"
llama = LlamaTransformer.from_pretrained(
    GEMMA_GGUF_PATH,
    from_type="gemma",
    load_eager=True,
)

In [13]:
from transformers import AutoTokenizer

TOKENIZER_REPO = "alpindale/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO)
# Pad on the right so that any padding follows the real tokens of each sequence.
tokenizer.padding_side = "right"

In [4]:
!mkdir -p data
import pandas as pd

df_adv = pd.read_csv("data/adv.csv")

format_prompt = """<start_of_turn>user\n
{}\n
<start_of_turn>model\n
{}"""
# offset = 1
# df_do = df.apply(lambda x: format_response.format(x['goal'], x['target']), axis=1)
# prompts_harmful = df.apply(lambda x: format_prompt.format(x['goal'], "")[:-offset], axis=1).to_list()[:100]
prompts_harmful = df_adv.apply(lambda x: format_prompt.format(x['goal'], ""), axis=1).to_list()[:100]
dataset_jail = pd.read_csv("data/jail.csv").apply(lambda x: x["Goal"], axis=1).to_list()
prompts_jail = [format_prompt.format(x, "") for x in dataset_jail]
import datasets
# https://colab.research.google.com/drive/1a-aQvKC9avdZpdyBn4jgRQFObTPy1JZw
hf_path = 'tatsu-lab/alpaca'
dataset = datasets.load_dataset(hf_path)
# filter for instructions that do not have inputs
prompts_harmless = []
for i in range(len(dataset['train'])):
    if len(prompts_harmless) >= len(prompts_harmful):
        break
    if dataset['train'][i]['input'].strip() == '':
        # prompts_harmless.append(format_prompt.format(dataset['train'][i]['instruction'], "")[:-offset])
        prompts_harmless.append(format_prompt.format(dataset['train'][i]['instruction'], ""))

# ds = datasets.load_dataset("MBZUAI/LaMini-instruction", split="train", streaming=True)
# prompts_harmless = []
# for _, text in zip(range(100), ds):
#     prompts_harmless.append(format_prompt.format(text["instruction"], ""))

# ds = datasets.load_dataset("nev/openhermes-2.5-phi-format-text", split="train", streaming=True)
# prompts_harmless = []
# for _, text in zip(range(100), ds):
#     text = text["text"]
#     text = "".join(text.partition("<|assistant|>\n")[:2])
#     prompts_harmless.append(text)

  pid, fd = os.forkpty()


In [14]:
from micrlhf.sampling import sample

msl = 128
# Sample one completion per harmless prompt from the unmodified model. The
# training cell later uses these as the "keep behaving normally" targets.
# NOTE(review): do_sample=True with no seed passed here, so results vary per run.
_harmless_sample_result = sample(
    llama,
    tokenizer,
    prompts_harmless,
    max_seq_len=msl,
    do_sample=True,
    verbose=True,
    return_only_completion=True,
)
# Element 0 of the returned value holds the completion texts.
completions_harmless = _harmless_sample_result[0]

  0%|          | 0/81 [00:00<?, ?it/s]

In [15]:
# (prompt, completion) pairs with each completion cut at its first "<eos>" and trimmed.
ds_harmless = [
    (p, c.split("<eos>", 1)[0].strip())
    for p, c in zip(prompts_harmless, completions_harmless)
]

In [6]:
from micrlhf.sampling import sample, trange, jnp, load_tokenizer, jit_wrapper
import jax

# Tokenize harmful then harmless prompts as one batch, padded/truncated to 128
# tokens (padding side is whatever the tokenizer cell configured).
_all_prompts = prompts_harmful + prompts_harmless
tokens = tokenizer.batch_encode_plus(
    _all_prompts,
    return_tensors="np",
    padding="max_length",
    truncation=True,
    max_length=128,
    return_attention_mask=True,
)

# Place the ids on the model's mesh: rows split over "dp", positions over "sp".
_sharding = jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp"))
token_array = jax.device_put(jnp.asarray(tokens["input_ids"]), _sharding)
token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")
inputs = llama.inputs.from_basic_segments(token_array)

In [29]:
from micrlhf.utils.activation_manipulation import ablate_direction
from functools import partial


def _weighted_token_nll(logits, token_ids, mask):
    """Per-position next-token NLL for one sequence, normalized by its mask size.

    The logits at position t predict token t+1. So the last logits row is
    dropped, and the first entry of the targets and the mask is dropped.

    Each remaining position gets -log p(target) * mask / sum(mask), so summing
    the result gives the mean NLL over the masked target tokens.

    The denominator is clamped to 1. A row with no masked tokens (for example,
    its completion was truncated away) then yields zeros instead of
    0/0 = NaN, which would poison the loss, the gradient and the optimizer state.
    """
    log_probs = jax.nn.log_softmax(logits[:-1].astype(jnp.float32), -1)
    target_ids = token_ids[1:]
    target_mask = mask[1:]
    target_log_probs = jnp.take_along_axis(log_probs, target_ids[:, None], -1)[:, 0]
    return -target_log_probs * target_mask / jnp.maximum(target_mask.sum(), 1)


@partial(jax.jit, static_argnames=("normalize", "batch_axis"))
def get_loss(direction, llama, input_ids, loss_mask, normalize=True, batch_axis="direction"):
    """Masked next-token loss of `llama` with `direction` ablated.

    Args:
      direction: vector handed to `ablate_direction`; the quantity being optimized.
      llama: the transformer.
      input_ids: named array of token ids with axes ("batch", "seq").
      loss_mask: named array with axes ("batch", "seq"). Non-zero entries mark
        the tokens whose prediction contributes to the loss.
      normalize, batch_axis: forwarded to `ablate_direction` (static under jit).

    Returns:
      Scalar: each sequence's masked mean NLL, averaged over the batch.
    """
    inputs = llama.inputs.from_basic_segments(input_ids)
    ablated_model = ablate_direction(llama, direction, normalize=normalize, batch_axis=batch_axis)
    logits = ablated_model(inputs)
    per_token = pz.nx.nmap(_weighted_token_nll)(
        logits.untag("seq", "vocabulary"), input_ids.untag("seq"), loss_mask.untag("seq"))
    return per_token.sum().unwrap("batch").mean()


# Loss value and its gradient w.r.t. `direction` (argument 0).
lwg = jax.jit(jax.value_and_grad(get_loss), static_argnames=("normalize", "batch_axis"))

In [35]:
from micrlhf.utils.vector_storage import download_vector
import optax
import random

n_iterations = 50
batch_size = 100
max_length = 128
learning_rate = 0.01
seed = 0
# NOTE: currently unused. get_loss divides each row by its own mask sum, so
# scaling a row's loss_mask by a constant weight would cancel out. A per-group
# weight would have to be applied to the per-row losses instead.
harmless_weight = 1

rng = random.Random(seed)


def _encode_pair(prompt_text, completion_text):
    """Tokenize a (prompt, completion) pair into two lists of ids.

    The completion's first token is dropped (presumably the BOS that
    `tokenizer.encode` prepends), so the two lists can be concatenated directly.
    """
    return tokenizer.encode(prompt_text), tokenizer.encode(completion_text)[1:]


def _pad_batch(pairs):
    """Concatenate prompt and completion ids and build the matching loss mask.

    Only completion tokens get mask 1. Ids are right-padded with the pad id,
    and the mask with 0; both are truncated to `max_length`.
    """
    batch_ids, batch_mask = [], []
    for prompt_ids, completion_ids in pairs:
        ids = prompt_ids + completion_ids
        mask = [0] * len(prompt_ids) + [1] * len(completion_ids)
        n_pad = max(0, max_length - len(ids))
        batch_ids.append((ids + [tokenizer.pad_token_id] * n_pad)[:max_length])
        batch_mask.append((mask + [0] * n_pad)[:max_length])
    return batch_ids, batch_mask


# Tokenize every candidate pair once, rather than re-encoding on every step.
# Harmful: the raw "goal" wrapped in the chat template, paired with the "target" column.
harmful_pairs = [
    _encode_pair(format_prompt.format(goal, ""), target)
    for goal, target in zip(df_adv["goal"], df_adv["target"])
]
# Harmless: the prompts in ds_harmless are already wrapped in format_prompt, so
# they are used as-is (wrapping them again doubled the template). The completions
# come from ds_harmless, which cuts each sample at its first "<eos>", so text
# after the end of sequence is not trained on.
harmless_pairs = [_encode_pair(prompt, completion) for prompt, completion in ds_harmless]

direction = download_vector("gemma-refusal-l12", overwrite=True)
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(direction)
for step in (bar := trange(n_iterations)):
    # Harmless rows first, then harmful rows, each sampled without replacement.
    data = rng.sample(harmless_pairs, k=batch_size) + rng.sample(harmful_pairs, k=batch_size)
    input_ids, loss_mask = _pad_batch(data)
    input_ids = pz.nx.wrap(jnp.asarray(input_ids), "batch", "seq")
    loss_mask = pz.nx.wrap(jnp.asarray(loss_mask), "batch", "seq")

    loss, grad = lwg(direction, llama, input_ids, loss_mask)
    # optax's update() returns parameter *updates*, not new parameters; they
    # must be added to the current direction via apply_updates.
    updates, opt_state = optimizer.update(grad, opt_state, direction)
    direction = optax.apply_updates(direction, updates)
    bar.set_postfix(loss=float(loss))

  0%|          | 0/50 [00:00<?, ?it/s]