[DDP][Compile] Test to Ensure torch.compile works w/static_graph=True (pytorch#114621)

Resolves pytorch#93672. This was actually fixed by pytorch#103487, but at the
time I didn't realize that PR also fixed torch.compile.

Differential Revision: [D51596148](https://our.internmc.facebook.com/intern/diff/D51596148/)

Pull Request resolved: pytorch#114621
Approved by: https://github.com/wconstab
rohan-varma authored and hyperfraise committed Dec 21, 2023
1 parent 3c1f69e commit e12c766
Showing 3 changed files with 57 additions and 9 deletions.
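
A minimal sketch of the combination the new test exercises: a DDP module built with static_graph=True and then passed to torch.compile. The launch scaffolding below (torchrun environment variables, NCCL backend) is illustrative and not taken from the PR; the model, input shape, and iteration count mirror the new test added in distributed_test.py.

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        # Assumes launch via torchrun, which sets RANK/WORLD_SIZE/LOCAL_RANK
        # and the rendezvous environment variables.
        dist.init_process_group(backend="nccl")
        rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(rank)

        model = torch.nn.Linear(10, 10).cuda(rank)
        ddp_static = DDP(model, device_ids=[rank], static_graph=True)
        # Compiling a static-graph DDP module is what pytorch#93672 reported as
        # broken and what this commit adds test coverage for.
        compiled = torch.compile(ddp_static)

        inp = torch.rand(10, 10).cuda(rank)
        for _ in range(6):
            compiled(inp).sum().backward()

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()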
3 changes: 0 additions & 3 deletions docs/source/notes/ddp.rst
@@ -70,9 +70,6 @@ DDP works with TorchDynamo. When used with TorchDynamo, apply the DDP model wrapper
before compiling the model, such that torchdynamo can apply ``DDPOptimizer``
(graph-break optimizations) based on DDP bucket sizes. (See `TorchDynamo DDPOptimizer <./ddp.html#torchdynamo-ddpoptimizer>`_ for more information.)

TorchDynamo support for DDP currently requires setting `static_graph=False`, due to
interactions between the graph tracing process and DDP's mechanism for observing operations happening on its module,
but this should be fixed ultimately.

.. code::
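
Illustrating the ordering described in the documentation paragraph above (this sketch is not part of the diff): wrap the module in DDP first, then call torch.compile, so that DDPOptimizer can split graphs along DDP bucket boundaries. It assumes the default process group has already been initialized by the launcher.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes dist.init_process_group(...) has already been called.
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(16, 16).cuda(rank)
    model = DDP(model, device_ids=[rank])  # apply the DDP wrapper first...
    model = torch.compile(model)           # ...then compile, so DDPOptimizer can
                                           # break the graph at DDP bucket boundaries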
31 changes: 25 additions & 6 deletions test/distributed/test_dynamo_distributed.py
@@ -290,26 +290,45 @@ def test_ddp_baseline_aot_eager_multiprocess(self):
outputs = m(inputs)
self.assertTrue(same(correct_outputs, outputs))

def _test_hf_bert_ddp_inductor(self, static_graph):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model, inputs = get_hf_bert(self.rank)
model = DDP(model, static_graph=static_graph)
run_hf_bert_ddp(self, model, inputs, "inductor")

@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(config, "optimize_ddp", True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor(self):
self._test_hf_bert_ddp_inductor(static_graph=False)

@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@patch.object(config, "optimize_ddp", True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor_static_graph(self):
self._test_hf_bert_ddp_inductor(static_graph=True)

def _test_hf_bert_aot_eager(self, static_graph):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model, inputs = get_hf_bert(self.rank)
model = DDP(model)
run_hf_bert_ddp(self, model, inputs, "inductor")
model = DDP(model, static_graph=static_graph)
run_hf_bert_ddp(self, model, inputs, "aot_eager")

@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@patch.object(config, "optimize_ddp", True)
def test_hf_bert_ddp_aot_eager(self):
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
model, inputs = get_hf_bert(self.rank)
model = DDP(model)
run_hf_bert_ddp(self, model, inputs, "aot_eager")
self._test_hf_bert_aot_eager(static_graph=False)

@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@patch.object(config, "optimize_ddp", True)
def test_hf_bert_ddp_aot_eager_static_graph(self):
self._test_hf_bert_aot_eager(static_graph=True)

@skip_if_lt_x_gpu(2)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
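
For context, the @patch.object decorators in these tests toggle two config flags for the duration of each test; `config` is taken here to be Dynamo's config module. Outside the test harness, the rough equivalent is setting the flags directly, with the caveat that this applies globally rather than per-test:

    import torch._dynamo.config
    import torch._inductor.config

    # Let Dynamo's DDPOptimizer introduce graph breaks at DDP bucket boundaries.
    torch._dynamo.config.optimize_ddp = True
    # Use eager-style random ops in Inductor so outputs are comparable across backends.
    torch._inductor.config.fallback_random = True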
32 changes: 32 additions & 0 deletions torch/testing/_internal/distributed/distributed_test.py
@@ -10072,5 +10072,37 @@ def test_ddp_device_mesh_initialization(self):
model, device_mesh=device_mesh
)

@skip_if_lt_x_gpu(2)
@require_world_size(2)
@skip_but_pass_in_sandcastle_if(
BACKEND not in DistTestCases.backend_feature["ddp"],
f"The {BACKEND} backend does not support DistributedDataParallel",
)
def test_ddp_compile_static_graph(self):
"Tests that DDP works with torch compile when static_graph=True"
model = torch.nn.Linear(10, 10).cuda(self.rank)
model_clone = copy.deepcopy(model)
ddp = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[self.rank],
)
ddp_static = torch.nn.parallel.DistributedDataParallel(
model_clone,
device_ids=[self.rank],
static_graph=True
)
ddp = torch.compile(ddp)
ddp_static = torch.compile(ddp_static)
input = torch.rand(10, 10).cuda(self.rank)
# verify output and gradient parity
for _ in range(6):
out_ddp = ddp(input).sum()
out_ddp_static = ddp_static(input).sum()
self.assertEqual(out_ddp, out_ddp_static)
out_ddp.backward()
out_ddp_static.backward()
for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()):
self.assertEqual(p1.grad, p2.grad)


instantiate_parametrized_tests(DistributedTest._DistTestBase)
