Skip to content

Commit

Permalink
Fix tests (#14289)
Browse files Browse the repository at this point in the history
  • Loading branch information
NielsRogge authored Nov 6, 2021
1 parent 24b30d4 commit 34307bb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 6 additions & 2 deletions tests/test_modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def test_training(self):
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config)
model.to(torch_device)
model.train()
Expand All @@ -259,7 +261,9 @@ def test_training_gradient_checkpointing(self):
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config)
model.to(torch_device)
model.train()
Expand Down
4 changes: 3 additions & 1 deletion tests/test_modeling_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,9 @@ def test_training(self):
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
if model_class.__name__ == "SegformerForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config)
model.to(torch_device)
model.train()
Expand Down

0 comments on commit 34307bb

Please sign in to comment.