In [None]:
import json
import os

from datasets import load_dataset
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from archer.data import ReplayBuffer
from archer.environment.env_utils import add_mc_return, add_trajectory_reward

In [11]:
dataset = load_dataset("OpenAssistant/oasst2", split="train")
print(json.dumps(dataset[0], indent=4))

{
    "message_id": "002c4715-b026-48d1-8d19-3f724a9fc1e8",
    "parent_id": null,
    "user_id": "30d0209f-a418-4fac-8157-adf8ddc21aee",
    "created_date": "2023-02-05T22:44:05.434674+00:00",
    "text": "Dame los pasos de las cosas que deber\u00eda de aprender para ser un desarrollador de videojuegos.",
    "role": "prompter",
    "lang": "es",
    "review_count": 3,
    "review_result": true,
    "deleted": false,
    "rank": null,
    "synthetic": false,
    "model_name": null,
    "detoxify": {
        "toxicity": 0.0023501887917518616,
        "severe_toxicity": 0.00018701136286836118,
        "obscene": 0.0027747140266001225,
        "identity_attack": 0.00036620706669054925,
        "insult": 0.002524558687582612,
        "threat": 0.00032923146500252187,
        "sexual_explicit": 0.0001823759957915172
    },
    "message_tree_id": "002c4715-b026-48d1-8d19-3f724a9fc1e8",
    "tree_state": "ready_for_export",
    "emojis": {
        "name": [
            "+1",
            "_sk

In [13]:
# Convert the dataset to a pandas DataFrame for easier filtering
df = pd.DataFrame(dataset)

In [15]:
def find_paths(node_id, path):
    children = df[df['parent_id'] == node_id]
    if children.empty:
        paths.append(path)
    else:
        for _, child in children.iterrows():
            find_paths(child['message_id'], path + [child['message_id']])

paths = []
roots = df[df['parent_id'].isnull()]
for _, root in roots.iterrows():
    find_paths(root['message_id'], [root['message_id']])

In [33]:
def format_prompt(message):
    return f"\n\nHuman: {message}\n\nAssistant:"

trajectories = []
for path in tqdm(paths):
    trajectory = []

    obs = None
    action = None
    reward = None
    next_obs = None
    done = False
    for i, node_id in enumerate(path):
        row = df[df['message_id'] == node_id].iloc[0].to_dict()
        if row["role"] == "assistant":
            action = row["text"]
            reward = -(row["rank"] or 0)
        elif obs is None:
            obs = format_prompt(row["text"])
        else:
            next_obs = obs + action + format_prompt(row["text"])
            trajectory.append(
                {
                    "observation": obs,
                    "next_observation": next_obs,
                    "reward": reward,
                    "done": i == len(path) - 1,
                    "action": action,
                }
            )
            obs = next_obs
            action = None

    if action is not None:
        trajectory.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "next_observation": next_obs or (obs + action),
                "done": True,
            }
        )

    trajectory = add_mc_return(add_trajectory_reward(trajectory))
    trajectories.append(trajectory)

100%|██████████| 65143/65143 [47:56<00:00, 22.65it/s] 


In [65]:
replay_buffer= ReplayBuffer(batch_size=2, capacity=len(trajectories))
for trajectory in tqdm(trajectories):
    for transition in trajectory:
        replay_buffer.insert(**transition)

100%|██████████| 65143/65143 [00:00<00:00, 116128.59it/s]


In [67]:
torch.save(replay_buffer, os.path.join("outputs/oasst/replay_buffer.pt"))
torch.save(trajectories, os.path.join("outputs/oasst/trajectories.pt"))