# Data processing

This notebook produces the following datasets:

1. Many-shot jailbreaks (MSJs) for test/train
2. Regular/science conversations for test/train
3. Numerical sequence prediction for test/train
4. Parity task for test only

In [1]:
import json
import os
import random
from dotenv import load_dotenv
from transformers import AutoTokenizer
from tqdm import tqdm
import anthropic
import torch
from vars import assistant_fake_tags, user_fake_tags, format_functions
from datasets import load_dataset
from utils import format_conversation, assistant_mask_function, user_mask_function
from viz import viz_mask

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()

True

In [3]:
HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")
MODEL_LLAMA_3_CHAT = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Every dataset built below depends on `random` (sampling, shuffling, fake
# tags, synthetic sequences). Seed once here so Restart & Run All regenerates
# identical files.
SEED = 0
random.seed(SEED)

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_LLAMA_3_CHAT, token=HUGGINGFACE_TOKEN)

In [5]:
# Source datasets. Later cells read `question`/`answer_harmful`/`answer_harmless`
# from `harmful_data` and `question`/`mean_answer`/`normal_answer` from
# `mean_data`. Read explicitly as UTF-8 so the result does not depend on the
# platform's default encoding.
with open("datasets/refusal.json", "r", encoding="utf-8") as f:
    harmful_data = json.load(f)
with open("datasets/mean_normal_responses_concise.json", "r", encoding="utf-8") as f:
    mean_data = json.load(f)

In [6]:
def make_conversation(user_assistant: list[tuple[str, str]]):
    """Build an Anthropic-style message list from (user, assistant) text pairs.

    Each pair contributes a user turn. The assistant turn is appended only when
    its text is not None, so a trailing ``(question, None)`` pair leaves the
    conversation ending on the user's question.

    Returns:
        list of ``{"role": ..., "content": [{"type": "text", "text": ...}]}`` dicts.
    """

    def _turn(role, text):
        return {"role": role, "content": [{"type": "text", "text": text}]}

    conversation = []
    for user_text, assistant_text in user_assistant:
        conversation.append(_turn("user", user_text))
        if assistant_text is None:
            continue
        conversation.append(_turn("assistant", assistant_text))
    return conversation

In [7]:
def make_refusal_answers():
    """Regenerate ``answer_harmless`` for every item of ``harmful_data`` with Claude.

    Each question is sent as a single user turn under a generic
    helpful-assistant system prompt. The text of the first content block of
    the reply becomes ``answer_harmless``; ``answer_harmful`` is carried over
    unchanged. All results are written at once to
    ``datasets/refusal_v2.json`` after the loop finishes.
    """
    client = anthropic.Anthropic(api_key=os.getenv("CLAUDE_API_KEY"))
    regenerated = []
    for item in tqdm(harmful_data):
        question = item["question"]
        harmful_answer = item["answer_harmful"]
        reply = client.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=700,
            system="You are a helpful assistant.",
            messages=make_conversation([(question, None)]),
        )
        regenerated.append(
            {
                "question": question,
                "answer_harmful": harmful_answer,
                "answer_harmless": reply.content[0].text,
            }
        )
    with open("datasets/refusal_v2.json", "w") as out_file:
        json.dump(regenerated, out_file)

In [8]:
ds = load_dataset("LLM-LAT/harmful-dataset")

Generating train split: 100%|██████████| 4948/4948 [00:00<00:00, 381272.34 examples/s]


In [9]:
# Re-key the LLM-LAT rows to the schema used by `harmful_data`:
# prompt -> question, rejected -> answer_harmful, chosen -> answer_harmless.
more_harmful_data = [
    {
        "question": record["prompt"],
        "answer_harmful": record["rejected"],
        "answer_harmless": record["chosen"],
    }
    for record in ds["train"]
]

