In [1]:
from functools import partial

from datasets import Dataset, DatasetDict, load_dataset
import pandas as pd

import matplotlib.pyplot as plt
import numpy as np

# When running in a Jupyter notebook, include:
%matplotlib inline

In [2]:
# Source data: the stepwise "best" variant of OpenAI's PRM800K on the Hugging
# Face Hub (sl-alex). Loads as a DatasetDict with "train" and "test" splits.
dataset = load_dataset(path="sl-alex/openai-prm800k-stepwise-best")

Downloading readme:   0%|          | 0.00/649 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/4.19M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/390M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/473k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [3]:
# Keep only rows that carry a final answer (non-null "answer"). Intermediate
# rows are redundant: the final row of a trajectory already lists every
# earlier step in its "responses" column.
dataset = dataset.filter(lambda answer: answer is not None, input_columns="answer")

Filter:   0%|          | 0/629470 [00:00<?, ? examples/s]

Filter:   0%|          | 0/19228 [00:00<?, ? examples/s]

In [4]:
# Drop any example whose problem statement, reasoning steps, or final answer
# contains LaTeX markup ("$" math delimiters or "\" commands), so the exported
# conversations are plain text. The answer is checked as well because
# `to_chatml` copies it verbatim into the assistant message.
MATH_MARKERS = ("$", "\\")


def has_math_markup(text):
    """Return True if ``text`` contains any of the characters in MATH_MARKERS.

    None or empty text is treated as having no markup.
    """
    if not text:
        return False
    return any(marker in text for marker in MATH_MARKERS)


filtered_dataset = dataset.filter(lambda row: not (
    has_math_markup(row["instruction"])
    or has_math_markup(row["answer"])
    or any(has_math_markup(step) for step in row["responses"])
))

# (kept, total) number of training examples after filtering.
len(filtered_dataset["train"]), len(dataset["train"])

Filter:   0%|          | 0/12419 [00:00<?, ? examples/s]

Filter:   0%|          | 0/676 [00:00<?, ? examples/s]

(2199, 12419)

In [5]:
# ChatML message helpers. Each helper returns a plain dict with the keys
# "name", "role" and "content"; the wrappers below only fix the role (and,
# for the system variants, the name).


def make_chatml(name, role, content):
    """Build a single ChatML message.

    Args:
        name: Optional speaker/channel name (may be None).
        role: Message role, e.g. "system", "user" or "assistant".
        content: Message text.

    Returns:
        dict with keys "name", "role" and "content".
    """
    return dict(name=name, role=role, content=content)


def system(name, content):
    """Build a "system"-role message with the given name."""
    return make_chatml(role="system", name=name, content=content)


def situation(content):
    """Build a system message named "situation"."""
    return system(name="situation", content=content)


def thought(content):
    """Build a system message named "thought"."""
    return system(name="thought", content=content)


def information(content):
    """Build a system message named "information"."""
    return system(name="information", content=content)


def me(content, name=None):
    """Build an "assistant"-role message, optionally tagged with a name."""
    return make_chatml(role="assistant", content=content, name=name)


def person(content, name=None):
    """Build a "user"-role message, optionally tagged with a name."""
    return make_chatml(role="user", content=content, name=name)

In [6]:
# Spot-check one filtered training example to inspect the row schema
# (instruction / responses / next_response / answer / is_human_response).
filtered_dataset["train"][40]

{'instruction': 'Joe is studying a bacteria population.  There are 20 bacteria present at 3:00 p.m. and the population doubles every 3 minutes.  Assuming none of the bacteria die, how many bacteria are present at 3:15 p.m. the same day?',
 'responses': ['We first need to find out how many minutes have passed since 3:00 p.m.',
  'Yes, it is now 3:15 p.m., so 15 minutes have passed.',
  'Right. And we are told that the bacteria population doubles every 3 minutes.',
  'That means that at 3:03 p.m., there were 40 bacteria. And at 3:06 p.m., there were 80 bacteria.',
  'We can keep going until we reach 3:15 p.m.',
  'Right. So at 3:09 p.m., there were 160 bacteria. And at 3:12 p.m., there were 320 bacteria.',
  'Finally, at 3:15 p.m., there were 640 bacteria.'],
 'next_response': 'So there are 640 bacteria present at 3:15 p.m.',
 'answer': '640',
 'is_human_response': False}

In [7]:
def to_chatml(row):
    """Turn one PRM800K row into a three-message ChatML conversation.

    The conversation is:
      1. a user message holding the problem statement ("instruction"),
      2. a system "thought" message listing the reasoning steps
         ("responses") as a 1-based numbered list under a "Thoughts:" header,
      3. an assistant message of the form 'The answer is "<answer>".'.

    Args:
        row: mapping with "instruction", "responses" and "answer" entries.

    Returns:
        dict with a single "chatml" key holding the list of three messages,
        suitable as the return value of a ``Dataset.map`` function.
    """
    question = row["instruction"]
    steps = row["responses"]
    final_answer = row["answer"]

    # One line per step: "1. ...", "2. ...", ...
    numbered = "\n".join(
        f"{position}. {step}" for position, step in enumerate(steps, start=1)
    )

    conversation = [
        person(question),
        thought("Thoughts:\n\n" + numbered),
        me(f'The answer is "{final_answer}".'),
    ]
    return {"chatml": conversation}

In [8]:
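# NOTE(review): leftover from another notebook — `orca_mini_prompt_response` is
# not defined here. Remove this cell or point it at a dataset with a `token_count` column.
# lens = np.array([t for t in orca_mini_prompt_response["train"]["token_count"] if t < 2048])
# plt.hist(lens, 100)
# plt.show()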

In [9]:
# Add a "chatml" column (question, numbered thoughts, final answer) to every split.
# NOTE(review): map() keeps the existing columns alongside "chatml"; pass
# remove_columns=... if only "chatml" should be published — confirm intent.
final = filtered_dataset.map(function=to_chatml)

Map:   0%|          | 0/2199 [00:00<?, ? examples/s]

Map:   0%|          | 0/153 [00:00<?, ? examples/s]

In [10]:
# Publish every split to the Hugging Face Hub as a private dataset repository.
final.push_to_hub(repo_id="diwank/prm800k-chatml", private=True)

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading metadata:   0%|          | 0.00/725 [00:00<?, ?B/s]