Add data loader for HF oasst1 (#2951)
Make it possible to work with the OASST1 dataset directly from the HuggingFace hub.
Add a new `hf_dataset_name` parameter to the `load_oasst_export` function.
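
A minimal usage sketch (parameter names per the updated `load_oasst_export` signature in the diff below; the language list, split and mode values here are illustrative):

```python
from model_training.custom_datasets.oasst_dataset import load_oasst_export

# Load OASST1 directly from the HuggingFace hub instead of a local JSONL export.
train, val = load_oasst_export(
    hf_dataset_name="OpenAssistant/oasst1",
    lang="en,es,de,fr",  # comma-separated language codes to keep
    val_split=0.1,
    mode="sft",
)
print(len(train), len(val))
```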

---------

Co-authored-by: grgau <pedro.ferracini@usp.br>
Co-authored-by: Tobias Pitters <31857876+CloseChoice@users.noreply.github.com>
Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
4 people committed Jun 13, 2023
1 parent fff7272 commit 463d729
Showing 7 changed files with 179 additions and 37 deletions.
44 changes: 39 additions & 5 deletions model/README.md
@@ -16,15 +16,49 @@ export DATA_PATH=$PWD/.cache
export MODEL_PATH=$PWD/.saved_models
```

-2. Then download the OA data.
+2. Then download the OA message tree JSONL file or declare the HuggingFace
+   dataset to use.

Create a new configuration section, or modify an existing one, in
`config.yaml` (SFT), `config_rm.yaml` (RM) or `config_rl.yaml` (RL) in the
`model_training/configs/` directory, and specify either the OA JSONL data file
or the HuggingFace dataset to use.

- To use a local OASST JSONL file (either `.jsonl` or `.jsonl.gz`), specify the
  file name with the `input_file_path` configuration option. Place the file in
  the `cache_dir` (`DATA_PATH`) or specify an absolute path.

```bash
-cp /path/to/<oa.jsonl> $DATA_PATH
+cp /path/to/<oasst.trees.jsonl> $DATA_PATH
```

Example:

```yaml
my_data_config:
  datasets:
    - oasst_export:
        input_file_path: oasst_export.trees.jsonl.gz
```

-Change the `<oa.jsonl>` file used in the `model_training/configs/config.yaml`,
-`model_training/configs/config_rl.yaml` and `reward/instructor/rank_datasets.py`
-files.
- To use a HuggingFace dataset, specify the dataset name with the
  `hf_dataset_name` configuration option.

Example:

```yaml
my_data_config:
  datasets:
    - oasst_export:
        hf_dataset_name: OpenAssistant/oasst1
```

_Note_: If both `hf_dataset_name` and `input_file_path` are specified,
`input_file_path` takes precedence.
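
For example (config name illustrative; in this case the loader reads the local file and never touches the hub dataset):

```yaml
my_data_config:
  datasets:
    - oasst_export:
        input_file_path: oasst_export.trees.jsonl.gz # used
        hf_dataset_name: OpenAssistant/oasst1 # ignored
```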

See the
[OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
dataset card on the HuggingFace hub for more information.

- (TODO) add better parsing of the config files that is consistent for sft, rm
and rl training.
26 changes: 22 additions & 4 deletions model/model_training/configs/config.yaml
@@ -186,7 +186,9 @@ oasst_only:
  datasets:
    - oasst_export:
        lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
-        input_file_path: 2023-04-04_oasst_ready.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
+        #input_file_path: 2023-04-12_oasst_ready.trees.jsonl.gz
+        #top_k: 1
        val_split: 0.05
  sort_by_length: false
  use_custom_sampler: false
@@ -206,14 +208,28 @@ oasst_export_eu:
  datasets:
    - oasst_export:
        lang: "en,es,de,fr"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
    - gpt4all
    - alpaca
    - code_alpaca
    - oig_file:
        source_url: https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl
        max_count: 10000
        min_length: 100
        val_split: 0.1
    - oig_file:
        source_url: https://huggingface.co/datasets/laion/OIG/raw/main/unified_grade_school_math_instructions.jsonl
        val_split: 0.1
        min_length: 100
  sort_by_length: false
  use_custom_sampler: false

oasst_export_latin_cyrillic:
  save_strategy: epoch
  datasets:
    - oasst_export:
        lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
    - alpaca
    - oig_file:
        source_url: https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl
@@ -364,7 +380,7 @@ llama-30b-sft-6:
  datasets:
    - oasst_export:
        lang: "bg,ca,cs,da,de,en,es,fr,hr,hu,it,nl,pl,pt,ro,ru,sl,sr,sv,uk"
-        input_file_path: 2023-04-12_oasst_release_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
        val_split: 0.05
    - vicuna:
        val_split: 0.05
@@ -712,6 +728,7 @@ galactica-125m:
  gradient_accumulation_steps: 2
  per_device_train_batch_size: 4
  per_device_eval_batch_size: 4
+  dtype: fp32

gpt-jt:
  learning_rate: 8e-6
@@ -761,3 +778,4 @@ debug:
  log_wandb: false
  verbose: true
  num_train_epochs: 0.2
+  dtype: fp32
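
To train with one of these configuration sections, its name is passed to the trainer. A sketch of a typical invocation (the `trainer_sft.py` entry point and `--configs` flag follow the repository's existing workflow and are not part of this diff):

```bash
python trainer_sft.py --configs defaults oasst_only
```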
6 changes: 2 additions & 4 deletions model/model_training/configs/config_rm.yaml
@@ -49,7 +49,6 @@ oasst-rm-1-pythia-6.9b:
  pooling: last
  datasets:
    - augment_oasst:
-        #input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl
        input_file_path: augmented_latin_cyrillic_oasst_2023-03-27_v2.jsonl
    - anthropic_rlhf:
        fraction: 0.1
@@ -98,10 +97,9 @@ oasst-rm-1-pythia-2.8b:
  datasets:
    - oasst_export:
        lang: "en,es,de,fr"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
        val_split: 0.1
    - augment_oasst:
-        #input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl
        input_file_path: augmented_latin_cyrillic_oasst_2023-03-27_v2.jsonl
    - anthropic_rlhf:
        fraction: 0.1
@@ -142,7 +140,7 @@ oasst-rm-1-pythia-1.4b:
  datasets:
    - oasst_export:
        lang: "en,es,de,fr"
-        input_file_path: 2023-03-27_oasst_research_ready_synth.jsonl.gz
+        hf_dataset_name: OpenAssistant/oasst1
        val_split: 0.1
    - augment_oasst:
        input_file_path: augmented_latin_cyrillic_oasst_2023-03-27.jsonl
36 changes: 24 additions & 12 deletions model/model_training/custom_datasets/oasst_dataset.py
@@ -1,8 +1,9 @@
from pathlib import Path
-from typing import Literal, Optional
+from typing import Iterable, Literal, Optional

from model_training.custom_datasets.formatting import DatasetEntrySft, Role, Utterance
-from oasst_data import ExportMessageNode, read_message_trees, visit_threads_depth_first
+from oasst_data import ExportMessageNode, read_dataset_message_trees, read_message_trees, visit_threads_depth_first
+from oasst_data.schemas import ExportMessageTree
from torch import Generator
from torch.utils.data import Dataset, random_split

@@ -20,7 +21,8 @@ def __getitem__(self, index):


def load_oasst_export(
-    input_file_path: str | Path,
+    input_file_path: Optional[str | Path] = None,
+    hf_dataset_name: Optional[str] = "OpenAssistant/oasst1",
    val_split: float = 0.2,
    lang: str = "en",
    top_k: Optional[int] = None,
@@ -31,20 +33,27 @@
    if mode not in ("sft", "rm", "rl"):
        raise ValueError(f"Unknown dataset mode: {mode}")

-    lang_codes = lang.split(",")
+    lang_codes: list[str] = lang.split(",")

    generator = Generator()
    generator.manual_seed(manual_seed)

-    if not isinstance(input_file_path, Path):
-        input_file_path = Path(input_file_path)
-    if not input_file_path.is_absolute() and data_path:
-        if not isinstance(data_path, Path):
-            data_path = Path(data_path)
-        input_file_path = data_path / input_file_path
+    tree_iter: Iterable[ExportMessageTree] = None
+    if input_file_path:
+        if not isinstance(input_file_path, Path):
+            input_file_path = Path(input_file_path)
+        if not input_file_path.is_absolute() and data_path:
+            if not isinstance(data_path, Path):
+                data_path = Path(data_path)
+            input_file_path = data_path / input_file_path
+        tree_iter = read_message_trees(input_file_path)
+    elif hf_dataset_name:
+        tree_iter = read_dataset_message_trees(hf_dataset_name, split="train+validation")
+    else:
+        raise RuntimeError("Either `input_file_path` or `hf_dataset_name` must be specified.")

    threads_per_tree = []
-    for tree in read_message_trees(input_file_path):
+    for tree in tree_iter:
        if tree.tree_state != "ready_for_export" or not tree.prompt.review_result or tree.prompt.lang not in lang_codes:
            continue

@@ -145,6 +154,9 @@ def flatten(ds: ListDataset) -> ListDataset:
    train = flatten(splits[0])
    val = flatten(splits[1])

-    print(f"OASST data {str(input_file_path)}: {len(train)=}, {len(val)=}")
+    if input_file_path:
+        print(f"OASST JSONL file {str(input_file_path)}: {len(train)=}, {len(val)=}")
+    else:
+        print(f"OASST HF dataset {hf_dataset_name}: {len(train)=}, {len(val)=}")

    return train, val
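
The local-file route through the same function is unchanged; for comparison, a sketch (file name illustrative; relative paths resolve against `data_path` as shown above):

```python
from model_training.custom_datasets.oasst_dataset import load_oasst_export

train, val = load_oasst_export(
    input_file_path="2023-04-12_oasst_ready.trees.jsonl.gz",  # illustrative
    data_path=".cache",  # relative input paths resolve against this directory
    lang="en",
    mode="sft",
)
```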
11 changes: 10 additions & 1 deletion oasst-data/oasst_data/__init__.py
@@ -1,4 +1,11 @@
-from oasst_data.reader import read_message_list, read_message_tree_list, read_message_trees, read_messages
+from oasst_data.reader import (
+    read_dataset_message_trees,
+    read_dataset_messages,
+    read_message_list,
+    read_message_tree_list,
+    read_message_trees,
+    read_messages,
+)
from oasst_data.schemas import (
    ExportMessageEvent,
    ExportMessageEventEmoji,
@@ -33,4 +40,6 @@
"visit_messages_depth_first",
"write_message_trees",
"write_messages",
"read_dataset_message_trees",
"read_dataset_messages",
]
90 changes: 80 additions & 10 deletions oasst-data/oasst_data/reader.py
@@ -4,6 +4,7 @@
from typing import Callable, Iterable, Optional, TextIO

import pydantic
+from datasets import load_dataset

from .schemas import ExportMessageNode, ExportMessageTree

@@ -17,22 +18,24 @@ def open_jsonl_read(input_file_path: str | Path) -> TextIO:
    return input_file_path.open("r", encoding="UTF-8")


-def read_oasst_obj(line: str) -> ExportMessageTree | ExportMessageNode:
-    dict_tree = json.loads(line)
+def read_oasst_obj(obj_dict: dict) -> ExportMessageTree | ExportMessageNode:
    # validate data
-    if "message_id" in dict_tree:
-        return pydantic.parse_obj_as(ExportMessageNode, dict_tree)
-    elif "message_tree_id" in dict_tree:
-        return pydantic.parse_obj_as(ExportMessageTree, dict_tree)
+    if "message_id" in obj_dict:
+        return pydantic.parse_obj_as(ExportMessageNode, obj_dict)
+    elif "message_tree_id" in obj_dict:
+        return pydantic.parse_obj_as(ExportMessageTree, obj_dict)

    raise RuntimeError("Unknown object in jsonl file")


-def read_oasst_jsonl(input_file_path: str | Path) -> Iterable[ExportMessageTree | ExportMessageNode]:
+def read_oasst_jsonl(
+    input_file_path: str | Path,
+) -> Iterable[ExportMessageTree | ExportMessageNode]:
    with open_jsonl_read(input_file_path) as file_in:
        # read one object per line
        for line in file_in:
-            yield read_oasst_obj(line)
+            dict_tree = json.loads(line)
+            yield read_oasst_obj(dict_tree)


def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTree]:
@@ -42,18 +45,85 @@ def read_message_trees(input_file_path: str | Path) -> Iterable[ExportMessageTree]:


def read_message_tree_list(
-    input_file_path: str | Path, filter: Optional[Callable[[ExportMessageTree], bool]] = None
+    input_file_path: str | Path,
+    filter: Optional[Callable[[ExportMessageTree], bool]] = None,
) -> list[ExportMessageTree]:
    return [t for t in read_message_trees(input_file_path) if not filter or filter(t)]


def convert_hf_message(row: dict) -> None:
    emojis = row.get("emojis")
    if emojis:
        row["emojis"] = dict(zip(emojis["name"], emojis["count"]))
    labels = row.get("labels")
    if labels:
        row["labels"] = {
            name: {"value": value, "count": count}
            for name, value, count in zip(labels["name"], labels["value"], labels["count"])
        }
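
# Illustration with hypothetical values: the hub dataset stores these fields
# column-wise, e.g. emojis == {"name": ["+1", "-1"], "count": [5, 1]}, which
# convert_hf_message rewrites in place to the export-format mapping
# {"+1": 5, "-1": 1}; labels likewise become {name: {"value": ..., "count": ...}}.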


def read_messages(input_file_path: str | Path) -> Iterable[ExportMessageNode]:
    for x in read_oasst_jsonl(input_file_path):
        assert isinstance(x, ExportMessageNode)
        yield x


def read_message_list(
-    input_file_path: str | Path, filter: Optional[Callable[[ExportMessageNode], bool]] = None
+    input_file_path: str | Path,
+    filter: Optional[Callable[[ExportMessageNode], bool]] = None,
) -> list[ExportMessageNode]:
    return [t for t in read_messages(input_file_path) if not filter or filter(t)]


def read_dataset_message_trees(
    hf_dataset_name: str = "OpenAssistant/oasst1",
    split: str = "train+validation",
) -> Iterable[ExportMessageTree]:
    dataset = load_dataset(hf_dataset_name, split=split)

    tree_dict: dict = None
    parents: list = None
    for row in dataset:
        convert_hf_message(row)
        if row["parent_id"] is None:
            if tree_dict:
                tree = read_oasst_obj(tree_dict)
                assert isinstance(tree, ExportMessageTree)
                yield tree

            tree_dict = {
                "message_tree_id": row["message_id"],
                "tree_state": row["tree_state"],
                "prompt": row,
            }
            parents = []
        else:
            while parents[-1]["message_id"] != row["parent_id"]:
                parents.pop()
            parent = parents[-1]
            if "replies" not in parent:
                parent["replies"] = []
            parent["replies"].append(row)

        row.pop("message_tree_id", None)
        row.pop("tree_state", None)
        parents.append(row)

    if tree_dict:
        tree = read_oasst_obj(tree_dict)
        assert isinstance(tree, ExportMessageTree)
        yield tree


def read_dataset_messages(
    hf_dataset_name: str = "OpenAssistant/oasst1",
    split: str = "train+validation",
) -> Iterable[ExportMessageNode]:
    dataset = load_dataset(hf_dataset_name, split=split)

    for row in dataset:
        convert_hf_message(row)
        message = read_oasst_obj(row)
        assert isinstance(message, ExportMessageNode)
        yield message
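
A quick way to exercise the two new readers (a sketch, assuming network access to the HuggingFace hub; `validation` is one of the published splits of OpenAssistant/oasst1):

```python
from oasst_data import read_dataset_message_trees, read_dataset_messages

# Trees are rebuilt from the flat row stream; parent rows precede their replies.
for tree in read_dataset_message_trees("OpenAssistant/oasst1", split="validation"):
    print(tree.message_tree_id, tree.tree_state)
    break

# Alternatively, iterate plain messages without reconstructing the tree shape.
for message in read_dataset_messages("OpenAssistant/oasst1", split="validation"):
    print(message.message_id, message.role)
    break
```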
3 changes: 2 additions & 1 deletion oasst-data/pyproject.toml
@@ -7,7 +7,8 @@ authors = [
]
dependencies = [
    "pydantic>=1.10.4",
-    "loguru==0.6.0"
+    "loguru==0.6.0",
+    "datasets>=2.12.0"
]

[project.optional-dependencies]
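
Because `oasst-data` now depends on the `datasets` package, reinstall it after pulling this change; a typical editable install from the repository root (path assumed here) is:

```bash
pip install -e ./oasst-data
```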
