Add alpaca reverse augmentation possibility (#2342)
closes #2335

- Add the possibility to reverse question and response for the Alpaca dataset
- Add debug configs to show how to use it

Alpaca with reverse augmentation can be run with:
```bash
python trainer_sft.py --configs pythia-70m-deduped
```
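The new loader can also be exercised directly; here is a minimal sketch based on the `load_alpaca_dataset` signature added below (the `val_split` and `cache_dir` values are illustrative):

```python
from model_training.custom_datasets.qa_datasets import load_alpaca_dataset

# Build train/val splits that contain reversed (response -> instruction)
# pairs in addition to the usual (instruction -> response) pairs.
train, val = load_alpaca_dataset(
    "alpaca",
    val_split=0.05,      # illustrative split fraction
    cache_dir=".cache",  # illustrative cache location
    mode="sft",
    reverse_augmentation=True,
    keep_unreversed=True,
)
print(len(train), len(val))
print(train[0])  # a (question, answer) tuple in "sft" mode
```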
I couldn't run this myself due to gradient-unscaling errors, so before this gets merged we should do a run to check whether the reversal really improves the loss.

---------

Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
CloseChoice and andreaskoepf committed Apr 7, 2023
1 parent f1e7437 commit f8c1cd2
Showing 3 changed files with 60 additions and 28 deletions.
3 changes: 2 additions & 1 deletion model/model_training/configs/config.yaml
@@ -256,8 +256,9 @@ llama-30b:
  save_total_limit: 4
  use_flash_attention: true

pythia:
pythia-70m-deduped:
  learning_rate: 8e-6
  # model_name: EleutherAI/pythia-1b-deduped
  model_name: EleutherAI/pythia-70m-deduped
  weight_decay: 0.0
  max_length: 520
9 changes: 3 additions & 6 deletions model/model_training/custom_datasets/__init__.py
@@ -10,14 +10,13 @@
from model_training.custom_datasets.prompt_dialogue import Gpt4All, load_oig_file
from model_training.custom_datasets.qa_datasets import (
    SODA,
    Alpaca,
    CodeAlpaca,
    JokeExplaination,
    QADataset,
    SODADialogue,
    TranslatedQA,
    Vicuna,
    WebGPT,
    load_alpaca_dataset,
)
from model_training.custom_datasets.rank_datasets import AugmentedOA
from model_training.custom_datasets.summarization import HFSummary, HFSummaryPairs, SummarizationDataset
@@ -118,10 +117,8 @@ def get_one_dataset(
        dataset = DiveMT()
    elif dataset_name == "webgpt":
        dataset = WebGPT(mode=mode)
    elif dataset_name == "alpaca":
        dataset = Alpaca(mode=mode, cache_dir=data_path)
    elif dataset_name == "code_alpaca":
        dataset = CodeAlpaca(mode=mode, cache_dir=data_path)
    elif dataset_name in ("alpaca", "code_alpaca"):
        train, eval = load_alpaca_dataset(dataset_name, val_split=val_split, cache_dir=data_path, **kwargs)
    elif dataset_name == "gpt4all":
        dataset = Gpt4All(mode=mode, cache_dir=data_path)
    elif dataset_name == "prosocial_dialogue":
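With this change, "alpaca" and "code_alpaca" come back from the dispatcher pre-split, and extra keyword arguments such as `reverse_augmentation` are forwarded to the loader via `**kwargs`. A hypothetical call shape (the full `get_one_dataset` signature is not visible in this hunk, so the argument names beyond those used above are assumptions):

```python
# Hypothetical usage; parameter names other than dataset_name, mode,
# data_path and val_split are assumptions about the hidden signature.
train, eval = get_one_dataset(
    dataset_name="alpaca",
    mode="sft",
    data_path=".cache",
    val_split=0.05,
    reverse_augmentation=True,  # forwarded to load_alpaca_dataset via **kwargs
)
```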
76 changes: 55 additions & 21 deletions model/model_training/custom_datasets/qa_datasets.py
@@ -12,7 +12,8 @@

import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
from torch import Generator
from torch.utils.data import Dataset, Subset, random_split

# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
@@ -420,40 +421,73 @@ def __getitem__(self, index):
        return self.pairs[index]


class AlpacaBase(Dataset):
    def __init__(self, dataset_name: str, mode: str, cache_dir: str = None) -> None:
class AlpacaBaseDataset(Dataset):
    def __init__(self, data: list, mode: str):
        super().__init__()
        self.data = data
        if mode not in ("sft", "rl"):
            raise NotImplementedError(
                f"Alpaca Dataset for mode {mode} is not implemented. Currently supported modes are 'sft' and 'rl'."
            )
        self.mode = mode
        dataset = load_dataset(dataset_name, cache_dir=cache_dir)
        rows = []
        for row in dataset["train"]:
            question = row["instruction"]
            if len(row["input"]) > 0:
                input_ = "{}\n{}".format(question, row["input"])
            else:
                input_ = question
            rows.append((input_, row["output"]))
        self.rows = rows

    def __len__(self):
        return len(self.rows)
        return len(self.data)

    def __getitem__(self, index):
        question, answer = self.rows[index]
        question, answer = self.data[index]
        if self.mode == "sft":
            return (question, answer)
        elif self.mode == "rl":
            return (question,)


class Alpaca(AlpacaBase):
    def __init__(self, mode: str = "sft", cache_dir: str = None) -> None:
        super().__init__(dataset_name="yahma/alpaca-cleaned", mode=mode, cache_dir=cache_dir)
def load_alpaca_dataset(
    dataset_name: str,
    val_split: float,
    cache_dir: str,
    mode: str = "sft",
    manual_seed: int = 287631038922,
    reverse_augmentation: bool = False,
    keep_unreversed: bool = True,
) -> tuple[AlpacaBaseDataset, AlpacaBaseDataset]:
    generator = Generator()
    generator.manual_seed(manual_seed)

    def process_split(
        dataset: Subset, reverse_augmentation: bool = False, keep_unreversed: bool = True
    ) -> list[tuple[str, str]]:
        data = []
        for row in dataset:
            question = row["instruction"]
            if len(row["input"]) > 0:
                input_ = "{}\n{}".format(question, row["input"])
            else:
                input_ = question
            if reverse_augmentation:
                data.append((row["output"], input_))
                # in the reverse-augmentation case we can keep both the reversed and the unreversed pair
                if keep_unreversed:
                    data.append((input_, row["output"]))
            else:
                data.append((input_, row["output"]))
        return data

    if dataset_name == "alpaca":
        dataset = load_dataset("yahma/alpaca-cleaned", cache_dir=cache_dir)
    elif dataset_name == "code_alpaca":
        dataset = load_dataset("sahil2801/CodeAlpaca-20k", cache_dir=cache_dir)
    else:
        raise ValueError(f"Expected dataset_name to be 'alpaca' or 'code_alpaca'. Received {dataset_name}.")

class CodeAlpaca(AlpacaBase):
    def __init__(self, mode: str = "sft", cache_dir: str = None) -> None:
        super().__init__(dataset_name="sahil2801/CodeAlpaca-20k", mode=mode, cache_dir=cache_dir)
    splits = random_split(dataset["train"], lengths=[1.0 - val_split, val_split], generator=generator)
    train = AlpacaBaseDataset(
        process_split(splits[0], reverse_augmentation=reverse_augmentation, keep_unreversed=keep_unreversed), mode=mode
    )
    val = AlpacaBaseDataset(
        process_split(splits[1], reverse_augmentation=False, keep_unreversed=keep_unreversed), mode=mode
    )
    return train, val


class Vicuna(Dataset):
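To make the augmentation concrete, here is a standalone sketch of what the inner `process_split` helper yields for a single row (the row itself is made up):

```python
row = {"instruction": "Translate to French.", "input": "Hello", "output": "Bonjour"}

question = row["instruction"]
input_ = "{}\n{}".format(question, row["input"]) if row["input"] else question

# With reverse_augmentation=True and keep_unreversed=True, one source row
# yields two training pairs. Note that the validation split above is always
# built with reverse_augmentation=False, so it is never reversed.
pairs = [
    (row["output"], input_),   # reversed: learn to reconstruct the prompt
    (input_, row["output"]),   # unreversed: the usual instruction -> response pair
]
print(pairs)
```

The train/val split itself is reproducible because `random_split` is driven by a `torch.Generator` with a fixed `manual_seed`; fractional `lengths` like `[1.0 - val_split, val_split]` require a reasonably recent PyTorch (1.13 or newer).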
