diff --git a/bindings/python/iree/runtime/CMakeLists.txt b/bindings/python/iree/runtime/CMakeLists.txt index 13bb1c07cb7e..f0bbdb78e380 100644 --- a/bindings/python/iree/runtime/CMakeLists.txt +++ b/bindings/python/iree/runtime/CMakeLists.txt @@ -11,13 +11,9 @@ iree_pyext_module( SRCS "initialize_module.cc" "binding.h" - "function_abi.h" "hal.h" - "host_types.h" "vm.h" - "function_abi.cc" "hal.cc" - "host_types.cc" "status_utils.cc" "status_utils.h" "vm.cc" @@ -51,15 +47,6 @@ iree_py_library( ::PyExtRt ) -iree_py_test( - NAME - function_abi_test - SRCS - "function_abi_test.py" - LABELS - "nokokoro" -) - iree_py_test( NAME hal_test diff --git a/bindings/python/iree/runtime/__init__.py b/bindings/python/iree/runtime/__init__.py index 0596d889aed6..c0eb79601cf0 100644 --- a/bindings/python/iree/runtime/__init__.py +++ b/bindings/python/iree/runtime/__init__.py @@ -13,12 +13,8 @@ from . import binding # Pull some of the native symbols into the public API. -# FunctionAbi imports -from .binding import FunctionAbi # Hal imports from .binding import BufferUsage, HalBuffer, HalDevice, HalDriver, HalElementType, MemoryAccess, MemoryType, Shape -# HostTypeFactory imports -from .binding import HostTypeFactory # Vm imports from .binding import create_hal_module, Linkage, VmVariantList, VmFunction, VmInstance, VmContext, VmModule # SystemApi diff --git a/bindings/python/iree/runtime/function_abi.cc b/bindings/python/iree/runtime/function_abi.cc deleted file mode 100644 index ce1158f47480..000000000000 --- a/bindings/python/iree/runtime/function_abi.cc +++ /dev/null @@ -1,842 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "bindings/python/iree/runtime/function_abi.h" - -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "bindings/python/iree/runtime/hal.h" -#include "bindings/python/iree/runtime/status_utils.h" -#include "bindings/python/iree/runtime/vm.h" -#include "iree/base/api.h" -#include "iree/base/signature_parser.h" -#include "iree/hal/api.h" -#include "iree/modules/hal/hal_module.h" -#include "iree/vm/api.h" - -namespace iree { -namespace python { - -namespace { - -class SipLinearizeInputsVisitor { - public: - SipLinearizeInputsVisitor(SipSignatureParser& parser, py::tuple& py_args, - py::dict& py_kwargs, - absl::InlinedVector& linear_py_args) - : parser_(parser), - py_args_(py_args), - py_kwargs_(py_kwargs), - linear_py_args_(linear_py_args) {} - - void IntegerKey(SipSignatureParser& p, int k) { - auto current = tos(); - try { - auto current_seq = current.cast(); - stack_.push_back(current_seq[k]); - } catch (std::exception& e) { - auto message = - absl::StrCat("Expected sequence index ", k, " not found in ", - py::repr(current).cast()); - SetError(std::move(message)); - } - } - - void StringKey(SipSignatureParser& p, absl::string_view k) { - auto current = tos(); - py::str py_k(k.data(), k.size()); - try { - auto current_dict = tos().cast(); - stack_.push_back(current_dict[py_k]); - } catch (std::exception& e) { - auto message = absl::StrCat("Expected key '", k, "' not found in ", - py::repr(current).cast()); - SetError(std::move(message)); - } - } - - void OpenStruct(SipSignatureParser& p, - SipSignatureParser::StructType struct_type) { - // Only structs directly off of the root are opened without a key. - if (!stack_.empty()) return; - - py::handle tos; - switch (struct_type) { - case SipSignatureParser::StructType::kDict: - tos = py_kwargs_; - break; - case SipSignatureParser::StructType::kSequence: - tos = py_args_; - break; - } - stack_.push_back(tos); - } - - void CloseStruct(SipSignatureParser& p) { - if (!stack_.empty()) { - stack_.pop_back(); - } - } - - void MapToRawSignatureIndex(SipSignatureParser& p, int index) { - if (static_cast(linear_py_args_.size()) <= index) { - linear_py_args_.resize(index + 1); - } - linear_py_args_[index] = tos(); - if (!stack_.empty()) { - stack_.pop_back(); - } - } - - private: - py::handle tos() { - if (stack_.empty()) { - SetError("Mismatched structures during unpacking arguments"); - return py::handle(); - } - return stack_.back(); - } - - void SetError(std::string message) { parser_.SetError(message); } - - SipSignatureParser& parser_; - py::tuple& py_args_; - py::dict& py_kwargs_; - absl::InlinedVector& linear_py_args_; - - // The struct stack. Top is the last. - // When the stack is empty, opening a struct will push the first entry: - // py_args_ if a sequence and py_kwargs_ if a dict. Otherwise, new stack - // levels are opened upon key resolution. - // Either CloseStruct or MapToRawSignatureIndex terminate each level of - // the stack. - absl::InlinedVector stack_; -}; - -class SipStructureResultsVisitor { - public: - SipStructureResultsVisitor( - SipSignatureParser& parser, - absl::InlinedVector& linear_py_results) - : parser_(parser), linear_py_results_(linear_py_results) {} - - void IntegerKey(SipSignatureParser& p, int k) { - pending_assign_key_ = py::int_(k); - } - - void StringKey(SipSignatureParser& p, absl::string_view k) { - pending_assign_key_ = py::str(k.data(), k.size()); - } - - void OpenStruct(SipSignatureParser& p, - SipSignatureParser::StructType struct_type) { - py::object struct_obj; - bool is_dict; - switch (struct_type) { - case SipSignatureParser::StructType::kDict: - struct_obj = py::dict(); - is_dict = true; - break; - case SipSignatureParser::StructType::kSequence: - struct_obj = py::list(); - is_dict = false; - break; - default: - SetError("Illegal structure type"); - return; - } - // Must assign before pushing so as to assign to the prior level. - AssignCurrent(struct_obj); - stack_.push_back(std::make_pair(std::move(struct_obj), is_dict)); - } - - void CloseStruct(SipSignatureParser& p) { - if (!stack_.empty()) stack_.pop_back(); - pending_assign_key_ = py::none(); // Just in case (for error path). - } - - void MapToRawSignatureIndex(SipSignatureParser& p, int index) { - if (index < 0 || index >= static_cast(linear_py_results_.size())) { - SetError("Raw result index out of range in reflection metadata"); - return; - } - py::object current_obj = linear_py_results_[index]; - AssignCurrent(std::move(current_obj)); - } - - py::object ConsumeResult() { - if (result) - return std::move(result); - else - return py::none(); - } - - private: - void AssignCurrent(py::object value) { - if (stack_.empty()) { - if (result) { - SetError("Attempt to unpack multiple roots"); - return; - } - result = std::move(value); - } else { - if (!pending_assign_key_ || pending_assign_key_.is_none()) { - SetError("Attempt to assign out of order"); - return; - } - - try { - auto stack_entry = stack_.back(); - bool is_dict = stack_entry.second; - if (is_dict) { - stack_entry.first.cast()[pending_assign_key_] = value; - } else { - int index = pending_assign_key_.cast(); - py::list l = stack_entry.first.cast(); - // Technically, signature keys can come out of order, which is sad. - // none-fill the list as needed to fill the gap. - // TODO: Further guarantees can be enforced at conversion time, - // simplifying this. - bool extended = false; - int list_size = l.size(); - if (list_size <= index) { - while (l.size() < index) { - l.append(py::none()); - extended = true; - } - l.append(std::move(value)); - } else { - l[index] = std::move(value); - } - pending_assign_key_ = py::none(); - } - } catch (std::exception& e) { - SetError("Corrupt sip signature: Signature/data type mismatch"); - pending_assign_key_ = py::none(); - } - } - } - - void SetError(std::string message) { parser_.SetError(message); } - - SipSignatureParser& parser_; - absl::InlinedVector& linear_py_results_; - py::object result; - - // Parse state. - // A new level of the stack is opened for each container. Each entry is a - // pair of (container, is_dict). If not is_dict, it is assumed to be a list. - absl::InlinedVector, 4> stack_; - // If a pending key has been set for a following assignment, it is noted - // here. The nested assignments, the call sequence is: - // 1. OpenStruct - // For-each key: - // a. IntegerKey or StringKey - // b. MapToRawSignatureIndex - // 2. CloseStruct - // For single-result situations, it is legal to just have a single, top-level - // call to MapToRawSignatureIndex, which causes the entire result to be - // equal to the current object. - py::object pending_assign_key_; -}; - -// Python friendly entry-point for creating an instance from a list -// of attributes. This is not particularly efficient and is primarily -// for testing. Typically, this will be created directly from a function -// and the attribute introspection will happen internal to C++. -std::unique_ptr PyCreateAbi( - HalDevice& device, std::shared_ptr host_type_factory, - std::vector> attrs) { - auto lookup = - [&attrs](absl::string_view key) -> absl::optional { - for (const auto& kv : attrs) { - if (kv.first == key) return kv.second; - } - return absl::nullopt; - }; - return FunctionAbi::Create(device, std::move(host_type_factory), lookup); -} - -VmVariantList PyAllocateResults(FunctionAbi* self, VmVariantList& f_args, - bool static_alloc) { - auto f_results = VmVariantList::Create(self->raw_result_arity()); - if (static_alloc) { - // For static dispatch, attempt to fully allocate and perform shape - // inference. - self->AllocateResults(absl::MakeConstSpan(self->raw_config().results), - f_args, f_results); - } - return f_results; -} - -// RAII wrapper for a Py_buffer which calls PyBuffer_Release when it goes -// out of scope. -class PyBufferReleaser { - public: - PyBufferReleaser(Py_buffer& b) : b_(b) {} - ~PyBufferReleaser() { PyBuffer_Release(&b_); } - - private: - Py_buffer& b_; -}; - -pybind11::error_already_set RaiseBufferMismatchError( - std::string message, py::handle obj, - const RawSignatureParser::Description& desc) { - message.append("For argument = "); - auto arg_py_str = py::str(obj); - auto arg_str = static_cast(arg_py_str); - message.append(arg_str); - message.append(" (expected "); - desc.ToString(message); - message.append(")"); - return RaiseValueError(message.c_str()); -} - -// Verifies and maps the py buffer shape and layout to the bound argument. -// Returns false if not compatible. -void MapBufferAttrs(Py_buffer& py_view, - const RawSignatureParser::Description& desc, - absl::InlinedVector& dynamic_dims) { - // Verify that rank matches. - if (py_view.ndim != desc.dims.size()) { - throw RaiseBufferMismatchError( - absl::StrCat("Mismatched buffer rank (received: ", py_view.ndim, - ", expected: ", desc.dims.size(), "): "), - py::handle(py_view.obj), desc); - } - - // Verify that the item size matches. - size_t f_item_size = - AbiConstants::kScalarTypeSize[static_cast(desc.buffer.scalar_type)]; - if (f_item_size != py_view.itemsize) { - throw RaiseBufferMismatchError( - absl::StrCat("Mismatched buffer item size (received: ", - py_view.itemsize, ", expected: ", f_item_size, "): "), - py::handle(py_view.obj), desc); - } - - // Note: The python buffer format does not map precisely to IREE's type - // system, so the below is only advisory for where they do match. Otherwise, - // it is basically a bitcast. - const char* f_expected_format = - kScalarTypePyFormat[static_cast(desc.buffer.scalar_type)]; - - // If the format is booleans, we should treat it as bytes. - const char* f_found_format = py_view.format; - if (strcmp(f_found_format, "?") == 0) { - f_found_format = "b"; - } - - if (f_expected_format != nullptr && - strcmp(f_expected_format, f_found_format) != 0) { - throw RaiseBufferMismatchError( - absl::StrCat("Mismatched buffer format (received: ", py_view.format, - ", expected: ", f_expected_format, "): "), - py::handle(py_view.obj), desc); - } - - // Verify shape, populating dynamic_dims while looping. - for (size_t i = 0; i < py_view.ndim; ++i) { - auto py_dim = py_view.shape[i]; - auto f_dim = desc.dims[i]; - if (f_dim < 0) { - // Dynamic. - dynamic_dims.push_back(py_dim); - } else if (py_dim != f_dim) { - // Mismatch. - throw RaiseBufferMismatchError( - absl::StrCat("Mismatched buffer dim (received: ", py_dim, - ", expected: ", f_dim, "): "), - py::handle(py_view.obj), desc); - } - } -} - -void PackScalar(const RawSignatureParser::Description& desc, py::handle py_arg, - VmVariantList& f_args) { - iree_vm_value_t value; - value.type = IREE_VM_VALUE_TYPE_I32; - switch (desc.scalar.type) { - case AbiConstants::ScalarType::kUint8: - case AbiConstants::ScalarType::kUint16: - case AbiConstants::ScalarType::kUint32: { - value.i32 = py_arg.cast(); - break; - } - case AbiConstants::ScalarType::kSint8: - case AbiConstants::ScalarType::kSint16: - case AbiConstants::ScalarType::kSint32: { - value.i32 = py_arg.cast(); - break; - } - default: - throw RaisePyError(PyExc_NotImplementedError, "Unsupported scalar type"); - } - CheckApiStatus(iree_vm_list_push_value(f_args.raw_ptr(), &value), - "Could not pack scalar argument"); -} - -py::object UnpackScalar(const RawSignatureParser::Description& desc, - const iree_vm_variant_t& f_result) { - switch (desc.scalar.type) { - case AbiConstants::ScalarType::kUint8: - case AbiConstants::ScalarType::kUint16: - case AbiConstants::ScalarType::kUint32: { - return py::int_(static_cast(f_result.i32)); - } - case AbiConstants::ScalarType::kSint8: - case AbiConstants::ScalarType::kSint16: - case AbiConstants::ScalarType::kSint32: { - return py::int_(f_result.i32); - } - default: - throw RaisePyError(PyExc_NotImplementedError, "Unsupported scalar type"); - } -} - -} // namespace - -//------------------------------------------------------------------------------ -// FunctionAbi -//------------------------------------------------------------------------------ - -std::string FunctionAbi::DebugString() const { - RawSignatureParser p; - auto s = p.FunctionSignatureToString(raw_config_.signature); - if (!s) { - return ""; - } - auto result = absl::StrCat(""); - return result; -} - -std::unique_ptr FunctionAbi::Create( - HalDevice& device, std::shared_ptr host_type_factory, - AttributeLookup lookup) { - auto abi = - std::make_unique(device, std::move(host_type_factory)); - - // Fetch key attributes for the raw ABI. - auto raw_version = lookup("fv"); - auto raw_fsig_str = lookup("f"); - - // Validation. - if (!raw_fsig_str) { - throw RaiseValueError("No raw abi reflection metadata for function"); - } - if (!raw_version || *raw_version != "1") { - throw RaiseValueError("Unsupported raw function ABI version"); - } - - // Parse signature. - abi->raw_config().signature = std::string(*raw_fsig_str); - RawSignatureParser raw_parser; - raw_parser.VisitInputs(*raw_fsig_str, - [&abi](const RawSignatureParser::Description& d) { - abi->raw_config().inputs.push_back(d); - }); - raw_parser.VisitResults(*raw_fsig_str, - [&abi](const RawSignatureParser::Description& d) { - abi->raw_config().results.push_back(d); - }); - if (raw_parser.GetError()) { - auto message = absl::StrCat( - "Error parsing raw ABI signature: ", *raw_parser.GetError(), " ('", - *raw_fsig_str, "')"); - throw RaiseValueError(message.c_str()); - } - - auto reported_abi = lookup("abi"); - auto sip_signature = lookup("sip"); - if (reported_abi && *reported_abi == "sip" && sip_signature) { - abi->sip_signature_ = std::string(*sip_signature); - } - return abi; -} - -void FunctionAbi::Pack(py::tuple& py_args, py::dict& py_kwargs, - absl::Span descs, VmVariantList& args, - bool writable) { - absl::InlinedVector linear_py_args; - if (!sip_signature_) { - // There is no python -> linear translation. - size_t e = py_args.size(); - linear_py_args.resize(e); - for (size_t i = 0; i < e; ++i) { - linear_py_args[i] = py_args[i]; - } - } else { - // Linearize based on sip signature. - // Note that we use explicit errors here and do not let exceptions escape - // since parsing may be happening in a library not compiled for exceptions. - SipSignatureParser parser; - SipLinearizeInputsVisitor visitor(parser, py_args, py_kwargs, - linear_py_args); - parser.VisitInputs(visitor, *sip_signature_); - auto error = parser.GetError(); - if (error) { - auto message = - absl::StrCat("Could not unpack python arguments: ", *error); - throw RaiseValueError(message.c_str()); - } - } - RawPack(descs, absl::MakeSpan(linear_py_args), args, writable); -} - -py::object FunctionAbi::Unpack(absl::Span descs, - VmVariantList& f_results) { - absl::InlinedVector linear_py_results; - linear_py_results.resize(f_results.size()); - RawUnpack(descs, f_results, absl::MakeSpan(linear_py_results)); - - if (!sip_signature_) { - // Just emulate unpacking to a tuple, which is the standard way of - // returning multiple results from a python function. - auto linear_size = linear_py_results.size(); - if (linear_size == 0) { - return py::none(); - } else if (linear_size == 1) { - return std::move(linear_py_results.front()); - } - // Fall back to tuple multi-result form. - py::tuple py_result_tuple(linear_size); - for (size_t i = 0; i < linear_size; ++i) { - py_result_tuple[i] = std::move(linear_py_results[i]); - } - return std::move(py_result_tuple); // Without move, warns of copy. - } - - // Structured unpack with the sip signature. - // Note that we use explicit errors here and do not let exceptions escape - // since parsing may be happening in a library not compiled for exceptions. - SipSignatureParser parser; - SipStructureResultsVisitor visitor(parser, linear_py_results); - parser.VisitResults(visitor, *sip_signature_); - auto error = parser.GetError(); - if (error) { - auto message = - absl::StrCat("Could not create python structured results: ", *error); - throw RaiseValueError(message.c_str()); - } - - assert(!PyErr_Occurred()); - return visitor.ConsumeResult(); -} - -void FunctionAbi::RawPack(absl::Span descs, - absl::Span py_args, VmVariantList& f_args, - bool writable) { - if (descs.size() != py_args.size()) { - throw RaiseValueError("Mismatched RawPack() input arity"); - } - - for (size_t i = 0, e = descs.size(); i < e; ++i) { - const Description& desc = descs[i]; - switch (desc.type) { - case RawSignatureParser::Type::kBuffer: - PackBuffer(desc, py_args[i], f_args, writable); - break; - case RawSignatureParser::Type::kRefObject: - throw RaisePyError(PyExc_NotImplementedError, - "Ref objects not yet supported"); - break; - case RawSignatureParser::Type::kScalar: - PackScalar(desc, py_args[i], f_args); - break; - default: - throw RaisePyError(PyExc_NotImplementedError, - "Unsupported argument type"); - } - } -} - -void FunctionAbi::RawUnpack(absl::Span descs, - VmVariantList& f_results, - absl::Span py_results) { - py::object this_object = - py::cast(this, py::return_value_policy::take_ownership); - if (descs.size() != f_results.size() || descs.size() != py_results.size()) { - std::string s = std::string("Mismatched RawUnpack() result arity; descs=") + - std::to_string(descs.size()) + - ", f_results=" + std::to_string(f_results.size()) + - ", py_results=" + std::to_string(py_results.size()); - throw RaiseValueError(s.c_str()); - } - for (size_t i = 0, e = descs.size(); i < e; ++i) { - const Description& desc = descs[i]; - iree_vm_variant_t f_result = iree_vm_variant_empty(); - iree_status_t status = - iree_vm_list_get_variant(f_results.raw_ptr(), i, &f_result); - if (!iree_status_is_ok(status)) { - iree_status_ignore(status); - throw RaiseValueError("Could not get result from list"); - } - switch (desc.type) { - case RawSignatureParser::Type::kBuffer: { - iree_hal_buffer_view_t* buffer_view = - iree_hal_buffer_view_deref(f_result.ref); - if (!buffer_view) { - throw RaiseValueError( - "Could not deref result buffer view (wrong type?)"); - } - iree_hal_buffer_t* raw_buffer = - iree_hal_buffer_view_buffer(buffer_view); - if (!raw_buffer) { - throw RaiseValueError("Could not deref result buffer (wrong type?)"); - } - HalBuffer buffer = HalBuffer::RetainAndCreate(raw_buffer); - - // Extract dims from the buffer view. - size_t rank = 0; - absl::InlinedVector dims(6); - iree_status_t status = iree_hal_buffer_view_shape( - buffer_view, dims.capacity(), dims.data(), &rank); - if (iree_status_is_out_of_range(status)) { - dims.resize(rank); - status = iree_hal_buffer_view_shape(buffer_view, dims.capacity(), - dims.data(), &rank); - } - CheckApiStatus(status, "Error extracting shape"); - dims.resize(rank); - - // Deal with int32_t != int (but require 32bits). Happens on some - // embedded platforms. - static_assert(sizeof(dims[0]) == sizeof(int), - "expected int to be 32 bits"); - py_results[i] = host_type_factory_->CreateImmediateNdarray( - desc.buffer.scalar_type, - absl::MakeConstSpan(reinterpret_cast(dims.data()), - dims.size()), - std::move(buffer), this_object); - break; - } - case RawSignatureParser::Type::kRefObject: - throw RaisePyError(PyExc_NotImplementedError, - "Ref objects not yet supported"); - break; - case RawSignatureParser::Type::kScalar: - py_results[i] = UnpackScalar(desc, f_result); - break; - default: - throw RaisePyError(PyExc_NotImplementedError, - "Unsupported result type"); - } - } -} - -void FunctionAbi::AllocateResults(absl::Span descs, - VmVariantList& f_args, - VmVariantList& f_results) { - if (f_args.size() != raw_config().inputs.size()) { - throw RaiseValueError("Mismatched AllocateResults() input arity"); - } - - for (size_t i = 0, e = descs.size(); i < e; ++i) { - const Description& desc = descs[i]; - iree_device_size_t alloc_size = - AbiConstants::kScalarTypeSize[static_cast( - desc.buffer.scalar_type)]; - switch (desc.type) { - case RawSignatureParser::Type::kBuffer: { - absl::InlinedVector dims; - for (auto dim : desc.dims) { - if (dim < 0) { - // If there is a dynamic dim, fallback to completely func allocated - // result. This is the worst case because it will force a - // pipeline stall. - // TODO(laurenzo): Invoke shape resolution function if available - // to allocate full result. - f_results.AppendNullRef(); - } - alloc_size *= dim; - dims.push_back(dim); - } - - // Static cases are easy. - iree_hal_buffer_t* raw_buffer; - CheckApiStatus(iree_hal_allocator_allocate_buffer( - device_.allocator(), - static_cast( - IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE), - IREE_HAL_BUFFER_USAGE_ALL, alloc_size, &raw_buffer), - "Error allocating host visible buffer"); - auto element_type = static_cast( - kScalarTypeToHalElementType[static_cast( - desc.scalar.type)]); - iree_hal_buffer_view_t* buffer_view; - CheckApiStatus( - iree_hal_buffer_view_create(raw_buffer, dims.data(), dims.size(), - element_type, &buffer_view), - "Error allocating buffer_view"); - iree_hal_buffer_release(raw_buffer); - iree_vm_ref_t buffer_view_ref = - iree_hal_buffer_view_move_ref(buffer_view); - CheckApiStatus( - iree_vm_list_push_ref_move(f_results.raw_ptr(), &buffer_view_ref), - "Error moving buffer"); - break; - } - case RawSignatureParser::Type::kRefObject: - throw RaisePyError(PyExc_NotImplementedError, - "Ref objects not yet supported"); - break; - case RawSignatureParser::Type::kScalar: - break; - default: - throw RaisePyError(PyExc_NotImplementedError, - "Unsupported allocation argument type"); - } - } -} - -void FunctionAbi::PackBuffer(const RawSignatureParser::Description& desc, - py::handle py_arg, VmVariantList& f_args, - bool writable) { - // Request a view of the buffer (use the raw python C API to avoid some - // allocation and copying at the pybind level). - Py_buffer py_view; - // Note that only C-Contiguous ND-arrays are presently supported, so - // only request that via PyBUF_ND. Long term, we should consult an - // "oracle" in the runtime to determine the precise required format and - // set flags accordingly (and fallback/copy on failure). - int flags = PyBUF_FORMAT | PyBUF_ND; - if (writable) { - flags |= PyBUF_WRITABLE; - } - - // Acquire the backing buffer and setup RAII release. - if (PyObject_GetBuffer(py_arg.ptr(), &py_view, flags) != 0) { - // The GetBuffer call is required to set an appropriate error. - throw py::error_already_set(); - } - PyBufferReleaser py_view_releaser(py_view); - - // Whether the py object needs to be retained with the argument. - // Should be set to true if directly mapping, false if copied. - bool depends_on_pyobject = false; - - // Verify compatibility. - absl::InlinedVector dynamic_dims; - MapBufferAttrs(py_view, desc, dynamic_dims); - - // Allocate a HalBuffer. - // This is hard-coded to C-contiguous right now. - // TODO(laurenzo): Expand to other layouts as needed. - // TODO(laurenzo): Wrap and retain original buffer (depends_on_pyobject=true). - iree_hal_buffer_t* raw_buffer; - CheckApiStatus(iree_hal_allocator_allocate_buffer( - device_.allocator(), - static_cast( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL | - IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE), - IREE_HAL_BUFFER_USAGE_ALL, py_view.len, &raw_buffer), - "Failed to allocate device visible buffer"); - CheckApiStatus( - iree_hal_buffer_write_data(raw_buffer, 0, py_view.buf, py_view.len), - "Error writing to input buffer"); - - // Only capture the reference to the exporting object (incrementing it) - // once guaranteed successful. - if (depends_on_pyobject) { - // Note for future implementation: there needs to be a place to stash - // references to be kept alive which back a buffer. This is likely an - // additional bag of refs returned from this function, which can then - // be attached to an invocation. - throw RaisePyError(PyExc_NotImplementedError, - "Dependent buffer arguments not implemented"); - } - - // Create the buffer_view. (note that numpy shape is ssize_t) - auto element_type = static_cast( - kScalarTypeToHalElementType[static_cast(desc.scalar.type)]); - absl::InlinedVector dims(py_view.ndim); - std::copy(py_view.shape, py_view.shape + py_view.ndim, dims.begin()); - iree_hal_buffer_view_t* buffer_view; - CheckApiStatus( - iree_hal_buffer_view_create(raw_buffer, dims.data(), dims.size(), - element_type, &buffer_view), - "Error allocating buffer_view"); - iree_hal_buffer_release(raw_buffer); - iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_move_ref(buffer_view); - CheckApiStatus(iree_vm_list_push_ref_move(f_args.raw_ptr(), &buffer_view_ref), - "Error moving buffer view"); -} - -std::vector SerializeVmVariantList(VmVariantList& vm_list) { - size_t size = vm_list.size(); - std::vector results; - results.reserve(size); - for (iree_host_size_t i = 0; i < size; ++i) { - iree_vm_variant_t variant = iree_vm_variant_empty(); - iree_status_t status = - iree_vm_list_get_variant(vm_list.raw_ptr(), i, &variant); - CheckApiStatus(status, "Failed to get vm variant from list"); - - if (iree_vm_variant_is_value(variant)) { - results.push_back("i32=" + std::to_string(variant.i32)); - } else if (iree_vm_variant_is_ref(variant) && - iree_hal_buffer_view_isa(variant.ref)) { - auto buffer_view = iree_hal_buffer_view_deref(variant.ref); - - std::string result_str(4096, '\0'); - iree_status_t status; - do { - iree_host_size_t actual_length = 0; - iree_host_size_t max_element_count = - std::numeric_limits::max(); - status = iree_hal_buffer_view_format(buffer_view, max_element_count, - result_str.size() + 1, - &result_str[0], &actual_length); - result_str.resize(actual_length); - } while (iree_status_is_out_of_range(status)); - CheckApiStatus(status, - "Failed to create a string representation of the inputs"); - - results.push_back(result_str); - } else { - RaiseValueError( - "Expected vm_list's elements to be scalars or buffer views."); - } - } - return results; -} - -void SetupFunctionAbiBindings(pybind11::module m) { - py::class_>(m, "FunctionAbi") - .def(py::init(&PyCreateAbi)) - .def("__repr__", &FunctionAbi::DebugString) - .def_property_readonly("raw_input_arity", &FunctionAbi::raw_input_arity) - .def_property_readonly("raw_result_arity", &FunctionAbi::raw_result_arity) - .def("pack_inputs", - [](FunctionAbi* self, py::args py_args, py::kwargs py_kwargs) { - VmVariantList f_args = VmVariantList::Create(py_args.size()); - self->Pack(py_args, py_kwargs, - absl::MakeConstSpan(self->raw_config().inputs), f_args, - false /* writable */); - return f_args; - }) - .def("serialize_vm_list", - [](FunctionAbi* self, VmVariantList& vm_list) { - return SerializeVmVariantList(vm_list); - }) - .def("allocate_results", &PyAllocateResults, py::arg("f_results"), - py::arg("static_alloc") = true) - .def("unpack_results", [](FunctionAbi* self, VmVariantList& f_results) { - return self->Unpack(absl::MakeConstSpan(self->raw_config().results), - f_results); - }); -} - -} // namespace python -} // namespace iree diff --git a/bindings/python/iree/runtime/function_abi.h b/bindings/python/iree/runtime/function_abi.h deleted file mode 100644 index f91bb5a9fdc8..000000000000 --- a/bindings/python/iree/runtime/function_abi.h +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_BINDINGS_PYTHON_IREE_RT_FUNCTION_ABI_H_ -#define IREE_BINDINGS_PYTHON_IREE_RT_FUNCTION_ABI_H_ - -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "binding.h" -#include "hal.h" -#include "host_types.h" -#include "iree/base/signature_parser.h" -#include "vm.h" - -namespace iree { -namespace python { - -// Forward declarations. -class HalDevice; - -// Instantiated with function attributes in order to process inputs/outputs. -class FunctionAbi { - public: - using AttributeLookup = - std::function(absl::string_view)>; - FunctionAbi(HalDevice& device, - std::shared_ptr host_type_factory) - : device_(HalDevice::RetainAndCreate(device.raw_ptr())), - host_type_factory_(std::move(host_type_factory)) {} - virtual ~FunctionAbi() = default; - - using Description = RawSignatureParser::Description; - using InputDescriptionVector = absl::InlinedVector; - using ResultDescriptionVector = absl::InlinedVector; - - struct RawConfig { - InputDescriptionVector inputs; - ResultDescriptionVector results; - - // The following are retained to aid debugging but may be empty if - // disabled. - std::string signature; - }; - - // Creates an instance based on the function attributes. - static std::unique_ptr Create( - HalDevice& device, std::shared_ptr host_type_factory, - AttributeLookup lookup); - - RawConfig& raw_config() { return raw_config_; } - int raw_input_arity() const { return raw_config_.inputs.size(); } - int raw_result_arity() const { return raw_config_.results.size(); } - - // Structured packing. Linearizes structures according to the ABI and - // delegates to RawPack. - void Pack(pybind11::tuple& py_args, pybind11::dict& kwargs, - absl::Span descs, VmVariantList& args, - bool writable); - - // Structured unpacking. Delegates to RawUnpack and delinearizes according to - // the ABI. - pybind11::object Unpack(absl::Span descs, - VmVariantList& f_results); - - // Raw packing. These always operate on the linear span of raw inputs and - // results. Some ABIs perform a higher level of mapping on top of this, - // which can be accessed via the non-prefixed Pack/Unpack methods. - // Given a span of descriptions, packs the given py_args into the span - // of function args. All spans must be of the same size. - void RawPack(absl::Span descs, - absl::Span py_args, VmVariantList& args, - bool writable); - - // Raw unpacks f_results into py_results. - // Note that this consumes entries in f_results as needed, leaving them - // as nullptr. - // Ordinarily, this will be invoked along with AllocateResults() but it - // is broken out for testing. - void RawUnpack(absl::Span descs, VmVariantList& f_results, - absl::Span py_results); - - // Given bound function arguments (from RawPack or equiv) and signature - // descriptors, allocates results for the function invocation. For fully - // specified result types, this can be done purely by matching up - // reflection metadata and an oracle for determining layout. For dynamically - // shaped or data-dependent shaped results, the metadata about the function - // arguments may be required to generate additional allocation function calls. - // Finally, in truly data-dependent cases, some results may not be resolvable - // ahead of time, resulting in a nullptr in f_results. In such cases, the - // invocation must ensure proper barriers are in place to fully execute the - // function prior to delivering results to the user layer. - void AllocateResults(absl::Span descs, - VmVariantList& f_args, VmVariantList& f_results); - - // Gets the string representation. - std::string DebugString() const; - - private: - void PackBuffer(const RawSignatureParser::Description& desc, - py::handle py_arg, VmVariantList& f_args, bool writable); - - HalDevice device_; - std::shared_ptr host_type_factory_; - RawConfig raw_config_; - // If present, the SIP signature maps a "structured signature" to linearized - // input and result lists. In layman's terms, this maps the normal python - // *args and **kwargs calling convention with nested dicts and sequences. - // It is used by TensorFlow, which lacks higher level types for such things. - absl::optional sip_signature_; -}; - -void SetupFunctionAbiBindings(pybind11::module m); - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_IREE_RT_FUNCTION_ABI_H_ diff --git a/bindings/python/iree/runtime/function_abi_test.py b/bindings/python/iree/runtime/function_abi_test.py deleted file mode 100644 index 5697a653e7a4..000000000000 --- a/bindings/python/iree/runtime/function_abi_test.py +++ /dev/null @@ -1,237 +0,0 @@ -# Lint as: python3 -# Copyright 2019 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# pylint: disable=line-too-long -# pylint: disable=broad-except -"""Tests for the function abi.""" - -import re - -from absl import logging -from absl.testing import absltest -import iree.runtime -import numpy as np - -ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1 = ( - ("fv", "1"), - # Equiv to: - # (Buffer) -> (Buffer) - ("f", "I15!B11!d10d128d64R15!B11!t6d32d8d64"), -) - -ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1 = ( - ("fv", "1"), - # Equiv to: - # (Buffer) -> (Buffer) - ("f", "I15!B11!d-1d128d64R15!B11!t6d-1d8d64"), -) - -ATTRS_SIP_1LEVEL_DICT = ( - # Extracted from reflection attributes for mobilebert. - # (via iree-dump-module). - # Input dict of "input_ids", "input_mask", "segment_ids" - # Output dict of "end_logits", "start_logits" - # Raw signature is: - # (Buffer, Buffer, Buffer) -> (Buffer, Buffer) - ("fv", "1"), - ("f", "I34!B9!t6d1d384B9!t6d1d384B9!t6d1d384R19!B7!d1d384B7!d1d384"), - ("abi", "sip"), - ("sip", - "I53!D49!K10!input_ids_1K11!input_mask_2K12!segment_ids_0R39!D35!K11!end_logits_0K13!start_logits_1" - ), -) - -ATTRS_SIP_LINEAR_2ARG = ( - # SIP form of a function that takes 2 args of Buffer and - # returns one of the same type/shape. - ("fv", "1"), - ("f", "I11!B3!d1B3!d1R6!B3!d1"), - ("abi", "sip"), - ("sip", "I12!S9!k0_0k1_1R3!_0"), -) - - -class HostTypeFactory(absltest.TestCase): - - def test_baseclass(self): - htf = iree.runtime.HostTypeFactory() - logging.info("HostTypeFactory: %s", htf) - - -class FunctionAbiTest(absltest.TestCase): - - @classmethod - def setUpClass(cls): - super().setUpClass() - driver_names = iree.runtime.HalDriver.query() - for driver_name in driver_names: - logging.info("Try to create driver: %s", driver_name) - try: - cls.driver = iree.runtime.HalDriver.create(driver_name) - cls.device = cls.driver.create_default_device() - except Exception: - logging.error("Could not create driver: %s", driver_name) - else: - break - - def setUp(self): - super().setUp() - self.htf = iree.runtime.HostTypeFactory.get_numpy() - - def test_sip_dict_arg_result_success(self): - fabi = iree.runtime.FunctionAbi(self.device, self.htf, - ATTRS_SIP_1LEVEL_DICT) - self.assertEqual( - ", Buffer, Buffer) -> (Buffer, Buffer) SIP:'I53!D49!K10!input_ids_1K11!input_mask_2K12!segment_ids_0R39!D35!K11!end_logits_0K13!start_logits_1'>", - repr(fabi)) - input_ids = np.zeros((1, 384), dtype=np.int32) - input_mask = np.zeros((1, 384), dtype=np.int32) - segment_ids = np.zeros((1, 384), dtype=np.int32) - f_args = fabi.pack_inputs(input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids) - self.assertEqual( - "", - repr(f_args)) - f_results = fabi.allocate_results(f_args) - logging.info("f_results: %s", f_results) - self.assertEqual( - "", - repr(f_results)) - py_result = fabi.unpack_results(f_results) - start_logits = py_result["start_logits"] - end_logits = py_result["end_logits"] - self.assertEqual(np.float32, start_logits.dtype) - self.assertEqual(np.float32, end_logits.dtype) - self.assertEqual((1, 384), start_logits.shape) - self.assertEqual((1, 384), end_logits.shape) - - def test_sip_linear_success(self): - fabi = iree.runtime.FunctionAbi(self.device, self.htf, - ATTRS_SIP_LINEAR_2ARG) - self.assertEqual( - ", Buffer) -> (Buffer) SIP:'I12!S9!k0_0k1_1R3!_0'>", - repr(fabi)) - arg0 = np.zeros((1,), dtype=np.float32) - arg1 = np.zeros((1,), dtype=np.float32) - f_args = fabi.pack_inputs(arg0, arg1) - self.assertEqual( - "", - repr(f_args)) - f_results = fabi.allocate_results(f_args) - logging.info("f_results: %s", f_results) - self.assertEqual("", - repr(f_results)) - result = fabi.unpack_results(f_results) - print("SINGLE RESULT:", result) - self.assertEqual(np.float32, result.dtype) - self.assertEqual((1,), result.shape) - - def test_static_arg_success(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - logging.info("fabi: %s", fabi) - self.assertEqual( - ") -> " - "(Buffer)>", repr(fabi)) - self.assertEqual(1, fabi.raw_input_arity) - self.assertEqual(1, fabi.raw_result_arity) - - arg = np.zeros((10, 128, 64), dtype=np.float32) - packed = fabi.pack_inputs(arg) - logging.info("packed: %s", packed) - self.assertEqual("", - repr(packed)) - - def test_static_result_success(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - arg = np.zeros((10, 128, 64), dtype=np.float32) - f_args = fabi.pack_inputs(arg) - f_results = fabi.allocate_results(f_args) - logging.info("f_results: %s", f_results) - self.assertEqual("", - repr(f_results)) - py_result = fabi.unpack_results(f_results) - self.assertEqual(np.int32, py_result.dtype) - self.assertEqual((32, 8, 64), py_result.shape) - - def test_dynamic_alloc_result_success(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - arg = np.zeros((10, 128, 64), dtype=np.float32) - f_args = fabi.pack_inputs(arg) - f_results = fabi.allocate_results(f_args, static_alloc=False) - logging.info("f_results: %s", f_results) - self.assertEqual("", repr(f_results)) - - def test_dynamic_arg_success(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1) - logging.info("fabi: %s", fabi) - self.assertEqual( - ") -> " - "(Buffer)>", repr(fabi)) - self.assertEqual(1, fabi.raw_input_arity) - self.assertEqual(1, fabi.raw_result_arity) - - arg = np.zeros((10, 128, 64), dtype=np.float32) - packed = fabi.pack_inputs(arg) - logging.info("packed: %s", packed) - self.assertEqual("", - repr(packed)) - - def test_static_arg_rank_mismatch(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - logging.info("fabi: %s", fabi) - arg = np.zeros((10,), dtype=np.float32) - with self.assertRaisesRegex( - ValueError, - re.escape("Mismatched buffer rank (received: 1, expected: 3)")): - fabi.pack_inputs(arg) - - def test_static_arg_eltsize_mismatch(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - logging.info("fabi: %s", fabi) - arg = np.zeros((10, 128, 64), dtype=np.float64) - with self.assertRaisesRegex( - ValueError, - re.escape("Mismatched buffer item size (received: 8, expected: 4)")): - fabi.pack_inputs(arg) - - def test_static_arg_dtype_mismatch(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - logging.info("fabi: %s", fabi) - arg = np.zeros((10, 128, 64), dtype=np.int32) - with self.assertRaisesRegex( - ValueError, - re.escape("Mismatched buffer format (received: i, expected: f)")): - fabi.pack_inputs(arg) - - def test_static_arg_static_dim_mismatch(self): - fabi = iree.runtime.FunctionAbi( - self.device, self.htf, - ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1) - logging.info("fabi: %s", fabi) - arg = np.zeros((10, 32, 64), dtype=np.float32) - with self.assertRaisesRegex( - ValueError, - re.escape("Mismatched buffer dim (received: 32, expected: 128)")): - fabi.pack_inputs(arg) - - -if __name__ == "__main__": - absltest.main() diff --git a/bindings/python/iree/runtime/host_types.cc b/bindings/python/iree/runtime/host_types.cc deleted file mode 100644 index b2dcdcf789da..000000000000 --- a/bindings/python/iree/runtime/host_types.cc +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "bindings/python/iree/runtime/host_types.h" - -#include - -#include "absl/container/inlined_vector.h" -#include "bindings/python/iree/runtime/hal.h" -#include "bindings/python/iree/runtime/status_utils.h" -#include "iree/base/signature_parser.h" -#include "pybind11/numpy.h" - -namespace iree { -namespace python { - -const std::array( - AbiConstants::ScalarType::kMaxScalarType) + - 1> - kScalarTypePyFormat = { - "f", // kIeeeFloat32 = 0, - nullptr, // kIeeeFloat16 = 1, - "d", // kIeeeFloat64 = 2, - nullptr, // kGoogleBfloat16 = 3, - "b", // kSint8 = 4, - "h", // kSint16 = 5, - "i", // kSint32 = 6, - "q", // kSint64 = 7, - "c", // kUint8 = 8, - "H", // kUint16 = 9, - "I", // kUint32 = 10, - "Q", // kUint64 = 11, -}; -static_assert(kScalarTypePyFormat.size() == - AbiConstants::kScalarTypeSize.size(), - "Mismatch kScalarTypePyFormat"); - -const std::array( - AbiConstants::ScalarType::kMaxScalarType) + - 1> - kScalarTypeToHalElementType = { - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, - 32), // kIeeeFloat32 = 0, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, - 16), // kIeeeFloat16 = 1, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_FLOAT_IEEE, - 64), // kIeeeFloat64 = 2, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_UNKNOWN, - 16), // kGoogleBfloat16 = 3, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, - 8), // kSint8 = 4, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, - 16), // kSint16 = 5, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, - 32), // kSint32 = 6, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_SIGNED, - 64), // kSint64 = 7, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, - 8), // kUint8 = 8, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, - 16), // kUint16 = 9, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, - 32), // kUint32 = 10, - IREE_HAL_ELEMENT_TYPE_VALUE(IREE_HAL_NUMERICAL_TYPE_INTEGER_UNSIGNED, - 64), // kUint64 = 11, -}; -static_assert(kScalarTypeToHalElementType.size() == - AbiConstants::kScalarTypeSize.size(), - "Mismatch kScalarTypeToHalElementType"); - -namespace { - -class PyMappedMemory { - public: - struct Description { - size_t element_size; - const char* format; - absl::InlinedVector dims; - absl::InlinedVector strides; - - static Description ForNdarray(AbiConstants::ScalarType scalar_type, - absl::Span dims) { - unsigned scalar_type_i = static_cast(scalar_type); - if (scalar_type_i > - static_cast(AbiConstants::ScalarType::kMaxScalarType)) { - throw RaiseValueError("Illegal ScalarType"); - } - - Description d; - d.element_size = AbiConstants::kScalarTypeSize[scalar_type_i]; - d.format = kScalarTypePyFormat[scalar_type_i]; - if (!d.format) { - throw RaisePyError(PyExc_NotImplementedError, - "Unimplemented ScalarType"); - } - if (!dims.empty()) { - d.dims.resize(dims.size()); - d.strides.resize(dims.size()); - - for (size_t i = 0, e = dims.size(); i < e; ++i) { - d.dims[i] = dims[i]; - } - d.strides[dims.size() - 1] = d.element_size; - for (int i = dims.size() - 2; i >= 0; --i) { - d.strides[i] = d.strides[i + 1] * dims[i + 1]; - } - } - return d; - } - }; - - PyMappedMemory(Description desc, iree_hal_buffer_mapping_t mapped_memory, - HalBuffer buffer, py::object parent_keep_alive) - : parent_keep_alive_(std::move(parent_keep_alive)), - desc_(std::move(desc)), - mapped_memory_(mapped_memory), - buf_(std::move(buffer)) {} - ~PyMappedMemory() { - if (buf_) { - iree_hal_buffer_unmap_range(&mapped_memory_); - } - } - PyMappedMemory(PyMappedMemory&& other) - : mapped_memory_(other.mapped_memory_), buf_(std::move(other.buf_)) {} - - const Description& desc() const { return desc_; } - - static std::unique_ptr Read(Description desc, - HalBuffer buffer, - py::object parent_keep_alive) { - iree_device_size_t byte_length = - iree_hal_buffer_byte_length(buffer.raw_ptr()); - iree_hal_buffer_mapping_t mapped_memory; - CheckApiStatus(iree_hal_buffer_map_range( - buffer.raw_ptr(), IREE_HAL_MEMORY_ACCESS_READ, - 0 /* element_offset */, byte_length, &mapped_memory), - "Could not map memory"); - return std::make_unique(std::move(desc), mapped_memory, - std::move(buffer), - std::move(parent_keep_alive)); - } - - py::buffer_info ToBufferInfo() { - // TODO(laurenzo): py::buffer_info is a heavy-weight way to get the - // buffer. See about implementing the lower level buffer protocol. - // Unfortunately, this part of the pybind C++ API is all defined in terms - // of std::vector, making it less efficient than necessary. - return py::buffer_info(mapped_memory_.contents.data, desc_.element_size, - desc_.format, desc_.dims.size(), desc_.dims, - desc_.strides); - } - - private: - // Important: Since the parent_keep_alive object may be keeping things - // alive needed to deallocate various other fields, it must be destructed - // last (by being first here). - py::object parent_keep_alive_; - Description desc_; - iree_hal_buffer_mapping_t mapped_memory_; - HalBuffer buf_; -}; - -class NumpyHostTypeFactory : public HostTypeFactory { - py::object CreateImmediateNdarray(AbiConstants::ScalarType element_type, - absl::Span dims, - HalBuffer buffer, - py::object parent_keep_alive) override { - std::unique_ptr mapped_memory = PyMappedMemory::Read( - PyMappedMemory::Description::ForNdarray(element_type, dims), - std::move(buffer), std::move(parent_keep_alive)); - // Since an immediate ndarray was requested, we can just return a native - // ndarray directly (versus a proxy that needs to lazily map on access). - auto buffer_info = mapped_memory->ToBufferInfo(); - auto py_mapped_memory = py::cast(std::move(mapped_memory), - py::return_value_policy::take_ownership); - return py::array(py::dtype(buffer_info), buffer_info.shape, - buffer_info.strides, buffer_info.ptr, - std::move(py_mapped_memory) /* base */); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// HostTypeFactory -//------------------------------------------------------------------------------ - -std::shared_ptr HostTypeFactory::GetNumpyFactory() { - static auto global_instance = std::make_shared(); - return global_instance; -} - -py::object HostTypeFactory::CreateImmediateNdarray( - AbiConstants::ScalarType element_type, absl::Span dims, - HalBuffer buffer, py::object parent_keep_alive) { - throw RaisePyError(PyExc_NotImplementedError, - "CreateImmediateNdarray not implemented"); -} - -void SetupHostTypesBindings(pybind11::module m) { - py::class_>( - m, "HostTypeFactory") - .def(py::init<>()) - .def_static("get_numpy", &HostTypeFactory::GetNumpyFactory); - py::class_>( - m, "PyMappedMemory", py::buffer_protocol()) - .def_buffer(&PyMappedMemory::ToBufferInfo); -} - -} // namespace python -} // namespace iree diff --git a/bindings/python/iree/runtime/host_types.h b/bindings/python/iree/runtime/host_types.h deleted file mode 100644 index 2a9b7c85c855..000000000000 --- a/bindings/python/iree/runtime/host_types.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_BINDINGS_PYTHON_IREE_RT_HOST_TYPES_H_ -#define IREE_BINDINGS_PYTHON_IREE_RT_HOST_TYPES_H_ - -#include - -#include "absl/types/span.h" -#include "bindings/python/iree/runtime/binding.h" -#include "bindings/python/iree/runtime/hal.h" -#include "iree/base/signature_parser.h" - -namespace iree { -namespace python { - -extern const std::array< - const char*, - static_cast(AbiConstants::ScalarType::kMaxScalarType) + 1> - kScalarTypePyFormat; -extern const std::array< - uint32_t, - static_cast(AbiConstants::ScalarType::kMaxScalarType) + 1> - kScalarTypeToHalElementType; - -class HostTypeFactory { - public: - virtual ~HostTypeFactory() = default; - - // Creates a default implementation which interops with numpy. - static std::shared_ptr GetNumpyFactory(); - - // Creates a C-contiguous ndarray of the given element_type/dims and backed - // by the given buffer. The resulting array has no synchronization and is - // available for use immediately. - virtual py::object CreateImmediateNdarray( - AbiConstants::ScalarType element_type, absl::Span dims, - HalBuffer buffer, py::object parent_keep_alive); - - // TODO(laurenzo): Add a CreateDelayedNdarray() which is conditioned on - // a semaphore. This is actually what should be used for async results. -}; - -void SetupHostTypesBindings(pybind11::module m); - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_IREE_RT_HOST_TYPES_H_ diff --git a/bindings/python/iree/runtime/initialize_module.cc b/bindings/python/iree/runtime/initialize_module.cc index c630806ac310..28f6d92b2b98 100644 --- a/bindings/python/iree/runtime/initialize_module.cc +++ b/bindings/python/iree/runtime/initialize_module.cc @@ -5,9 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "bindings/python/iree/runtime/binding.h" -#include "bindings/python/iree/runtime/function_abi.h" #include "bindings/python/iree/runtime/hal.h" -#include "bindings/python/iree/runtime/host_types.h" #include "bindings/python/iree/runtime/status_utils.h" #include "bindings/python/iree/runtime/vm.h" #include "iree/base/status.h" @@ -21,8 +19,6 @@ PYBIND11_MODULE(binding, m) { iree_hal_driver_registry_default())); m.doc() = "IREE Binding Backend Helpers"; - SetupFunctionAbiBindings(m); - SetupHostTypesBindings(m); SetupHalBindings(m); SetupVmBindings(m); } diff --git a/bindings/python/iree/runtime/system_api.py b/bindings/python/iree/runtime/system_api.py index f847bcb3d0aa..4f4a5fcdf643 100644 --- a/bindings/python/iree/runtime/system_api.py +++ b/bindings/python/iree/runtime/system_api.py @@ -108,7 +108,6 @@ class Config: driver: _binding.HalDriver device: _binding.HalDevice vm_instance: _binding.VmInstance - host_type_factory: _binding.HostTypeFactory default_vm_modules: Tuple[_binding.VmModule, ...] def __init__(self, driver_name: Optional[str] = None): @@ -119,7 +118,6 @@ def __init__(self, driver_name: Optional[str] = None): hal_module = _binding.create_hal_module(self.device) strings_module = _binding.create_strings_module() tensorlist_module = _binding.create_tensorlist_module() - self.host_type_factory = _binding.HostTypeFactory.get_numpy() self.default_vm_modules = (hal_module, strings_module, tensorlist_module) @@ -178,56 +176,6 @@ def _convert_lists_to_tuples(pytree): return pytree -class BoundFunction: - """Wraps a VmFunction, VmContext and ABI into a pythonic function.""" - - def __init__(self, context: "SystemContext", - vm_function: _binding.VmFunction): - self._context = context - self._vm_function = vm_function - self._abi = context.create_function_abi(vm_function) - self._serialized_inputs = None - self._serialized_outputs = None - - def __call__(self, *args, **kwargs): - # Convert tensors, device arrays, ints, ... to IREE-friendly inputs. - args = [normalize_value(value) for value in args] - kwargs = {k: normalize_value(v) for k, v in kwargs.items()} - args = [_bool_to_int8(value) for value in args] - kwargs = {k: _bool_to_int8(v) for k, v in kwargs.items()} - - # NOTE: This is just doing sync dispatch right now. In the future, - # this should default to async and potentially have some kind of policy - # flag that can allow it to be overridden. - inputs = self._abi.pack_inputs(*args, **kwargs) - self._serialized_inputs = tuple(self._abi.serialize_vm_list(inputs)) - results = self._abi.allocate_results(inputs, static_alloc=False) - self._context._vm_context.invoke(self._vm_function, inputs, results) - self._serialized_outputs = tuple(self._abi.serialize_vm_list(results)) - unpacked_results = self._abi.unpack_results(results) - - # TODO(#5359): Add support for list and tuple return types. - # The SIP signature used by the runtime bindings cannot differentiate - # between Lists and Tuples, as it only has a single 'sequence' type. - # The function abi uses py::list when unpacking the results according to the - # SIP signature. The most common instance of a returned Sequence in Python - # however is multiple return values, and that is represented by a tuple. - # We manually change the return types of all Sequences to Tuple in order to - # match the semantics of this case. - unpacked_results = _convert_lists_to_tuples(unpacked_results) - - return unpacked_results - - def __repr__(self): - return f"" - - def get_serialized_values(self): - if self._serialized_inputs is None: - raise RuntimeError("Attempted to call get_serialized_values() before " - "any values were passed.") - return self._serialized_inputs, self._serialized_outputs - - class BoundModule: """Wraps a VmModule with its context and provides nice python accessors. @@ -262,21 +210,11 @@ def __getitem__(self, name): if vm_function is None: raise KeyError(f"Function '{name}' not found in module '{self}'") - # TODO: Remove this fork and delete the local BoundFunction once SIP is - # removed. We take the new path if there is a native IREE ABI attribute - # or no SIP ('f') attribute. - reflection_dict = vm_function.reflection - if "iree.abi" in reflection_dict or "f" not in reflection_dict: - # TODO: Needing to know the precise device to allocate on here is bad - # layering and will need to be fixed in some fashion if/when doing - # heterogenous dispatch. - return FunctionInvoker(self._context.vm_context, - self._context.config.device, vm_function) - - # Legacy SIP path. - bound_function = BoundFunction(self._context, vm_function) - self._lazy_functions[name] = bound_function - return bound_function + # TODO: Needing to know the precise device to allocate on here is bad + # layering and will need to be fixed in some fashion if/when doing + # heterogenous dispatch. + return FunctionInvoker(self._context.vm_context, + self._context.config.device, vm_function) def __repr__(self): return f"" @@ -338,11 +276,6 @@ def instance(self) -> _binding.VmInstance: def modules(self) -> BoundModules: return self._bound_modules - def create_function_abi(self, f: _binding.VmFunction) -> _binding.FunctionAbi: - return self._vm_context.create_function_abi(self._config.device, - self._config.host_type_factory, - f) - def add_vm_modules(self, vm_modules): assert self._is_dynamic, "Cannot 'add_module' on a static context" for m in vm_modules: diff --git a/bindings/python/iree/runtime/vm.cc b/bindings/python/iree/runtime/vm.cc index 0cb03d884486..7c272aea1e5e 100644 --- a/bindings/python/iree/runtime/vm.cc +++ b/bindings/python/iree/runtime/vm.cc @@ -9,7 +9,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" -#include "bindings/python/iree/runtime/function_abi.h" #include "bindings/python/iree/runtime/status_utils.h" #include "iree/base/api.h" #include "iree/base/status.h" @@ -130,36 +129,6 @@ void VmContext::RegisterModules(std::vector modules) { CheckApiStatus(status, "Error registering modules"); } -std::unique_ptr VmContext::CreateFunctionAbi( - HalDevice& device, std::shared_ptr host_type_factory, - iree_vm_function_t f) { - // Resolve attrs. - absl::InlinedVector, 4> - attrs; - for (int i = 0;; ++i) { - attrs.push_back({}); - auto status = iree_vm_get_function_reflection_attr( - f, i, &attrs.back().first, &attrs.back().second); - if (iree_status_is_not_found(status)) { - iree_status_ignore(status); - attrs.pop_back(); - break; - } - CheckApiStatus(status, "Error getting reflection attr"); - } - auto attr_lookup = - [&attrs](absl::string_view key) -> absl::optional { - for (const auto& attr : attrs) { - absl::string_view found_key(attr.first.data, attr.first.size); - absl::string_view found_value(attr.second.data, attr.second.size); - if (found_key == key) return found_value; - } - return absl::nullopt; - }; - - return FunctionAbi::Create(device, std::move(host_type_factory), attr_lookup); -} - void VmContext::Invoke(iree_vm_function_t f, VmVariantList& inputs, VmVariantList& outputs) { CheckApiStatus(iree_vm_invoke(raw_ptr(), f, nullptr, inputs.raw_ptr(), @@ -565,8 +534,6 @@ void SetupVmBindings(pybind11::module m) { py::arg("modules") = absl::optional>()) .def("register_modules", &VmContext::RegisterModules) .def_property_readonly("context_id", &VmContext::context_id) - .def("create_function_abi", &VmContext::CreateFunctionAbi, - py::arg("device"), py::arg("host_type_factory"), py::arg("f")) .def("invoke", &VmContext::Invoke); py::class_(m, "VmModule") diff --git a/bindings/python/iree/runtime/vm.h b/bindings/python/iree/runtime/vm.h index 6d775f5ca3c6..25bbfffd8add 100644 --- a/bindings/python/iree/runtime/vm.h +++ b/bindings/python/iree/runtime/vm.h @@ -9,7 +9,7 @@ #include "absl/types/optional.h" #include "bindings/python/iree/runtime/binding.h" -#include "bindings/python/iree/runtime/host_types.h" +#include "bindings/python/iree/runtime/hal.h" #include "iree/base/api.h" #include "iree/vm/api.h" #include "iree/vm/bytecode_module.h" @@ -144,11 +144,6 @@ class VmContext : public ApiRefCounted { // Synchronously invokes the given function. void Invoke(iree_vm_function_t f, VmVariantList& inputs, VmVariantList& outputs); - - // Creates a function ABI suitable for marshalling function inputs/results. - std::unique_ptr CreateFunctionAbi( - HalDevice& device, std::shared_ptr host_type_factory, - iree_vm_function_t f); }; class VmInvocation : public ApiRefCounted { diff --git a/bindings/python/iree/runtime/vm_test.py b/bindings/python/iree/runtime/vm_test.py index b10c0f5644b1..78a86d10d113 100644 --- a/bindings/python/iree/runtime/vm_test.py +++ b/bindings/python/iree/runtime/vm_test.py @@ -74,7 +74,6 @@ def setUpClass(cls): iree.compiler.core.DEFAULT_TESTING_DRIVER) cls.device = cls.driver.create_default_device() cls.hal_module = iree.runtime.create_hal_module(cls.device) - cls.htf = iree.runtime.HostTypeFactory.get_numpy() def test_variant_list(self): l = iree.runtime.VmVariantList(5) diff --git a/docs/developers/design_docs/function_abi.md b/docs/developers/design_docs/function_abi.md index 38b6a41d7bd8..e005918af61f 100644 --- a/docs/developers/design_docs/function_abi.md +++ b/docs/developers/design_docs/function_abi.md @@ -183,251 +183,3 @@ type specific fields: - `["sdict_kwargs", ...]`: Same as `sdict` but signifies to languages that allow keyword-argument passing that this is the keyword-argument dictionary. It can only ever be present as the last entry of the root arguments `slist`. - -## Deprecated V0 ABIs - -These will be removed as soon as the corresponding code is removed. - -### Generic Signature Mangling - -Where possible, ABI metadata is encoded into a plain-text signature in a way -that is easily transported across component boundaries and can be efficiently -implemented without additional dependencies (i.e. just string manipulation). - -The suggested format is manipulated via the C++ reference implementations -`SignatureBuilder` and `SignatureParser` classes (see -`iree/base/signature_parser.h`). See documentation and code for those classes -for more details. - -### ABIs - -#### Raw Function ABI - -All exported functions implement the raw function ABI, which defines the -metadata and calling convention for marshalling inputs and results to their -underlying implementations. - -_Attributes:_ - -- `fv` = 1 (current version of the raw function ABI) -- `f` = encoded raw function signature (see below) -- `fbr` = result buffer allocation function name (optional) - -The reflection metadata documented here augments the underlying type system such -that host language bindings can interop as needed. This additional metadata is -needed in most dynamic cases because the compiled assets operate on fundamental -types with most characteristics type erased away (think: `void*` level things vs -high-level `ShapedBuffer` level things). - -##### Grammar - -The signature is implemented in terms of the SignatureBuilder, using tagged -Integer and Spans. - -```text -signature ::= 'I' length-prefixed(type-sequence) - 'R' length-prefixed(type-sequence) - -type-sequence ::= (arg-result-type)* -arg-result-type ::= buffer-type - | ref-object-type - | scalar-type - | unrecognized-type -buffer-type ::= 'B' length-prefixed(scalar-element-type? dim*) -scalar-type ::= 'S' length-prefixed(scalar-element-type?) -scalar-element-type ::= 't' ( - '0' # IEEE float32 (default if not specified) - | '1' # IEEE float16 - | '2' # IEEE float64 - | '3' # Google bfloat16 - | '4' # Signed int8 - | '5' # Signed int16 - | '6' # Signed int32 - | '7' # Signed int64 - | '8' # Unsigned int8 - | '9' # Unsigned int16 - | '10' # Unsigned int32 - | '11' # Unsigned int64 - ) -dim :: = 'd' integer # -1 indicates a dynamic dim -ref-object-type ::= 'O' length-prefixed() # Details TBD -unrecognized-type ::= 'U' length-prefixed() - -# Lexical primitives -integer ::= -?[0-9]+ -length ::= [0-9]+ -# The `length` encodes the length in bytes of `production`, plus 1 for the '!'. -length-prefixed(production) ::= length '!' production -any-byte-sequence ::= -``` - -##### Interpretation and Rationale - -###### Memory layout - -The astute reader will note that the above metadata is insufficient to determine -the memory layout of a buffer. The reason is that any more specific details than -this (contiguity, strides, alignment, etc) can actually only be known once the -actual compute devices have been enumerated and the resulting matrix of -conversions is more dynamic than can be expressed in something as static as a -function signature. The above formulation is an input to an additional runtime -oracle which produces appropriate full buffer descriptions. - -While the exact implementation is host-language specific, consider the following -more detailed set of declarations that may exist in such a binding layer: - -```c++ -// Inspired heavily by the Py_buffer type. -// See: https://docs.python.org/3/c-api/buffer.html -struct BufferDescription { - ScalarType element_type; - // For contiguous arrays, this is is the length of the underlying memory. - // For non-contiguous, this is the size of the buffer if it were copied - // to a contiguous representation. - size_t len; - // Number of dims and strides. - size_t ndim; - int* shape; - int* strides; -}; - -// Mirrors the 'buffer-type' production in the above grammar. -struct SignatureBufferType; - -// Oracle which combines signature metadata with a user-provided, materialized -// BufferDescription to derive a BufferDescription that is compatible for -// invocation. Returns an updated buffer description if the original is -// not compatible or fully specified. -// This can be used in a couple of ways: -// a) On function invocation to determine whether a provided buffer can be -// used as-is or needs to be converted (copied). -// b) To provide a factory function to the host language to create a -// compatible buffer. -optional BufferDescriptionOracle( - DeviceContext*, SignatureBufferType, BufferDescription) - throws UnsupportedBufferException; -``` - -The above scheme should allow host-language and device coordination with respect -to buffer layout. For the moment, the responsibility to convert the buffer to a -compatible memory layout is on the host-language binding. However, often it is -the most efficient to schedule this for execution on a device. In the future, it -is anticipated that there will be a built-in pathway for scheduling such a -conversion (which would allow pipelining and offload of buffer conversions). - -###### Deferred result allocation - -In general, exported functions accept pre-allocated results that should be -mutated. For the simplest cases, such results can be `null` and retrieved upon -completion of the function. This, however, puts severe limitations on the -ability to pipeline. For fully specified signatures (no dynamic shapes), the -`BufferDescriptionOracle` and the signature is sufficient to pre-allocate -appropriate results, which allows chains of result-producing invocations to be -pipelined. - -If, however, a `buffer-type` is not fully specified, the compiler may emit a -special _result allocator_ function, which will be referenced in the `fbr` -attribute. Such a function would have a signature like this: - -```c++ -tuple __allocate_results(tuple dynamic_dims); -``` - -Such a function takes a tuple of all dynamic buffer dims in the function input -signature and returns a tuple of allocated buffers for each dynamic result. Note -that it may not be possible to fully allocate results in this fashion (i.e. if -the result layout is data dependent), in which case a null buffer is returned -for that slot (and the host library would need to await on the invocation to get -the fully populated result). - -A similar mechanism will need to be created at some future point for -under-specified results of other (non-buffer) types. - -###### Contiguity hinting - -Commonly in some kinds of dataflows, the compiler needs to be free to internally -toggle buffer continuity (i.e. C/row-major, Fortran/col-major, etc). In many -cases, such toggling does not naturally escape through the exported function -boundaries, in which case, there is no ABI impact. However, it is anticipated -that there is benefit to letting the toggle propagate through the exported ABI -boundary, in which case, the `buffer-type` will likely be extended with a -contiguity hint indicating the preference. When combined with the buffer -description oracle and in-pipeline conversion features described above, this -could yield a powerful mechanism for dynamically and efficiently managing such -transitions. - -Such an enhancement would almost certainly necessitate a major version bump in -the ABI and would be logical to implement once the advanced features above are -functional. - -#### Structured Index Path ABI - -Functions may support the SIP ABI if their input and result tuples logically map -onto "structures" (nested sequence/dicts). - -_Attributes:_ - -- `sipv` = 1 (current SIP ABI version) -- `sip` = encoded SIP signature (see below) - -This ABI maps a raw, linear sequence of inputs and results onto an input and -result "structure" -- which in this context refers to a nested assembly of -sequences (with integer keys) and dictionaries (with string keys). Such a -facility is useful for encoding input/result mappings in a way that is common in -dynamic languages (such as Python). - -In practice, this ABI supports the calling convention for TensorFlow, which -allows functions that accept and produce nestings via the -[`tf.nest`](https://www.tensorflow.org/api_docs/python/tf/nest) facility. In -implementing it, however, care has been taken to allow the calling convention to -generalize to other similar cases. - -##### Grammar - -The signature is implemented in terms of the SignatureBuilder, using tagged -Integer and Spans. - -```text -# Defines the structured value for the inputs ('I') and results ('R') -# of the function. -signature ::= 'I' length-prefixed(structured-value) - 'R' length-prefixed(structured-value) - -structured-value ::= raw-fn-index | sequence | dict -raw-fn-index ::= '_' integer -sequence ::= 'S' length-prefixed( (integer-key structured-value)* ) -integer-key ::= 'k' integer -dict ::= 'D' length-prefixed( (string-key structured-value)* ) -string-key ::= 'K' length-prefixed( any-byte-sequence ) - -# Low-level lexical primitives: -integer ::= -?[0-9]+ -length ::= [0-9]+ -# The `length` encodes the length in bytes of `production`, plus 1 for the '!'. -length-prefixed(production) ::= length '!' production -any-byte-sequence ::= -``` - -Structured values define a tree of recursive dicts/lists, with `raw-fn-index` at -the leaves. The interpretation is that a raw-fn-index that has been reached by -traversing N expansions of the structured-value production is assigned an "index -path" which is a list of the N keys that were traversed to reach it. For -example, for N=0, the index path is empty. For N=1, and if an integer-key with -numerical value 0 was traversed to reach the raw-fn-index, then the index path -is [0]. - -.... give a few examples more, writing out various nested dicts/lists in -Python-esque notation to clarify this concept .... - -See the `SipSignatureParser::ToStringVisitor` for a canonical example of how to -interpret the signature. - -##### Implementations - -- C++ - - - `SipSignatureMangler`: Produces a function signature given individual - input and result assignment of physical indices to nested index paths in - the structure tree. - - `SipSignatureParser`: Parses signatures and dispatches calls to a - visitor. diff --git a/integrations/tensorflow/iree_tf_compiler/TF/BUILD b/integrations/tensorflow/iree_tf_compiler/TF/BUILD index 004ecd20b279..95ae15420292 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/BUILD +++ b/integrations/tensorflow/iree_tf_compiler/TF/BUILD @@ -14,7 +14,6 @@ cc_library( name = "TF", srcs = [ "ConvertToMHLO.cpp", - "LowerExportedFunctions.cpp", "LowerGlobalTensors.cpp", "Passes.cpp", "PrettifyDebugInfo.cpp", @@ -38,7 +37,6 @@ cc_library( "//iree_tf_compiler/dialect/tf_tensorlist/conversion:convert_tf_tensorlist_to_tensorlist", "//iree_tf_compiler/dialect/tf_tensorlist/conversion:convert_tf_to_tf_tensorlist", "//iree_tf_compiler/dialect/tf_tensorlist/ir:tf_tensorlist_dialect", - "@iree//iree/compiler/Bindings/SIP/Utils", "@iree//iree/compiler/Dialect/Flow/IR", "@iree//iree/compiler/Dialect/Flow/Transforms", "@iree//iree/compiler/Dialect/HAL/IR", diff --git a/integrations/tensorflow/iree_tf_compiler/TF/LowerExportedFunctions.cpp b/integrations/tensorflow/iree_tf_compiler/TF/LowerExportedFunctions.cpp deleted file mode 100644 index 87adac5ee268..000000000000 --- a/integrations/tensorflow/iree_tf_compiler/TF/LowerExportedFunctions.cpp +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h" -#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" -#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" -#include "iree_tf_compiler/TF/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" - -namespace mlir { -namespace iree_integrations { -namespace TF { - -using mlir::iree_compiler::IREE::SIP::SipSignatureMangler; - -static LogicalResult setRawSignatureIndex(FuncOp funcOp, - SipSignatureMangler &mangler, - int rawIndex, - ArrayAttr indexPathAttr) { - llvm::SmallVector indexKeys; - for (auto &indexAttr : indexPathAttr) { - if (auto stringAttr = indexAttr.dyn_cast()) { - auto stringRef = stringAttr.getValue(); - indexKeys.emplace_back(StringRef(stringRef.data(), stringRef.size())); - } else if (auto intAttr = indexAttr.dyn_cast()) { - indexKeys.emplace_back(intAttr.getInt()); - } else { - return funcOp.emitError() - << "Each index path component must be a string or integer"; - } - } - - if (!mangler.SetRawSignatureIndex(rawIndex, indexKeys)) { - return funcOp.emitError() - << "Unable to generate mangled form for index path"; - } - - return success(); -} - -class LowerExportedFunctionsPass - : public PassWrapper> { - public: - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - if (failed(run())) { - signalPassFailure(); - } - } - - LogicalResult run() { - mlir::Builder builder(getOperation()); - const Identifier savedModelIndexPathIdent = - builder.getIdentifier("tf_saved_model.index_path"); - const Identifier sipIdent = builder.getIdentifier("sip"); - const Identifier abiIdent = builder.getIdentifier("abi"); - const Identifier abiVersionIdent = builder.getIdentifier("abiv"); - - // Handle saved model exported functions. - for (auto func : getOperation().getOps()) { - // Transfer exported names to IREE. - auto exported_names = mlir::tf_saved_model::GetExportedNames(func); - if (exported_names.empty()) continue; - - // TODO(laurenzo): After VM rework, we should just keep the - // function name as-is and create explicit export ops for each exported - // function. - if (exported_names.size() > 1) { - return func.emitError() << "Multiple exported names not supported yet"; - } - func.setName(exported_names.front()); - - // Function level reflection attributes. - SipSignatureMangler inputsMangler; - SipSignatureMangler resultsMangler; - SmallVector funcReflectAttrs; - funcReflectAttrs.push_back(builder.getNamedAttr( - abiIdent, builder.getStringAttr(sipIdent.strref()))); - funcReflectAttrs.push_back( - builder.getNamedAttr(abiVersionIdent, builder.getI32IntegerAttr(1))); - - // Tag it as an IREE exported function. - func->setAttr("iree.module.export", builder.getUnitAttr()); - - // Process per-argument attrs and generate reflection metadata. - for (int i = 0, e = func.getNumArguments(); i < e; i++) { - auto indexPathAttr = - func.getArgAttrOfType(i, savedModelIndexPathIdent); - if (!indexPathAttr) { - return func.emitError() - << "Missing argument attribute: " << savedModelIndexPathIdent; - } - func.removeArgAttr(i, savedModelIndexPathIdent); - - if (failed( - setRawSignatureIndex(func, inputsMangler, i, indexPathAttr))) { - return failure(); - } - } - - // Process per-result attrs and generate reflection metadata. - for (int i = 0, e = func.getNumResults(); i < e; i++) { - auto indexPathAttr = func.getResultAttrOfType( - i, savedModelIndexPathIdent); - if (!indexPathAttr) { - return func.emitError() - << "Missing result attribute: " << savedModelIndexPathIdent; - } - func.removeResultAttr(i, savedModelIndexPathIdent); - - if (failed( - setRawSignatureIndex(func, resultsMangler, i, indexPathAttr))) { - return failure(); - } - } - - // Add the function level reflection attribute. - auto functionSignature = SipSignatureMangler::ToFunctionSignature( - inputsMangler, resultsMangler); - if (!functionSignature) { - return func.emitError() << "Unable to generate sip function signature"; - } - funcReflectAttrs.push_back(builder.getNamedAttr( - sipIdent, builder.getStringAttr(functionSignature->encoded()))); - - if (!funcReflectAttrs.empty()) { - func->setAttr("iree.reflection", - builder.getDictionaryAttr(funcReflectAttrs)); - } - - // Remove its designation as a saved model export. - func->removeAttr("tf_saved_model.exported_names"); - } - - // We should have now removed anything requiring saved model semantics. - getOperation()->removeAttr("tf_saved_model.semantics"); - return success(); - } -}; - -std::unique_ptr> createLowerExportedFunctionsPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "iree-tf-saved-model-lower-exported-functions", - "Lower tf_saved_model exported functions to ones with IREE SIP metadata"); - -} // namespace TF -} // namespace iree_integrations -} // namespace mlir diff --git a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h index d1ba9190f6d4..f20d428a8609 100644 --- a/integrations/tensorflow/iree_tf_compiler/TF/Passes.h +++ b/integrations/tensorflow/iree_tf_compiler/TF/Passes.h @@ -40,11 +40,6 @@ std::unique_ptr createConvertToMHLOPass(); // a module that does not have `tf_saved_model.semantics`. std::unique_ptr> createLowerGlobalTensorsPass(); -// In a module tagged with `tf_saved_model.semantics`, lowers any tf_saved_model -// exported functions to IREE exported functions with appropriate reflection -// metadata. -std::unique_ptr> createLowerExportedFunctionsPass(); - // In a module tagged with `tf_saved_model.semantics`, creates IREE ABI // functions for any saved model exported functions. std::unique_ptr> createSavedModelToIREEABIPass(); @@ -77,7 +72,6 @@ inline void registerAllPasses() { createConvertToMHLOPass(); createLowerGlobalTensorsPass(); - createLowerExportedFunctionsPass(); createPrettifyDebugInfoPass(); createPropagateResourceCastsPass(); createSavedModelToIREEABIPass(); diff --git a/iree/compiler/Bindings/SIP/BUILD b/iree/compiler/Bindings/SIP/BUILD deleted file mode 100644 index f27d209ddf95..000000000000 --- a/iree/compiler/Bindings/SIP/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) diff --git a/iree/compiler/Bindings/SIP/CMakeLists.txt b/iree/compiler/Bindings/SIP/CMakeLists.txt deleted file mode 100644 index 91baf7ef6f6f..000000000000 --- a/iree/compiler/Bindings/SIP/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Bindings/SIP/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Bindings/SIP/Transforms/BUILD b/iree/compiler/Bindings/SIP/Transforms/BUILD deleted file mode 100644 index 6005b8865c89..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/BUILD +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Transforms", - srcs = [ - "MaterializeReflectionAttrs.cpp", - "Passes.cpp", - ], - hdrs = [ - "Passes.h", - ], - deps = [ - "//iree/compiler/Bindings/SIP/Utils", - "//iree/compiler/Dialect/Flow/IR", - "//iree/compiler/Dialect/Flow/Transforms", - "//iree/compiler/Dialect/IREE/IR", - "//iree/compiler/Dialect/Shape/IR", - "//iree/compiler/Dialect/Shape/Transforms", - "//iree/compiler/Dialect/Shape/Utils:TypeConversion", - "//iree/compiler/Utils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Shape", - "@llvm-project//mlir:ShapeTransforms", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/iree/compiler/Bindings/SIP/Transforms/CMakeLists.txt b/iree/compiler/Bindings/SIP/Transforms/CMakeLists.txt deleted file mode 100644 index c668a790fd9b..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/CMakeLists.txt +++ /dev/null @@ -1,43 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Bindings/SIP/Transforms/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_cc_library( - NAME - Transforms - HDRS - "Passes.h" - SRCS - "MaterializeReflectionAttrs.cpp" - "Passes.cpp" - DEPS - LLVMSupport - MLIRIR - MLIRPass - MLIRShape - MLIRShapeOpsTransforms - MLIRStandard - MLIRSupport - MLIRTensor - MLIRTransformUtils - MLIRTransforms - iree::compiler::Bindings::SIP::Utils - iree::compiler::Dialect::Flow::IR - iree::compiler::Dialect::Flow::Transforms - iree::compiler::Dialect::IREE::IR - iree::compiler::Dialect::Shape::IR - iree::compiler::Dialect::Shape::Transforms - iree::compiler::Dialect::Shape::Utils::TypeConversion - iree::compiler::Utils - PUBLIC -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Bindings/SIP/Transforms/MaterializeReflectionAttrs.cpp b/iree/compiler/Bindings/SIP/Transforms/MaterializeReflectionAttrs.cpp deleted file mode 100644 index e1797b5ee5b8..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/MaterializeReflectionAttrs.cpp +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include - -#include "iree/compiler/Bindings/SIP/Transforms/Passes.h" -#include "iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h" -#include "llvm/ADT/Optional.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -using AbiConstants::ScalarType; - -static llvm::Optional mapScalarType(Type elementType) { - // Map ScalarType. - if (elementType.isSignlessInteger()) { - auto bits = elementType.getIntOrFloatBitWidth(); - // TODO(laurenzo): These types are still signless. Assume signed and - // preserve once represented. - switch (bits) { - // We represent bools as 8-bit integers right now. - case 1: - case 8: - return ScalarType::kSint8; - case 16: - return ScalarType::kSint16; - case 32: - return ScalarType::kSint32; - case 64: - return ScalarType::kSint64; - default: - return llvm::None; - } - } else if (auto floatType = elementType.dyn_cast()) { - if (floatType.isF32()) { - return ScalarType::kIeeeFloat32; - } else if (floatType.isF64()) { - return ScalarType::kIeeeFloat64; - } else if (floatType.isF16()) { - return ScalarType::kIeeeFloat16; - } else if (floatType.isBF16()) { - return ScalarType::kGoogleBfloat16; - } else { - return llvm::None; - } - } - return llvm::None; -} - -static LogicalResult mangleTensorType(TensorType t, - RawSignatureMangler &mangler) { - auto scalarType = mapScalarType(t.getElementType()); - if (!scalarType) return failure(); - - llvm::SmallVector dims; - for (auto typeDim : t.getShape()) { - if (typeDim < 0) { - dims.push_back(-1); - } else if (typeDim > std::numeric_limits::max()) { - return failure(); - } else { - dims.push_back(typeDim); - } - } - - // Tensors map to buffers in the ABI. - mangler.AddShapedNDBuffer(*scalarType, dims); - return success(); -} - -static LogicalResult mangleScalarType(Type t, RawSignatureMangler &mangler) { - auto mappedType = mapScalarType(t); - if (!mappedType) return failure(); - mangler.AddScalar(*mappedType); - return success(); -} - -static LogicalResult mangleType(Type type, RawSignatureMangler &mangler) { - if (auto tensorType = type.dyn_cast()) { - return mangleTensorType(tensorType, mangler); - } - return mangleScalarType(type, mangler); -} - -class MaterializeReflectionAttrsPass - : public PassWrapper { - void runOnFunction() override { - auto func = getFunction(); - auto funcType = func.getType(); - auto builder = Builder(&getContext()); - - // Only process exported functions that are not marked to omit an abi. - if (!func->getAttr("iree.module.export")) return; - if (func->getAttr("iree.abi.stub")) return; - if (func->getAttr("iree.abi.none")) return; - - // Arguments. - RawSignatureMangler inputsMangler; - for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { - if (failed(mangleType(funcType.getInput(i), inputsMangler))) { - func.emitWarning() - << "Argument #" << i << " of function " << func.getName() - << " is not a recognized public ABI type and the function" - << " may not be invokable by standard tools"; - inputsMangler.AddUnrecognized(); - } - } - - // Results. - RawSignatureMangler resultsMangler; - for (int i = 0, e = funcType.getNumResults(); i < e; ++i) { - if (failed(mangleType(funcType.getResult(i), resultsMangler))) { - func.emitWarning() - << "Result #" << i << " of function " << func.getName() - << " is not a recognized public ABI type and the function" - << " may not be invokable by standard tools"; - resultsMangler.AddUnrecognized(); - } - } - - // Update the function level attribute. - auto reflectionIdent = builder.getIdentifier("iree.reflection"); - auto fIdent = builder.getIdentifier("f"); - auto fVersionIdent = builder.getIdentifier("fv"); - SignatureBuilder functionSignature = - RawSignatureMangler::ToFunctionSignature(inputsMangler, resultsMangler); - NamedAttrList l(func->getAttrOfType(reflectionIdent)); - l.set(fIdent, builder.getStringAttr(functionSignature.encoded())); - l.set(fVersionIdent, builder.getStringAttr("1")); - func->setAttr(reflectionIdent, l.getDictionary(&getContext())); - } -}; - -std::unique_ptr> createMaterializeReflectionAttrsPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "iree-sip-materialize-reflection-attrs", - "Materializes argument/result level reflection metadata for exported " - "functions."); - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Bindings/SIP/Transforms/Passes.cpp b/iree/compiler/Bindings/SIP/Transforms/Passes.cpp deleted file mode 100644 index 1c26130a2b2d..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/Passes.cpp +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Transforms/Passes.h" - -#include - -#include "iree/compiler/Dialect/Flow/Transforms/Passes.h" -#include "mlir/Pass/PassOptions.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Passes.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -void buildTransformPassPipeline(OpPassManager &passManager) { - // Materialize default arg/result reflection metadata. - // This pass must come before any 1:N type expansion that will not be retained - // in the public ABI (i.e. loose shape dims, etc). - passManager.addNestedPass( - IREE::SIP::createMaterializeReflectionAttrsPass()); - - // Cleanup the IR after manipulating it. - passManager.addNestedPass(createCanonicalizerPass()); - passManager.addNestedPass(createCSEPass()); - passManager.addPass(createSymbolDCEPass()); -} - -void registerTransformPassPipeline() { - PassPipelineRegistration<> transformPassPipeline( - "iree-sip-transform-pipeline", - "Runs the SIP-compatible binding support pipeline", - [](OpPassManager &passManager) { - buildTransformPassPipeline(passManager); - }); -} - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Bindings/SIP/Transforms/Passes.h b/iree/compiler/Bindings/SIP/Transforms/Passes.h deleted file mode 100644 index 9bf7061ec09c..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/Passes.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2021 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_BINDINGS_SIP_TRANSFORMS_PASSES_H_ -#define IREE_COMPILER_BINDINGS_SIP_TRANSFORMS_PASSES_H_ - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -//===----------------------------------------------------------------------===// -// Helpers -//===----------------------------------------------------------------------===// - -// Adds a set of passes to the given pass manager that setup a module for use -// with an IREE SIP-compatible runtime binding implementation (python, etc). -void buildTransformPassPipeline(OpPassManager &passManager); - -void registerTransformPassPipeline(); - -//===----------------------------------------------------------------------===// -// SIP-compatible bindings support -//===----------------------------------------------------------------------===// - -// Materializes reflection metadata on exported function arguments and results. -// This runs as close to the input processing as possible as it needs to -// annotate the ABI that the consumer is expecting to interop with. -std::unique_ptr> createMaterializeReflectionAttrsPass(); - -//===----------------------------------------------------------------------===// -// Register all Passes -//===----------------------------------------------------------------------===// - -inline void registerPasses() { - registerTransformPassPipeline(); - createMaterializeReflectionAttrsPass(); -} - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_BINDINGS_SIP_TRANSFORMS_PASSES_H_ diff --git a/iree/compiler/Bindings/SIP/Transforms/test/BUILD b/iree/compiler/Bindings/SIP/Transforms/test/BUILD deleted file mode 100644 index 311d4a79b2f5..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/test/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -load("//iree:lit_test.bzl", "iree_lit_test_suite") -load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -iree_lit_test_suite( - name = "lit", - srcs = enforce_glob( - [ - "materialize_reflection_attrs.mlir", - ], - include = ["*.mlir"], - ), - data = [ - "//iree/tools:IreeFileCheck", - "//iree/tools:iree-opt", - ], -) diff --git a/iree/compiler/Bindings/SIP/Transforms/test/CMakeLists.txt b/iree/compiler/Bindings/SIP/Transforms/test/CMakeLists.txt deleted file mode 100644 index 31f66fa3b165..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/test/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Bindings/SIP/Transforms/test/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_lit_test_suite( - NAME - lit - SRCS - "materialize_reflection_attrs.mlir" - DATA - iree::tools::IreeFileCheck - iree::tools::iree-opt -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Bindings/SIP/Transforms/test/materialize_reflection_attrs.mlir b/iree/compiler/Bindings/SIP/Transforms/test/materialize_reflection_attrs.mlir deleted file mode 100644 index 761142607051..000000000000 --- a/iree/compiler/Bindings/SIP/Transforms/test/materialize_reflection_attrs.mlir +++ /dev/null @@ -1,191 +0,0 @@ -// RUN: iree-opt -split-input-file -verify-diagnostics -iree-sip-materialize-reflection-attrs %s | IreeFileCheck %s - -// CHECK-LABEL: func @notExported -// CHECK-NOT: iree.reflection -func @notExported(%arg0 : tensor<4x4xi64>) -> tensor<4x4xi64> { - return %arg0 : tensor<4x4xi64> -} - -// ----- - -// CHECK-LABEL: func @emptyWithVersion -// CHECK-SAME: iree.reflection = {f = "I1!R1!", fv = "1"} -func @emptyWithVersion() -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @exportedTensor -// CHECK-SAME: iree.reflection = {f = "I19!B7!t7d4d4B7!t7d5d5R10!B7!t7d5d5", fv = "1"} -func @exportedTensor(%arg0 : tensor<4x4xi64>, %arg1 : tensor<5x5xi64>) -> tensor<5x5xi64> attributes { - iree.module.export -} { - return %arg1 : tensor<5x5xi64> -} - -// ----- - -// CHECK-LABEL: func @dynamicDim -// CHECK-SAME: iree.reflection = {f = "I11!B8!t7d-1d4R1!", fv = "1"} -func @dynamicDim(%arg0 : tensor) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @scalari32 -// CHECK-SAME: iree.reflection = {f = "I6!S3!t6R1!", fv = "1"} -func @scalari32(%arg0 : i32) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorFloat32 -// CHECK-SAME: iree.reflection = {f = "I6!B3!d1R1!", fv = "1"} -func @tensorFloat32(%arg0 : tensor<1xf32>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorFloat64 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t2d1R1!", fv = "1"} -func @tensorFloat64(%arg0 : tensor<1xf64>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorFloat16 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t1d1R1!", fv = "1"} -func @tensorFloat16(%arg0 : tensor<1xf16>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorBfloat16 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t3d1R1!", fv = "1"} -func @tensorBfloat16(%arg0 : tensor<1xbf16>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorSint8 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t4d1R1!", fv = "1"} -func @tensorSint8(%arg0 : tensor<1xi8>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorSint16 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t5d1R1!", fv = "1"} -func @tensorSint16(%arg0 : tensor<1xi16>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorSint32 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t6d1R1!", fv = "1"} -func @tensorSint32(%arg0 : tensor<1xi32>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @tensorSint64 -// CHECK-SAME: iree.reflection = {f = "I8!B5!t7d1R1!", fv = "1"} -func @tensorSint64(%arg0 : tensor<1xi64>) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @noReflectionOnAbiNone -// CHECK-NOT: iree.reflection -func @noReflectionOnAbiNone(%arg0 : tensor<4x4xi64>, %arg1 : tensor<5x5xi64>) -> tensor<5x5xi64> attributes { - iree.module.export, - iree.abi.none -} { - return %arg1 : tensor<5x5xi64> -} - -// ----- - -// CHECK-LABEL: @unsupportedTypeOnAbiNone -// Should not generate warning -func @unsupportedTypeOnAbiNone(%arg0 : i1) -> () attributes { - iree.module.export, - iree.abi.none -} { - return -} - -// ----- - -// CHECK-LABEL: @reflectionOnBool -// CHECK-SAME: iree.reflection = {f = "I6!S3!t4R1!", fv = "1"} -func @reflectionOnBool(%arg0 : i1) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// expected-warning @+1 {{Argument #0 of function unsupportedType is not a recognized public ABI type and the function may not be invokable by standard tools}} -func @unsupportedType(%arg0 : i3) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @unrecognizedArgument -// CHECK-SAME: iree.reflection = {f = "I4!U1!R1!", fv = "1"} -// expected-warning @+1 {{Argument #0 of function unrecognizedArgument is not a recognized public ABI type and the function may not be invokable by standard tools}} -func @unrecognizedArgument(%arg0 : i3) -> () attributes { - iree.module.export -} { - return -} - -// ----- - -// CHECK-LABEL: func @unrecognizedResult -// CHECK-SAME: iree.reflection = {f = "I1!R4!U1!", fv = "1"} -// expected-warning @+1 {{Result #0 of function unrecognizedResult is not a recognized public ABI type and the function may not be invokable by standard tools}} -func @unrecognizedResult() -> (i3) attributes { - iree.module.export -} { - %0 = constant 0 : i3 - return %0 : i3 -} diff --git a/iree/compiler/Bindings/SIP/Utils/BUILD b/iree/compiler/Bindings/SIP/Utils/BUILD deleted file mode 100644 index 55db1cbc7d1f..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2021 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Utils", - srcs = [ - "Signature.cpp", - "SignatureBuilder.cpp", - "SignatureParser.cpp", - ], - hdrs = [ - "Signature.h", - "SignatureBuilder.h", - "SignatureParser.h", - ], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Support", - ], -) - -cc_test( - name = "SignatureTest", - srcs = ["SignatureTest.cpp"], - deps = [ - ":Utils", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - ], -) diff --git a/iree/compiler/Bindings/SIP/Utils/CMakeLists.txt b/iree/compiler/Bindings/SIP/Utils/CMakeLists.txt deleted file mode 100644 index 47daf954f60f..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -################################################################################ -# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # -# iree/compiler/Bindings/SIP/Utils/BUILD # -# # -# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # -# CMake-only content. # -# # -# To disable autogeneration for this file entirely, delete this header. # -################################################################################ - -iree_add_all_subdirs() - -iree_cc_library( - NAME - Utils - HDRS - "Signature.h" - "SignatureBuilder.h" - "SignatureParser.h" - SRCS - "Signature.cpp" - "SignatureBuilder.cpp" - "SignatureParser.cpp" - DEPS - LLVMSupport - MLIRSupport - PUBLIC -) - -iree_cc_test( - NAME - SignatureTest - SRCS - "SignatureTest.cpp" - DEPS - ::Utils - iree::testing::gtest - iree::testing::gtest_main -) - -### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/iree/compiler/Bindings/SIP/Utils/Signature.cpp b/iree/compiler/Bindings/SIP/Utils/Signature.cpp deleted file mode 100644 index 51c295646a05..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/Signature.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Utils/Signature.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -// ----------------------------------------------------------------------------- -// AbiConstants -// ----------------------------------------------------------------------------- - -const std::array AbiConstants::kScalarTypeSize = { - 4, // kIeeeFloat32 = 0, - 2, // kIeeeFloat16 = 1, - 8, // kIeeeFloat64 = 2, - 2, // kGoogleBfloat16 = 3, - 1, // kSint8 = 4, - 2, // kSint16 = 5, - 4, // kSint32 = 6, - 8, // kSint64 = 7, - 1, // kUint8 = 8, - 2, // kUint16 = 9, - 4, // kUint32 = 10, - 8, // kUint64 = 11, -}; - -const std::array AbiConstants::kScalarTypeNames = { - "float32", "float16", "float64", "bfloat16", "sint8", "sint16", - "sint32", "sint64", "uint8", "uint16", "uint32", "uint64", -}; - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Bindings/SIP/Utils/Signature.h b/iree/compiler/Bindings/SIP/Utils/Signature.h deleted file mode 100644 index cc4ce538997e..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/Signature.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_H_ -#define IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_H_ - -#include -#include - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -namespace AbiConstants { - -// Canonical integer mappings are maintained for core scalar type codes -// since they change infrequently and are used everywhere. -// Generally, favor adding a custom type vs extending this arbitrarily. -enum class ScalarType : unsigned { - kIeeeFloat32 = 0, - kIeeeFloat16 = 1, - kIeeeFloat64 = 2, - kGoogleBfloat16 = 3, - kSint8 = 4, - kSint16 = 5, - kSint32 = 6, - kSint64 = 7, - kUint8 = 8, - kUint16 = 9, - kUint32 = 10, - kUint64 = 11, - kMaxScalarType = 11, -}; - -// Array that maps ScalarType codes to the size in bytes. -extern const std::array(ScalarType::kMaxScalarType) + 1> - kScalarTypeSize; - -extern const std::array(ScalarType::kMaxScalarType) + 1> - kScalarTypeNames; - -} // namespace AbiConstants - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_H_ diff --git a/iree/compiler/Bindings/SIP/Utils/SignatureBuilder.cpp b/iree/compiler/Bindings/SIP/Utils/SignatureBuilder.cpp deleted file mode 100644 index e6f81247c4d7..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/SignatureBuilder.cpp +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -// ----------------------------------------------------------------------------- -// SignatureBuilder -// ----------------------------------------------------------------------------- - -SignatureBuilder& SignatureBuilder::Integer(int value, char tag) { - assert(tag == '_' || (tag >= 'a' && tag <= 'z') && - "integer signature tag must be '_' or 'a'..'z'"); - encoded_.push_back(tag); - encoded_.append(std::to_string(value)); - return *this; -} - -SignatureBuilder& SignatureBuilder::Span(StringRef contents, char tag) { - assert((tag >= 'A' && tag <= 'Z') && "span signature tag must be 'A'..'Z'"); - encoded_.push_back(tag); - // If the contents starts with a digit or the escape char (!), then escape it. - encoded_.append(std::to_string(contents.size() + 1)); - encoded_.push_back('!'); - encoded_.append(contents.begin(), contents.end()); - return *this; -} - -// ----------------------------------------------------------------------------- -// RawSignatureMangler -// ----------------------------------------------------------------------------- - -SignatureBuilder RawSignatureMangler::ToFunctionSignature( - const SignatureBuilder& inputs, const SignatureBuilder& results) { - SignatureBuilder func_builder; - inputs.AppendTo(func_builder, 'I'); - results.AppendTo(func_builder, 'R'); - return func_builder; -} - -void RawSignatureMangler::AddUnrecognized() { builder_.Span(StringRef(), 'U'); } - -void RawSignatureMangler::AddAnyReference() { - // A more constrained ref object would have a non empty span. - builder_.Span(StringRef(), 'O'); -} - -void RawSignatureMangler::AddShapedNDBuffer( - AbiConstants::ScalarType element_type, ArrayRef shape) { - SignatureBuilder item_builder; - // Fields: - // 't': scalar type code - // 'd': shape dimension - if (static_cast(element_type) != 0) { - item_builder.Integer(static_cast(element_type), 't'); - } - for (int d : shape) { - item_builder.Integer(d, 'd'); - } - item_builder.AppendTo(builder_, 'B'); -} - -void RawSignatureMangler::AddScalar(AbiConstants::ScalarType type) { - SignatureBuilder item_builder; - // Fields: - // 't': scalar type code - if (static_cast(type) != 0) { - item_builder.Integer(static_cast(type), 't'); - } - item_builder.AppendTo(builder_, 'S'); -} - -// ----------------------------------------------------------------------------- -// SipSignatureMangler -// ----------------------------------------------------------------------------- - -SipSignatureMangler::SipSignatureMangler() = default; - -bool SipSignatureMangler::SetRawSignatureIndex(int raw_signature_index, - ArrayRef path) { - if (raw_signature_index < 0) { - return false; - } - - Value* level = &root_; - for (const auto& key : path) { - // Is the indexing mode compatible? - if (level->index_mode == IndexMode::kNone) { - // Not yet committed: just adopt this first access. - level->index_mode = key.index_mode(); - } else if (level->index_mode != key.index_mode()) { - // Indexing mode mismatch. - return false; - } - - auto found_it = level->children.find(key); - if (found_it == level->children.end()) { - // Create a new level. - auto child = std::make_unique(); - Value* unowned_child = child.get(); - level->children.insert(std::make_pair(key, std::move(child))); - level = unowned_child; - continue; - } - - // Found. - level = found_it->second.get(); - } - - // Should now be on the leaf/terminal. - if (level->index_mode != IndexMode::kNone || - level->raw_signature_index != -1) { - // It is not a leaf or has already been setup as a leaf. - return false; - } - - level->raw_signature_index = raw_signature_index; - return true; -} - -bool SipSignatureMangler::ToStructureSignature(SignatureBuilder* sb, - const Value* level) const { - char sub_span_tag; - switch (level->index_mode) { - case IndexMode::kNone: - // Leaf with un-assigned raw index. - if (level->raw_signature_index < 0) { - // An un-assigned leaf is only allowed for the root. - assert(level == &root_ && "Un-assigned non-root leaf not allowed"); - return level == &root_; - } else { - sb->Integer(level->raw_signature_index); - return true; - } - case IndexMode::kSequence: - sub_span_tag = 'S'; - break; - case IndexMode::kDict: - sub_span_tag = 'D'; - break; - default: - return false; - } - - SignatureBuilder child_sb; - for (const auto& kv : level->children) { - const Key& key = kv.first; - if (key.is_integer_key()) { - child_sb.Integer(key.ikey(), 'k'); - } else if (key.is_string_key()) { - child_sb.Span(key.skey(), 'K'); - } else { - return false; - } - if (!ToStructureSignature(&child_sb, kv.second.get())) return false; - } - - child_sb.AppendTo(*sb, sub_span_tag); - return true; -} - -llvm::Optional SipSignatureMangler::ToFunctionSignature( - const SipSignatureMangler& inputs_struct, - const SipSignatureMangler& results_struct) { - auto inputs_sb = inputs_struct.ToStructureSignature(); - auto results_sb = results_struct.ToStructureSignature(); - - if (!inputs_sb || !results_sb) return {}; - - SignatureBuilder func_sb; - inputs_sb->AppendTo(func_sb, 'I'); - results_sb->AppendTo(func_sb, 'R'); - return func_sb; -} - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h b/iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h deleted file mode 100644 index 7290b85826dc..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_BUILDER_H_ -#define IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_BUILDER_H_ - -#include -#include -#include -#include - -#include "iree/compiler/Bindings/SIP/Utils/Signature.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -// Builds up a signature string from components. -// The signature syntax is a sequence of Integer or Span fields: -// integer_tag ::= '_' | [a-z] -// integer ::= integer_tag ('-')?[0-9]+ -// span_tag ::= [A-Z] -// span ::= span_tag (LENGTH:[0-9]+) .{LENGTH} -// -// component ::= integer-component | span-component -// integer-component ::= integer-tag integer -// span-component ::= span-tag length '!' contents -// # (Where 'length' encoded the length in bytes of 'contents' plus 1 for -// # the '!'. -// -// Low-level lexical primitives: -// integer ::= -?[0-9]+ -// length ::= [0-9]+ -// integer-tag ::= '_' | [a-z] -// span-tag ::= [A-Z] -class SignatureBuilder { - public: - SignatureBuilder() = default; - ~SignatureBuilder() = default; - - std::string& encoded() { return encoded_; } - const std::string& encoded() const { return encoded_; } - - // Appends an integer component with the given tag (or generic integer - // tag '_'). The tag must be a lower-case ascii letter between 'a'..'z' - // inclusive. - SignatureBuilder& Integer(int value, char tag = '_'); - - // Appends a literal span with a tag. - // The tag must be an upper-case ascii letter between 'A'..'Z' inclusive. - SignatureBuilder& Span(StringRef contents, char tag); - - // Appends to another builder as a sub-span with the given tag. - const SignatureBuilder& AppendTo(SignatureBuilder& other, char tag) const { - other.Span(encoded_, tag); - return *this; - } - - private: - std::string encoded_; -}; - -// ----------------------------------------------------------------------------- -// Raw signatures -// ----------------------------------------------------------------------------- - -// Mangles raw function signatures. -// See docs/developers/design_docs/function_abi.md. -class RawSignatureMangler { - public: - static SignatureBuilder ToFunctionSignature(const SignatureBuilder& inputs, - const SignatureBuilder& results); - - // Combines mangled input and result signatures into a function signature. - static SignatureBuilder ToFunctionSignature( - const RawSignatureMangler& inputs, const RawSignatureMangler& results) { - return ToFunctionSignature(inputs.builder(), results.builder()); - } - - // Adds an unrecognized type. By default, this is an empty span, but in the - // future, it may contain some further description. - void AddUnrecognized(); - - // Adds an unconstrained reference-type object. - void AddAnyReference(); - - // Adds a shaped nd buffer operand with the given element type and shape. - // Unknown dims should be -1. - // This is the common case for external interfacing and requires a fully - // ranked shape. - void AddShapedNDBuffer(AbiConstants::ScalarType element_type, - ArrayRef shape); - - void AddScalar(AbiConstants::ScalarType type); - - const SignatureBuilder& builder() const { return builder_; } - - private: - SignatureBuilder builder_; -}; - -// ----------------------------------------------------------------------------- -// Sip signatures -// ----------------------------------------------------------------------------- - -// Mangles function signatures according to the Sip (Structured Index Path) V1 -// scheme. -// -// Mangler for the 'sip' ABI. See docs/developers/design_docs/function_abi.md -// in the documentation. -class SipSignatureMangler { - public: - enum class IndexMode { - kNone, - kSequence, - kDict, - }; - - class Key { - public: - Key(int ikey) : skey_(), ikey_(ikey) { assert(ikey_ >= 0); } - Key(StringRef skey) : skey_(skey), ikey_(-1) {} - Key(const char* skey) : skey_(skey), ikey_(-1) {} - - bool is_integer_key() const { return ikey_ >= 0; } - bool is_string_key() const { return ikey_ < 0; } - - IndexMode index_mode() const { - return is_integer_key() ? IndexMode::kSequence : IndexMode::kDict; - } - - int ikey() const { return ikey_; } - StringRef skey() const { return skey_; } - - bool operator==(const Key& other) const { - return ikey_ == other.ikey_ && skey_ == other.skey_; - } - bool operator<(const Key& other) const { - return (ikey_ != other.ikey_) ? (ikey_ < other.ikey_) - : (skey_ < other.skey_); - } - - private: - StringRef skey_; - int ikey_; - }; - SipSignatureMangler(); - - // Sets the raw signature index at a structure leaf as identified by path. - // Returns whether the path and index are valid. - bool SetRawSignatureIndex(int raw_signature_index, ArrayRef path); - - // Emits a signature for the resulting structure, which will typically - // be embedded in a full function signature as either inputs or results. - llvm::Optional ToStructureSignature() const { - SignatureBuilder sb; - if (!ToStructureSignature(&sb, &root_)) { - return llvm::None; - } - return sb; - } - - // Generates a full function signature from structured inputs and results. - static llvm::Optional ToFunctionSignature( - const SipSignatureMangler& inputs_struct, - const SipSignatureMangler& results_struct); - - private: - struct Value { - // If this is a leaf, then this will be >= 0 and maps to the flat input/ - // result index in the raw signature. - int raw_signature_index = -1; - - // Whether the value is being indexed as a sequence or a dict. - IndexMode index_mode = IndexMode::kNone; - - // If not a leaf, then this is the children. - std::map> children; - }; - - bool ToStructureSignature(SignatureBuilder* sb, const Value* level) const; - Value root_; -}; - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_BUILDER_H_ diff --git a/iree/compiler/Bindings/SIP/Utils/SignatureParser.cpp b/iree/compiler/Bindings/SIP/Utils/SignatureParser.cpp deleted file mode 100644 index be77db0eb00a..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/SignatureParser.cpp +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Utils/SignatureParser.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -// ----------------------------------------------------------------------------- -// SignatureParser -// ----------------------------------------------------------------------------- - -SignatureParser::Type SignatureParser::Next() { - next_type_ = Type::kError; - next_tag_ = 0; - next_ival_ = 0; - next_sval_ = StringRef(); - if (cursor_ == encoded_.end()) { - next_type_ = Type::kEnd; - return next_type_; - } - - next_tag_ = *cursor_; - StringRef::const_iterator ival_begin = cursor_ + 1; - StringRef::const_iterator ival_end = ival_begin; - while (ival_end != encoded_.end() && - ((*ival_end >= '0' && *ival_end <= '9') || - (*ival_end == '-' && ival_end == ival_begin))) { - ++ival_end; - } - - // No numeric value. - if (ival_end == ival_begin) { - return next_type_; - } - - // Parse ival. - if (StringRef(&(*ival_begin), ival_end - ival_begin) - .consumeInteger(10, next_ival_)) { - // Should not be possible. - return next_type_; - } - - // For integer components ('_', 'a'..'z'), that is all. - if (next_tag_ == '_' || (next_tag_ >= 'a' && next_tag_ <= 'z')) { - next_type_ = Type::kInteger; - cursor_ = ival_end; - return next_type_; - } - - // For string components ('A'..'Z'), extract the string. - if (next_tag_ >= 'A' && next_tag_ <= 'Z') { - if (next_ival_ < 0) return next_type_; // Negative size error. - StringRef::const_iterator sval_begin = ival_end; - StringRef::const_iterator sval_end = sval_begin + next_ival_; - if (sval_end > encoded_.end()) return next_type_; // Underrun. - - // Remove escape char if escaped. - if (next_ival_ == 0 || *sval_begin != '!') { - next_type_ = Type::kError; - return next_type_; - } - next_ival_ -= 1; - ++sval_begin; - - next_sval_ = StringRef(&(*sval_begin), sval_end - sval_begin); - cursor_ = sval_end; - next_type_ = Type::kSpan; - return next_type_; - } - - // Otherwise, error. - return next_type_; -} - -bool SignatureParser::SeekTag(char tag) { - while (next_tag_ != tag && next_type_ != Type::kEnd) { - Next(); - } - return next_type_ != Type::kEnd; -} - -// ----------------------------------------------------------------------------- -// RawSignatureParser -// ----------------------------------------------------------------------------- - -void RawSignatureParser::Description::ToString(std::string& s) const { - switch (type) { - case Type::kBuffer: { - const char* scalar_type_name = "!BADTYPE!"; - unsigned scalar_type_u = static_cast(buffer.scalar_type); - if (scalar_type_u >= 0 && - scalar_type_u <= AbiConstants::kScalarTypeNames.size()) { - scalar_type_name = AbiConstants::kScalarTypeNames[static_cast( - scalar_type_u)]; - } - s.append("Buffer<"); - s.append(scalar_type_name); - s.append("["); - for (size_t i = 0; i < dims.size(); ++i) { - if (i > 0) s.push_back('x'); - if (dims[i] >= 0) { - s.append(std::to_string(dims[i])); - } else { - s.push_back('?'); - } - } - s.append("]>"); - break; - } - case Type::kRefObject: { - s.append("RefObject"); - break; - } - case Type::kScalar: { - const char* type_name = "!BADTYPE!"; - unsigned type_u = static_cast(scalar.type); - if (type_u >= 0 && type_u <= AbiConstants::kScalarTypeNames.size()) { - type_name = - AbiConstants::kScalarTypeNames[static_cast(type_u)]; - } - s.append(type_name); - break; - } - default: - s.append("!UNKNOWN!"); - } -} - -llvm::Optional RawSignatureParser::FunctionSignatureToString( - StringRef signature) { - std::string s; - - bool print_sep = false; - auto visitor = [&print_sep, &s](const Description& d) { - if (print_sep) { - s.append(", "); - } - d.ToString(s); - print_sep = true; - }; - s.push_back('('); - VisitInputs(signature, visitor); - s.append(") -> ("); - print_sep = false; - VisitResults(signature, visitor); - s.push_back(')'); - - if (!GetError()) { - return s; - } else { - return llvm::None; - } -} - -// ----------------------------------------------------------------------------- -// SipSignatureParser -// ----------------------------------------------------------------------------- - -void SipSignatureParser::ToStringVisitor::IntegerKey(SipSignatureParser& p, - int k) { - s_.append(indent_); - s_.append(std::to_string(k)); -} - -void SipSignatureParser::ToStringVisitor::StringKey(SipSignatureParser& p, - StringRef k) { - s_.append(indent_); - s_.append(k.data(), k.size()); -} - -void SipSignatureParser::ToStringVisitor::OpenStruct(SipSignatureParser& p, - StructType struct_type) { - indent_.append(" "); - switch (struct_type) { - case StructType::kDict: - close_char_.push_back('}'); - s_.append(":{"); - break; - case StructType::kSequence: - close_char_.push_back(']'); - s_.append(":["); - break; - default: - close_char_.push_back('?'); - s_.append(":?"); - } - s_.append("\n"); -} - -void SipSignatureParser::ToStringVisitor::CloseStruct(SipSignatureParser& p) { - if (indent_.size() >= 2) { - indent_.resize(indent_.size() - 2); - } - s_.append(indent_); - s_.push_back(close_char_.back()); - close_char_.pop_back(); - s_.append(",\n"); -} - -void SipSignatureParser::ToStringVisitor::MapToRawSignatureIndex( - SipSignatureParser& p, int index) { - s_.append("=raw("); - s_.append(std::to_string(index)); - s_.append("),\n"); -} - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Bindings/SIP/Utils/SignatureParser.h b/iree/compiler/Bindings/SIP/Utils/SignatureParser.h deleted file mode 100644 index f40cc85c7588..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/SignatureParser.h +++ /dev/null @@ -1,394 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_PARSER_H_ -#define IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_PARSER_H_ - -#include -#include -#include -#include - -#include "iree/compiler/Bindings/SIP/Utils/Signature.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace SIP { - -// Parses a signature produced by SignatureBuilder. -// The parser works field-by-field and it is up to the caller to handle nesting -// by handling nested SignatureParsers (typically by calling nested()). -class SignatureParser { - public: - enum class Type { - kEnd, - kInteger, - kSpan, - kError, - }; - - explicit SignatureParser(StringRef encoded) - : encoded_(encoded), cursor_(encoded_.begin()) { - Next(); - } - - // Gets the next component from the signature. - Type Next(); - - // Seek to the next field with the given tag (potentially this one). - // Returns true if found. If false, the parser will be at kEnd. - bool SeekTag(char tag); - - bool end_or_error() const { - return next_type_ == Type::kEnd || next_type_ == Type::kError; - } - Type type() const { return next_type_; } - char tag() const { return next_tag_; } - int ival() const { return next_ival_; } - StringRef sval() const { return next_sval_; } - SignatureParser nested() const { return SignatureParser(next_sval_); } - - private: - StringRef encoded_; - - // Cursor is always positioned at the start of the next component. - StringRef::const_iterator cursor_; - - Type next_type_; - int next_ival_; - StringRef next_sval_; - char next_tag_; -}; - -// ----------------------------------------------------------------------------- -// Raw signatures -// ----------------------------------------------------------------------------- - -// Parses function signatures generated by RawSignatureMangler. -class RawSignatureParser { - public: - using DimVector = SmallVector; - - enum class Type { - kBuffer = 0, - kRefObject = 1, - kScalar = 2, - }; - - // Description of an input or result. - struct Description { - // Type category of the argument. - Type type; - - // For shaped types, this is the corresponding dims. - DimVector dims; - - union { - // Further details for Type == kBuffer. - struct { - AbiConstants::ScalarType scalar_type; - } buffer; - // Further details for Type == kScalar. - struct { - AbiConstants::ScalarType type; - } scalar; - }; - - // Human readable description. - void ToString(std::string& s) const; - }; - - using Visitor = std::function; - - void VisitInputs(StringRef signature, Visitor visitor) { - SignatureParser sp(signature); - if (!sp.SeekTag('I')) { - SetError("Inputs span not found"); - return; - } - auto nested = sp.nested(); - return Visit(visitor, nested); - } - - void VisitResults(StringRef signature, Visitor visitor) { - SignatureParser sp(signature); - if (!sp.SeekTag('R')) { - SetError("Results span not found"); - return; - } - auto nested = sp.nested(); - return Visit(visitor, nested); - } - - // Produces a human readable function signature from the encoded form. - // Does not return a value on error. - llvm::Optional FunctionSignatureToString(StringRef signature); - - // If the parser is in an error state, accesses the error. - const llvm::Optional& GetError() { return error_; } - void SetError(std::string error) { - if (!error_) error_ = std::move(error); - } - - private: - void Visit(Visitor& v, SignatureParser& item_parser) { - Description d; - while (!item_parser.end_or_error() && !error_) { - // Reset shared fields. - d.dims.clear(); - - switch (item_parser.tag()) { - case 'B': - if (!FillBuffer(d, SignatureParser(item_parser.nested()))) { - return; - } - break; - case 'O': - if (!FillRefObject(d, SignatureParser(item_parser.nested()))) { - return; - } - break; - case 'S': - if (!FillScalar(d, SignatureParser(item_parser.nested()))) { - return; - } - break; - default: - SetError("Unrecognized raw tag"); - return; - } - - v(d); - item_parser.Next(); - } - } - - bool FillScalar(Description& d, SignatureParser p) { - d.type = Type::kScalar; - d.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat32; // Default - while (!p.end_or_error()) { - switch (p.tag()) { - case 't': - if (p.ival() < 0 || - p.ival() > - static_cast(AbiConstants::ScalarType::kMaxScalarType)) { - SetError("Illegal ScalarType code"); - return false; - } - d.buffer.scalar_type = - static_cast(p.ival()); - break; - default: - SetError("Unrecognized scalar field tag"); - return false; - } - p.Next(); - } - return true; - } - - bool FillBuffer(Description& d, SignatureParser p) { - d.type = Type::kBuffer; - d.buffer.scalar_type = AbiConstants::ScalarType::kIeeeFloat32; // Default - while (!p.end_or_error()) { - switch (p.tag()) { - case 't': - if (p.ival() < 0 || - p.ival() > - static_cast(AbiConstants::ScalarType::kMaxScalarType)) { - SetError("Illegal ScalarType code"); - return false; - } - d.buffer.scalar_type = - static_cast(p.ival()); - break; - case 'd': - d.dims.push_back(p.ival()); - break; - default: - SetError("Unrecognized buffer field tag"); - return false; - } - p.Next(); - } - return true; - } - - bool FillRefObject(Description& d, SignatureParser p) { - d.type = Type::kRefObject; - while (!p.end_or_error()) { - switch (p.tag()) { - default: - SetError("Unrecognized ref object field tag"); - return false; - } - p.Next(); - } - return true; - } - - llvm::Optional error_; -}; - -// ----------------------------------------------------------------------------- -// Sip signatures -// ----------------------------------------------------------------------------- - -// Parser for signatures generated by SipSignatureMangler. -// This uses a Visitor interface to walk either input or result structs. -// -// Mangler for the 'sip' ABI. See docs/developers/design_docs/function_abi.md -// in the documentation. -class SipSignatureParser { - public: - enum class StructType { - kSequence, - kDict, - }; - - template - struct VisitorAdapter { - VisitorAdapter(SipSignatureParser& p, Visitor& v) : p(p), v(v) {} - - void IntegerKey(int k) { v.IntegerKey(p, k); } - void StringKey(StringRef k) { v.StringKey(p, k); } - - void OpenStruct(StructType struct_type) { v.OpenStruct(p, struct_type); } - void CloseStruct() { v.CloseStruct(p); } - - void MapToRawSignatureIndex(int index) { - v.MapToRawSignatureIndex(p, index); - } - - SipSignatureParser& p; - Visitor& v; - }; - - class ToStringVisitor { - public: - void IntegerKey(SipSignatureParser& p, int k); - void StringKey(SipSignatureParser& p, StringRef k); - void OpenStruct(SipSignatureParser& p, StructType struct_type); - void CloseStruct(SipSignatureParser& p); - void MapToRawSignatureIndex(SipSignatureParser& p, int index); - - std::string& s() { return s_; } - - private: - std::string s_; - std::string indent_; - std::string close_char_; - }; - - template - void VisitInputs(Visitor& v, StringRef signature) { - SignatureParser sp(signature); - if (!sp.SeekTag('I')) { - return SetError("Inputs struct not found"); - } - VisitorAdapter va(*this, v); - auto nested = sp.nested(); - return Visit(va, nested, false); - } - - template - void VisitResults(Visitor& v, StringRef signature) { - SignatureParser sp(signature); - if (!sp.SeekTag('R')) { - return SetError("Results struct not found"); - } - VisitorAdapter va(*this, v); - auto nested = sp.nested(); - return Visit(va, nested, false); - } - - // If the parser is in an error state, accesses the error. - const llvm::Optional& GetError() { return error_; } - void SetError(std::string error) { - if (!error_) error_ = std::move(error); - } - - private: - template - void Visit(VisitorAdapter& v, SignatureParser& struct_parser, - bool allow_key); - - llvm::Optional error_; -}; - -template -void SipSignatureParser::Visit(VisitorAdapter& v, - SignatureParser& struct_parser, - bool global_allow_key) { - bool allow_key; - bool allow_value; - - auto reset_state = [&]() { - allow_key = global_allow_key; - allow_value = !allow_key; - }; - reset_state(); - - while (!struct_parser.end_or_error() && !error_) { - switch (struct_parser.tag()) { - case 'k': - if (!allow_key) { - return SetError("Struct key not allowed here"); - } - allow_key = false; - allow_value = true; - v.IntegerKey(struct_parser.ival()); - break; - case 'K': - if (!allow_key) { - return SetError("Struct key not allowed here"); - } - allow_key = false; - allow_value = true; - v.StringKey(struct_parser.sval()); - break; - case '_': - if (!allow_value) { - return SetError("Value not allowed here"); - } - v.MapToRawSignatureIndex(struct_parser.ival()); - reset_state(); - break; - case 'S': - case 'D': { - if (!allow_value) { - return SetError("Value not allowed here"); - } - v.OpenStruct(struct_parser.tag() == 'S' ? StructType::kSequence - : StructType::kDict); - SignatureParser child_struct_parser(struct_parser.sval()); - Visit(v, child_struct_parser, true); - v.CloseStruct(); - reset_state(); - break; - } - default: - return SetError("Unrecognized tag"); - } - struct_parser.Next(); - } - - if (struct_parser.type() == SignatureParser::Type::kError) { - return SetError("Syntax error in signature"); - } -} - -} // namespace SIP -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_BINDINGS_SIP_UTILS_SIGNATURE_PARSER_H_ diff --git a/iree/compiler/Bindings/SIP/Utils/SignatureTest.cpp b/iree/compiler/Bindings/SIP/Utils/SignatureTest.cpp deleted file mode 100644 index 79c27b90b61b..000000000000 --- a/iree/compiler/Bindings/SIP/Utils/SignatureTest.cpp +++ /dev/null @@ -1,399 +0,0 @@ -// Copyright 2019 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Utils/SignatureBuilder.h" -#include "iree/compiler/Bindings/SIP/Utils/SignatureParser.h" -#include "iree/testing/gtest.h" - -namespace { - -using namespace mlir::iree_compiler::IREE::SIP; - -class SipSignatureTest : public ::testing::Test { - protected: - std::string PrintInputSignature(llvm::Optional signature) { - EXPECT_TRUE(signature); - - SipSignatureParser parser; - SipSignatureParser::ToStringVisitor printer; - parser.VisitInputs(printer, signature->encoded()); - EXPECT_FALSE(parser.GetError()) << "Parse error: " << *parser.GetError(); - return std::move(printer.s()); - } - - std::string PrintResultsSignature( - llvm::Optional signature) { - EXPECT_TRUE(signature); - - SipSignatureParser parser; - SipSignatureParser::ToStringVisitor printer; - parser.VisitResults(printer, signature->encoded()); - EXPECT_FALSE(parser.GetError()) << "Parse error: " << *parser.GetError(); - return std::move(printer.s()); - } -}; - -TEST(SignatureBuilderTest, TestInteger) { - SignatureBuilder sb1; - sb1.Integer(5).Integer(1, 'a').Integer(10, 'z').Integer(-5991, 'x'); - EXPECT_EQ("_5a1z10x-5991", sb1.encoded()); - - SignatureParser sp1(sb1.encoded()); - - // Expect 5. - ASSERT_EQ(SignatureParser::Type::kInteger, sp1.type()); - EXPECT_EQ('_', sp1.tag()); - EXPECT_EQ(5, sp1.ival()); - EXPECT_TRUE(sp1.sval().empty()); - - // Expect 1. - ASSERT_EQ(SignatureParser::Type::kInteger, sp1.Next()); - EXPECT_EQ('a', sp1.tag()); - EXPECT_EQ(1, sp1.ival()); - EXPECT_TRUE(sp1.sval().empty()); - - // Expect 10. - ASSERT_EQ(SignatureParser::Type::kInteger, sp1.Next()); - EXPECT_EQ('z', sp1.tag()); - EXPECT_EQ(10, sp1.ival()); - EXPECT_TRUE(sp1.sval().empty()); - - // Expect -5991. - ASSERT_EQ(SignatureParser::Type::kInteger, sp1.Next()); - EXPECT_EQ('x', sp1.tag()); - EXPECT_EQ(-5991, sp1.ival()); - EXPECT_TRUE(sp1.sval().empty()); - - // Expect end. - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); -} - -TEST(SignatureBuilderTest, TestSpan) { - SignatureBuilder sb1; - sb1.Span("foobar", 'A').Span("FOOBAR_23_FOOBAR", 'Z'); - EXPECT_EQ("A7!foobarZ17!FOOBAR_23_FOOBAR", sb1.encoded()); - - SignatureParser sp1(sb1.encoded()); - - // Expect "foobar". - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.type()); - EXPECT_EQ('A', sp1.tag()); - EXPECT_EQ("foobar", sp1.sval()); - EXPECT_EQ(6, sp1.ival()); // Length. - - // Expect "FOOBAR_23_FOOBAR" - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.Next()); - EXPECT_EQ('Z', sp1.tag()); - EXPECT_EQ("FOOBAR_23_FOOBAR", sp1.sval()); - EXPECT_EQ(16, sp1.ival()); // Length. - - // Expect end. - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); -} - -TEST(SignatureBuilderTest, TestEscapedNumericSpan) { - SignatureBuilder sb1; - sb1.Span("12345", 'A').Span("-23", 'Z'); - EXPECT_EQ("A6!12345Z4!-23", sb1.encoded()); - - SignatureParser sp1(sb1.encoded()); - - // Expect "foobar". - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.type()); - EXPECT_EQ('A', sp1.tag()); - EXPECT_EQ("12345", sp1.sval()); - EXPECT_EQ(5, sp1.ival()); // Length. - - // Expect "FOOBAR_23_FOOBAR" - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.Next()); - EXPECT_EQ('Z', sp1.tag()); - EXPECT_EQ("-23", sp1.sval()); - EXPECT_EQ(3, sp1.ival()); // Length. - - // Expect end. - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); -} - -TEST(SignatureBuilderTest, TestEscapedEscapeChar) { - SignatureBuilder sb1; - sb1.Span("!2345", 'A').Span("-23", 'Z'); - EXPECT_EQ("A6!!2345Z4!-23", sb1.encoded()); - - SignatureParser sp1(sb1.encoded()); - - // Expect "foobar". - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.type()); - EXPECT_EQ('A', sp1.tag()); - EXPECT_EQ("!2345", sp1.sval()); - EXPECT_EQ(5, sp1.ival()); // Length. - - // Expect "FOOBAR_23_FOOBAR" - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.Next()); - EXPECT_EQ('Z', sp1.tag()); - EXPECT_EQ("-23", sp1.sval()); - EXPECT_EQ(3, sp1.ival()); // Length. - - // Expect end. - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); -} - -TEST(SignatureBuilderTest, TestNested) { - SignatureBuilder sb1; - sb1.Integer(5); - SignatureBuilder().Integer(6).AppendTo(sb1, 'X'); - EXPECT_EQ("_5X3!_6", sb1.encoded()); - - SignatureParser sp1(sb1.encoded()); - ASSERT_EQ(SignatureParser::Type::kInteger, sp1.type()); - EXPECT_EQ('_', sp1.tag()); - EXPECT_EQ(5, sp1.ival()); - ASSERT_EQ(SignatureParser::Type::kSpan, sp1.Next()); - EXPECT_EQ('X', sp1.tag()); - auto sp2 = sp1.nested(); - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); - ASSERT_EQ(SignatureParser::Type::kInteger, sp2.type()); - EXPECT_EQ(6, sp2.ival()); - EXPECT_EQ('_', sp2.tag()); - ASSERT_EQ(SignatureParser::Type::kEnd, sp2.Next()); -} - -TEST(SignatureParserTest, Empty) { - SignatureParser sp1(""); - EXPECT_EQ(SignatureParser::Type::kEnd, sp1.type()); - ASSERT_EQ(SignatureParser::Type::kEnd, sp1.Next()); -} - -TEST(SignatureParserTest, IllegalTag) { - SignatureParser sp1("\0011 "); - EXPECT_EQ(SignatureParser::Type::kError, sp1.type()); - ASSERT_EQ(SignatureParser::Type::kError, sp1.Next()); -} - -TEST(SignatureParserTest, ShortLength) { - SignatureParser sp1("Z4abc"); - EXPECT_EQ(SignatureParser::Type::kError, sp1.type()); - ASSERT_EQ(SignatureParser::Type::kError, sp1.Next()); -} - -TEST(SignatureParserTest, NonNumeric) { - SignatureParser sp1("_+12"); - EXPECT_EQ(SignatureParser::Type::kError, sp1.type()); - ASSERT_EQ(SignatureParser::Type::kError, sp1.Next()); -} - -TEST(SignatureParserTest, NegativeLength) { - SignatureParser sp1("Z-3abc"); - EXPECT_EQ(SignatureParser::Type::kError, sp1.type()); - ASSERT_EQ(SignatureParser::Type::kError, sp1.Next()); -} - -TEST(SignatureParserTest, ZeroLengthSpan) { - SignatureParser sp1("Z1!"); - EXPECT_EQ(SignatureParser::Type::kSpan, sp1.type()); - EXPECT_EQ(0, sp1.ival()); - EXPECT_EQ("", sp1.sval()); - EXPECT_EQ(SignatureParser::Type::kEnd, sp1.Next()); -} - -// ----------------------------------------------------------------------------- -// Raw signatures -// ----------------------------------------------------------------------------- - -TEST(RawSignatureManglerTest, DefaultBuffer) { - RawSignatureMangler sm; - sm.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32, {}); - EXPECT_EQ("B1!", sm.builder().encoded()); -} - -TEST(RawSignatureManglerTest, FullBuffer) { - RawSignatureMangler sm; - std::vector dims = {-1, 128, 64}; - sm.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat64, dims); - EXPECT_EQ("B13!t2d-1d128d64", sm.builder().encoded()); -} - -TEST(RawSignatureManglerTest, DefaultScalar) { - RawSignatureMangler sm; - sm.AddScalar(AbiConstants::ScalarType::kIeeeFloat32); - EXPECT_EQ("S1!", sm.builder().encoded()); -} - -TEST(RawSignatureManglerTest, FullScalar) { - RawSignatureMangler sm; - sm.AddScalar(AbiConstants::ScalarType::kSint32); - EXPECT_EQ("S3!t6", sm.builder().encoded()); -} - -TEST(RawSignatureManglerTest, AnyRef) { - RawSignatureMangler sm; - sm.AddAnyReference(); - EXPECT_EQ("O1!", sm.builder().encoded()); -} - -TEST(RawSignatureParserTest, EmptySignature) { - RawSignatureMangler inputs; - RawSignatureMangler results; - - auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results); - RawSignatureParser p; - auto s = p.FunctionSignatureToString(sig.encoded()); - ASSERT_TRUE(s) << *p.GetError(); - EXPECT_EQ("() -> ()", *s); -} - -TEST(RawSignatureParserTest, StaticNdArrayBuffer) { - RawSignatureMangler inputs; - std::vector dims = {10, 128, 64}; - inputs.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32, dims); - RawSignatureMangler results; - std::vector dims2 = {32, 8, 64}; - results.AddShapedNDBuffer(AbiConstants::ScalarType::kSint32, dims2); - - auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results); - EXPECT_EQ("I15!B11!d10d128d64R15!B11!t6d32d8d64", sig.encoded()); - - RawSignatureParser p; - auto s = p.FunctionSignatureToString(sig.encoded()); - ASSERT_TRUE(s) << *p.GetError(); - EXPECT_EQ("(Buffer) -> (Buffer)", *s); -} - -TEST(RawSignatureParserTest, DynamicNdArrayBuffer) { - RawSignatureMangler inputs; - std::vector dims = {-1, 128, 64}; - inputs.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32, dims); - RawSignatureMangler results; - std::vector dims2 = {-1, 8, 64}; - results.AddShapedNDBuffer(AbiConstants::ScalarType::kSint32, dims2); - - auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results); - EXPECT_EQ("I15!B11!d-1d128d64R15!B11!t6d-1d8d64", sig.encoded()); - - RawSignatureParser p; - auto s = p.FunctionSignatureToString(sig.encoded()); - ASSERT_TRUE(s) << *p.GetError(); - EXPECT_EQ("(Buffer) -> (Buffer)", *s); -} - -TEST(RawSignatureParserTest, Scalar) { - RawSignatureMangler inputs; - inputs.AddScalar(AbiConstants::ScalarType::kSint32); - RawSignatureMangler results; - results.AddScalar(AbiConstants::ScalarType::kIeeeFloat64); - - auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results); - EXPECT_EQ("I6!S3!t6R6!S3!t2", sig.encoded()); - - RawSignatureParser p; - auto s = p.FunctionSignatureToString(sig.encoded()); - ASSERT_TRUE(s) << *p.GetError(); - EXPECT_EQ("(sint32) -> (float64)", *s); -} - -TEST(RawSignatureParserTest, AllTypes) { - RawSignatureMangler inputs; - inputs.AddAnyReference(); - std::vector dims = {-1, 128, 64}; - inputs.AddShapedNDBuffer(AbiConstants::ScalarType::kIeeeFloat32, dims); - inputs.AddScalar(AbiConstants::ScalarType::kSint32); - RawSignatureMangler results; - std::vector dims2 = {32, -1, 64}; - results.AddShapedNDBuffer(AbiConstants::ScalarType::kUint64, dims2); - - auto sig = RawSignatureMangler::ToFunctionSignature(inputs, results); - EXPECT_EQ("I23!O1!B11!d-1d128d64S3!t6R17!B13!t11d32d-1d64", sig.encoded()); - - RawSignatureParser p; - auto s = p.FunctionSignatureToString(sig.encoded()); - ASSERT_TRUE(s) << *p.GetError(); - EXPECT_EQ( - "(RefObject, Buffer, sint32) -> " - "(Buffer)", - *s); -} - -// ----------------------------------------------------------------------------- -// Sip signatures -// ----------------------------------------------------------------------------- - -TEST_F(SipSignatureTest, NoInputsResults) { - const char kExpectedInputs[] = R"()"; - const char kExpectedResults[] = R"()"; - - SipSignatureMangler inputs; - SipSignatureMangler results; - - auto signature = SipSignatureMangler::ToFunctionSignature(inputs, results); - EXPECT_EQ("I1!R1!", signature->encoded()); - - auto inputs_string = PrintInputSignature(signature); - EXPECT_EQ(kExpectedInputs, inputs_string) << inputs_string; - - auto results_string = PrintResultsSignature(signature); - EXPECT_EQ(kExpectedResults, results_string) << results_string; -} - -TEST_F(SipSignatureTest, SequentialInputSingleResult) { - const char kExpectedInputs[] = R"(:[ - 0=raw(0), - 1=raw(1), -], -)"; - const char kExpectedResults[] = R"(=raw(0), -)"; - - SipSignatureMangler inputs; - inputs.SetRawSignatureIndex(0, {{0}}); - inputs.SetRawSignatureIndex(1, {{1}}); - - SipSignatureMangler results; - results.SetRawSignatureIndex(0, {}); - - auto signature = SipSignatureMangler::ToFunctionSignature(inputs, results); - auto inputs_string = PrintInputSignature(signature); - EXPECT_EQ(kExpectedInputs, inputs_string) << inputs_string; - - auto results_string = PrintResultsSignature(signature); - EXPECT_EQ(kExpectedResults, results_string) << results_string; -} - -TEST_F(SipSignatureTest, NestedInputMultiResult) { - const char kExpectedInputs[] = R"(:[ - 0:{ - bar=raw(1), - foo=raw(0), - }, - 1=raw(2), -], -)"; - const char kExpectedResults[] = R"(:[ - 0=raw(0), - 1=raw(1), -], -)"; - - SipSignatureMangler inputs; - inputs.SetRawSignatureIndex(0, {{0, "foo"}}); - inputs.SetRawSignatureIndex(1, {{0, "bar"}}); - inputs.SetRawSignatureIndex(2, {{1}}); - - SipSignatureMangler results; - results.SetRawSignatureIndex(0, {{0}}); - results.SetRawSignatureIndex(1, {{1}}); - - auto signature = SipSignatureMangler::ToFunctionSignature(inputs, results); - auto inputs_string = PrintInputSignature(signature); - EXPECT_EQ(kExpectedInputs, inputs_string) << inputs_string; - - auto results_string = PrintResultsSignature(signature); - EXPECT_EQ(kExpectedResults, results_string) << results_string; -} - -} // namespace diff --git a/iree/compiler/Dialect/HAL/Transforms/BUILD b/iree/compiler/Dialect/HAL/Transforms/BUILD index 1549890908aa..027944d358c9 100644 --- a/iree/compiler/Dialect/HAL/Transforms/BUILD +++ b/iree/compiler/Dialect/HAL/Transforms/BUILD @@ -27,7 +27,6 @@ cc_library( "PackConstantPoolStorage.cpp", "Passes.cpp", "PropagateConstantWorkgroupInfo.cpp", - "PublicAbiGeneration.cpp", "ResolveEntryPointOrdinals.cpp", "SerializeExecutables.cpp", "TranslateExecutables.cpp", @@ -36,7 +35,6 @@ cc_library( "Passes.h", ], deps = [ - "//iree/compiler/Bindings/SIP/Utils", "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/HAL/Conversion", "//iree/compiler/Dialect/HAL/Conversion/FlowToHAL", diff --git a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt index 1534bd59cc63..44359acc4144 100644 --- a/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Transforms/CMakeLists.txt @@ -30,7 +30,6 @@ iree_cc_library( "PackConstantPoolStorage.cpp" "Passes.cpp" "PropagateConstantWorkgroupInfo.cpp" - "PublicAbiGeneration.cpp" "ResolveEntryPointOrdinals.cpp" "SerializeExecutables.cpp" "TranslateExecutables.cpp" @@ -42,7 +41,6 @@ iree_cc_library( MLIRStandard MLIRSupport MLIRTransforms - iree::compiler::Bindings::SIP::Utils iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::HAL::Conversion iree::compiler::Dialect::HAL::Conversion::FlowToHAL diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp index 4a77b87da63d..06ceb3e3965a 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.cpp +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.cpp @@ -88,12 +88,6 @@ void buildHALTransformPassPipeline(OpPassManager &passManager, // sizes are as much as possible available as constants. passManager.addNestedPass(createPackAllocationsPass(targetOptions)); - // For each exported function, processes the reflection metadata and - // generates public ABI wrappers for various calling conventions. - // Phase ordering note: This operates on functions whose signatures have - // been expanded to primitives. - passManager.addPass(createPublicABIGenerationPass()); - // After all executables are translated and before resolving entry point // ordinals, we allow the backends to link executables together. For example, // the LLVM AOT backend may combine all executable targets for the same diff --git a/iree/compiler/Dialect/HAL/Transforms/Passes.h b/iree/compiler/Dialect/HAL/Transforms/Passes.h index a1bbbe9027fc..25cdf9f42f0c 100644 --- a/iree/compiler/Dialect/HAL/Transforms/Passes.h +++ b/iree/compiler/Dialect/HAL/Transforms/Passes.h @@ -89,11 +89,6 @@ std::unique_ptr> createResolveEntryPointOrdinalsPass(); std::unique_ptr> createSerializeExecutablesPass(TargetOptions targetOptions); -// For functions that contain reflection metadata in an -// iree.generateabi.reflection attribute, generate public ABI functions for -// typical clients to use. -std::unique_ptr> createPublicABIGenerationPass(); - //===----------------------------------------------------------------------===// // Resource initialization, caching, and optimization //===----------------------------------------------------------------------===// @@ -146,7 +141,6 @@ inline void registerHALPasses() { createLinkExecutablesPass(targetOptions); createResolveEntryPointOrdinalsPass(); createSerializeExecutablesPass(targetOptions); - createPublicABIGenerationPass(); createIdentifyConstantPoolsPass(targetOptions); createPackConstantPoolStoragePass(); createMaterializeConstantPoolBuffersPass(); diff --git a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp b/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp deleted file mode 100644 index 430e2f39fc32..000000000000 --- a/iree/compiler/Dialect/HAL/Transforms/PublicAbiGeneration.cpp +++ /dev/null @@ -1,435 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/compiler/Bindings/SIP/Utils/SignatureParser.h" -#include "iree/compiler/Dialect/HAL/IR/HALOps.h" -#include "iree/compiler/Dialect/HAL/IR/HALTypes.h" -#include "iree/compiler/Dialect/HAL/Transforms/Passes.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Types.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { -namespace HAL { - -namespace { - -using mlir::iree_compiler::IREE::SIP::RawSignatureParser; -using mlir::iree_compiler::IREE::SIP::AbiConstants::ScalarType; - -Type mapScalarType(MLIRContext *context, ScalarType scalarType) { - switch (scalarType) { - case ScalarType::kIeeeFloat32: - return FloatType::getF32(context); - case ScalarType::kIeeeFloat64: - return FloatType::getF64(context); - case ScalarType::kIeeeFloat16: - return FloatType::getF16(context); - case ScalarType::kGoogleBfloat16: - return FloatType::getBF16(context); - case ScalarType::kSint32: - case ScalarType::kUint32: - return IntegerType::get(context, 32); - case ScalarType::kSint64: - case ScalarType::kUint64: - return IntegerType::get(context, 64); - case ScalarType::kSint16: - case ScalarType::kUint16: - return IntegerType::get(context, 16); - case ScalarType::kSint8: - case ScalarType::kUint8: - return IntegerType::get(context, 8); - default: - return nullptr; - } -} - -LogicalResult mapRawAbiTypes( - Location loc, SmallVectorImpl &descs, - SmallVectorImpl &types) { - auto *context = loc.getContext(); - auto bufferViewType = HAL::BufferViewType::get(loc.getContext()); - for (auto &d : descs) { - switch (d.type) { - case RawSignatureParser::Type::kBuffer: - // ABI buffers map to shape-erased ref of buffer_views. - types.push_back(bufferViewType); - break; - case RawSignatureParser::Type::kRefObject: { - // TODO(laurenzo): Map supported ref objects. - std::string dstr; - d.ToString(dstr); - return emitError(loc) << "unsupported ABI type: " << dstr; - } - case RawSignatureParser::Type::kScalar: { - auto t = mapScalarType(context, d.scalar.type); - if (!t) { - std::string dstr; - d.ToString(dstr); - return emitError(loc) << "unsupported ABI type: " << dstr; - } - types.push_back(t); - break; - } - } - } - - return success(); -} - -LogicalResult generateAsynchronousBody( - FuncOp rawCalleeFuncOp, FuncOp funcOp, OpBuilder moduleBuilder, - SmallVectorImpl &inputTypes, - SmallVectorImpl &inputDescs, - SmallVectorImpl &resultTypes, - SmallVectorImpl &resultDescs) { - auto *context = funcOp.getContext(); - auto loc = funcOp.getLoc(); - Block *entryBlock = funcOp.addEntryBlock(); - OpBuilder builder = OpBuilder::atBlockEnd(entryBlock); - - // TODO(#1285): Pass semaphores into raw function so modules can run async - // Wait until the wait semaphore reaches the wait value. - auto waitSemaphore = entryBlock->getArgument(0); - auto waitValue = entryBlock->getArgument(1); - auto waitOp = builder.create( - loc, builder.getIntegerType(32), waitSemaphore, waitValue); - builder.create(loc, waitOp.getResult(), - "semaphore wait failed"); - - // Build call operands. - SmallVector callOperands; - for (const auto &input : llvm::enumerate(inputDescs)) { - // Skip first two arguments (wait semaphore, wait value). - auto blockArg = entryBlock->getArgument(input.index() + 2); - switch (input.value().type) { - case RawSignatureParser::Type::kBuffer: { - // Pass the backing buffer. - // TODO(laurenzo): Validate shape. - callOperands.push_back(builder.create( - loc, IREE::HAL::BufferType::get(context), blockArg)); - - // Now, each dynamic dim is passed individually. - for (auto dim : llvm::enumerate(input.value().dims)) { - if (dim.value() >= 0) { - // Static. - continue; - } - // Dynamic: Get each dim individually. - // There is an optimization potential here if more than a couple of - // dynamic dims to use the bulk dim query op, but here just get one - // at a time as needed. - auto dimValue = builder.create( - loc, builder.getIndexType(), blockArg, - builder.getIndexAttr(dim.index())); - callOperands.push_back(dimValue); - } - break; - } - case RawSignatureParser::Type::kScalar: { - // Assume that scalars are pass-through. - callOperands.push_back(blockArg); - break; - } - case RawSignatureParser::Type::kRefObject: { - // Assume that ref objects are pass-through. - callOperands.push_back(blockArg); - break; - } - } - } - - // Build call. - auto callOp = builder.create(loc, rawCalleeFuncOp, callOperands); - - // And convert each result. For any buffer results, this involves a - // contraction from (buffer, index...) -> (buffer_view). - auto callResults = callOp.getResults(); - auto callResultsIt = callResults.begin(); - SmallVector funcResults; - for (const auto &output : llvm::enumerate(resultDescs)) { - if (callResultsIt == callResults.end()) { - return emitError(loc) - << "mismatched reflection metadata and function signature " - << "(overall arity)"; - } - Value nextCallResult = *(callResultsIt++); - switch (output.value().type) { - case RawSignatureParser::Type::kBuffer: { - // Unpack dims (dynamic dims come from call result, static become - // consts). - SmallVector dimValues; - for (auto dim : llvm::enumerate(output.value().dims)) { - if (dim.value() >= 0) { - // Static. - dimValues.push_back( - builder.create(loc, dim.value())); - } else { - // Dynamic. - if (callResultsIt == callResults.end()) { - return emitError(loc) - << "mismatched reflection metadata and function signature " - << "(dynamic dim)"; - } - dimValues.push_back(*callResultsIt); - ++callResultsIt; - } - } - - // Determine element type. - Type mappedScalarType = - mapScalarType(context, output.value().scalar.type); - auto elementType = getElementTypeValue(mappedScalarType); - if (!elementType) { - return emitError(loc) - << "unsupported hal element type: " << mappedScalarType; - } - - // Build buffer_view. - funcResults.push_back(builder.create( - loc, nextCallResult, *elementType, dimValues)); - break; - } - case RawSignatureParser::Type::kScalar: { - // Assume that scalars are pass-through. - funcResults.push_back(nextCallResult); - break; - } - case RawSignatureParser::Type::kRefObject: { - // Assume that ref objects are pass-through. - funcResults.push_back(nextCallResult); - break; - } - } - } - - // TODO(#1285): Pass semaphores into raw function so modules can run async - // Signal the signal semaphore to its signal value. - auto signalSemaphore = - entryBlock->getArgument(entryBlock->getNumArguments() - 2); - auto signalValue = entryBlock->getArgument(entryBlock->getNumArguments() - 1); - builder.create(loc, signalSemaphore, signalValue); - - // Add the return. - builder.create(loc, funcResults); - return success(); -} - -LogicalResult generateSynchronousBody(FuncOp funcOp, FuncOp asyncFuncOp, - OpBuilder moduleBuilder) { - auto loc = funcOp.getLoc(); - Block *entryBlock = funcOp.addEntryBlock(); - OpBuilder builder = OpBuilder::atBlockEnd(entryBlock); - - Value zero = builder.createOrFold(loc, builder.getIndexAttr(0)); - Value one = builder.createOrFold(loc, builder.getIndexAttr(1)); - - auto device = builder.create(loc); - auto semaphore = builder.create( - loc, IREE::HAL::SemaphoreType::get(builder.getContext()), - device.getResult(), zero); - - // Construct async arguments: - // wait_semaphore, wait_value, args, signal_semaphore, signal_value - SmallVector callAsyncArguments; - callAsyncArguments.push_back(semaphore); - callAsyncArguments.push_back(zero); - for (const auto &arg : entryBlock->getArguments()) { - callAsyncArguments.push_back(arg); - } - callAsyncArguments.push_back(semaphore); - callAsyncArguments.push_back(one); - auto callAsyncOp = - builder.create(loc, asyncFuncOp, callAsyncArguments); - - // Wait until the semaphore reaches the signal value. - auto waitOp = builder.create( - loc, builder.getIntegerType(32), semaphore, one); - builder.create(loc, waitOp.getResult(), - "semaphore wait failed"); - - // Return results of the async op. - builder.create(loc, callAsyncOp.getResults()); - - return success(); -} - -LogicalResult generateRawAbiFunctions(OpBuilder &moduleBuilder, - FuncOp rawCalleeFuncOp, - StringRef exportName, - DictionaryAttr reflection, - StringRef signatureSr) { - auto context = rawCalleeFuncOp.getContext(); - auto loc = rawCalleeFuncOp.getLoc(); - - StringRef signature(signatureSr.data(), signatureSr.size()); - SmallVector inputDescs; - SmallVector resultDescs; - - // Parse the reflection metadata. - RawSignatureParser p; - p.VisitInputs(signature, [&](const RawSignatureParser::Description &d) { - inputDescs.push_back(d); - }); - p.VisitResults(signature, [&](const RawSignatureParser::Description &d) { - resultDescs.push_back(d); - }); - if (p.GetError()) { - return rawCalleeFuncOp.emitError() - << "illegal abi signature ('" << signatureSr - << "'): " << *p.GetError(); - } - - // Map to function signature types. - SmallVector inputTypes; - SmallVector resultTypes; - if (failed(mapRawAbiTypes(loc, inputDescs, inputTypes))) { - return failure(); - } - assert(inputTypes.size() == inputDescs.size()); - if (failed(mapRawAbiTypes(loc, resultDescs, resultTypes))) { - return failure(); - } - assert(resultTypes.size() == resultDescs.size()); - - // Create the new asynchronous function export. - SmallVector asyncInputTypes; - // Prefix with wait semaphore and its value. - // TODO(scotttodd): SemaphoreValue wrapper for single {semaphore, value} - // TODO(scotttodd): SemaphoreList wrapper for list of SemaphoreValues - asyncInputTypes.push_back(HAL::SemaphoreType::get(context)); - asyncInputTypes.push_back(moduleBuilder.getIndexType()); - for (const auto &inputType : inputTypes) { - asyncInputTypes.push_back(inputType); - } - // Postfix with signal semaphore and its value. - asyncInputTypes.push_back(HAL::SemaphoreType::get(context)); - asyncInputTypes.push_back(moduleBuilder.getIndexType()); - - // TODO(scotttodd): populate async export attributes - // * iree.reflection (considering new args?) - // * iree.abi.stub - SmallVector asyncExportAttrs; - asyncExportAttrs.push_back(moduleBuilder.getNamedAttr( - "iree.module.export", - StringAttr::get(context, (exportName + "$async").str()))); - - auto asyncType = FunctionType::get(context, asyncInputTypes, resultTypes); - auto asyncName = (rawCalleeFuncOp.getName() + "$async").str(); - auto asyncFuncOp = - moduleBuilder.create(loc, asyncName, asyncType, asyncExportAttrs); - - if (failed(generateAsynchronousBody(rawCalleeFuncOp, asyncFuncOp, - moduleBuilder, inputTypes, inputDescs, - resultTypes, resultDescs))) { - return failure(); - } - - // Create the new synchronous function export. - SmallVector syncExportAttrs; - syncExportAttrs.push_back(moduleBuilder.getNamedAttr( - "iree.module.export", moduleBuilder.getStringAttr(exportName))); - syncExportAttrs.push_back( - moduleBuilder.getNamedAttr("iree.reflection", reflection)); - syncExportAttrs.push_back( - moduleBuilder.getNamedAttr("iree.abi.stub", UnitAttr::get(context))); - - auto syncType = FunctionType::get(context, inputTypes, resultTypes); - auto syncName = (rawCalleeFuncOp.getName() + "$sync").str(); - auto syncFuncOp = - moduleBuilder.create(loc, syncName, syncType, syncExportAttrs); - - if (failed(generateSynchronousBody(syncFuncOp, asyncFuncOp, moduleBuilder))) { - return failure(); - } - - return success(); -} - -LogicalResult generateAbiFunctions(FuncOp funcOp, StringRef exportName, - DictionaryAttr reflection) { - OpBuilder builder(funcOp.getContext()); - builder.setInsertionPointAfter(funcOp); - - auto rawSignatureSpec = reflection.get("f").dyn_cast_or_null(); - if (rawSignatureSpec) { - if (failed(generateRawAbiFunctions(builder, funcOp, exportName, reflection, - rawSignatureSpec.getValue()))) { - return failure(); - } - } - - return success(); -} - -Optional getFuncOpExportName(FuncOp op) { - auto exportAttr = op->getAttr("iree.module.export"); - if (!exportAttr) return llvm::None; - - if (exportAttr.isa()) { - // Just the function name. - return op.getName(); - } else if (auto nameAttr = exportAttr.dyn_cast()) { - return nameAttr.getValue(); - } - - return llvm::None; -} - -class PublicABIGenerationPass - : public PassWrapper> { - public: - void runOnOperation() override { - auto *context = &getContext(); - for (auto &op : getOperation().getBody()->getOperations()) { - if (auto funcOp = dyn_cast(op)) { - // Skip functions we generate. - if (funcOp->getAttr("iree.abi.stub")) continue; - - // Any function marked for export we make private and expose via - // generated ABI wrappers with the original name. - Optional exportName = getFuncOpExportName(funcOp); - if (!exportName) continue; - auto reflection = funcOp->getAttr("iree.reflection") - .dyn_cast_or_null(); - if (!reflection) continue; - - // Rename and remove reflection (it will go on the ABI entry point). - funcOp->removeAttr("iree.module.export"); - funcOp->removeAttr("iree.reflection"); - funcOp->setAttr("noinline", UnitAttr::get(context)); - - if (reflection) { - if (failed(generateAbiFunctions(funcOp, *exportName, reflection))) { - signalPassFailure(); - return; - } - } - } - } - } -}; - -} // namespace - -std::unique_ptr> createPublicABIGenerationPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "iree-hal-public-abi-generation", "Creates public ABI entry points"); - -} // namespace HAL -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir diff --git a/iree/compiler/Dialect/HAL/Transforms/test/BUILD b/iree/compiler/Dialect/HAL/Transforms/test/BUILD index 6f90f51229b7..9635cfa33d9b 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/BUILD +++ b/iree/compiler/Dialect/HAL/Transforms/test/BUILD @@ -28,7 +28,6 @@ iree_lit_test_suite( "pack_allocations.mlir", "pack_constant_pool_storage.mlir", "propagate_constant_workgroup_info.mlir", - "public_abi_generation.mlir", "resolve_entry_point_ordinals.mlir", ], include = ["*.mlir"], diff --git a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt index c48201b7ac4c..d68729bd00c3 100644 --- a/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt +++ b/iree/compiler/Dialect/HAL/Transforms/test/CMakeLists.txt @@ -25,7 +25,6 @@ iree_lit_test_suite( "pack_allocations.mlir" "pack_constant_pool_storage.mlir" "propagate_constant_workgroup_info.mlir" - "public_abi_generation.mlir" "resolve_entry_point_ordinals.mlir" DATA iree::tools::IreeFileCheck diff --git a/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir b/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir deleted file mode 100644 index 802d6072a6bd..000000000000 --- a/iree/compiler/Dialect/HAL/Transforms/test/public_abi_generation.mlir +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: iree-opt -split-input-file -iree-hal-public-abi-generation %s | IreeFileCheck %s - -// CHECK-LABEL: @noReflectionExport -// CHECK-SAME: attributes {iree.module.export} -func @noReflectionExport(%arg0 : tensor<4xf32>) -> tensor<4xf32> - attributes {iree.module.export} { - return %arg0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @staticTwoArg -// Note: reflection matches signature: -// (%arg0 : tensor<4x4xi64>, %arg1 : tensor<5x6xi64>) -> tensor<5x6xi64> -// A new function with $async suffix based on buffer_view with wait and signal -// semaphore arguments should be generated. -// CHECK: func @staticTwoArg$async(%[[ARG0:.+]]: !hal.semaphore, %[[ARG1:.+]]: index, %[[ARG2:.+]]: !hal.buffer_view, %[[ARG3:.+]]: !hal.buffer_view, %[[ARG4:.+]]: !hal.semaphore, %[[ARG5:.+]]: index) -// CHECK-SAME: attributes -// CHECK-SAME: iree.module.export = "staticTwoArg$async" -func @staticTwoArg(%arg0: !hal.buffer, %arg1: !hal.buffer) -> !hal.buffer - attributes {iree.module.export, - iree.reflection = {f = "I19!B7!t7d4d4B7!t7d5d6R10!B7!t7d5d6", fv = "1"}} -{ - // CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[ARG0]] : !hal.semaphore> until(%[[ARG1]]) : i32 - // CHECK-DAG: hal.check_success %[[WAITRESULT]] - // CHECK-DAG: %[[BUFFER0:.+]] = hal.buffer_view.buffer %[[ARG2]] : !hal.buffer - // CHECK-DAG: %[[BUFFER1:.+]] = hal.buffer_view.buffer %[[ARG3]] : !hal.buffer - // CHECK-DAG: %[[R0:.+]] = call @staticTwoArg(%[[BUFFER0]], %[[BUFFER1]]) - // CHECK-DAG: %[[C5:.+]] = constant 5 : index - // CHECK-DAG: %[[C6:.+]] = constant 6 : index - // CHECK-DAG: %[[VIEW:.+]] = hal.buffer_view.create %[[R0]], element_type = %c16777280_i32, shape = [%[[C5]], %[[C6]]] : !hal.buffer -> !hal.buffer_view - // CHECK-DAG: hal.semaphore.signal<%[[ARG4]] : !hal.semaphore> value(%[[ARG5]]) - // CHECK: return %[[VIEW]] - return %arg1 : !hal.buffer -} -// A new function with $sync suffix based on buffer_view should be generated. -// It should wrap the $async function. -// CHECK: func @staticTwoArg$sync(%[[ARG0:.+]]: !hal.buffer_view, %[[ARG1:.+]]: !hal.buffer_view) -// CHECK-SAME: attributes -// CHECK-SAME: iree.abi.stub -// CHECK-SAME: iree.module.export = "staticTwoArg" -// CHECK-SAME: iree.reflection = {f = "I19!B7!t7d4d4B7!t7d5d6R10!B7!t7d5d6", fv = "1"} -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device -// CHECK-DAG: %[[SEMAPHORE:.+]] = hal.semaphore.create device(%[[DEVICE]] : !hal.device) initial(%[[C0]]) : !hal.semaphore -// CHECK-DAG: %[[RESULT:.+]] = call @staticTwoArg$async(%[[SEMAPHORE]], %[[C0]], %[[ARG0]], %[[ARG1]], %[[SEMAPHORE]], %[[C1]]) : (!hal.semaphore, index, !hal.buffer_view, !hal.buffer_view, !hal.semaphore, index) -> !hal.buffer_view -// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[SEMAPHORE]] : !hal.semaphore> until(%[[C1]]) : i32 -// CHECK-DAG: hal.check_success %[[WAITRESULT]] -// CHECK: return %[[RESULT]] : !hal.buffer_view - -// ----- - -// CHECK-LABEL: @dynamicTwoDims -// Note: reflection matches signature: -// (%arg0 : tensor) -> tensor -// A new function with $async suffix based on buffer_view with wait and signal -// semaphore arguments should be generated. -// CHECK: func @dynamicTwoDims$async(%[[ARG0:.+]]: !hal.semaphore, %[[ARG1:.+]]: index, %[[ARG2:.+]]: !hal.buffer_view, %[[ARG3:.+]]: !hal.semaphore, %[[ARG4:.+]]: index) -// CHECK-SAME: attributes -// CHECK-SAME: iree.module.export = "dynamicTwoDims$async" -// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[ARG0]] : !hal.semaphore> until(%[[ARG1]]) : i32 -// CHECK-DAG: hal.check_success %[[WAITRESULT]] -// CHECK-DAG: %[[BUFFER:.+]] = hal.buffer_view.buffer %[[ARG2]] : !hal.buffer -// CHECK-DAG: %[[DIM0:.+]] = hal.buffer_view.dim %[[ARG2]], 0 : index -// CHECK-DAG: %[[DIM1:.+]] = hal.buffer_view.dim %[[ARG2]], 1 : index -// CHECK-DAG: %[[RESULT:.+]]:3 = call @dynamicTwoDims(%[[BUFFER]], %[[DIM0]], %[[DIM1]]) -// CHECK-DAG: %[[RESULT_VIEW:.+]] = hal.buffer_view.create %[[RESULT]]#0, element_type = %c50331680_i32, shape = [%[[RESULT]]#1, %[[RESULT]]#2] : !hal.buffer -> !hal.buffer_view -// CHECK-DAG: hal.semaphore.signal<%[[ARG3]] : !hal.semaphore> value(%[[ARG4]]) -// CHECK: return %[[RESULT_VIEW]] -// A new function with $sync suffix based on buffer_view should be generated. -// It should wrap the $async function. -// CHECK: func @dynamicTwoDims$sync(%[[ARG0:.+]]: !hal.buffer_view) -// CHECK-SAME: attributes -// CHECK-SAME: iree.abi.stub -// CHECK-SAME: iree.module.export = "dynamicTwoDims" -// CHECK-SAME: iree.reflection = {f = "I10!B7!d-1d-1R10!B7!d-1d-1", fv = "1"} -// CHECK-DAG: %[[C0:.+]] = constant 0 : index -// CHECK-DAG: %[[C1:.+]] = constant 1 : index -// CHECK-DAG: %[[DEVICE:.+]] = hal.ex.shared_device : !hal.device -// CHECK-DAG: %[[SEMAPHORE:.+]] = hal.semaphore.create device(%[[DEVICE]] : !hal.device) initial(%[[C0]]) : !hal.semaphore -// CHECK-DAG: %[[RESULT:.+]] = call @dynamicTwoDims$async(%[[SEMAPHORE]], %[[C0]], %[[ARG0]], %[[SEMAPHORE]], %[[C1]]) : (!hal.semaphore, index, !hal.buffer_view, !hal.semaphore, index) -> !hal.buffer_view -// CHECK-DAG: %[[WAITRESULT:.+]] = hal.semaphore.await<%[[SEMAPHORE]] : !hal.semaphore> until(%[[C1]]) : i32 -// CHECK-DAG: hal.check_success %[[WAITRESULT]] -// CHECK: return %[[RESULT]] : !hal.buffer_view -func @dynamicTwoDims(%arg0 : !hal.buffer, %arg1 : index, %arg2 : index) -> (!hal.buffer, index, index) - attributes {iree.module.export, - iree.reflection = {f = "I10!B7!d-1d-1R10!B7!d-1d-1", fv = "1"}} -{ - %0 = constant 5 : index - %1 = constant 6 : index - return %arg0, %0, %1 : !hal.buffer, index, index -} diff --git a/iree/compiler/Translation/BUILD b/iree/compiler/Translation/BUILD index 5be18c04585a..7653572fecf6 100644 --- a/iree/compiler/Translation/BUILD +++ b/iree/compiler/Translation/BUILD @@ -18,7 +18,6 @@ cc_library( hdrs = ["IREEVM.h"], deps = [ "//iree/compiler/Bindings/Native/Transforms", - "//iree/compiler/Bindings/SIP/Transforms", "//iree/compiler/Bindings/TFLite/Transforms", "//iree/compiler/Dialect/Flow/IR", "//iree/compiler/Dialect/Flow/Transforms", diff --git a/iree/compiler/Translation/CMakeLists.txt b/iree/compiler/Translation/CMakeLists.txt index 6e7de19b1955..6764fabf23a8 100644 --- a/iree/compiler/Translation/CMakeLists.txt +++ b/iree/compiler/Translation/CMakeLists.txt @@ -26,7 +26,6 @@ iree_cc_library( MLIRSupport MLIRTranslation iree::compiler::Bindings::Native::Transforms - iree::compiler::Bindings::SIP::Transforms iree::compiler::Bindings::TFLite::Transforms iree::compiler::Dialect::Flow::IR iree::compiler::Dialect::Flow::Transforms diff --git a/iree/compiler/Translation/IREEVM.cpp b/iree/compiler/Translation/IREEVM.cpp index e25b1a150fc3..d18212fb8710 100644 --- a/iree/compiler/Translation/IREEVM.cpp +++ b/iree/compiler/Translation/IREEVM.cpp @@ -7,7 +7,6 @@ #include "iree/compiler/Translation/IREEVM.h" #include "iree/compiler/Bindings/Native/Transforms/Passes.h" -#include "iree/compiler/Bindings/SIP/Transforms/Passes.h" #include "iree/compiler/Bindings/TFLite/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/HAL/Transforms/Passes.h" @@ -35,9 +34,6 @@ namespace iree_compiler { struct BindingOptions { // Whether to include runtime support functions for the IREE native ABI. bool native = true; - // Whether to include runtime support functions and metadata required for - // SIP-compatible bindings (like bindings/python/iree). - bool sip = false; // Whether to include runtime support functions required for the IREE TFLite // API compatibility bindings. bool tflite = false; @@ -53,11 +49,6 @@ static BindingOptions getBindingOptionsFromFlags() { "Include runtime support for native IREE ABI-compatible bindings"), llvm::cl::init(true), llvm::cl::cat(bindingOptionsCategory)}; - static llvm::cl::opt *bindingsSIPFlag = new llvm::cl::opt{ - "iree-sip-bindings-support", - llvm::cl::desc("Include runtime support for SIP-compatible bindings"), - llvm::cl::init(false), llvm::cl::cat(bindingOptionsCategory)}; - static llvm::cl::opt *bindingsTFLiteFlag = new llvm::cl::opt{ "iree-tflite-bindings-support", llvm::cl::desc( @@ -66,7 +57,6 @@ static BindingOptions getBindingOptionsFromFlags() { BindingOptions bindingOptions; bindingOptions.native = *bindingsNativeFlag; - bindingOptions.sip = *bindingsSIPFlag; bindingOptions.tflite = *bindingsTFLiteFlag; return bindingOptions; } @@ -172,9 +162,6 @@ static void buildIREEVMTransformPassPipeline( if (bindingOptions.native) { IREE::ABI::buildTransformPassPipeline(passManager); } - if (bindingOptions.sip) { - IREE::SIP::buildTransformPassPipeline(passManager); - } if (bindingOptions.tflite) { IREE::TFLite::buildTransformPassPipeline(passManager); }