Joke explanation generalization (#2899)
Addresses issue #2827.
This PR updates the JokeExplaination class to store its data as
DatasetEntry objects, which generalizes the data format. DatasetEntry
provides a consistent data structure for joke-explanation pairs, making
the data easier to work with.
It also corrects the return annotation of the AlpacaGpt4 class's
__getitem__ method.
The changes in this PR include:
- Representing each joke-explanation pair as a DatasetEntry
- Updating the JokeExplaination class to store DatasetEntry objects
  instead of tuples
- Correcting the return annotation of the AlpacaGpt4 class's
  __getitem__ method
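As a rough illustration of the idea, the sketch below wraps joke-explanation pairs in a small entry type. The `DatasetEntry` here is a simplified stand-in written for this example and is not the project's actual class, which may carry more fields; `load_joke_entries` and the sample line are likewise hypothetical. Only the `questions`/`answers` keyword names and the `"joke"`/`"explaination"` keys are taken from the diff.

```python
import json
from dataclasses import dataclass


@dataclass
class DatasetEntry:
    """Simplified stand-in for illustration; not the project's real class."""

    questions: list[str]
    answers: list[str]


def load_joke_entries(lines):
    """Hypothetical helper: turn JSONL lines into DatasetEntry objects."""
    entries = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        data = json.loads(line)
        # "explaination" is the misspelled key used by the source data.
        entries.append(
            DatasetEntry(questions=[data["joke"]], answers=[data["explaination"]])
        )
    return entries


# Made-up sample line, only to show the shape of the result.
sample = [
    '{"joke": "Why did the chicken cross the road?", '
    '"explaination": "The answer is deliberately literal."}',
]
entries = load_joke_entries(sample)
print(entries[0].questions[0])
```

Compared with a bare `(joke, explanation)` tuple, callers read named fields rather than positional indices, so other datasets can return the same shape.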

---------

Co-authored-by: sampatkalyan <120446217+Andavarapu-Sampat-Kalyan@users.noreply.github.com>
sampatkalyan and ASampatKalyan committed Apr 27, 2023
1 parent 0a8d408 commit cab4b58
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions model/model_training/custom_datasets/qa_datasets.py
@@ -333,8 +333,6 @@ def __init__(self, cache_dir) -> None:
             with open(joke_explain_filename, "w") as fout:
                 fout.write(content)
 
-        question = ""
-        answer = ""
         self.pairs = []
         with open(joke_explain_filename, "r") as f:
             for line in f:
@@ -343,16 +341,12 @@ def __init__(self, cache_dir) -> None:
                 # DO NOT change this
                 # its the data that had syntax error
                 explanation = data["explaination"]
-                self.pairs.append((joke, explanation))
+                self.pairs.append(DatasetEntry(questions=[joke], answers=[explanation]))
 
-        if len(question) > 0 and len(answer) > 0:
-            self.pairs.append((question, answer))
-        self.length = len(self.pairs)
-
-    def __len__(self):
+    def __len__(self) -> int:
         return self.length
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> DatasetEntry:
         return self.pairs[index]


@@ -610,6 +604,6 @@ def _process_instruction(self, row: dict[str, str], input_max_length: int) -> Da
     def __len__(self) -> int:
         return len(self.rows)
 
-    def __getitem__(self, index: int) -> list[str] | tuple[str]:
+    def __getitem__(self, index: int) -> DatasetEntry:
         dialogue = self.rows[index]
         return dialogue
