Skip to content

Commit

Permalink
Correct Collator (#1975)
Browse files Browse the repository at this point in the history
Fixes a mistake in the labels mask, related to #1974.
  • Loading branch information
sanagno committed Mar 6, 2023
1 parent 85fbaf6 commit ba336e3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
13 changes: 7 additions & 6 deletions model/model_training/custom_datasets/dialogue_collator.py
Expand Up @@ -54,12 +54,13 @@ def __call__(self, features):
list(map(lambda x: x[1], flatten_message["offset_mapping"])),
)
)
label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1)
try:
label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
except IndexError:
# due to truncation, we might not have the last termination token
label_mask[-1] = False
label_mask = np.array(list(map(lambda x: x % 2 == 1, message_indices)))
label_mask[-1] = False # make sure last token is inactive, has an effect only when truncatting
# try:
# label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True
# except IndexError:
# # due to truncation, we might not have the last termination token
# label_mask[-1] = False

label_masks.append(label_mask)
if len(flatten_message["input_ids"]) < self.mix_length_threshold and self.samples_mixing:
Expand Down
5 changes: 1 addition & 4 deletions model/model_training/custom_datasets/formatting.py
Expand Up @@ -3,10 +3,7 @@

def format_pair(pairs):
return [
"{}{}{}".format(QA_SPECIAL_TOKENS["Question"], pairs[i], QA_SPECIAL_TOKENS["Answer"])
if i % 2 == 0
else pairs[i]
for i in range(len(pairs))
"{}{}".format(QA_SPECIAL_TOKENS["Question" if i % 2 == 0 else "Answer"], pairs[i]) for i in range(len(pairs))
]


Expand Down

0 comments on commit ba336e3

Please sign in to comment.