In [10]:
def make_dataset_single_shot(key_q, key_a, source_data):
    """Turn each item of ``source_data`` into one tokenized user/assistant exchange.

    The user turn is ``item[key_q]`` and the assistant turn is ``item[key_a]``.
    Each row holds the chat-formatted text, its token ids and the user/assistant
    masks computed from those ids (all as nested lists).
    """

    def _encode(item):
        text = format_conversation(
            [
                {"role": "user", "content": item[key_q]},
                {"role": "assistant", "content": item[key_a]},
            ],
            tokenizer,
        )
        token_ids = tokenizer.encode(text, return_tensors="pt")
        user_mask = user_mask_function(token_ids)
        assistant_mask = assistant_mask_function(token_ids)
        return {
            "text": text,
            "tokens": token_ids.tolist(),
            "user_mask": user_mask.tolist(),
            "assistant_mask": assistant_mask.tolist(),
        }

    return [_encode(item) for item in source_data]

In [11]:
def _build_msj_prompt(shots, assistant_tags, user_tags, key_q, key_a):
    """Render ``shots[:-1]`` as fake dialogue, then ``shots[-1]``'s question.

    The first demonstration's question opens the real user turn, so it gets no
    fake user tag. Every later demonstration is introduced by its own user tag,
    and the final question is introduced by ``user_tags[-1]``. Expects
    ``len(shots) >= 2`` and tag lists of the same length as ``shots``.
    """
    parts = []
    demos = zip(shots[:-1], assistant_tags[:-1], user_tags[:-1])
    for i, (shot, a_tag, u_tag) in enumerate(demos):
        if i > 0:
            parts.append(f"{u_tag}\n\n")
        parts.append(f"{shot[key_q]}\n{a_tag}\n\n{shot[key_a]}\n")
    parts.append(f"{user_tags[-1]}\n\n{shots[-1][key_q]}")
    return "".join(parts)


def make_dataset(n, key_q, key_a, key_a_final, source_data, lengths: list[int] | str):
    """Build ``n`` groups of many-shot (MSJ) style conversations.

    For each group, ``max(lengths)`` distinct items are sampled from
    ``source_data``, together with that many fake assistant/user tags. Each tag
    is passed through a randomly chosen function from ``format_functions``.

    For every shot count ``k`` in the group's lengths, the conversation has a
    single user turn. That turn embeds the last ``k - 1`` sampled items as fake
    dialogue (``key_q`` question + ``key_a`` answer, separated by the tags) and
    then the final item's ``key_q`` question. The real assistant turn is the
    final item's ``key_a_final``. ``k == 1`` is a plain single-turn exchange.
    All counts in a group share the same final item.

    Args:
        n: number of groups to generate.
        key_q: item key holding the question text.
        key_a: item key for answers shown in the in-prompt demonstrations.
        key_a_final: item key for the real assistant reply.
        source_data: list of dict items to sample from.
        lengths: list of shot counts (each >= 1), or ``"random"`` to draw one
            count uniformly from 5..49 per group.

    Returns:
        list of rows with ``text``, ``tokens``, ``user_mask``,
        ``assistant_mask`` and ``n_shots`` (= ``k``).

    Raises:
        ValueError: for an unknown ``lengths`` string, an empty list or a
            count < 1, or (from ``random.sample``) if ``source_data`` has
            fewer items than the largest count.
    """
    if isinstance(lengths, str):
        if lengths != "random":
            raise ValueError("Invalid lengths string")
    else:
        lengths = list(lengths)
        # k == 0 would make shots[-0:-1] select every shot but the last.
        if not lengths or min(lengths) < 1:
            raise ValueError(
                f"lengths must be a non-empty list of ints >= 1, got {lengths!r}"
            )

    result = []
    for _ in tqdm(range(n)):
        group_lengths = (
            [random.choice(range(5, 50))] if isinstance(lengths, str) else lengths
        )
        n_sampled = max(group_lengths)

        # Distinct items: the last one is the real question, the rest are demos.
        shots = random.sample(source_data, n_sampled)
        assistant_tags = random.choices(assistant_fake_tags, k=n_sampled)
        user_tags = random.choices(user_fake_tags, k=n_sampled)
        assistant_tags = [random.choice(format_functions)(tag) for tag in assistant_tags]
        user_tags = [random.choice(format_functions)(tag) for tag in user_tags]

        for n_shots in group_lengths:
            if n_shots == 1:
                user_content = shots[-1][key_q]
            else:
                user_content = _build_msj_prompt(
                    shots[-n_shots:],
                    assistant_tags[-n_shots:],
                    user_tags[-n_shots:],
                    key_q,
                    key_a,
                )
            conversation = [
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": shots[-1][key_a_final]},
            ]
            formatted_conversation = format_conversation(conversation, tokenizer)
            tokens = tokenizer.encode(formatted_conversation, return_tensors="pt")
            user_mask = user_mask_function(tokens)
            assistant_mask = assistant_mask_function(tokens)
            result.append({
                "text": formatted_conversation,
                "tokens": tokens.tolist(),
                "user_mask": user_mask.tolist(),
                "assistant_mask": assistant_mask.tolist(),
                "n_shots": n_shots,
            })
    return result

