Skip to content

Commit

Permalink
Long context datasets (#3646)
Browse files Browse the repository at this point in the history
## What
Added support for RAG based dataloaders 
Currently supports 
- shahules786/Multi-chapter-summaries
  • Loading branch information
shahules786 committed Aug 9, 2023
1 parent 9a208ea commit c4c9f37
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
9 changes: 8 additions & 1 deletion model/model_training/custom_datasets/__init__.py
Expand Up @@ -5,7 +5,12 @@

import numpy as np
from model_training.custom_datasets.extra_rm_datasets import load_anthropic_rlhf, load_hellaswag, load_shp
from model_training.custom_datasets.instruction import INSTRUCTION_DATASETS, InstructionDataset
from model_training.custom_datasets.instruction import (
INSTRUCTION_DATASETS,
RAG_DATASETS,
InstructionDataset,
RAGDataset,
)
from model_training.custom_datasets.oasst_dataset import load_oasst_export
from model_training.custom_datasets.pretrain_datasets import FanFics, RedPajama
from model_training.custom_datasets.prompt_dialogue import DolphinMix, Gpt4All, OrcaChat, load_oig_file
Expand Down Expand Up @@ -181,6 +186,8 @@ def get_one_dataset(
dataset = OrcaChat(cache_dir=data_path, **kwargs)
elif dataset_name == "dolphin-mix":
dataset = DolphinMix(cache_dir=data_path, **kwargs)
elif dataset_name in RAG_DATASETS.keys():
dataset = RAGDataset(dataset_name, cache_dir=data_path, **kwargs)
else:
raise ValueError(f"Unknown dataset {dataset_name}")

Expand Down
29 changes: 29 additions & 0 deletions model/model_training/custom_datasets/instruction.py
Expand Up @@ -124,3 +124,32 @@ def __getitem__(self, idx) -> DatasetEntry:
answers=answers,
lang=lang,
)


RAG_DATASETS = {
"multi-chapter-summaries": "shahules786/Multi-chapter-summaries",
}


class RAGDataset(Dataset):
def __init__(
self,
dataset,
split: str = "train",
cache_dir: str = ".cache/",
):
if dataset not in RAG_DATASETS.keys():
raise ValueError(f"Invalid dataset {dataset}")

if dataset == "multi-chapter-summaries":
self.prompt, self.context, self.response = "prompt", "context", "summary"

self.dataset = load_dataset(RAG_DATASETS[dataset], cache_dir=cache_dir)[split]

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
prompt, context, response = [self.dataset[idx][key] for key in [self.prompt, self.context, self.response]]

return create_dataset_entry_qa(mode="sft", questions=[prompt + context], answers=[response])

0 comments on commit c4c9f37

Please sign in to comment.