Skip to content

Commit

Permalink
Conversion to tensors requires padding (#10661)
Browse files Browse the repository at this point in the history
  • Loading branch information
LysandreJik committed Mar 11, 2021
1 parent 2adc8c9 commit 7e44287
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tests/test_modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,9 @@ def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
self.assertListEqual(self.expected_text, generated_words)

def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device)
model_inputs = self.tokenizer(self.src_text, padding=True, return_tensors="pt", **tokenizer_kwargs).to(
torch_device
)
self.assertEqual(self.model.device, model_inputs.input_ids.device)
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_tf_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
self.assertListEqual(self.expected_text, generated_words)

def translate_src_text(self, **tokenizer_kwargs):
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, return_tensors="tf")
model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, padding=True, return_tensors="tf")
generated_ids = self.model.generate(
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
)
Expand Down

0 comments on commit 7e44287

Please sign in to comment.