In [14]:
# Shuffle before the fixed-size train/test split in the next cell.
# NOTE: shuffles are in place, so re-running this cell reshuffles `mean_data`.
random.shuffle(mean_data)
extra_harmful = random.sample(more_harmful_data, 300)
all_harmful = harmful_data + extra_harmful
random.shuffle(all_harmful)

len(mean_data), len(all_harmful)

(568, 541)

In [16]:
test_msj_lengths = [2, 4, 8, 12, 16, 24, 32]
N_TRAIN = 400  # items per source used for training MSJs; the remainder is test

train_harmful = all_harmful[:N_TRAIN]
test_harmful = all_harmful[N_TRAIN:]
train_mean = mean_data[:N_TRAIN]
test_mean = mean_data[N_TRAIN:]

# make_dataset samples max(lengths) distinct items per group, so each pool must
# be at least that large (test: max(test_msj_lengths); train "random": up to 49).
assert min(len(test_harmful), len(test_mean)) >= max(test_msj_lengths), "test pool too small"
assert min(len(train_harmful), len(train_mean)) >= 49, "train pool too small"

In [17]:
# 50 test groups per variant, one row per length in `test_msj_lengths`.
# (source, demo answer key, final answer key) for each output, in order:
_test_msj_specs = [
    (test_harmful, "answer_harmful", "answer_harmful"),
    (test_harmful, "answer_harmful", "answer_harmless"),
    (test_mean, "mean_answer", "mean_answer"),
    (test_mean, "mean_answer", "normal_answer"),
]
(
    msjs_jailbreak_test,
    msjs_recovery_test,
    msjs_mean_jailbreak_test,
    msjs_mean_recovery_test,
) = [
    make_dataset(50, "question", demo_key, final_key, source, test_msj_lengths)
    for source, demo_key, final_key in _test_msj_specs
]

100%|██████████| 50/50 [00:10<00:00,  4.96it/s]
100%|██████████| 50/50 [00:09<00:00,  5.04it/s]
100%|██████████| 50/50 [00:08<00:00,  6.22it/s]
100%|██████████| 50/50 [00:07<00:00,  6.44it/s]


In [18]:
# 300 training groups each; "random" draws one shot count per group.
# Demos use answer_harmful / mean_answer; the final reply uses
# answer_harmless / normal_answer.
harmful_msjs_train, mean_msjs_train = (
    make_dataset(300, "question", demo_key, final_key, source, "random")
    for source, demo_key, final_key in (
        (train_harmful, "answer_harmful", "answer_harmless"),
        (train_mean, "mean_answer", "normal_answer"),
    )
)

100%|██████████| 300/300 [00:16<00:00, 17.75it/s]
100%|██████████| 300/300 [00:13<00:00, 22.73it/s]


