Skip to content

Commit

Permalink
[feature] Add rallio new instruction dataset v3
Browse files Browse the repository at this point in the history
  • Loading branch information
theblackcat102 committed Feb 6, 2023
1 parent 0be4d88 commit 7421615
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 41 deletions.
7 changes: 5 additions & 2 deletions model/supervised_finetuning/custom_datasets/__init__.py
@@ -1,7 +1,7 @@
"""
High level functions for model training
"""
from custom_datasets.prompt_dialogue import InstructionTuning, PromptGeneratedDataset
from custom_datasets.prompt_dialogue import InstructionTuning, PrivateInstructionTuning, PromptGeneratedDataset
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, TranslatedQA, WebGPT
from custom_datasets.summarization import SummarizationDataset
from custom_datasets.toxic_conversation import ProsocialDialogue, ProsocialDialogueExplaination
Expand Down Expand Up @@ -32,7 +32,7 @@
"debate_sum",
"tldr_news",
]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning"]
OTHER = ["prosocial_dialogue", "explain_prosocial", "instruct_tuning", "private_tuning"]


def train_val_dataset(dataset, val_split=0.2):
Expand Down Expand Up @@ -92,6 +92,9 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "instruct_tuning":
dataset = InstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "private_tuning":
dataset = PrivateInstructionTuning(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
elif dataset_name == "translate_qa":
dataset = TranslatedQA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.01)
Expand Down
44 changes: 43 additions & 1 deletion model/supervised_finetuning/custom_datasets/prompt_dialogue.py
Expand Up @@ -2,7 +2,7 @@
import os
from urllib.request import urlopen

from custom_datasets.formatting import format_pair
from custom_datasets.formatting import QA_SPECIAL_TOKENS, format_pair
from torch.utils.data import Dataset


Expand Down Expand Up @@ -102,3 +102,45 @@ def __len__(self):

def __getitem__(self, index):
return format_pair(self.pairs[index])


class PrivateInstructionTuning(Dataset):
"""
We have seen some promising capabilities from instruction tuning
with the following mix of datasets that are derived from datasets
available online.
The files for this data are in json format as a list of tuples
where each tuple is (source,instruction_response_pair)
Not to be confused with unatural instruction
"""

name = "private_tuning"
filename = "oa_v3_fixed_plus_safety.jsonl"

def __init__(self, cache_dir) -> None:
super().__init__()
os.makedirs(cache_dir, exist_ok=True)

self.pairs = []
for file_link in [self.filename]:
basename = file_link.split("/")[-1]
instruction_tune_file = os.path.join(cache_dir, basename)

with open(instruction_tune_file, "r", encoding="utf-8") as f:
for line in f:
row = json.loads(line)
prefix = ""
for _, convo in enumerate(row["text"].split("User:")):
if "Assistant" in convo:
prompt, answer = convo.split("Assistant:", maxsplit=1)
answer = answer.replace("<|endoftext|>", "").strip()
self.pairs.append((prefix + QA_SPECIAL_TOKENS["Question"] + prompt, answer))
prefix += "".join(format_pair((prompt, answer)))

def __len__(self):
return len(self.pairs)

def __getitem__(self, index):
prompt, answer = self.pairs[index]
return "{}{}".format(prompt, QA_SPECIAL_TOKENS["Answer"]), answer
41 changes: 3 additions & 38 deletions model/supervised_finetuning/custom_datasets/qa_datasets.py
Expand Up @@ -3,7 +3,6 @@
"""
import json
import os
import random
import re
from urllib.request import urlopen

Expand Down Expand Up @@ -116,7 +115,7 @@ class QADataset(Dataset):
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
}

def __init__(self, dataset, cache_dir, split, mix_prob=0.2):
def __init__(self, dataset, cache_dir, split):
self.no_val = False
if dataset in self.DATASET_FORMAT_MAPPING:
context = self.DATASET_FORMAT_MAPPING[dataset]
Expand All @@ -139,23 +138,11 @@ def __init__(self, dataset, cache_dir, split, mix_prob=0.2):
else:
raise ValueError("Unknown dataset : " + dataset)
self.length = len(self.dataset)
self.mix_prob = mix_prob

def __len__(self):
return self.length

def __getitem__(self, idx):
if self.mix_prob > 0 and random.random() < self.mix_prob and idx > 5 and idx < (self.length - 5):

additional = random.randint(0, 10) - 5
while additional == idx:
additional = random.randint(0, 10) - 5

answer_pair = self.index_fn(self.dataset[additional + idx])
history_text = "".join(format_pair(answer_pair))
question, answer = self.index_fn(self.dataset[idx])
question = history_text + question
return format_pair((question, answer))

data = self.dataset[idx]
return format_pair(self.index_fn(data))
Expand Down Expand Up @@ -312,9 +299,8 @@ class JokeExplaination(Dataset):
name = "joke"
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"

def __init__(self, cache_dir, mix_prob=0.2) -> None:
def __init__(self, cache_dir) -> None:
super().__init__()
self.mix_prob = mix_prob
os.makedirs(cache_dir, exist_ok=True)
joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl")
if not os.path.exists(joke_explain_filename):
Expand All @@ -341,26 +327,15 @@ def __len__(self):
return self.length

def __getitem__(self, index):
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == index:
additional = random.randint(0, 10) - 5

history_text = "".join(format_pair(self.pairs[additional + index]))
question, answer = self.pairs[index]
question = history_text + question
return format_pair((question, answer))

return format_pair(self.pairs[index])


class TranslatedQA(Dataset):

name = "oa_translated"

def __init__(self, cache_dir, mix_prob=0.2) -> None:
def __init__(self, cache_dir) -> None:
super().__init__()
self.mix_prob = mix_prob
os.makedirs(cache_dir, exist_ok=True)
path = os.path.join(cache_dir, "oa_translated")
os.makedirs(path, exist_ok=True)
Expand All @@ -383,14 +358,4 @@ def __len__(self):
return self.length

def __getitem__(self, index):
if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
additional = random.randint(0, 10) - 5
while additional == index:
additional = random.randint(0, 10) - 5

history_text = "".join(format_pair(self.pairs[additional + index]))
question, answer = self.pairs[index]
question = history_text + question
return format_pair((question, answer))

return format_pair(self.pairs[index])

0 comments on commit 7421615

Please sign in to comment.