diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 259b4cc916d3..87ed1d9d17e5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -885,11 +885,11 @@ def test_model_parallelism(self): @require_torch_gpu def test_sharded_checkpoints(self): + torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) base_output = model(**inputs_dict) model_size = compute_module_sizes(model)[""] @@ -909,7 +909,8 @@ def test_sharded_checkpoints(self): new_model = new_model.to(torch_device) torch.manual_seed(0) - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + if "generator" in inputs_dict: + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) @@ -942,7 +943,8 @@ def test_sharded_checkpoints_device_map(self): new_model = new_model.to(torch_device) torch.manual_seed(0) - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + if "generator" in inputs_dict: + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))