cpplint & Eager mode: refactor and add comments to empty_* functions, general lint cleanup in ort_aten (#12238)

* empty* comments and code reuse

* lint

* more cpplint

* add cpplint settings

* test empty
msftlincoln committed Jul 20, 2022
1 parent 72c689a commit 424120d
Showing 3 changed files with 84 additions and 73 deletions.
7 changes: 6 additions & 1 deletion .vscode/settings.json
@@ -34,5 +34,10 @@
"python.linting.pydocstyleArgs": [
"--convention=google"
],
"python.linting.banditEnabled": true
"python.linting.banditEnabled": true,
"cpplint.lineLength": 120,
"cpplint.filters": [
"-build/include_subdir",
"-runtime/references"
]
}
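For context on the new settings: cpplint.lineLength raises cpplint's line-length limit to 120 columns; -build/include_subdir suppresses the check that wants #include paths to carry a directory prefix; and -runtime/references suppresses the check that flags mutable (non-const) reference parameters, a pattern ort_aten.cpp uses heavily. Below is a minimal illustration of code that would trip both suppressed checks at their defaults (the file and function names are made up, not from this repo):

// sample.cc (illustrative only)
#include "ort_aten.h"  // build/include_subdir: cpplint wants a path like "eager/ort_aten.h".

#include <vector>

// runtime/references: cpplint flags `out` and suggests a const reference or a
// pointer; ORT's eager bridge passes mutable OrtValue& outputs by reference
// throughout, so the project disables the check instead of rewriting them.
void fill_output(int value, std::vector<int>& out) {
  out.assign(4, value);
}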
143 changes: 71 additions & 72 deletions orttraining/orttraining/eager/ort_aten.cpp
@@ -10,11 +10,14 @@
#include <c10/util/irange.h>
#include <ATen/WrapDimUtils.h>

+#include <algorithm>
+#include <vector>
+#include <utility>

namespace torch_ort {
namespace eager {

-//#pragma region Helpers
+// #pragma region Helpers
using NodeAttributes = onnxruntime::NodeAttributes;
namespace {
inline bool is_device_supported(at::DeviceType type) {
@@ -34,7 +37,7 @@ namespace {
throw std::runtime_error("ORT copy: device not supported");
}
}
-}
+} // namespace

at::Tensor aten_tensor_from_ort(
OrtValue&& ot,
@@ -59,7 +62,7 @@ const std::vector<at::Tensor> aten_tensor_from_ort(

onnxruntime::MLDataType ort_scalar_type_from_aten(
at::ScalarType dtype) {
-switch (dtype){
+switch (dtype) {
case at::kFloat:
return onnxruntime::DataTypeImpl::GetType<float>();
case at::kDouble:
@@ -107,7 +110,7 @@ OrtValue create_ort_value(
break;
}
default:
-// TODO: support more types
+// TODO(unknown): support more types
// For most at::ScalarType, it should be safe to just call value.to<>
// on it, but for now we want to explicitly know when we've encountered
// a new scalar type while bringing up ORT eager mode.
@@ -131,13 +134,17 @@
auto element_type = ort_scalar_type_from_aten(tensor.scalar_type());

OrtValue ort_tensor;
-onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape(tensor.sizes().vec()), tensor.data_ptr(),
-    *mem_info, ort_tensor, 0L /* offset = 0 - because tensor.data_ptr() includes the underlying offset */,
-    tensor.strides().vec());
+onnxruntime::Tensor::InitOrtValue(
+    element_type,
+    onnxruntime::TensorShape(tensor.sizes().vec()),
+    tensor.data_ptr(),
+    *mem_info, ort_tensor,
+    0L, // offset = 0 - because tensor.data_ptr() includes the underlying offset
+    tensor.strides().vec());
return ort_tensor;
}

-OrtValue create_ort_value(const at::Tensor& tensor){
+OrtValue create_ort_value(const at::Tensor& tensor) {
auto& invoker = GetORTInvoker(tensor.device());
return create_ort_value(invoker, tensor);
}
@@ -146,7 +153,7 @@ std::vector<OrtValue> create_ort_value(
onnxruntime::ORTInvoker& invoker,
at::TensorList values) {
auto output = std::vector<OrtValue>{};
-for (auto element: values){
+for (auto element : values) {
output.push_back(create_ort_value(element));
}
return output;
@@ -157,7 +164,7 @@ onnx::AttributeProto create_ort_attribute(
at::Scalar value,
const bool isTensor,
at::ScalarType type) {
-if (isTensor){
+if (isTensor) {
onnx::AttributeProto attr;
attr.set_name(name);
attr.set_type(onnx::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR);
@@ -190,8 +197,7 @@ onnx::AttributeProto create_ort_attribute(
ORT_THROW("Unsupported: at::ScalarType::", value.type());
}
return attr;
-}
-else{
+} else {
return create_ort_attribute(name, value, value.type());
}
}
@@ -254,33 +260,33 @@ onnx::AttributeProto create_ort_attribute(
return attr;
}

-bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::Scalar scalar, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), scalar.type()) != valid_types.end();
}

-bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::Tensor tensor, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), tensor.scalar_type()) != valid_types.end();
}

-bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::IntArrayRef arrary, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), at::kInt) != valid_types.end() ||
std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}

-bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(int64_t val, const std::vector<at::ScalarType>& valid_types) {
return std::find(valid_types.begin(), valid_types.end(), at::kLong) != valid_types.end();
}

-bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(c10::optional<int64_t> val, const std::vector<at::ScalarType>& valid_types) {
return IsSupportedType(val.value(), valid_types);
}

-bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types){
+bool IsSupportedType(at::TensorList tensors, const std::vector<at::ScalarType>& valid_types) {
return IsSupportedType(tensors[0], valid_types);
}

-ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype){
-switch (dtype){
+ONNX_NAMESPACE::TensorProto_DataType GetONNXTensorProtoDataType(at::ScalarType dtype) {
+switch (dtype) {
case at::kFloat:
return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
case at::kDouble:
@@ -349,7 +355,7 @@ c10::optional<at::ScalarType> PromoteScalarTypesWithCategory(
return typeFromTensor;
}

-OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type){
+OrtValue CastToType(onnxruntime::ORTInvoker& invoker, const OrtValue& input, at::ScalarType type) {
std::vector<OrtValue> output(1);
NodeAttributes attrs(2);
attrs["to"] = create_ort_attribute(
@@ -425,7 +431,7 @@ void resize_output(
resize_impl_ort_(invoker, output, shape);
}

-//#pragma endregion
+// #pragma endregion

/*
* Resize backing store of a TensorImpl.
@@ -530,52 +536,44 @@ void resize_impl_ort_(
return;
}

-//#pragma region Hand-Implemented ATen Ops
+// #pragma region Hand-Implemented ATen Ops

namespace aten {

-at::Tensor empty_memory_format(
+at::Tensor empty_strided(
    at::IntArrayRef size,
-   // *,
+   at::IntArrayRef stride,
    c10::optional<at::ScalarType> dtype_opt,
-   c10::optional<at::Layout> layout_opt,
-   c10::optional<at::Device> device_opt,
-   c10::optional<bool> pin_memory,
-   c10::optional<at::MemoryFormat> memory_format) {
-  ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);
-
-  assert(dtype_opt.has_value());
-  assert(device_opt.has_value());
+   c10::optional<at::Layout> layout_opt, // Ignored because there's no ONNX support.
+   c10::optional<at::Device> device_opt, // Will be ORT by the time this is dispatched.
+   c10::optional<bool> pin_memory_opt) { // Ignored because there's no ONNX support.
+  ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);

-  // TODO: validate options and memory format
-  // TODO: figure out how to get the correct element type.
   OrtValue ot;
+  assert(device_opt.has_value());
+  at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
   auto& invoker = GetORTInvoker(*device_opt);
-  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(*dtype_opt), onnxruntime::TensorShape(size.vec()),
-      invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot);
+  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
+      invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
+      stride.vec());
   return aten_tensor_from_ort(
       std::move(ot),
       at::TensorOptions()
          .device(*device_opt)
-         .dtype(*dtype_opt));
+         .dtype(dtype));
 }

-at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride, c10::optional<at::ScalarType> dtype_opt,
-                         c10::optional<at::Layout> layout_opt, c10::optional<at::Device> device_opt,
-                         c10::optional<bool> pin_memory_opt) {
-  ORT_LOG_FN(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
+at::Tensor empty_memory_format(
+   at::IntArrayRef size,
+   c10::optional<at::ScalarType> dtype_opt,
+   c10::optional<at::Layout> layout_opt,
+   c10::optional<at::Device> device_opt,
+   c10::optional<bool> pin_memory,
+   c10::optional<at::MemoryFormat> memory_format) { // Ignored because there's no ONNX support.
+  ORT_LOG_FN(size, dtype_opt, layout_opt, device_opt, pin_memory, memory_format);

-  // TODO: how to handle type conversion
-  OrtValue ot;
-  assert(device_opt.has_value());
-  // TODO: how to support layout
-  // assert(!layout_opt.has_value());
-  at::ScalarType dtype = c10::dtype_or_default(dtype_opt);
-  auto& invoker = GetORTInvoker(*device_opt);
-  onnxruntime::Tensor::InitOrtValue(ort_scalar_type_from_aten(dtype), onnxruntime::TensorShape(size.vec()),
-      invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), ot,
-      stride.vec());
-  return aten_tensor_from_ort(std::move(ot), at::TensorOptions().device(*device_opt).dtype(dtype));
+  // Use the strided impl with default (no strides specified).
+  return empty_strided(size, at::IntArrayRef({}), dtype_opt, layout_opt, device_opt, pin_memory);
}
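Note on the refactor above: empty_memory_format now delegates to empty_strided with an empty stride list, leaving a single allocation path. Presumably the tensor layer falls back to default contiguous, row-major strides when none are supplied; the sketch below illustrates that fallback computation. contiguous_strides_of is a hypothetical helper for illustration, not a function in this diff:

#include <cstdint>
#include <vector>

// Hypothetical helper (not in this diff): the default contiguous, row-major
// strides that an empty stride list presumably falls back to.
std::vector<int64_t> contiguous_strides_of(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  // Each stride is the product of all dimension sizes to its right.
  for (size_t i = sizes.size(); i > 1; --i) {
    strides[i - 2] = strides[i - 1] * sizes[i - 1];
  }
  return strides;
}
// e.g. sizes {3, 4} -> strides {4, 1}; sizes {2, 3, 5} -> strides {15, 5, 1}.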

// aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
@@ -602,9 +600,9 @@ at::Tensor as_strided(
at::Tensor _reshape_alias(
const at::Tensor& self,
at::IntArrayRef size,
-at::IntArrayRef stride){
+at::IntArrayRef stride) {
ORT_LOG_FN(self, size, stride);
-// TODO: support stride
+// TODO(unknown): support stride
auto& invoker = GetORTInvoker(self.device());
auto ort_input = create_ort_value(invoker, self);
return aten_tensor_from_ort(
@@ -645,7 +643,7 @@ at::Tensor& copy_(
: src.device());
const auto ort_src = create_ort_value(invoker, src);
auto ort_self = create_ort_value(invoker, self);
-if (self.scalar_type() != src.scalar_type()){
+if (self.scalar_type() != src.scalar_type()) {
// invoke cast first
std::vector<OrtValue> ort_cast_output(1);
onnxruntime::NodeAttributes attrs(1);
@@ -661,8 +659,7 @@ at::Tensor& copy_(
"ORT return failure status:" + status.ErrorMessage());

copy(invoker, ort_cast_output[0], ort_self);
-}
-else{
+} else {
copy(invoker, ort_src, ort_self);
}

@@ -671,7 +668,7 @@

at::Tensor _copy_from_and_resize(
const at::Tensor& self,
-const at::Tensor& dst){
+const at::Tensor& dst) {
ORT_LOG_FN(self, dst);

assert_tensor_supported(self);
@@ -688,11 +685,11 @@ at::Tensor _copy_from_and_resize(
return self;
}

-at::Tensor& zero_(at::Tensor& self){
+at::Tensor& zero_(at::Tensor& self) {
auto& invoker = GetORTInvoker(self.device());
auto ort_in_self = create_ort_value(invoker, self);
OrtValue flag_val;
-//construct a constant tensor
+// construct a constant tensor
auto element_type = onnxruntime::DataTypeImpl::GetType<int64_t>();
onnxruntime::Tensor::InitOrtValue(element_type, onnxruntime::TensorShape({}),
invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), flag_val);
@@ -715,18 +712,19 @@ at::Tensor& zero_(at::Tensor& self){
return self;
}

-// TODO: enhance opgen.py to support inplace binary operations.
+// TODO(unknown): enhance opgen.py to support inplace binary operations.
// aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
at::Tensor& add__Tensor(
at::Tensor& self,
const at::Tensor& other,
const at::Scalar& alpha) {
ORT_LOG_FN(self, other, alpha);

+auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
 if (
-    !IsSupportedType(alpha, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
-    !IsSupportedType(other, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16}) ||
-    !IsSupportedType(self, {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16})) {
+    !IsSupportedType(alpha, st) ||
+    !IsSupportedType(other, st) ||
+    !IsSupportedType(self, st)) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(add__Tensor)>::call(self, other, alpha);
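Side note on the auto st pattern above: brace-initializing an auto variable deduces std::initializer_list, which the IsSupportedType overloads then accept through their const std::vector<at::ScalarType>& parameter (a temporary vector is materialized per call). A minimal self-contained sketch of the same idea; the names are illustrative, not from this diff:

#include <algorithm>
#include <vector>

// A const std::vector<int>& parameter accepts a std::initializer_list<int>
// argument by materializing a temporary vector at each call site.
bool contains(const std::vector<int>& valid, int v) {
  return std::find(valid.begin(), valid.end(), v) != valid.end();
}

int main() {
  auto valid = {1, 2, 3};  // deduces std::initializer_list<int>, reused across calls
  return contains(valid, 2) && !contains(valid, 7) ? 0 : 1;
}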
@@ -827,8 +825,9 @@ bool keepdim,
at::Tensor& out) {
ORT_LOG_FN(self, dim, keepdim, out);

+auto st = {at::kDouble, at::kLong, at::kHalf, at::kShort, at::kInt, at::kByte, at::kFloat, at::kBFloat16};
 if (
-    !IsSupportedType(self, {at::kLong, at::kShort, at::kHalf, at::kBFloat16, at::kFloat, at::kByte, at::kInt, at::kDouble})) {
+    !IsSupportedType(self, st)) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(argmax_out)>::call(self, dim, keepdim, out);
@@ -1034,7 +1033,7 @@ at::Tensor& _log_softmax_out(
ORT_LOG_FN(self, dim, half_to_float, out);

if (
-    !IsSupportedType(self, {at::kBFloat16,at::kDouble,at::kFloat,at::kHalf})) {
+    !IsSupportedType(self, {at::kBFloat16, at::kDouble, at::kFloat, at::kHalf})) {
return at::native::call_fallback_fn<
&at::native::cpu_fallback,
ATEN_OP(_log_softmax_out)>::call(self, dim, half_to_float, out);
@@ -1096,7 +1095,7 @@ at::Tensor& _log_softmax_out(
ort_outputs_2_Transpose[0] = ort_input_out;

NodeAttributes attrs_2(1);
attrs_2["perm"] = create_ort_attribute("perm", axes);;
attrs_2["perm"] = create_ort_attribute("perm", axes);

status = invoker.Invoke("Transpose", {
std::move(ort_outputs_1_LogSoftmax[0]),
@@ -1165,9 +1164,9 @@ at::Tensor& mm_out(
}


-} // namespace aten
+}  // namespace aten

-//#pragma endregion
+// #pragma endregion

-} // namespace eager
-} // namespace torch_ort
+}  // namespace eager
+}  // namespace torch_ort
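For context, and not part of this commit: ops like empty_strided and empty_memory_format reach PyTorch through dispatcher registrations. A hypothetical sketch of that wiring, assuming the standard TORCH_LIBRARY_IMPL macro and the ORT dispatch key; ORT eager mode actually generates its registrations with its opgen tooling rather than writing them by hand:

#include <torch/library.h>

// Hypothetical sketch only: routes the two aten empty ops to the ORT
// backend implementations defined in ort_aten.cpp above.
TORCH_LIBRARY_IMPL(aten, ORT, m) {
  m.impl("empty.memory_format", torch_ort::eager::aten::empty_memory_format);
  m.impl("empty_strided", torch_ort::eager::aten::empty_strided);
}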
7 changes: 7 additions & 0 deletions orttraining/orttraining/eager/test/ort_ops.py
@@ -231,6 +231,13 @@ def test_zero_stride(self):
cpu_tensor_copied = ort_tensor.cpu()
assert cpu_tensor_copied.stride() == (0, 0, 0)

+def test_empty(self):
+    device = self.get_device()
+    cpu_tensor = torch.empty(size=(3, 4))
+    ort_tensor = torch.empty(size=(3, 4), device=device)
+    assert ort_tensor.is_ort
+    assert ort_tensor.size() == cpu_tensor.size()
+
def test_softmax(self):
device = self.get_device()
cpu_tensor = torch.rand(3, 5)
