Feature/add rm mode to dataset entry (#2867)
works towards #2819

---------

Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
CloseChoice and andreaskoepf committed Apr 24, 2023
1 parent 28ae5ef commit bcc9360
Showing 3 changed files with 216 additions and 40 deletions.
63 changes: 41 additions & 22 deletions model/model_training/custom_datasets/formatting.py
@@ -64,36 +64,55 @@ def system_tag(self, eos_token: str) -> str | None:
system_tag = f"{QA_SPECIAL_TOKENS['System']}{system_tag_key_values}\n{eos_token}"
return system_tag

def get_formatted(self, mode: Mode, eos_token: str) -> str | list[str]:
def _get_formatted_rm(self, eos_token: str, max_replies: int, system_tag: None | str):
assert len(self.answers) > 1
answers = self.answers[:max_replies]
match len(self.questions):
case 0:
question = ""
# todo: not sure if this case is correct but it is equivalent to current non-dataset entry behaviour
answers = [f"{a}{eos_token}" for a in answers]
case 1:
question = f"{QA_SPECIAL_TOKENS['Question']}{self.questions[0]}{eos_token}"
answers = [f"{QA_SPECIAL_TOKENS['Answer']}{a}{eos_token}" for a in answers]
case _:
raise ValueError("Received more than one question in RM mode. This is unexpected. Aborting")
if system_tag is not None:
question = f"{system_tag}{question}"
return (question, answers) # NotImplementedError("This is currently not implemented.")

def get_formatted(self, mode: Mode, eos_token: str, **kwargs) -> str | list[str] | tuple[str, list[str]]:
system_tag = self.system_tag(eos_token)
if mode == Mode.rl:
if system_tag is not None:
return f"{system_tag}{QA_SPECIAL_TOKENS['Question']}{self.questions[0]}{QA_SPECIAL_TOKENS['Answer']}"
else:
return f"{QA_SPECIAL_TOKENS['Question']}{self.questions[0]}{QA_SPECIAL_TOKENS['Answer']}"
if system_tag is not None:
qa_list = [system_tag]
elif mode == Mode.rm:
return self._get_formatted_rm(
eos_token=eos_token, max_replies=kwargs.get("max_replies", 5), system_tag=system_tag
)
else:
qa_list = list()
for q, a in zip_longest(self.questions, self.answers):
match (q, a):
case (str(), str()):
qa_list.extend(
[
f"{QA_SPECIAL_TOKENS['Question']}{q}{eos_token}",
f"{QA_SPECIAL_TOKENS['Answer']}{a}{eos_token}",
]
)
case (str(), None):
qa_list.append(f"{QA_SPECIAL_TOKENS['Question']}{q}{eos_token}")
case (None, None):
break
case (None, str()):
raise ValueError("Received answer without getting corresponding question. Aborting")
if mode == Mode.sft:
if system_tag is not None:
qa_list = [system_tag]
else:
qa_list = list()
for q, a in zip_longest(self.questions, self.answers):
match (q, a):
case (str(), str()):
qa_list.extend(
[
f"{QA_SPECIAL_TOKENS['Question']}{q}{eos_token}",
f"{QA_SPECIAL_TOKENS['Answer']}{a}{eos_token}",
]
)
case (str(), None):
qa_list.append(f"{QA_SPECIAL_TOKENS['Question']}{q}{eos_token}")
case (None, None):
break
case (None, str()):
raise ValueError("Received answer without getting corresponding question. Aborting")
return qa_list
elif mode == Mode.rm:
raise NotImplementedError("This is currently not implemented.")

@classmethod
def create_from_prompter_assistant_interplay(cls, qa: dict[str, str]):
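As an aside, here is a minimal sketch of how the new Mode.rm path of get_formatted is expected to be called, assuming the question/answer special tokens and the behaviour shown in the hunk above; the eos token string, the example texts, and the no-metadata/no-system-tag behaviour are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch (not part of the diff). Assumes an entry without
# metadata produces no system tag and that rm mode requires at least two answers.
from model_training.custom_datasets.entities import Mode
from model_training.custom_datasets.formatting import DatasetEntry

entry = DatasetEntry(
    questions=["What is the capital of France?"],
    answers=["Paris.", "The capital of France is Paris."],
)

# rm mode returns a (prefix, replies) tuple instead of the flat list used for sft:
#   prefix  -> QA_SPECIAL_TOKENS["Question"] + question + eos (system tag prepended if present)
#   replies -> [QA_SPECIAL_TOKENS["Answer"] + answer + eos, ...], truncated to max_replies
prefix, replies = entry.get_formatted(mode=Mode.rm, eos_token="</s>", max_replies=5)
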
43 changes: 26 additions & 17 deletions model/model_training/custom_datasets/ranking_collator.py
@@ -1,7 +1,9 @@
from dataclasses import dataclass
from typing import Optional, Union

from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from model_training.custom_datasets.entities import Mode
from model_training.custom_datasets.formatting import DatasetEntry
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase

from .formatting import format_pairs, format_reply

@@ -19,25 +21,30 @@ class RankingDataCollator:
pad_to_multiple_of: Optional[int] = None
max_replies: Optional[int] = 5

def process_one(self, example, return_length=False):
messages, replies = example

if self.max_replies:
assert self.max_replies > 1, "max_replies parameter must be > 1 or None"
if len(replies) > self.max_replies:
replies = replies[: self.max_replies]

def process_one(
self, example: tuple[str | list[str] | None, list[str]] | DatasetEntry, return_length: bool = False
) -> list[BatchEncoding]:
assert self.tokenizer.eos_token
eos = self.tokenizer.eos_token

if messages is None or len(messages) == 1 and messages[0] is None:
# special handling for non-dialogue datasets like Hellaswag
prefix = ""
replies = [r + eos for r in replies]
if isinstance(example, DatasetEntry):
prefix, replies = example.get_formatted(mode=Mode.rm, eos_token=eos)
else:
# append eos token to each messages
prefix = "".join(format_pairs(messages, eos_token=eos))
replies = [format_reply(r, eos_token=eos) for r in replies]
messages, replies = example

if self.max_replies:
assert self.max_replies > 1, "max_replies parameter must be > 1 or None"
if len(replies) > self.max_replies:
replies = replies[: self.max_replies]

if messages is None or len(messages) == 1 and messages[0] is None:
# special handling for non-dialogue datasets like Hellaswag
prefix = ""
replies = [r + eos for r in replies]
else:
# append eos token to each messages
prefix = "".join(format_pairs(messages, eos_token=eos))
replies = [format_reply(r, eos_token=eos) for r in replies]

prefix_tokens = self.tokenizer(prefix, padding=False, truncation=False)
reply_tokens = [self.tokenizer(r, padding=False, truncation=False) for r in replies]
@@ -60,7 +67,9 @@ def process_one(self, example, return_length=False):

return reply_tokens

def __call__(self, examples):
def __call__(
self, examples: list[tuple[str | list[str] | None, list[str]]] | list[DatasetEntry]
) -> tuple[list[BatchEncoding], list[int]]:
flat_tokenized, cu_lens = [], [0]
n_samples = 0
for example in examples:
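For orientation, a hedged sketch of how the updated collator is expected to be driven with both the legacy tuple format and the new DatasetEntry objects; the tokenizer checkpoint is an arbitrary placeholder, and the reading of cu_lens follows the assertions in the tests below.

# Illustrative only (not part of the commit); any tokenizer with eos/pad tokens would do.
from transformers import AutoTokenizer

from model_training.custom_datasets.formatting import DatasetEntry
from model_training.custom_datasets.ranking_collator import RankingDataCollator

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token  # padding=True needs a pad token

collator = RankingDataCollator(tokenizer=tokenizer, padding=True)
examples = [
    (["First instruction."], ["Reply A", "Reply B"]),  # legacy (messages, replies) tuple
    DatasetEntry(questions=["Second instruction."], answers=["Reply C", "Reply D"]),  # new entry type
]

# batch holds the padded input_ids/attention_mask of all replies, flattened across examples;
# cu_lens gives cumulative reply counts per example, here [0, 2, 4].
batch, cu_lens = collator(examples=examples)
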
150 changes: 149 additions & 1 deletion model/model_training/tests/test_ranking_collator.py
@@ -1,10 +1,158 @@
from argparse import Namespace

import pytest
import torch
from model_training.custom_datasets import get_one_dataset
from model_training.custom_datasets.formatting import QA_SPECIAL_TOKENS, DatasetEntry
from model_training.custom_datasets.ranking_collator import RankingDataCollator
from model_training.utils.utils import get_tokenizer
from model_training.utils.utils import get_tokenizer, match_tokenizer_name
from torch.utils.data import DataLoader
from transformers.models.auto.tokenization_auto import AutoTokenizer


@pytest.fixture
def pythia_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("tests/resources/data_collator", local_files_only=True)
# for this test we use the pythia special tokens but note that this test is model agnostic
tokenizer_config = match_tokenizer_name("pythia")

tokenizer.add_special_tokens(
{
"pad_token": tokenizer_config.special_tokens.pad_token,
"eos_token": tokenizer_config.special_tokens.eos_token,
"sep_token": tokenizer_config.special_tokens.sep_token,
}
)

additional_special_tokens = list(QA_SPECIAL_TOKENS.values())

tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
return tokenizer


def test_ranking_collator_system_tag(pythia_tokenizer):
first_example = DatasetEntry(
questions=["First instruction."],
answers=["Answer to first instruction.", "Answer to first instruction."],
lang="en",
quality=0.7,
)
second_example = DatasetEntry(
questions=["Second instruction."],
answers=["Answer to second instruction.", "Answer to second instruction."],
humor=0.1,
length=1000,
)
examples = [first_example, second_example]
rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
batch, cu_lens = rdc(examples=examples)
assert len(batch) == 2
assert cu_lens == [0, len(first_example.answers), len(first_example.answers) + len(second_example.answers)]
assert batch.data["attention_mask"].shape[0] == 4  # we have 4 replies in total
assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
eos = pythia_tokenizer.eos_token

# check each instruction
first_example_first_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][0])
f"{QA_SPECIAL_TOKENS['Question']}{first_example.questions[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_example.answers[0]}{eos}" in first_example_first_answer_decoded
"lang: en" in first_example_first_answer_decoded
"quality: 0.7" in first_example_first_answer_decoded

first_example_second_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][1])
f"{QA_SPECIAL_TOKENS['Question']}{first_example.questions[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_example.answers[1]}{eos}" in first_example_second_answer_decoded
"lang: en" in first_example_second_answer_decoded
"quality: 0.7" in first_example_second_answer_decoded

second_example_first_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][2])
f"{QA_SPECIAL_TOKENS['Question']}{second_example.questions[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_example.answers[0]}{eos}" in second_example_first_answer_decoded
"humor: 0.1" in second_example_first_answer_decoded
"length: 1000" in second_example_first_answer_decoded

second_example_second_answer_decoded = pythia_tokenizer.decode(batch.data["input_ids"][3])
assert f"{QA_SPECIAL_TOKENS['Question']}{second_example.questions[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_example.answers[1]}{eos}" in second_example_second_answer_decoded
assert "humor: 0.1" in second_example_second_answer_decoded
assert "length: 1000" in second_example_second_answer_decoded


def test_ranking_collator_no_messages(pythia_tokenizer):
first_messages = None
first_replies = [
"Response A to None",
"Response B to None",
"Response C to None",
]
examples = [(first_messages, first_replies)]
rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
eos = pythia_tokenizer.eos_token
examples_ds = [DatasetEntry(questions=first_messages or [], answers=first_replies)]
# make sure that formatting via dataset entry and lists is the same
for ex in [examples, examples_ds]:
batch, cu_lens = rdc(examples=ex)
assert len(batch) == 2
assert cu_lens == [0, len(first_replies)]
assert batch.data["attention_mask"].shape[0] == 3  # we have 3 replies in total
assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape

# check each instruction
assert pythia_tokenizer.decode(batch.data["input_ids"][0]) == f"{first_replies[0]}{eos}"
assert pythia_tokenizer.decode(batch.data["input_ids"][1]) == f"{first_replies[1]}{eos}"
assert pythia_tokenizer.decode(batch.data["input_ids"][2]) == f"{first_replies[2]}{eos}"
assert (batch.attention_mask == torch.where(batch.input_ids == 1, 0, 1)).all()


def test_ranking_collator_local(pythia_tokenizer):
first_messages = ["First Instruction."]
first_replies = [
"Response A to First Instruction",
"Response B to First Instruction",
"First Response C to First Instruction",
]
second_messages = ["Second Instruction."]
second_replies = ["Response A to Second Instruction", "Response B to Second Instruction"]
examples = [(first_messages, first_replies), (second_messages, second_replies)]
rdc = RankingDataCollator(tokenizer=pythia_tokenizer, padding=True)
eos = pythia_tokenizer.eos_token
pad = pythia_tokenizer.pad_token

examples_ds = [
DatasetEntry(questions=first_messages, answers=first_replies),
DatasetEntry(questions=second_messages, answers=second_replies),
]
# make sure that formatting via dataset entry and lists is the same
for ex in [examples, examples_ds]:
batch, cu_lens = rdc(examples=ex)

assert len(batch) == 2
assert cu_lens == [0, len(first_replies), len(first_replies) + len(second_replies)]
assert batch.data["attention_mask"].shape[0] == 5 # we have 5 replies in total
assert batch.data["input_ids"].shape == batch.data["attention_mask"].shape
# check each instruction
assert (
pythia_tokenizer.decode(batch.data["input_ids"][0])
== f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[0]}{eos}"
+ 5 * pad
)
assert (
pythia_tokenizer.decode(batch.data["input_ids"][1])
== f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[1]}{eos}"
+ 5 * pad
)
assert (
pythia_tokenizer.decode(batch.data["input_ids"][2])
== f"{QA_SPECIAL_TOKENS['Question']}{first_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{first_replies[2]}{eos}"
)
assert (
pythia_tokenizer.decode(batch.data["input_ids"][3])
== f"{QA_SPECIAL_TOKENS['Question']}{second_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_replies[0]}{eos}"
+ 4 * pad
)
assert (
pythia_tokenizer.decode(batch.data["input_ids"][4])
== f"{QA_SPECIAL_TOKENS['Question']}{second_messages[0]}{eos}{QA_SPECIAL_TOKENS['Answer']}{second_replies[1]}{eos}"
+ 4 * pad
)

assert (batch.attention_mask == torch.where(batch.input_ids == 1, 0, 1)).all()


@pytest.mark.skip(reason="manual")