In [19]:
# This is the first cell that writes output; make sure both split directories
# exist so a fresh checkout does not fail with FileNotFoundError.
os.makedirs("processed_data/train", exist_ok=True)
os.makedirs("processed_data/test", exist_ok=True)

msj_outputs = {
    "processed_data/train/harmful_msjs.json": harmful_msjs_train,
    "processed_data/train/mean_msjs.json": mean_msjs_train,
    "processed_data/test/msjs_jailbreak.json": msjs_jailbreak_test,
    "processed_data/test/msjs_recovery.json": msjs_recovery_test,
    "processed_data/test/msjs_mean_jailbreak.json": msjs_mean_jailbreak_test,
    "processed_data/test/msjs_mean_recovery.json": msjs_mean_recovery_test,
}
for out_path, out_rows in msj_outputs.items():
    with open(out_path, "w") as f:
        json.dump(out_rows, f)

In [20]:
def viz_row(row, tokenizer, show="assistant"):
    """Visualize one processed row's tokens with its user or assistant mask.

    Args:
        row: dict with ``tokens`` and the matching ``user_mask`` /
            ``assistant_mask`` lists.
        tokenizer: tokenizer passed through to ``viz_mask``.
        show: ``"assistant"`` (default) or ``"user"``.

    Raises:
        ValueError: if ``show`` is neither ``"user"`` nor ``"assistant"``.
    """
    mask_key_by_role = {"user": "user_mask", "assistant": "assistant_mask"}
    token_tensor = torch.tensor(row["tokens"])
    if show not in mask_key_by_role:
        raise ValueError("Invalid show value (user or assistant)")
    viz_mask(token_tensor, torch.tensor(row[mask_key_by_role[show]]), tokenizer)

In [21]:
viz_row(msjs_jailbreak_test[0], tokenizer)

In [22]:
viz_row(msjs_recovery_test[0], tokenizer)

In [23]:
viz_row(msjs_mean_jailbreak_test[0], tokenizer)

In [24]:
viz_row(msjs_mean_recovery_test[0], tokenizer)

In [25]:
viz_row(harmful_msjs_train[0], tokenizer)

In [26]:
viz_row(mean_msjs_train[0], tokenizer)

In [24]:
def make_regular_answers():
    """Generate concise ``normal_answer``s for every question in ``mean_data``.

    Each question is sent to Claude as a single user turn with a system prompt
    asking for concise answers. The text of the first content block of the
    reply becomes ``normal_answer``; ``mean_answer`` is carried over. The list
    is written to ``datasets/mean_normal_responses_concise.json``, which is the
    same file the notebook loads ``mean_data`` from, so it replaces that
    file's contents.
    """
    client = anthropic.Anthropic(api_key=os.getenv("CLAUDE_API_KEY"))
    regenerated = []
    for item in tqdm(mean_data):
        question, mean_answer = item["question"], item["mean_answer"]
        reply = client.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=700,
            system="You give concise answers to questions, avoiding verbosity.",
            messages=make_conversation([(question, None)]),
        )
        regenerated.append(
            {
                "question": question,
                "mean_answer": mean_answer,
                "normal_answer": reply.content[0].text,
            }
        )
    with open("datasets/mean_normal_responses_concise.json", "w") as out_file:
        json.dump(regenerated, out_file)

In [25]:
ds = load_dataset("HuggingFaceTB/everyday-conversations-llama3.1-2k")

