Update DORT to follow PyTorch changes (microsoft#16394)
Fixes microsoft#16355. The root cause is the PyTorch change
[#103302](pytorch/pytorch#103302), which seems to
block calling `make_fx` inside a Dynamo backend.

Changes:
1. Move decomposition to `register_backend.py` so that DORT no longer has to
   call `make_fx` internally, which now triggers a number of new exceptions
   (see the sketch right after this list).
2. Remove shape inference based on `FakeTensorProp`, since the FX graph
   received from Dynamo now carries all shape information (a sketch follows
   the Before/After snippets below).
3. Fix a macro bug so that DORT can build without CUDA.
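
For (1), here is a minimal sketch of what the registration now looks like. It assumes the `onnxruntime.training.torchdynamo` module layout and that `ATEN2ATEN_DECOMP` is the table exported by `ort_backend.py`; any other names are illustrative, not part of this commit.
```python
# Sketch only: decomposition is applied by aot_autograd itself, so the ORT
# backend no longer needs to re-trace the captured graph with make_fx.
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

from onnxruntime.training.torchdynamo.ort_backend import ATEN2ATEN_DECOMP, OrtBackend

ort_compiler = OrtBackend()  # illustrative; register_backend.py keeps its own DEFAULT_BACKEND

# Aten-to-Aten decompositions are handed to aot_autograd here instead of
# being applied inside OrtBackend.compile via make_fx.
aot_ort = aot_autograd(
    fw_compiler=ort_compiler,
    partition_fn=min_cut_rematerialization_partition,
    decompositions=ATEN2ATEN_DECOMP,
)

# Usage (mirrors the comment in register_backend.py):
#   compiled_model = torch._dynamo.optimize(aot_ort)(model)
#   result = compiled_model(torch.rand(2, 2, dtype=torch.float))
#   result.sum().backward()
```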

Before (3):
```
#if defined(USE_CUDA) || defined(USE_ROCM)
  virtual PhiloxGenerator& PhiloxGenerator__Default() = 0;
#ifdef ENABLE_TRAINING_TORCH_INTEROP
...
#endif
#endif
```
After (3):
```
#if defined(USE_CUDA) || defined(USE_ROCM)
  virtual PhiloxGenerator& PhiloxGenerator__Default() = 0;
#endif
#ifdef ENABLE_TRAINING_TORCH_INTEROP
...
#endif
```
The latter looks better, since `ENABLE_TRAINING_TORCH_INTEROP` guards Python
bridge code, not the `PhiloxGenerator` used by random-number-generating
kernels.
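
For (2), shape information no longer needs to be recomputed inside the backend. This relies on PyTorch's existing convention (an assumption stated here, not something this commit adds) that Dynamo-produced FX graphs attach a FakeTensor to each node in `node.meta["val"]`. A rough sketch:
```python
# Sketch only: inspect shapes already attached to a Dynamo-captured FX graph,
# instead of re-deriving them with FakeTensorProp as DORT did before.
import torch

def dump_node_shapes(graph_module: torch.fx.GraphModule) -> None:
    for node in graph_module.graph.nodes:
        val = node.meta.get("val")  # FakeTensor, or None for non-tensor nodes
        if isinstance(val, torch.Tensor):
            print(node.name, node.op, tuple(val.shape), val.dtype, val.device)

# Illustrative use inside a Dynamo backend:
#   def my_backend(graph_module, example_inputs):
#       dump_node_shapes(graph_module)
#       return graph_module.forward
```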
wschin authored and carzh committed Jun 23, 2023
1 parent f2f46b7 commit 4c3c172
Showing 4 changed files with 9 additions and 18 deletions.
@@ -924,8 +924,8 @@ struct ProviderHost {
#endif

#if defined(USE_CUDA) || defined(USE_ROCM)

virtual PhiloxGenerator& PhiloxGenerator__Default() = 0;
#endif

#ifdef ENABLE_TRAINING_TORCH_INTEROP
virtual void contrib__PythonOpBase__Init(contrib::PythonOpBase* p, const OpKernelInfo& info) = 0;
@@ -940,7 +940,6 @@ struct ProviderHost {
virtual language_interop_ops::torch::RefCountTracker& GetRefCountTrackerInstance() = 0;
virtual void RefCountTracker__DumpDetails(const language_interop_ops::torch::RefCountTracker* p, const std::string& phase_name) = 0;
#endif
#endif

#if defined(USE_CANN)
virtual RandomGenerator& RandomGenerator__Default() = 0;
3 changes: 1 addition & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -1068,8 +1068,8 @@ struct ProviderHostImpl : ProviderHost {
#endif

#if defined(USE_CUDA) || defined(USE_ROCM)

PhiloxGenerator& PhiloxGenerator__Default() override { return PhiloxGenerator::Default(); }
#endif

#ifdef ENABLE_TRAINING_TORCH_INTEROP
void contrib__PythonOpBase__Init(contrib::PythonOpBase* p, const OpKernelInfo& info) override { p->PythonOpBase::Init(info); }
@@ -1092,7 +1092,6 @@ struct ProviderHostImpl : ProviderHost {
return p->language_interop_ops::torch::RefCountTracker::DumpDetails(phase_name);
}
#endif
#endif

#if defined(USE_CANN)
RandomGenerator& RandomGenerator__Default() override { return RandomGenerator::Default(); }
15 changes: 3 additions & 12 deletions orttraining/orttraining/python/training/torchdynamo/ort_backend.py
@@ -18,9 +18,7 @@
import torch.onnx
import torch.onnx._onnx_supported_ops
from torch._decomp import decomposition_table
from torch._dynamo.utils import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
@@ -182,8 +180,8 @@ def _get_support_dictionaries_and_decomposition_tables() -> (
(
_SUPPORT_DICT,
_EXTRA_SUPPORT_DICT,
_ATEN2ATEN_DECOMP,
_ATEN2PRIM_DECOMP,
ATEN2ATEN_DECOMP,
ATEN2PRIM_DECOMP,
) = _get_support_dictionaries_and_decomposition_tables()


@@ -628,15 +626,8 @@ def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphMod
if graph_module in self._partitioner_cache:
partitioned_prim_graph_module = self._partitioner_cache[graph_module]
else:
prim_graph_module = make_fx(
graph_module, tracing_mode="fake", _allow_non_fake_inputs=True, decomposition_table=_ATEN2ATEN_DECOMP
)(*args)
prim_graph_module = graph_module
# TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to.
# We need input and output tensors' devices to decide if aten::_to_copy is just a Cast.
fake_mode = detect_fake_mode(args)
if not fake_mode:
fake_mode = torch._subclasses.FakeTensorMode()
FakeTensorProp(prim_graph_module, mode=fake_mode).propagate(*args)
_replace_to_copy_with_to(prim_graph_module)
partitioner = CapabilityBasedPartitioner(
prim_graph_module, self._supported_ops, allows_single_node_partition=False
orttraining/orttraining/python/training/torchdynamo/register_backend.py
@@ -6,7 +6,7 @@
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

from .ort_backend import OrtBackend
from .ort_backend import ATEN2ATEN_DECOMP, OrtBackend

# This should be the underlying compiler for ALL graphs if
# the user uses ORT to accelerate PyTorch via Dynamo.
@@ -28,7 +28,9 @@
# compiled_model = torch._dynamo.optimize(aot_ort)(model)
# result = compiled_model(torch.rand(2, 2, dtype=torch.float)
# result.sum().backward()
aot_ort = aot_autograd(fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition)
aot_ort = aot_autograd(
fw_compiler=DEFAULT_BACKEND, partition_fn=min_cut_rematerialization_partition, decompositions=ATEN2ATEN_DECOMP
)

# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called).
#
