diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ec2bca537fc9..98c095f96804 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1001,6 +1001,11 @@ def test_determinism(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() def check_determinism(first, second): + # Simply don't compare if both tensors only contain `nan` elements + # See: https://github.com/huggingface/transformers/pull/40661 + if torch.all(torch.isnan(first)) and torch.all(torch.isnan(second)): + return + out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() out_1 = out_1[~np.isnan(out_1)]