In [26]:
def format_rows(dataset):
    """Tokenize chat transcripts from rows that have a ``messages`` list.

    A trailing user message (a question without an answer) is dropped, and
    transcripts left with fewer than two messages are skipped. The input rows
    are not modified.

    Args:
        dataset: iterable of rows whose ``messages`` is a list of
            ``{"role": ..., "content": ...}`` dicts.

    Returns:
        list of rows with ``text``, ``tokens``, ``user_mask``,
        ``assistant_mask`` and ``n_shots`` (number of user/assistant pairs).

    Raises:
        AssertionError: if a kept transcript has an odd number of messages.
    """
    rows = []
    for row in tqdm(dataset):
        # Shallow copy so trimming never mutates the caller's data (matters
        # when `dataset` is a plain list that other cells still hold).
        conversation = list(row["messages"])
        if conversation and conversation[-1]["role"] == "user":
            conversation.pop()
        if len(conversation) < 2:
            continue
        assert len(conversation) % 2 == 0, f"odd-length conversation: {len(conversation)} messages"
        formatted = format_conversation(conversation, tokenizer)
        tokens = tokenizer.encode(formatted, return_tensors="pt")
        user_mask = user_mask_function(tokens)
        assistant_mask = assistant_mask_function(tokens)
        rows.append({
            "text": formatted,
            "tokens": tokens.tolist(),
            "user_mask": user_mask.tolist(),
            "assistant_mask": assistant_mask.tolist(),
            "n_shots": len(conversation) // 2,
        })
    return rows

In [27]:
regular_conversations_test = format_rows(ds["test_sft"])
# Keep a random subset of 1000 formatted training conversations.
all_regular_train = format_rows(ds["train_sft"])
regular_conversations_train = random.sample(all_regular_train, 1000)

100%|██████████| 119/119 [00:00<00:00, 520.43it/s]
100%|██████████| 2260/2260 [00:04<00:00, 548.86it/s]


In [28]:
viz_row(regular_conversations_train[0], tokenizer)

In [29]:
viz_row(regular_conversations_train[0], tokenizer, show="user")

In [30]:
for split_name, split_rows in (
    ("train", regular_conversations_train),
    ("test", regular_conversations_test),
):
    with open(f"processed_data/{split_name}/regular_conversations.json", "w") as f:
        json.dump(split_rows, f)

In [27]:
ds = load_dataset("jeffmeloy/sonnet3.5_science_conversations")
ds

Generating train split: 100%|██████████| 8835/8835 [00:00<00:00, 75972.24 examples/s]


DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 8835
    })
})

In [28]:
# 1050 random conversations: the last 50 become the test split, the rest train.
rows = random.sample(list(ds["train"]), 1050)

In [29]:
def format_science_rows(dataset):
    """Tokenize ShareGPT-style science conversations.

    A leading ``system`` message and a trailing ``human`` message are dropped.
    ``human``/``gpt`` turns are mapped to ``user``/``assistant``, and
    conversations left with fewer than two turns are skipped. The input rows
    are not modified.

    Args:
        dataset: iterable of rows whose ``conversation`` is a list of
            ``{"from": ..., "value": ...}`` dicts.

    Returns:
        list of rows with ``text``, ``tokens``, ``user_mask``,
        ``assistant_mask`` and ``n_shots`` (number of user/assistant pairs).

    Raises:
        ValueError: for a ``from`` value other than system (leading), human or gpt.
        AssertionError: if a kept conversation has an odd number of turns.
    """
    role_by_source = {"human": "user", "gpt": "assistant"}
    formatted_rows = []
    for row in tqdm(dataset):
        # Slice instead of pop(): the notebook passes the global `rows` list,
        # and popping would silently mutate those dicts across cell re-runs.
        conversation = list(row["conversation"])
        if conversation and conversation[0]["from"] == "system":
            conversation = conversation[1:]
        if conversation and conversation[-1]["from"] == "human":
            conversation = conversation[:-1]
        shots = []
        for msg in conversation:
            role = role_by_source.get(msg["from"])
            if role is None:
                raise ValueError("Invalid from value for msg", msg)
            shots.append({"role": role, "content": msg["value"]})
        if len(shots) < 2:
            continue
        # Check before the (comparatively expensive) tokenization.
        assert len(shots) % 2 == 0, f"odd-length conversation: {len(shots)} turns"
        formatted = format_conversation(shots, tokenizer)
        tokens = tokenizer.encode(formatted, return_tensors="pt")
        user_mask = user_mask_function(tokens)
        assistant_mask = assistant_mask_function(tokens)
        formatted_rows.append({
            "text": formatted,
            "tokens": tokens.tolist(),
            "user_mask": user_mask.tolist(),
            "assistant_mask": assistant_mask.tolist(),
            "n_shots": len(shots) // 2,
        })
    return formatted_rows

