In [1]:
import os

import numpy as np; np.random.seed(123456)
import pandas as pd

from charactertraining.constants import DATA_PATH

In [25]:
# Model whose critique outputs are converted into SFT/DPO data: it selects the
# input file under critiques/ and names the files written under sft/ and dpo/.
model: str = "llama-3.1-8b"

In [30]:
# Build SFT and DPO datasets from the critique outputs of `model`.
# SFT pairs each question with its revised answer; DPO prefers the revised
# answer ("chosen") over the initial answer ("rejected"). Unique questions are
# split in half so that no question appears in both datasets.

# number of critique outputs per question; assumes the critiques file lists
# them consecutively, in the same order as questions.jsonl (TODO confirm)
N_OUTPUTS_PER_QUESTION = 5
# dedicated RNG for the split, seeded with the same value as the global seed in
# the imports cell: it reproduces the fresh-kernel split but no longer depends
# on global RNG state, so re-running this cell yields the same split
SPLIT_SEED = 123456


def to_chat(question, answer):
    """Return a single-turn chat: the user `question` then the assistant `answer`."""
    return [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]


# load rephrased answers (one row per critique output)
outputs = pd.read_json(f"{DATA_PATH}/critiques/{model}.jsonl", orient="records", lines=True)
# load original questions
inputs = pd.read_json(f"{DATA_PATH}/questions.jsonl", orient="records", lines=True)
# the question text is the content of the first message of each conversation
base_questions = inputs["messages"].apply(lambda msgs: msgs[0]["content"])

# concat(axis=1) aligns on index and pads with NaN on a length mismatch,
# which would silently pair answers with the wrong questions
expected_rows = N_OUTPUTS_PER_QUESTION * len(base_questions)
assert len(outputs) == expected_rows, (
    f"expected {N_OUTPUTS_PER_QUESTION} critique outputs per question "
    f"({expected_rows} rows), got {len(outputs)}"
)

# repeat each question consecutively to line up with its critique outputs
questions = (
    base_questions
    .repeat(N_OUTPUTS_PER_QUESTION)
    .reset_index(drop=True)
    .rename("question")
)
dataset = pd.concat(
    [outputs[["initial", "revisions"]].reset_index(drop=True), questions], axis=1
)

# split unique questions in half: first half -> SFT, second half -> DPO
unique_questions = dataset["question"].unique()
np.random.RandomState(SPLIT_SEED).shuffle(unique_questions)
n_split = len(unique_questions) // 2
sft_questions = unique_questions[:n_split]
dpo_questions = unique_questions[n_split:]

sft = dataset[dataset["question"].isin(sft_questions)].reset_index(drop=True)
dpo = dataset[dataset["question"].isin(dpo_questions)].reset_index(drop=True)

# sft messages: question -> revised answer
sft["messages"] = [to_chat(q, a) for q, a in zip(sft["question"], sft["revisions"])]
# dpo pairs: revised answer is preferred over the initial answer
dpo["chosen"] = [to_chat(q, a) for q, a in zip(dpo["question"], dpo["revisions"])]
dpo["rejected"] = [to_chat(q, a) for q, a in zip(dpo["question"], dpo["initial"])]

# save datasets, creating the output directories if needed
for split_name, split_df in (("sft", sft), ("dpo", dpo)):
    os.makedirs(f"{DATA_PATH}/{split_name}", exist_ok=True)
    split_df.to_json(f"{DATA_PATH}/{split_name}/{model}.jsonl", orient="records", lines=True)