Skip to content

Commit

Permalink
fix labels (#6213)
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Aug 3, 2020
1 parent cedc547 commit 0b41867
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]])
return {"input_ids": inputs, "labels": labels}
else:
labels = batch.clone().detach()
labels[labels == self.tokenizer.pad_token_id] = -100
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels}

def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
Expand Down

0 comments on commit 0b41867

Please sign in to comment.