Skip to content

Commit

Permalink
add check dataset appearances, update vicuna to dataset entry (#2818)
Browse files Browse the repository at this point in the history
- add a script which can be used to check if any dataset consists
specific regular expressions or words
- update vicuna so that it can be used with the new `DatasetEntry` class
- remove single references from vicuna (so `[1]` is removed, but I found
with the script mentioned above that there are a couple of occurances
where our `re_reference_remove` regex hits, but it is actually a list
(for language like e.g. python), therefore just remove single
references)
- if human response is none, then sample from multiple answers like
`['please continue', '...']`

This PR depends on #2809
  • Loading branch information
CloseChoice committed Apr 21, 2023
1 parent 22ea9d9 commit cd3f07f
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 17 deletions.
109 changes: 109 additions & 0 deletions model/model_training/check_dataset_appearances.py
@@ -0,0 +1,109 @@
"""
This script should help to detect any keywords or other unwanted appearances in the datasets
RUN WITH:
python check_dataset_appearances.py -d <datasets> --cache_dir <path-to-cache-dir> --mode <one of sft, rm, rl>
e.g.:
python check_dataset_appearances.py -d gpt4all webgpt --cache_dir .cache --mode sft
"""
import argparse
import pprint
from collections import defaultdict

from model_training.custom_datasets import get_one_dataset
from model_training.custom_datasets.entities import Mode
from model_training.custom_datasets.formatting import DatasetEntry
from model_training.custom_datasets.qa_datasets import (
re_reference_remove,
re_single_reference_remove,
re_whitespace_newline_match,
)
from model_training.custom_datasets.utils import FILTER_BY_WORDS

RE_TO_CHECK = [re_whitespace_newline_match, re_reference_remove, re_single_reference_remove]
STRINGS_TO_CHECK = [*FILTER_BY_WORDS]


def argument_parsing():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
"--datasets",
nargs="+",
required=True,
help="""
Multiple datasets can be passed to set different options.
For example, run as:
./check_dataset_counts.py --datasets math oasst_export_eu
to check the counts of the math and the oasst_export_eu dataset.
""",
)
parser.add_argument("--mode", dest="mode", type=Mode, choices=list(Mode))
parser.add_argument("--cache_dir", dest="cache_dir", type=str)

args, _ = parser.parse_known_args()

return args


def check_in_dataset_row(row: str | list[str] | tuple[str], matched=dict[str, list]):
def _check_single_string(row: str, matched: dict[str, list]) -> dict[str, list]:
for exp in RE_TO_CHECK:
if exp.match(row) is not None:
matched[exp].append(row)
for string in STRINGS_TO_CHECK:
if string in row:
string_idx = row.index(string)
matched[string].append(row[max(string_idx - 50, 0) : string_idx + 50])
return matched

if isinstance(row, str):
matched = _check_single_string(row, matched)
elif isinstance(row, (list, tuple)):
for r in row:
if not isinstance(r, str):
raise ValueError(f"Unexpected type: {type(row)}")
matched = _check_single_string(r, matched)
elif isinstance(row, DatasetEntry):
formatted = row.get_formatted(mode=args.mode, eos_token="</s>")
for r in formatted:
if not isinstance(r, str):
raise ValueError(f"Unexpected type: {type(r)}")
matched = _check_single_string(
r.replace("<|assistant|>", "").replace("<|prompter|>", "").replace("</s>", ""), matched
)
else:
raise ValueError(f"Received unexpected type: {type(row)}.")
return matched


def iterate_over_dataset(ds):
matched = defaultdict(list)
for row in ds:
check_in_dataset_row(row, matched)
return matched


if __name__ == "__main__":
args = argument_parsing()
pp = pprint.PrettyPrinter(indent=4)

