Skip to content

Commit

Permalink
[feature] Add conversation-mixing augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
theblackcat102 committed Feb 1, 2023
1 parent 638d8c1 commit f8eba68
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 12 deletions.
77 changes: 73 additions & 4 deletions model/supervised_finetuning/custom_datasets/qa_datasets.py
Expand Up @@ -3,6 +3,7 @@
"""
import glob
import json
import os
import random
import re
from urllib.request import urlopen

Expand Down Expand Up @@ -115,7 +116,7 @@ class QADataset(Dataset):
"reddit_asks": {"name": "eli5", "index_fn": index_eli5, "split_postfix": "_asks"},
}

def __init__(self, dataset, cache_dir, split):
def __init__(self, dataset, cache_dir, split, mix_prob=0.2):
self.no_val = False
if dataset in self.DATASET_FORMAT_MAPPING:
context = self.DATASET_FORMAT_MAPPING[dataset]
Expand All @@ -137,11 +138,25 @@ def __init__(self, dataset, cache_dir, split):
self.dataset = load_dataset(context["name"], **context["params"])
else:
raise ValueError("Unknown dataset : " + dataset)
self.length = len(self.dataset)
self.mix_prob = mix_prob

def __len__(self):
    """Number of examples; cached once in ``__init__`` as ``self.length``."""
    return self.length

def __getitem__(self, idx):
    """Return a formatted QA pair for ``idx``.

    With probability ``mix_prob`` (and when ``idx`` is at least 5 away from
    either dataset boundary, so ``idx + offset`` stays in range), prepend a
    neighbouring example as fake dialogue history — the "conversation mixing"
    augmentation.
    """
    if self.mix_prob > 0 and random.random() < self.mix_prob and idx > 5 and idx < (self.length - 5):
        # Random offset in [-5, 5].  An offset of 0 would prepend the sample
        # to itself, so redraw until it is non-zero.  (The original compared
        # the offset against `idx`, which can never match since idx > 5, so
        # the guard was dead and self-mixing could occur.)
        additional = random.randint(0, 10) - 5
        while additional == 0:
            additional = random.randint(0, 10) - 5

        answer_pair = self.index_fn(self.dataset[additional + idx])
        history_text = "".join(format_pair(answer_pair))
        question, answer = self.index_fn(self.dataset[idx])
        question = history_text + question
        return format_pair((question, answer))

    data = self.dataset[idx]
    return format_pair(self.index_fn(data))

Expand Down Expand Up @@ -297,8 +312,9 @@ class JokeExplaination(Dataset):
name = "joke"
url = "https://gist.github.com/theblackcat102/42b697e24a13fdb499e20edfbf618361/raw/1834dca207898c15f93b809d1195f6f6e47c9e1e/joke_explained.jsonl"

def __init__(self, cache_dir) -> None:
def __init__(self, cache_dir, mix_prob=0.2) -> None:
super().__init__()
self.mix_prob = mix_prob
os.makedirs(cache_dir, exist_ok=True)
joke_explain_filename = os.path.join(cache_dir, "joke_explaination.jsonl")
if not os.path.exists(joke_explain_filename):
Expand All @@ -319,9 +335,62 @@ def __init__(self, cache_dir) -> None:

if len(question) > 0 and len(answer) > 0:
self.pairs.append((question, answer))
self.length = len(self.pairs)

def __len__(self):
    """Number of (question, answer) pairs; cached in ``__init__``."""
    return self.length

def __getitem__(self, index):
    """Return a formatted joke/explanation pair for ``index``.

    With probability ``mix_prob`` (away from the dataset boundaries), a
    neighbouring pair is prepended as fake dialogue history.
    """
    if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
        # Non-zero offset in [-5, 5]; 0 would reuse the same pair as its own
        # history.  (The original guard checked `additional == index`, which
        # is unreachable since index > 5 while the offset is in [-5, 5].)
        additional = random.randint(0, 10) - 5
        while additional == 0:
            additional = random.randint(0, 10) - 5

        history_text = "".join(format_pair(self.pairs[additional + index]))
        question, answer = self.pairs[index]
        question = history_text + question
        return format_pair((question, answer))

    return format_pair(self.pairs[index])


class TranslatedQA(Dataset):
    """Translated OpenAssistant conversations read from cached JSONL files.

    Loads every ``*.jsonl`` file under ``<cache_dir>/oa_translated``; each
    line is a JSON object whose ``translate`` field is a list of
    ``{"human": ..., "answer": ...}`` conversation rounds.  With probability
    ``mix_prob``, ``__getitem__`` prepends a neighbouring pair as fake
    dialogue history (conversation mixing).
    """

    name = "oa_translated"

    def __init__(self, cache_dir, mix_prob=0.2) -> None:
        super().__init__()
        self.mix_prob = mix_prob
        os.makedirs(cache_dir, exist_ok=True)
        path = os.path.join(cache_dir, "oa_translated")
        os.makedirs(path, exist_ok=True)

        self.pairs = []
        for translated_jsonl in glob.glob(os.path.join(path, "*.jsonl")):
            with open(translated_jsonl, "r") as f:
                for line in f:
                    data = json.loads(line)
                    # NOTE(review): presumably skips samples containing code,
                    # where translation is unreliable — confirm intent.
                    if "Python " in data["text"]:
                        continue
                    # incorrect, TODO: fix later
                    for convo_round in data["translate"]:
                        self.pairs.append((convo_round["human"], convo_round["answer"]))

        self.length = len(self.pairs)

    def __len__(self):
        """Number of (human, answer) pairs; cached in ``__init__``."""
        return self.length

    def __getitem__(self, index):
        """Return a formatted pair, optionally mixed with a neighbour.

        Mixing only applies at least 5 positions away from either boundary so
        ``index + offset`` stays valid.
        """
        if random.random() < self.mix_prob and index > 5 and index < (self.length - 5):
            # Non-zero offset in [-5, 5]; 0 would prepend the sample to
            # itself.  (The original compared the offset to `index`, which is
            # unreachable since index > 5, leaving the guard dead.)
            additional = random.randint(0, 10) - 5
            while additional == 0:
                additional = random.randint(0, 10) - 5

            history_text = "".join(format_pair(self.pairs[additional + index]))
            question, answer = self.pairs[index]
            question = history_text + question
            return format_pair((question, answer))

        return format_pair(self.pairs[index])
Expand Up @@ -57,7 +57,7 @@ def index_summary_merge(text, summary):
class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split, max_words=512):
self.name = dataset
if summarization_config_mapping[dataset][0] in ["billsum", "tldr_news"] & split == "validation":
if (dataset in ["billsum", "tldr_news"]) and (split == "validation"):
split = "test"
self.dataset = load_dataset(*summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
self.text_column, self.summary_column = summarization_name_mapping[dataset]
Expand Down
28 changes: 21 additions & 7 deletions model/supervised_finetuning/custom_datasets/translation.py
Expand Up @@ -75,20 +75,34 @@


class TranslationPair(Dataset):
    """Base class for translation-pair datasets.

    Subclasses call ``super().__init__()`` and then fill ``self.pairs`` with
    (source, target) tuples, so the length is computed lazily on first use.
    With probability ``mix_prob``, ``__getitem__`` prepends a neighbouring
    pair as fake dialogue history (conversation mixing).
    """

    def __init__(self, mix_prob=0.2) -> None:
        super().__init__()
        self.pairs = []
        # -1 marks "not yet computed": subclasses populate self.pairs after
        # this constructor returns.
        self.length = -1
        self.mix_prob = mix_prob

    def __len__(self):
        if self.length < 0:
            self.length = len(self.pairs)
        return self.length

    def __getitem__(self, index):
        # Go through len(self), not self.length: if __getitem__ runs before
        # __len__ is ever called, self.length is still -1 and the mixing
        # branch would silently never trigger.
        if random.random() < self.mix_prob and index > 5 and index < (len(self) - 5):
            # Non-zero offset in [-5, 5]; 0 would prepend the sample to
            # itself.  (The original guard compared the offset to `index`,
            # which can never match since index > 5.)
            additional = random.randint(0, 10) - 5
            while additional == 0:
                additional = random.randint(0, 10) - 5

            history_text = "".join(format_pair(self.pairs[additional + index]))
            question, answer = self.pairs[index]
            question = history_text + question
            return format_pair((question, answer))

        return format_pair(self.pairs[index])


class WMT2019(TranslationPair):
def __init__(self, pair="zh-en", split="train") -> None:
super().__init__()
def __init__(self, pair="zh-en", split="train", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("wmt19", pair)[split]
self.pairs = []
src, tgt = pair.split("-")
Expand All @@ -108,8 +122,8 @@ class DiveMT(TranslationPair):

REMAP = {"tur": "tr", "ita": "it", "ukr": "uk", "nld": "nl", "vie": "vi", "ara": "ar"}

def __init__(self, split="train") -> None:
super().__init__()
def __init__(self, split="train", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("GroNLP/divemt", "main")[split]
tgt, src = "tgt_text", "src_text"
for row in dataset:
Expand All @@ -131,8 +145,8 @@ def __init__(self, split="train") -> None:
class TEDTalk(TranslationPair):
# NOTE: DO NOT use chinese pair, mix with traditional and cantonese, not clean

def __init__(self, pair="de-ja", split="train", year="2016") -> None:
super().__init__()
def __init__(self, pair="de-ja", split="train", year="2016", mix_prob=0.2) -> None:
super().__init__(mix_prob=mix_prob)
dataset = load_dataset("ted_talks_iwslt", language_pair=pair.split("-"), year=year)[split]
src, tgt = pair.split("-")
for row in dataset:
Expand Down

0 comments on commit f8eba68

Please sign in to comment.