diff --git a/docs/source/notes/ddp.rst b/docs/source/notes/ddp.rst
index 48cadb2218c0b..43256a2a68677 100644
--- a/docs/source/notes/ddp.rst
+++ b/docs/source/notes/ddp.rst
@@ -70,9 +70,6 @@ DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wra
 before compiling the model, such that torchdynamo can apply ``DDPOptimizer``
 (graph-break optimizations) based on DDP bucket sizes. (See `TorchDynamo DDPOptimizer <./ddp.html#torchdynamo-ddpoptimizer>`_ for more information.)
 
-TorchDynamo support for DDP currently requires setting `static_graph=False`, due to
-interactions between the graph tracing process and DDP's mechanism for observing operations happening on its module,
-but this should be fixed ultimately.
 
 .. code::
 
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index 82d4248fb6cb1..c9646638bcfb6 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -290,26 +290,45 @@ def test_ddp_baseline_aot_eager_multiprocess(self):
             outputs = m(inputs)
             self.assertTrue(same(correct_outputs, outputs))
 
+    def _test_hf_bert_ddp_inductor(self, static_graph):
+        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
+            model, inputs = get_hf_bert(self.rank)
+            model = DDP(model, static_graph=static_graph)
+            run_hf_bert_ddp(self, model, inputs, "inductor")
+
     @skip_if_lt_x_gpu(2)
     @import_transformers_or_skip()
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @patch.object(config, "optimize_ddp", True)
     @patch.object(torch._inductor.config, "fallback_random", True)
     def test_hf_bert_ddp_inductor(self):
+        self._test_hf_bert_ddp_inductor(static_graph=False)
+
+    @skip_if_lt_x_gpu(2)
+    @import_transformers_or_skip()
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @patch.object(config, "optimize_ddp", True)
+    @patch.object(torch._inductor.config, "fallback_random", True)
+    def test_hf_bert_ddp_inductor_static_graph(self):
+        self._test_hf_bert_ddp_inductor(static_graph=True)
+
+    def _test_hf_bert_aot_eager(self, static_graph):
         with _dynamo_dist_per_rank_init(self.rank, self.world_size):
             model, inputs = get_hf_bert(self.rank)
-            model = DDP(model)
-            run_hf_bert_ddp(self, model, inputs, "inductor")
+            model = DDP(model, static_graph=static_graph)
+            run_hf_bert_ddp(self, model, inputs, "aot_eager")
 
     @skip_if_lt_x_gpu(2)
     @import_transformers_or_skip()
     @patch.object(config, "optimize_ddp", True)
     def test_hf_bert_ddp_aot_eager(self):
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
-            model, inputs = get_hf_bert(self.rank)
-            model = DDP(model)
-            run_hf_bert_ddp(self, model, inputs, "aot_eager")
+        self._test_hf_bert_aot_eager(static_graph=False)
+
+    @skip_if_lt_x_gpu(2)
+    @import_transformers_or_skip()
+    @patch.object(config, "optimize_ddp", True)
+    def test_hf_bert_ddp_aot_eager_static_graph(self):
+        self._test_hf_bert_aot_eager(static_graph=True)
 
     @skip_if_lt_x_gpu(2)
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index a6e09e7c2e20a..3196d3cc7ea1f 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -10072,5 +10072,37 @@ def test_ddp_device_mesh_initialization(self):
             model, device_mesh=device_mesh
         )
 
+    @skip_if_lt_x_gpu(2)
+    @require_world_size(2)
+    @skip_but_pass_in_sandcastle_if(
+        BACKEND not in DistTestCases.backend_feature["ddp"],
+        f"The {BACKEND} backend does not support DistributedDataParallel",
+    )
+    def test_ddp_compile_static_graph(self):
+        "Tests that DDP works with torch compile when static_graph=True"
+        model = torch.nn.Linear(10, 10).cuda(self.rank)
+        model_clone = copy.deepcopy(model)
+        ddp = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[self.rank],
+        )
+        ddp_static = torch.nn.parallel.DistributedDataParallel(
+            model_clone,
+            device_ids=[self.rank],
+            static_graph=True
+        )
+        ddp = torch.compile(ddp)
+        ddp_static = torch.compile(ddp_static)
+        input = torch.rand(10, 10).cuda(self.rank)
+        # verify output and gradient parity
+        for _ in range(6):
+            out_ddp = ddp(input).sum()
+            out_ddp_static = ddp_static(input).sum()
+            self.assertEqual(out_ddp, out_ddp_static)
+            out_ddp.backward()
+            out_ddp_static.backward()
+            for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()):
+                self.assertEqual(p1.grad, p2.grad)
+
 
 instantiate_parametrized_tests(DistributedTest._DistTestBase)
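
For reference, the end-to-end pattern these changes enable (wrapping the model in DDP with `static_graph=True` before calling `torch.compile`) looks roughly like the sketch below. This is an illustrative example, not part of the patch: the `demo` helper, the NCCL/env:// rendezvous setup, the optimizer, and the hyperparameters are assumptions, and the script must be launched with a multi-process launcher such as torchrun so that `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are set.

```python
# Illustrative sketch (not part of this patch): DDP with static_graph=True, then torch.compile.
# Assumes a multi-process launch (e.g. torchrun) that sets RANK/WORLD_SIZE/LOCAL_RANK.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def demo() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group("nccl")  # env:// rendezvous from torchrun's env vars
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(10, 10).cuda(rank)
    # Apply the DDP wrapper before compiling, so TorchDynamo's DDPOptimizer
    # can place graph breaks based on DDP bucket sizes.
    ddp_model = DDP(model, device_ids=[rank], static_graph=True)
    compiled = torch.compile(ddp_model)

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    for _ in range(6):
        optimizer.zero_grad()
        loss = compiled(torch.rand(10, 10).cuda(rank)).sum()
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    demo()
```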