From 6e449cb575efa3845792e27f921e151ee5b96024 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 16:37:15 +0530 Subject: [PATCH 1/6] fix --- tests/models/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 259b4cc916d3..6d5df1133b38 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -885,11 +885,11 @@ def test_model_parallelism(self): @require_torch_gpu def test_sharded_checkpoints(self): + torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() model = model.to(torch_device) - torch.manual_seed(0) base_output = model(**inputs_dict) model_size = compute_module_sizes(model)[""] From d76972cc1d4055d0d09c3bdf677a26f08b7fcbf7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 16:46:29 +0530 Subject: [PATCH 2/6] fix --- tests/models/test_modeling_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 6d5df1133b38..15a075f442e6 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -887,6 +887,7 @@ def test_model_parallelism(self): def test_sharded_checkpoints(self): torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + print(f"{inputs_dict['sample'][0, :3, :3, 3].flatten()=}") model = self.model_class(**config).eval() model = model.to(torch_device) @@ -910,6 +911,7 @@ def test_sharded_checkpoints(self): torch.manual_seed(0) _, inputs_dict = self.prepare_init_args_and_inputs_for_common() + print(f"{inputs_dict['sample'][0, :3, :3, 3].flatten()=}") new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) From f8f96abd84350b51ba718c2b8ee485cd89348573 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 16:48:10 +0530 Subject: [PATCH 3/6] ugly --- tests/models/autoencoders/test_models_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index 0fc185b602a3..75d9ad54f85e 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -138,7 +138,7 @@ def dummy_input(self): batch_size = 4 num_channels = 3 sizes = (32, 32) - + torch.manual_seed(0) image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) return {"sample": image} From 77fb3bdbf5550ddd8fa57042a50b16c9a1f466c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 16:49:42 +0530 Subject: [PATCH 4/6] okay --- tests/models/autoencoders/test_models_vae.py | 2 +- tests/models/test_modeling_common.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index 75d9ad54f85e..0fc185b602a3 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -138,7 +138,7 @@ def dummy_input(self): batch_size = 4 num_channels = 3 sizes = (32, 32) - torch.manual_seed(0) + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) return {"sample": image} diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 15a075f442e6..a9bb752b00a3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -887,7 +887,6 @@ def test_model_parallelism(self): def test_sharded_checkpoints(self): torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - print(f"{inputs_dict['sample'][0, :3, :3, 3].flatten()=}") model = self.model_class(**config).eval() model = model.to(torch_device) @@ -910,8 +909,8 @@ def test_sharded_checkpoints(self): new_model = new_model.to(torch_device) torch.manual_seed(0) - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - print(f"{inputs_dict['sample'][0, :3, :3, 3].flatten()=}") + if "generator" in inputs_dict: + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) From ae534d1df3747ebbb67081dd4473fde3eebbb5f6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 16:59:04 +0530 Subject: [PATCH 5/6] fix more --- tests/models/test_modeling_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a9bb752b00a3..11a8c67879de 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -943,6 +943,8 @@ def test_sharded_checkpoints_device_map(self): new_model = new_model.to(torch_device) torch.manual_seed(0) + if "generator" in inputs_dict: + _, inputs_dict = self.prepare_init_args_and_inputs_for_common() _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5)) From 43207aedc2ecfeed0ac3ec81bcead77ff3b38fe0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 5 Jul 2024 17:02:23 +0530 Subject: [PATCH 6/6] fix oops --- tests/models/test_modeling_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 11a8c67879de..87ed1d9d17e5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -945,7 +945,6 @@ def test_sharded_checkpoints_device_map(self): torch.manual_seed(0) if "generator" in inputs_dict: _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - _, inputs_dict = self.prepare_init_args_and_inputs_for_common() new_output = new_model(**inputs_dict) self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))