From 0b418673575fb5aa2f6f657fb33d64eafb1d700f Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 3 Aug 2020 19:49:35 +0530 Subject: [PATCH] fix labels (#6213) --- src/transformers/data/data_collator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 085f7a68a8aac..cf8eb996f8e5e 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -87,7 +87,8 @@ def __call__(self, examples: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) return {"input_ids": inputs, "labels": labels} else: labels = batch.clone().detach() - labels[labels == self.tokenizer.pad_token_id] = -100 + if self.tokenizer.pad_token_id is not None: + labels[labels == self.tokenizer.pad_token_id] = -100 return {"input_ids": batch, "labels": labels} def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor: