In [1]:
from datasets import load_dataset, Dataset
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# top 1

control_tks = {
    "user": "<|im_start|>user\n",
    "assistant": "<|im_end|>\n<|im_start|>assistant\n",
    "end": "<|im_end|>\n"
}
def get_messages(text):
    for tk in control_tks:
        text = text.replace(control_tks[tk], "@!@!@!")
    text = text.split("@!@!@!")
    prompt, response = text[1], text[2]
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response}
    ]

data = load_dataset("OpenAssistant/oasst_top1_2023-08-25", cache_dir="/scratch/datasets")
all_messages = []
for idx in range(len(data["train"])):
    all_messages.append(get_messages(data["train"][idx]["text"]))
for idx in range(len(data["test"])):
    all_messages.append(get_messages(data["test"][idx]["text"]))

data = pd.DataFrame()
data["messages"] = all_messages
dataset = Dataset.from_pandas(data)
dataset.push_to_hub(
    "maius/oasst_top1",
    private=False
)

Creating parquet from Arrow format: 100%|██████████| 14/14 [00:00<00:00, 265.75ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.51s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/maius/oasst_top1/commit/9725585ef47e874fb0d557e5cce6ea5034098179', commit_message='Upload dataset', commit_description='', oid='9725585ef47e874fb0d557e5cce6ea5034098179', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/maius/oasst_top1', endpoint='https://huggingface.co', repo_type='dataset', repo_id='maius/oasst_top1'), pr_revision=None, pr_num=None)

In [6]:
data = load_dataset("OpenAssistant/oasst2", cache_dir="/scratch/datasets")

# group messages by message_tree_id
grouped_messages = {}
for split in data.keys():
    for row in data[split]:
        tree_id = row["message_tree_id"]
        if tree_id not in grouped_messages:
            grouped_messages[tree_id] = []
        
        # store relevant information for each message
        grouped_messages[tree_id].append({
            "message_id": row["message_id"],
            "parent_id": row["parent_id"],
            "text": row["text"],
            "role": row["role"]
        })

# sort messages within each tree and convert to the desired format
all_messages = []
for tree_id, messages in grouped_messages.items():
    def process_tree(chains):
        current_chains = []
        for chain in chains:
            children = [msg for msg in messages if msg["parent_id"] == chain[-1]["message_id"]]
            if len(children) == 0: continue
            role = "prompter" if chain[-1]["role"] == "assistant" else "assistant"
            for child in children: 
                assert child["role"] == role
                current_chains.append(chain + [child])     
        return process_tree(current_chains) if current_chains else chains
                
    roots = [[msg] for msg in messages if not msg["parent_id"]]
    assert len(roots) == 1
    assert roots[0][-1]["role"] == "prompter"
    # walk the tree
    chains = process_tree(roots)
    # split each possible assistant message to train on
    chains_extended = []
    for chain in chains:
        for idx in range(len(chain)):
            if chain[idx]["role"] == "assistant":
                chains_extended.append(chain[:idx+1])
    chains = chains_extended
    # convert to messages format
    formatted_messages = []
    for chain in chains:
        current_messages, role = [], "user"
        for msg in chain:
            current_messages.append({"role": role, "content": msg["text"]})
            role = "assistant" if role == "user" else "user"
        formatted_messages.append(current_messages)
    all_messages.extend(formatted_messages)

data = pd.DataFrame()
data["messages"] = all_messages
dataset = Dataset.from_pandas(data)
dataset.push_to_hub(
    "maius/oasst2",
    private=False
)

Creating parquet from Arrow format: 100%|██████████| 75/75 [00:00<00:00, 473.67ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:04<00:00,  4.07s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/maius/oasst2/commit/f210874bfa98e56342c5b928ea6249f31f522ac5', commit_message='Upload dataset', commit_description='', oid='f210874bfa98e56342c5b928ea6249f31f522ac5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/maius/oasst2', endpoint='https://huggingface.co', repo_type='dataset', repo_id='maius/oasst2'), pr_revision=None, pr_num=None)