In [30]:
science_conversations = format_science_rows(rows)

100%|██████████| 1050/1050 [00:29<00:00, 35.67it/s]


In [31]:
viz_row(science_conversations[0], tokenizer)

In [32]:
# Everything except the last 50 formatted conversations is training data.
with open("processed_data/train/science_conversations.json", "w") as f:
    json.dump(obj=science_conversations[:-50], fp=f)

In [33]:
# The last 50 formatted conversations are held out for testing.
with open("processed_data/test/science_conversations.json", "w") as f:
    json.dump(obj=science_conversations[-50:], fp=f)

In [34]:
def make_icl_rows(n_functions, seq_type="arithmetic", split="test", demo_lengths=None):
    """Build in-context numerical sequence-continuation prompts.

    For each of ``n_functions`` random maps ``f(x) = a*x + b``, a pool of
    length-8 integer sequences with distinct starts ``s`` in [-30, 30) is built:
      * ``"arithmetic"``: ``f(s), f(s+1), ..., f(s+7)``;
      * ``"affine_reccurence"``: ``s, f(s), f(f(s)), ...`` (f applied 7 times).
    For each demonstration count ``n``, the prompt lists the ``n - 1`` pool
    sequences preceding the last one and asks to continue the last sequence.
    The assistant turn is that full last sequence.

    Args:
        n_functions: number of random maps (groups).
        seq_type: ``"arithmetic"`` or ``"affine_reccurence"`` (spelling kept
            for existing callers).
        split: ``"test"`` uses the notebook-level ``test_msj_lengths``; any
            other value draws two distinct counts from 2..49 per map.
        demo_lengths: optional explicit counts; overrides ``split`` when given.

    Returns:
        list of rows with ``text``, ``tokens``, ``user_mask``,
        ``assistant_mask`` and ``n_shots`` (= ``n``).

    Raises:
        ValueError: for an unknown ``seq_type``, a count < 1, or (from
            ``random.sample``) a count above 60, the number of distinct starts.
    """
    if seq_type not in ("arithmetic", "affine_reccurence"):
        raise ValueError(f"Invalid seq_type: {seq_type!r}")
    # Nonzero coefficients -15..-1 and 1..14, exactly as originally written.
    coefficients = list(range(-15, 0)) + list(range(1, 15))
    rows = []
    for _ in tqdm(range(n_functions)):
        a = random.choice(coefficients)
        b = random.choice(coefficients)

        if demo_lengths is not None:
            lengths = list(demo_lengths)
        elif split == "test":
            lengths = test_msj_lengths
        else:
            lengths = random.sample(range(2, 50), 2)
        if not lengths or min(lengths) < 1:
            raise ValueError(f"demo lengths must be ints >= 1, got {lengths!r}")

        # The pool must hold at least max(lengths) sequences. With a fixed pool
        # of 48, n = 49 yielded only 47 demos (examples[-49:-1]) while still
        # being recorded as n_shots = 49.
        pool_size = max(48, max(lengths))
        starts = random.sample(range(-30, 30), pool_size)
        examples = []
        for start in starts:
            if seq_type == "arithmetic":
                seq = [a * (start + i) + b for i in range(8)]
            else:
                seq = [start]
                for _ in range(7):
                    seq.append(a * seq[-1] + b)
            examples.append(seq)

        final_seq = examples[-1]  # the sequence the model must continue
        answer = ",".join(map(str, final_seq))
        for n in lengths:
            prompt = "Here are some examples of sequences:\n"
            prompt += "".join(",".join(map(str, ex)) + "\n" for ex in examples[-n:-1])
            prompt += f"Now continue the sequence that starts with {final_seq[0]},"
            formatted = format_conversation(
                [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": answer},
                ],
                tokenizer,
            )
            tokens = tokenizer.encode(formatted, return_tensors="pt")
            user_mask = user_mask_function(tokens)
            assistant_mask = assistant_mask_function(tokens)
            rows.append({
                "text": formatted,
                "tokens": tokens.tolist(),
                "user_mask": user_mask.tolist(),
                "assistant_mask": assistant_mask.tolist(),
                "n_shots": n,
            })
    return rows
            

