Skip to content

Commit

Permalink
Updated SODA to use DatasetEntry (#2913)
Browse files Browse the repository at this point in the history
Issue: #2827 
Updated the `SODA` class to use DatasetEntry, which generalises the data and
provides a consistent data structure to store and access Q&A pairs.

Changes include:
- Updated the SODA class to store Q/A pairs as DatasetEntry objects.
- Updated the `__getitem__` function to return a DatasetEntry object.
  • Loading branch information
hardikyagnik committed Apr 27, 2023
1 parent bbbd370 commit 075b141
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions model/model_training/custom_datasets/qa_datasets.py
Expand Up @@ -232,7 +232,7 @@ def __getitem__(self, index) -> list[str] | tuple[list[str], list[str]]:
class SODA(Dataset):
name = "soda"

def process_soda_convo(self, data: dict[str, Any], input_max_length: int) -> list[list[str]] | None:
def process_soda_convo(self, data: dict[str, Any], input_max_length: int) -> DatasetEntry | None:
play_as = data["speakers"][1]
dialogue_bg = "{}{}".format(
# QA_SPECIAL_TOKENS["StartPrefix"],
Expand All @@ -256,7 +256,9 @@ def process_soda_convo(self, data: dict[str, Any], input_max_length: int) -> lis
data["dialogue"][0] = f"{dialogue_bg} {data['dialogue'][0]}"
# Use only input_max_length characters
truncated_dialogue = [k[:input_max_length] for k in data["dialogue"]]
return truncated_dialogue
questions = [q for idx, q in enumerate(truncated_dialogue) if idx % 2 == 0]
answers = [a for idx, a in enumerate(truncated_dialogue) if idx % 2 == 1]
return DatasetEntry(questions=questions, answers=answers)

def __init__(self, cache_dir, mode="sft", input_max_length=1024) -> None:
super().__init__()
Expand All @@ -275,13 +277,10 @@ def __init__(self, cache_dir, mode="sft", input_max_length=1024) -> None:
def __len__(self) -> int:
return len(self.pairs)

def __getitem__(self, index) -> list[str] | tuple[str]:
def __getitem__(self, index) -> DatasetEntry:
# special token added during preprocess
if self.mode == "sft":
return self.pairs[index]
elif self.mode == "rl":
# add prefix + first human question
return (self.pairs[index][0] + " " + self.pairs[index][1],)
dialogue = self.pairs[index]
return dialogue


class SODADialogue(Dataset):
Expand Down

0 comments on commit 075b141

Please sign in to comment.