Skip to content

Commit

Permalink
suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed May 14, 2024
1 parent 7e2ba6d commit 59b1378
Showing 1 changed file with 3 additions and 12 deletions.
15 changes: 3 additions & 12 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3863,19 +3863,10 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# In case of additional token (like class) we define a custom `mask_length`
if hasattr(self.model_tester, "mask_length"):
dummy_mask = torch.cat(
[
dummy_mask,
torch.zeros(self.model_tester.mask_length - dummy_mask.size(0)),
]
)
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
else:
dummy_mask = torch.cat(
[
dummy_mask,
torch.zeros(self.model_tester.seq_length - dummy_mask.size(0)),
]
)
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)

Expand Down

0 comments on commit 59b1378

Please sign in to comment.