Skip to content

Commit

Permalink
Add dataset loader for MegaCodeTraining112k & Evol-Instruct-Code-80k-…
Browse files Browse the repository at this point in the history
…v1 (#3605)

Added code to load `rombodawg/MegaCodeTraining112k` (key: megacode) and
`nickrosh/Evol-Instruct-Code-80k-v1` (key: evol_instruct_code).
Also added an optional `fill_min_length` parameter to
`InstructionDataset` class. If specified, instructions are concatenated
until the total string length of prompts and completions exceeds
`fill_min_length`. Seed for random order can optionally be specified
(default: 42).

Example:
```
  datasets:
    - megacode:
        fill_min_length: 24000
    - evol_instruct_code:
        fill_min_length: 24000
```

- updated transformers dependency to `==4.31.0`
  • Loading branch information
andreaskoepf committed Jul 25, 2023
1 parent 0c2fa4c commit bc5b70d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 18 deletions.
69 changes: 52 additions & 17 deletions model/model_training/custom_datasets/instruction.py
@@ -1,6 +1,9 @@
"""
These are in the form of 'INSTRUCTION', 'RESPONSE'
"""
import random
from typing import Optional

from datasets import load_dataset
from model_training.custom_datasets.formatting import DatasetEntry, create_dataset_entry_qa
from model_training.custom_datasets.utils import _filter_by_words
Expand All @@ -25,49 +28,80 @@
"oa_stackexchange": "donfu/oa-stackexchange",
"tell_a_joke": "mikegarts/oa_tell_a_joke_20000",
"wizardlm_70k": "ehartford/WizardLM_alpaca_evol_instruct_70k_unfiltered",
"megacode": "rombodawg/MegaCodeTraining112k",
"evol_instruct_code": "nickrosh/Evol-Instruct-Code-80k-v1",
}


class InstructionDataset(Dataset):
def __init__(self, dataset, cache_dir, split, mode="sft", fill_min_length: Optional[int] = None, seed: int = 42):
    """Load an instruction dataset and group its entries into QA examples.

    Args:
        dataset: Key into ``INSTRUCTION_DATASETS`` selecting which HF dataset to load.
        cache_dir: Hugging Face datasets cache directory.
        split: Dataset split passed through to ``load_dataset``.
        mode: Either ``"sft"`` or ``"rl"`` (asserted below).
        fill_min_length: If given, consecutive (shuffled) entries are packed into one
            example until the combined character length of prompts and completions
            exceeds this value; if ``None``, each entry becomes its own example.
        seed: Seed for the RNG that shuffles the packing order (default: 42).

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
    """
    assert mode in ("sft", "rl")
    self.name = dataset
    self.mode = mode

    # Per-dataset column naming; most datasets use upper-case defaults.
    data_files = None
    if dataset == "minimath":
        self.instruction_column = "question"
        self.response_column = "answer"
    elif dataset in ("wizardlm_70k", "evol_instruct_code"):
        self.instruction_column = "instruction"
        self.response_column = "output"
    elif dataset == "megacode":
        self.instruction_column = "prompt"
        self.response_column = "completion"
        # megacode ships multiple files; pin the one we train on.
        data_files = "RombosCodeTraining112k.json"
    else:
        self.instruction_column = "INSTRUCTION"
        self.response_column = "RESPONSE"

    ds = load_dataset(INSTRUCTION_DATASETS[dataset], cache_dir=cache_dir, split=split, data_files=data_files)
    # Each element is one example: a tuple of parallel question/answer lists.
    self.dataset: list[tuple[list[str], list[str]]] = []

    num_invalid = 0
    questions, answers = [], []
    item_len = 0

    # Deterministic shuffle so that packed groups are reproducible for a given seed.
    rng = random.Random(seed)
    order = list(range(len(ds)))
    rng.shuffle(order)

    # filter entries and optionally combine multiple entries
    for i in order:
        entry = ds[i]
        q = entry[self.instruction_column]
        a = entry[self.response_column]
        if (
            q is not None
            and len(q.strip()) > 0
            and a is not None
            and len(a.strip()) > 0
            and _filter_by_words(q)
            and _filter_by_words(a)
        ):
            questions.append(q)
            answers.append(a)
            item_len += len(a) + len(q)

            # Flush the current group once it is long enough (or always, when
            # fill_min_length is None, i.e. one entry per example).
            if fill_min_length is None or fill_min_length < item_len:
                self.dataset.append((questions, answers))
                item_len = 0
                questions, answers = [], []
        else:
            num_invalid += 1

    # Keep the trailing partially-filled group instead of dropping it.
    if len(questions) > 0 and len(answers) > 0:
        self.dataset.append((questions, answers))

    if num_invalid > 0:
        print(f"[Warning] {num_invalid} entries of {dataset} were invalid.")

def __len__(self):
    """Return the number of grouped (questions, answers) examples."""
    return len(self.dataset)

def __getitem__(self, idx) -> DatasetEntry:
data = self.dataset[idx]
lang = None
questions, answers = self.dataset[idx]

lang: str | None = None
# use "en" for datasets which have more than 95% English messages
if self.name in [
"humaneval_mbpp_codegen_qa",
Expand All @@ -78,9 +112,10 @@ def __getitem__(self, idx) -> DatasetEntry:
"tell_a_joke",
]:
lang = "en"

return create_dataset_entry_qa(
mode=self.mode,
questions=[data[self.instruction_column]],
answers=[data[self.response_column]],
questions=questions,
answers=answers,
lang=lang,
)
2 changes: 1 addition & 1 deletion model/pyproject.toml
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"langcodes==3.3.0",
"tqdm>=4.65.0",
"pydantic==1.10.7",
"transformers @ git+https://github.com/huggingface/transformers.git@e4a52b6a1536b1d9ef1ac55168bc4fede25605bc",
"transformers==4.31.0",
"wandb>=0.15.5",
]

Expand Down

0 comments on commit bc5b70d

Please sign in to comment.