[DDP] Rename state_dict var to ddp_state (pytorch#103282)
The name is confusing: it is just a dictionary used to pass state
to the DDP backward pass.

Differential Revision: [D46580516](https://our.internmc.facebook.com/intern/diff/D46580516/)
Pull Request resolved: pytorch#103282
Approved by: https://github.com/awgu
rohan-varma authored and pytorchmergebot committed Jun 14, 2023
1 parent 2d745b9 commit 2076a2f
Showing 1 changed file with 5 additions and 5 deletions.
torch/nn/parallel/distributed.py (5 additions, 5 deletions)
@@ -239,12 +239,12 @@ class _BufferCommHook:
 # is completed.
 class _DDPSink(Function):
     @staticmethod
-    def forward(ctx, reducer, state_dict, *inputs):
+    def forward(ctx, reducer, ddp_state, *inputs):
         # set_materialize_grads(False) will ensure that None gradients stay as
         # None and are not filled with zeros.
         ctx.set_materialize_grads(False)
         ctx.reducer = reducer
-        ctx.state_dict = state_dict
+        ctx.ddp_state = ddp_state
         ret = tuple(
             inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
         )
@@ -254,7 +254,7 @@ def forward(ctx, reducer, state_dict, *inputs):
     def backward(ctx, *grad_outputs):
         # Enqueue delay allreduce for static graph training on the first
         # iteration.
-        if ctx.state_dict["static_graph"] and ctx.state_dict["num_iterations"] == 1:
+        if ctx.ddp_state["static_graph"] and ctx.ddp_state["num_iterations"] == 1:
             Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
                 ctx.reducer._delay_all_reduce
             )
@@ -1468,7 +1468,7 @@ def _post_forward(self, output):
         if (self.find_unused_parameters and not self.static_graph) or (
            self.static_graph and self.num_iterations == 1
         ):
-            state_dict = {
+            ddp_state = {
                 "static_graph": self.static_graph,
                 "num_iterations": self.num_iterations,
             }
@@ -1492,7 +1492,7 @@ def _post_forward(self, output):
             # param.grad field is not touched and we don't error out.
             passthrough_tensor_list = _DDPSink.apply(
                 self.reducer,
-                state_dict,
+                ddp_state,
                 *output_tensor_list,
             )
             for i in range(len(output_placeholders)):
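For readers unfamiliar with the pattern being renamed: ddp_state is an ordinary Python dict stashed on the autograd ctx in forward and read back in backward, not a module state_dict(). Below is a minimal, self-contained sketch of that pattern; the _StateSink name and the print call are illustrative only and are not part of the actual DDP code.

import torch
from torch.autograd import Function


class _StateSink(Function):
    """Toy analogue of _DDPSink: passes tensors through while carrying a state dict."""

    @staticmethod
    def forward(ctx, ddp_state, *inputs):
        # Stash the plain dict on ctx so backward can read it later.
        ctx.ddp_state = ddp_state
        ctx.set_materialize_grads(False)
        return tuple(inp.clone() for inp in inputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        if ctx.ddp_state["static_graph"] and ctx.ddp_state["num_iterations"] == 1:
            # _DDPSink queues the delayed allreduce at this point; the toy version just notes it.
            print("first static-graph iteration reached in backward")
        # One gradient per forward input: None for the dict, pass-through for the tensors.
        return (None, *grad_outputs)


x = torch.ones(3, requires_grad=True)
(out,) = _StateSink.apply({"static_graph": True, "num_iterations": 1}, x)
out.sum().backward()  # x.grad is now a tensor of ones

The rename only changes what this dictionary is called; the forward-to-backward state-passing mechanism itself is untouched.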
