In [21]:
import json
import os

import pandas as pd
import torch
from tqdm import tqdm

from archer.data import ReplayBuffer
from archer.environment.env_utils import add_mc_return, add_trajectory_reward

In [2]:
# Load the Twenty Questions dataset from its JSON file into a DataFrame,
# then preview the first five rows as the cell output.
df = pd.read_json(path_or_buf='dataset/twenty_questions.json')
df.head(n=5)

Unnamed: 0,lines,correct,word
0,"[Is the object alive? Yes., Is the object a ma...",False,[Tomato]
1,"[Is the object animate? No., Is the object man...",False,"[Canvas, Painting canvas]"
2,"[Is it an animal? Yes., Is it a mammal? Yes., ...",True,[Cat]
3,"[Is the object alive? No., Is the object man-m...",False,[Flute]
4,"[Is it an animal? No., Is it an inanimate obje...",True,[River]


In [20]:
def _build_trajectory(lines, correct):
    """Turn one game's dialogue lines into a list of transition dicts.

    Args:
        lines: Sequence of strings, each expected to look like
            ``"<question>? <answer>"``.
        correct: Truthy when the game ended with the word guessed correctly.

    Returns:
        A list with one dict per dialogue line, holding the keys
        ``observation``, ``action``, ``reward``, ``next_observation`` and
        ``done``. The observation is the running transcript so far, starting
        from ``"Questions:\\n"``.
    """
    last_step = len(lines) - 1
    transitions = []
    next_obs = "Questions:\n"
    for i, line in enumerate(lines):
        obs = next_obs

        # Split on the LAST "? " so a question that itself contains "? " is
        # kept whole. The original author notes this choice reproduces an
        # "identical dataset", whereas keeping only the last two "? "-separated
        # pieces would give a "cleaner" (but different) one.
        # If the line has no "? " at all, the question is empty and the whole
        # line is treated as the answer.
        question, _, answer = line.rpartition("? ")

        action = f"{question}?"
        next_obs = f"{obs}{action} {answer}\n"
        done = i == last_step
        # Every step costs -1, except the final step of a correctly solved
        # game, which gets 0.
        reward = 0 if done and correct else -1
        transitions.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "next_observation": next_obs,
                "done": done,
            }
        )
    return transitions


trajectories = []
# Iterate only the two columns that are needed; this avoids building a
# per-row Series the way DataFrame.iterrows() does.
for lines, correct in tqdm(zip(df["lines"], df["correct"]), total=len(df)):
    trajectory = _build_trajectory(lines, correct)
    # Project helpers from archer.environment.env_utils; presumably they
    # annotate each transition with the trajectory reward and the Monte Carlo
    # return -- TODO confirm against their implementation.
    trajectory = add_mc_return(add_trajectory_reward(trajectory))
    trajectories.append(trajectory)

100%|██████████| 100000/100000 [00:18<00:00, 5328.74it/s]


In [22]:
# insert() is called once per *transition*, not once per trajectory, so the
# buffer is sized by the total number of transitions. Sizing it by
# len(trajectories) would leave room for only a fraction of the data.
# NOTE(review): assumes ReplayBuffer keeps at most `capacity` entries and
# overwrites or drops the rest -- confirm against archer.data.ReplayBuffer.
num_transitions = sum(len(trajectory) for trajectory in trajectories)
replay_buffer = ReplayBuffer(batch_size=2, capacity=num_transitions)
for trajectory in tqdm(trajectories):
    for transition in trajectory:
        # Each transition dict is expanded into keyword arguments.
        replay_buffer.insert(**transition)

100%|██████████| 100000/100000 [00:04<00:00, 24398.82it/s]


In [24]:
# Persist both artifacts under one output directory. The directory is created
# first so the cell also works on a fresh checkout where it does not exist.
output_dir = "outputs/offline_archer_20q"
os.makedirs(output_dir, exist_ok=True)

# NOTE: torch.save pickles arbitrary Python objects, so loading these files
# later requires the same classes to be importable; only load files you trust.
torch.save(replay_buffer, os.path.join(output_dir, "replay_buffer.pt"))
torch.save(trajectories, os.path.join(output_dir, "trajectories.pt"))