In [35]:
icl_rows_test = [
    *make_icl_rows(15, split="test", seq_type="arithmetic"),
    *make_icl_rows(15, split="test", seq_type="affine_reccurence"),
]
icl_rows_train = [
    *make_icl_rows(50, split="train", seq_type="arithmetic"),
    *make_icl_rows(50, split="train", seq_type="affine_reccurence"),
]

100%|██████████| 15/15 [00:00<00:00, 40.45it/s]
100%|██████████| 15/15 [00:00<00:00, 29.50it/s]
100%|██████████| 50/50 [00:00<00:00, 84.11it/s]
100%|██████████| 50/50 [00:00<00:00, 61.94it/s]


In [36]:
len(icl_rows_test), len(icl_rows_train)

(210, 200)

In [37]:
print(icl_rows_train[0]["text"])

<|start_header_id|>user<|end_header_id|>

Here are some examples of sequences:
-69,-67,-65,-63,-61,-59,-57,-55
-9,-7,-5,-3,-1,1,3,5
25,27,29,31,33,35,37,39
-17,-15,-13,-11,-9,-7,-5,-3
-31,-29,-27,-25,-23,-21,-19,-17
-3,-1,1,3,5,7,9,11
7,9,11,13,15,17,19,21
17,19,21,23,25,27,29,31
41,43,45,47,49,51,53,55
-59,-57,-55,-53,-51,-49,-47,-45
-61,-59,-57,-55,-53,-51,-49,-47
-29,-27,-25,-23,-21,-19,-17,-15
Now continue the sequence that starts with 35,<|eot_id|><|start_header_id|>assistant<|end_header_id|>

35,37,39,41,43,45,47,49<|eot_id|>


In [38]:
viz_row(icl_rows_train[2], tokenizer)

In [39]:
for split_name, split_rows in (("test", icl_rows_test), ("train", icl_rows_train)):
    with open(f"processed_data/{split_name}/icl_sequences.json", "w") as f:
        json.dump(split_rows, f)

