support dynamo/compile in torchcomms #55

@d4l3k

Currently, calling a torchcomms collective inside a function traced by Dynamo or torch.compile raises an error:

torch._dynamo.exc.Unsupported: Unsupported method call
  Explanation: Dynamo does not know how to trace method `all_reduce` of class `TorchComm`
  Hint: Avoid calling `TorchComm.all_reduce` in your code.
  Hint: Please report an issue to PyTorch.

  Developer debug context: call_method UserDefinedObjectVariable(TorchComm) all_reduce [LazyVariableTracker(unrealized: <class 'torch.Tensor'>), LazyVariableTracker(unrealized: <class 'pybind11_builtins.pybind11_static_property'>)] {'async_op': ConstantVariable(bool: False)}

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0156.html

from user code:
   File "/home/tristanr/scripts/torchcomms_dynamo.py", line 15, in my_func
    comm.all_reduce(t, torchcomms.ReduceOp.SUM, async_op=False)

We want to support this, at least for graph capture cases.

"""
Invoke with:
torchrun --nnodes 1 --nproc_per_node=gpu ~/scripts/torchcomms_dynamo.py
"""

import torch
import torch.distributed as dist
import torchcomms

comm = torchcomms.new_comm('ncclx', torch.device('cuda'), store=None, name='1234')

t = torch.ones(10, device=comm.get_device())

def my_func(t):
    comm.all_reduce(t, torchcomms.ReduceOp.SUM, async_op=False)
    t *= 10
    return t

try:
    compiled_func = torch.compile(my_func, fullgraph=True)
    compiled_func(t)

finally:
    comm.finalize()

For comparison, c10d::ProcessGroup registers one torch op per collective and always routes calls through the dispatcher, which is what makes them traceable. See https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/ProcessGroup.hpp#L270-L295 for more details.
