
Make ORT callable from various Pytorch compilers (LazyTensor, TorchDynamo, etc) #10460

Merged: 24 commits into main, Aug 22, 2022

Conversation

wschin
Contributor

@wschin commented Feb 3, 2022

A Pytorch backend is the component responsible for Pytorch's actual computation. For example, if you register a new GPU backend, Pytorch may call that backend for GPU computation. torch_xla provides an innovative mechanism for adding a Pytorch backend. Compared with an eager-mode backend, torch_xla captures larger graphs before invoking computation. The Pytorch community is actively working on extending the torch_xla framework to support custom backends. This PR demonstrates how to add ORT as one of them.

Highlights: with this PR,

  • ORT can capture computation from torch.jit, torch.fx, and normal Pytorch code (e.g., calls to torch.mul or an nn.Module).
  • Unsupported computation is automatically executed by Pytorch, so no fallback logic is required on our side.

Features implemented:

  • Various tensor types are supported.
  • CPU and GPU modes.
  • Session caching: a session is reused when the input schema is unchanged (see the sketch after this list).
  • No memory copy when exchanging tensors between ORT and Pytorch.
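
A minimal Python sketch of the session-caching idea (illustration only; the actual implementation in this PR is C++, and the helper name and schema format below are hypothetical):

import onnxruntime as ort

# Cache of InferenceSession objects keyed by the input schema seen at compile time.
_session_cache = {}

def get_session(onnx_model_bytes, input_schema):
    # input_schema is assumed to be a hashable summary of the inputs,
    # e.g. (("x", "float32", 2),) for a single rank-2 float input.
    key = tuple(input_schema)
    if key not in _session_cache:
        _session_cache[key] = ort.InferenceSession(
            onnx_model_bytes, providers=["CPUExecutionProvider"])
    return _session_cache[key]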

The test script below shows that no model code changes are needed to enable this feature.

import itertools
import torch

# Lines to set ORT as a new backend
import lazy_tensor_core
from onnxruntime.capi import _pybind_state as ost
import lazy_tensor_core.core.lazy_model as ltm
ost.register_ort_as_torch_jit_executor()
lazy_tensor_core._LAZYC._ltc_init_ts_backend()

def model(x):
    y = x * x
    z = y + x
    p = z * x
    q = p - x
    r = q / x
    return r

def run_forward(tag, x):
    y = model(x)
    ltm.mark_step()
    return y


def run_forward_backward(tag, x):
    assert x.requires_grad
    y = run_forward(tag, x)
    y.sum().backward()
    ltm.mark_step()
    return y, x.grad

def run(x, dtype, can_run_backward):
    baseline_device = 'cpu'
    checked_device = 'lazy'
    x_cpu = torch.tensor(x, device=baseline_device, dtype=dtype)
    x_lazy = torch.tensor(x, device=checked_device, dtype=dtype)

    expected = run_forward(baseline_device, x_cpu)
    generated = run_forward(checked_device, x_lazy)
    torch.allclose(expected, generated.to(baseline_device))

    if not can_run_backward:
      return

    expected, expected_grad = run_forward_backward(baseline_device, x_cpu.requires_grad_())
    generated, generated_grad = run_forward_backward(checked_device, x_lazy.requires_grad_())
    torch.allclose(expected, generated.to(baseline_device))
    torch.allclose(expected_grad, generated_grad.to(baseline_device))

x_floats = [-1.0, 1, [-1.0, 1.0, 2.0], [-0.0, 0.0, 7.0]]
x_ints = [[1, -1], [2, 1], [-2, -2]]
float_types = [torch.float64, torch.float, torch.float16]
int_types = [torch.int64, torch.int32]

# Test float types.
for x, dtype in itertools.product(x_floats, float_types):
  run(x, dtype, can_run_backward=True)

# Test int types.
for x, dtype in itertools.product(x_ints, int_types):
  run(x, dtype, can_run_backward=False)

Steps to build Pytorch LazyTensor

  1. Clone Pytorch master branch
  2. cd pytorch
  3. python setup.py clean
  4. VERBOSE=1 BUILD_LAZY_TS_BACKEND=1 TORCH_CUDA_ARCH_LIST="7.0;7.2;7.5;8.0;8.6" PATH=/usr/local/cuda-11.3/lib64:/usr/local/cuda-11.3/include:/usr/local/cuda-11.3/bin:$PATH CUDACXX=/usr/local/cuda-11.3/bin/nvcc ONNX_NAMESPACE=onnx1 DEBUG=1 BUILD_SHARED_LIBS=1 BUILD_CAFFE2=0 BUILD_CAFFE2_OPS=0 USE_GLOO=1 USE_NCCL=0 USE_NUMPY=1 USE_OBSERVERS=1 USE_OPENMP=1 USE_DISTRIBUTED=1 USE_MPI=1 BUILD_PYTHON=1 USE_MKLDNN=0 USE_CUDA=1 BUILD_TEST=1 USE_FBGEMM=1 USE_NNPACK=1 USE_QNNPACK=0 USE_XNNPACK=1 python3 setup.py develop

After you have Pytorch LazyTensor, build ORT (this PR) with
Torch_DIR=path_to_cloned_pytorch_repo CUDACXX=/usr/local/cuda-11.1/bin/nvcc ./build.sh --config Debug --enable_training --use_cuda --cuda_home /usr/local/cuda-11.1 --cudnn_home /usr/local/cuda-11.1 --build_wheel --parallel --skip_tests --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 --cuda_version=11.1 --enable_nvtx_profile --enable_training_torch_interop --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON --enable_lazy_tensor

@xadupre
Member

xadupre commented Feb 3, 2022

What happens when the conversion to ONNX fails? Is there a way to benchmark both runtimes in the same file?

@wschin
Contributor Author

wschin commented Feb 22, 2022

What happens when the conversion to ONNX fails? Is there a way to benchmark both runtimes in the same file?

Yes. We introduced a debugging feature that calls both ORT and Pytorch on the same graph and compares their results. It is triggered by setting ORT_LT_CHECK_BASELINE=1. We can certainly add a timing comparison.
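
Conceptually, the check behaves like the following Python sketch (the function name and structure are illustrative, not the PR's actual code, which lives in the C++ bridge): run the captured graph through the ORT backend and through stock Pytorch, compare the outputs, and optionally time both.

import os
import time
import torch

def check_against_baseline(model, x_lazy, x_cpu):
    # Only active when the debugging flag from this PR is set.
    if os.environ.get("ORT_LT_CHECK_BASELINE", "0") != "1":
        return
    start = time.perf_counter()
    ort_out = model(x_lazy).to("cpu")   # executed through the ORT backend (lazy device)
    t_ort = time.perf_counter()
    torch_out = model(x_cpu)            # executed by stock Pytorch on CPU
    t_torch = time.perf_counter()
    assert torch.allclose(torch_out, ort_out), "ORT and Pytorch results differ"
    print(f"ORT: {t_ort - start:.6f}s, Pytorch: {t_torch - t_ort:.6f}s")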

@lgtm-com

lgtm-com bot commented May 6, 2022

This pull request introduces 1 alert when merging 9bfc4a1 into 2a90922 - view on LGTM.com

new alerts:

  • 1 for Commented-out code

@lgtm-com

lgtm-com bot commented Jun 23, 2022

This pull request introduces 1 alert when merging 4ae149e into fa7f80c - view on LGTM.com

new alerts:

  • 1 for Commented-out code

@lgtm-com

lgtm-com bot commented Jun 25, 2022

This pull request introduces 1 alert when merging 2eb64b3 into f4ba199 - view on LGTM.com

new alerts:

  • 1 for Commented-out code

@wschin force-pushed the wechi/ltc1 branch 2 times, most recently from 71f0705 to 48dfa4d on July 12, 2022 07:52
LORT likely doesn't work with aten fallback so we only test LORT in its own CI.
@wschin changed the title from "[Draft] ORT as Pytorch's XLA-like Backend" to "Make ORT callable from various Pytorch compilers (LazyTensor, TorchDynamo, etc)" on Jul 12, 2022
Revert "Revert changes to enable external CUDA allocator. Will add it later."

This reverts commit d5487f2.

Fix external allocator
@@ -28,25 +28,58 @@ if(onnxruntime_ENABLE_TRAINING)
list(REMOVE_ITEM onnxruntime_pybind_srcs ${ONNXRUNTIME_ROOT}/python/onnxruntime_pybind_module.cc)
endif()

if (onnxruntime_ENABLE_EAGER_MODE)
# Add Pytorch as a library.
if (onnxruntime_ENABLE_LAZY_TENSOR OR onnxruntime_ENABLE_EAGER_MODE)
Member

So you also go with the AOT approach that links with Pytorch at build time? Do we have any concerns for training workloads with a different Pytorch version?

Contributor Author

I think the major concern is that the user has to build PyTorch from source. I am not sure whether pre-built PyTorch can be used, but it's a good idea to try; I will update this comment with the result.

onnxruntime/core/providers/cuda/cuda_allocator.h (resolved review comments, outdated)
orttraining/orttraining/lazy_tensor/accelerator.cpp (resolved review comments, outdated)

switch (node->kind()) {
// TODO: add as many ops as possible.
case aten::embedding:
Member

Are we able to borrow the list from the exporter?

Contributor Author

Yes, by using the script below, but the exporter may only partially support an op. Do you want me to call this Python function from C++ and generate the list in this PR?

# List the ATen ops the ONNX exporter can convert, with the minimum opset version for each.
from torch.onnx import _onnx_supported_ops
aten_list = _onnx_supported_ops.onnx_supported_ops()
for name, opset_version in aten_list:
    print(f'"``{name}``","{opset_version}"\n')

static onnxruntime::Environment& pybind_default_env = GetLtcEnv();
// All sessions use the same config.
static onnxruntime::SessionOptions sess_opts;
return std::make_unique<onnxruntime::InferenceSession>(sess_opts, pybind_default_env);
Member

Where does the CUDA execution provider get registered?

Contributor Author

In the CUDAExecutionProviderPool class; it creates one EP per GPU.
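
For illustration only (a Python-level analogy, not the PR's C++ code): the one-EP-per-GPU idea corresponds to pinning each session's CUDAExecutionProvider to a device via the device_id option.

import onnxruntime as ort

def make_session_for_device(onnx_model_bytes, device_id):
    # Pin the CUDA execution provider to one GPU; remaining ops fall back to CPU.
    providers = [
        ("CUDAExecutionProvider", {"device_id": device_id}),
        "CPUExecutionProvider",
    ]
    return ort.InferenceSession(onnx_model_bytes, providers=providers)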

NvtxRange range("Call sess.Run");
#endif
// Inputs are ready. Let's run ORT.
ORT_THROW_IF_ERROR(sess.Run(
Member

One thing I guess we haven't paid much attention to yet: how do we make sure we are on the same stream as Pytorch? And if the user defines custom streams in Pytorch, like "with torch.cuda.stream", how will that be reflected in our backend?

Contributor Author

The current stream from Pytorch can be retrieved by calling this function. The ORT session (or Aaron's new API) should consume that stream.
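
For reference, on the Python side the current stream (including one selected by the user via "with torch.cuda.stream(...)") can be queried as follows; the raw handle is what a backend would run on or synchronize with. This is just an illustration and requires a CUDA-enabled Pytorch build.

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    current = torch.cuda.current_stream()  # reflects the user-selected stream
    assert current == s
    raw_handle = current.cuda_stream       # integer cudaStream_t handle a backend could consume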

orttraining/orttraining/lazy_tensor/bridge.h (resolved review comments)

// Class holding the CUDA EPs (one unique EP per device)
// shared by all sessions.
class CUDAExecutionProviderPool {
Member

There is an execution provider pool in the training Python module; will we merge these two parts together in the future?

Contributor Author

Let me try to integrate them. If that fails, I will create an issue for tracking. By the way, I have replaced the LTC environment with ORT-eager's environment.

orttraining/orttraining/lazy_tensor/register.cpp (resolved review comments, outdated)
1. Reuse ORT-eager mode's environment.
2. Remove unused ctor.
cmake/CMakeLists.txt (resolved review comments, outdated)
@@ -1767,7 +1775,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
# Adding the torch lib path for loading DLLs for onnxruntime in eager mode
# This works for Python 3.7 and below, and doesn't work for Python 3.8+
# User will need to import torch before onnxruntime and it will work for all versions
if args.build_eager_mode and is_windows():
if (args.build_eager_mode or args.enable_lazy_tensor) and is_windows():
Member

So, it is not tested for Python 3.8+?

Contributor Author

I am using Python 3.8... I doubt we really need this comment. @souptc, any comments?
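
For context, a sketch of the workaround described in the build-script comment for Python 3.8+ on Windows; the os.add_dll_directory alternative is an assumption, not something this PR adds.

# Option 1 (from the build-script comment): import torch first so its DLLs are
# already loaded when onnxruntime is imported.
import torch
import onnxruntime

# Option 2 (hypothetical alternative for Python 3.8+ on Windows): register
# torch's lib directory explicitly before importing onnxruntime.
# import os, pathlib
# os.add_dll_directory(str(pathlib.Path(torch.__file__).parent / "lib"))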

Comment on lines +248 to +253
pybind11::gil_scoped_acquire guard{};
// Retrieve Python exporter function.
pybind11::function export_to_onnx =
pybind11::reinterpret_borrow<pybind11::function>(
pybind11::module::import("onnxruntime.training.experimental.exporter")
.attr("_export_jit_graph_to_onnx_model_proto"));
Contributor

:) Can we request the exporter team to provide nice C++ APIs?

Contributor

@baijumeswani FYI, @wschin created a nice issue for this: pytorch/pytorch#83764.

// Types of the inputs (typed to IValue) we got when compiling the subgraph.
// Since the subgraph is compiled for these types, feeding
// inputs with different types may fail.
std::vector<c10::TypePtr> input_types_;
Contributor

input_types_ is only used in ExampleRun. Can we remove it from being a member variable? And potentially remove the entire section in ExampleRun dedicated to inputs?

Contributor Author

Unfortunately, ExampleRun is required. We have output_types_, and I think that when debugging, the poor engineer would also like to see the corresponding input_types_. May we keep it?

return $?
}

os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1)
Contributor

There are multiple shellcheck warnings, FYI. You can view them in the Files tab or run shellcheck locally.

orttraining/orttraining/lazy_tensor/debug.cc (multiple resolved review comments, outdated)
return group;
}

torch::jit::value_list sortReverseTopological(torch::jit::ArrayRef<torch::jit::Value*> inputs) {
Contributor

Should method names be in PascalCase according to the style guide? Curious whether this is intentionally different from the naming in orttraining/orttraining/lazy_tensor/flags.cc.

Contributor Author

This function requires another round of refinement. It's mainly borrowed from Pytorch; if I change the naming too much here, it will be hard to map it back to the Pytorch code when doing that refinement. Maybe let's keep it for now. :)

Contributor

Sounds good! Maybe a TODO comment somewhere?

Contributor Author

I will create an item for tracking.

souptc previously approved these changes Aug 19, 2022
baijumeswani previously approved these changes Aug 19, 2022
@wschin merged commit dc486d1 into main on Aug 22, 2022
@wschin deleted the wechi/ltc1 branch on August 22, 2022 at 16:40