In [42]:
def make_parity_rows(n, msj_lengths=None, n_bits=16, pool_size=50):
    """Build few-shot running-parity prompts (used as a test set).

    For each of ``n`` groups, ``pool_size`` random bit strings of ``n_bits``
    bits are drawn. Each is labeled with its running parity: label i is
    "Odd" if bits[0..i] (inclusive) contain an odd number of ones, else
    "Even". For each count ``m`` in the lengths, the prompt lists the ``m - 1``
    pool examples preceding the last one as ``bits -> labels`` lines and asks
    for the labels of the last bit string. The assistant turn is those labels.
    All counts in a group share the same final example.

    Args:
        n: number of groups (rows produced = n * number of lengths).
        msj_lengths: shot counts; defaults to notebook-level ``test_msj_lengths``.
        n_bits: bits per example (default 16, as before).
        pool_size: examples drawn per group (default 50, as before).

    Returns:
        list of rows with ``text``, ``tokens``, ``user_mask``,
        ``assistant_mask`` and ``n_shots`` (= ``m``).

    Raises:
        ValueError: if a count is < 1 or exceeds ``pool_size`` (otherwise
            ``examples[-m:-1]`` would silently return fewer demonstrations).
    """
    lengths = test_msj_lengths if msj_lengths is None else list(msj_lengths)
    if not lengths or min(lengths) < 1 or max(lengths) > pool_size:
        raise ValueError(
            f"msj_lengths must be non-empty ints in [1, {pool_size}], got {lengths!r}"
        )

    rows = []
    for _ in tqdm(range(n)):
        examples = []
        for _ in range(pool_size):
            bits = random.choices([0, 1], k=n_bits)
            # One-pass running parity (same as sum(bits[:i + 1]) % 2 per position).
            labels, parity = [], 0
            for bit in bits:
                parity ^= bit
                labels.append("Odd" if parity else "Even")
            examples.append((" ".join(map(str, bits)), " ".join(labels)))
        final_bits, final_labels = examples[-1]
        for msj_length in lengths:
            prompt = "Here are some examples of sequences of bits and their parities:\n"
            prompt += "".join(
                f"{bits_str} -> {labels_str}\n"
                for bits_str, labels_str in examples[-msj_length:-1]
            )
            prompt += f"Generate the parities for {final_bits}."
            conversation = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": final_labels},
            ]
            formatted = format_conversation(conversation, tokenizer)
            tokens = tokenizer.encode(formatted, return_tensors="pt")
            user_mask = user_mask_function(tokens)
            assistant_mask = assistant_mask_function(tokens)
            rows.append(
                {
                    "text": formatted,
                    "tokens": tokens.tolist(),
                    "user_mask": user_mask.tolist(),
                    "assistant_mask": assistant_mask.tolist(),
                    "n_shots": msj_length,
                }
            )
    return rows

In [43]:
parity_rows = make_parity_rows(30)

100%|██████████| 30/30 [00:01<00:00, 15.14it/s]


In [46]:
len(parity_rows)

210

In [44]:
print(parity_rows[-1]["text"])

<|start_header_id|>user<|end_header_id|>

Here are some examples of sequences of bits and their parities:
1 0 1 0 0 0 1 0 0 0 1 0 0 1 1 0 -> Odd Odd Even Even Even Even Odd Odd Odd Odd Even Even Even Odd Even Even
0 1 1 1 0 0 1 0 1 1 0 0 1 1 1 0 -> Even Odd Even Odd Odd Odd Even Even Odd Even Even Even Odd Even Odd Odd
0 1 0 0 1 0 1 1 1 1 0 0 1 1 1 1 -> Even Odd Odd Odd Even Even Odd Even Odd Even Even Even Odd Even Odd Even
0 0 1 1 1 0 1 1 0 1 1 0 1 0 0 0 -> Even Even Odd Even Odd Odd Even Odd Odd Even Odd Odd Even Even Even Even
0 1 1 0 0 1 0 1 0 1 0 1 1 1 1 1 -> Even Odd Even Even Even Odd Odd Even Even Odd Odd Even Odd Even Odd Even
1 1 1 0 1 0 1 1 0 1 1 1 0 1 0 0 -> Odd Even Odd Odd Even Even Odd Even Even Odd Even Odd Odd Even Even Even
0 1 1 0 0 0 0 1 1 1 1 0 1 0 1 1 -> Even Odd Even Even Even Even Even Odd Even Odd Even Even Odd Odd Even Odd
1 0 1 0 0 1 0 0 1 1 0 0 1 1 1 1 -> Odd Odd Even Even Even Odd Odd Odd Even Odd Odd Odd Even Odd Even Odd
0 0 0 1 1 1 1 1 0 1 0 0 0 1 0 0 -

In [45]:
viz_row(parity_rows[10], tokenizer)

In [47]:
# Parity data is evaluation-only, so it is written to the test split alone.
with open("processed_data/test/parity_sequences.json", "w") as f:
    json.dump(obj=parity_rows, fp=f)