train_datasets, val_datasets = {}, {}
for dataset_name in args.datasets:
print(f"start with dataset {dataset_name}")
train, val = get_one_dataset(None, dataset_name, mode=args.mode.value, data_path=args.cache_dir)
train_datasets[dataset_name] = train
if val is not None:
val_datasets[dataset_name] = val
matched_train = iterate_over_dataset(train)
matched_val = iterate_over_dataset(val)
if len(matched_train) != 0:
pp.pprint(f"Found the following occurances in TRAIN {dataset_name}:")
pp.pprint(dict(matched_train))
if len(matched_val) != 0:
pp.pprint(f"Found the following occurances in VAL {dataset_name}:")
pp.pprint(dict(matched_val))
if len(matched_train) + len(matched_val) == 0:
print("Did not find of the specified regular expressions or filter words.")
45 changes: 28 additions & 17 deletions model/model_training/custom_datasets/qa_datasets.py
Expand Up @@ -20,6 +20,8 @@

# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
re_single_reference_remove = re.compile(r"\[\s?\d+\s?\]")
re_whitespace_newline_match = re.compile(r"^[\s\n]*$")


LINKING_CHARS = ["\n", "\n\n", " "]
Expand Down Expand Up @@ -486,35 +488,49 @@ class Vicuna(Dataset):
name = "vicuna"

@staticmethod
def process_vicuna_conversations(data: list[dict[str, None | str]], input_max_length: int) -> list[str] | None:
dialogue = []
def process_vicuna_conversations(
data: list[dict[str, None | str]], input_max_length: int
) -> tuple[list[str], list[str]] | None:
role = None
messages = []
# drop conversations that start with Bot
if len(data["conversations"]) == 0 or data["conversations"][0]["from"] != "human":
return None
questions = []
answers = []
for line in data["conversations"]:
speaker = line["from"] # 'human' or 'gpt'
message = line["value"]

if message is None or message == "":
if speaker == "gpt":
return None
elif speaker == "human":
# replace empty messages with one of the following
message = random.choice(["...", "Please continue", "Go on", ""])
# remove markdown escaping in revision 192ab2185289094fc556ec8ce5ce1e8e587154ca
# python-markdownify with escape_asterisks & escape_underscores True is used
# for pre-processing the dataset.
# See also https://github.com/LAION-AI/Open-Assistant/issues/2510
message = message.replace(r"\_", "_")
message = message.replace(r"\*", "*")
message = re_single_reference_remove.sub("", message)

if role != speaker:
if role is not None:
dialogue.append("\n".join(messages))
if role == "human":
questions.append("\n".join(messages)[:input_max_length])
if role == "gpt":
answers.append("\n".join(messages)[:input_max_length])
messages = []
role = speaker
messages.append(message.strip())

if role is not None and len(messages) > 0:
dialogue.append("\n".join(messages))
dialogue_truncated = [k[:input_max_length] for k in dialogue]
return dialogue_truncated
if role == "human":
questions.append("\n".join(messages)[:input_max_length])
if role == "gpt":
answers.append("\n".join(messages)[:input_max_length])
return questions, answers

def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: int = 2048) -> None:
super().__init__()
Expand All @@ -530,20 +546,15 @@ def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: i
revision="192ab2185289094fc556ec8ce5ce1e8e587154ca",
)["train"]
for data in dataset:
if (
processed_data := self.process_vicuna_conversations(data, input_max_length=input_max_length)
) is not None:
self.pairs.append(processed_data)
if (qa := self.process_vicuna_conversations(data, input_max_length=input_max_length)) is not None:
self.pairs.append(DatasetEntry(questions=qa[0], answers=qa[1]))

def __len__(self) -> int:
return len(self.pairs)

def __getitem__(self, index: int) -> list[str] | tuple[str]:
dialogue: list[str] = self.pairs[index]
if self.mode == "sft":
return dialogue
elif self.mode == "rl":
return tuple(dialogue[:-1])
def __getitem__(self, index: int) -> DatasetEntry:
dialogue = self.pairs[index]
return dialogue


class DatabricksDolly15k(Dataset):
Expand Down

0 comments on commit cd3f07f

Please sign in to comment.