From 5aabc30a903d0352a72f28b5658d43aeb732bf50 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Fri, 23 Oct 2020 11:35:14 -0700 Subject: [PATCH] PyTorch NNAPI integration prototype Summary: This is in prototype status, but pretty functional. There are two major parts. - Model converter. This is a pure Python component that consumes a model in TorchScript format, converts the operations into NNAPI semantics, and serializes the model in a custom format. It then wraps the result in a new TorchScript model that can invoke NNAPI under the hood. - Runtime. This is a TorchBind object that deserializes the model and sends the result to NNAPI. This is fairly simple since the serialized format is basically just a list of NNAPI calls to make, so most of the code is spent on bounds checking. A few notes on the design. - Currently, all tensor sizes need to be fixed, and those fixed sizes are burned directly into the serialized model. This will probably need to change. NNAPI supports variable-sized tensors, but the important hardware backends do not. However, we're seeing use cases crop up where the input size is not known until around the time that the model is loaded (for example, it might depend on the camera aspect ratio). I think the proper fix here is to remove the code in the converter that eagerly calculates the sizes of the intermediate tensors and replace it with a code generator that will generate some TorchScript code that will perform those calculations at model load time. This way, we will be able to support models that have variable-sized inputs while still only showing fixed-sized operands to NNAPI. - The important hardware backends want operands to be in NHWC order, but PyTorch natively represents all tensors and NCHW. The strategy for this is to keep NCHW during most of the conversion process, but track and additional value per operand representing the "dimension order". The dimension order gets propagated through convolutions and pointwise ops. When we're ready to serialize the model, we reorder the dimensions for "channels last" operands to NHWC. Test Plan: Some local testing with FB prod models. I'll need to add some examples and automated tests. ghstack-source-id: e1fa978af170d4d00c5270c52b9d4cb63843e7d2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/46780 --- aten/src/ATen/CMakeLists.txt | 8 + aten/src/ATen/nnapi/CMakeLists.txt | 24 + aten/src/ATen/nnapi/NeuralNetworks.h | 84 ++ aten/src/ATen/nnapi/codegen.py | 155 +++ aten/src/ATen/nnapi/nnapi_bind.cpp | 197 ++++ aten/src/ATen/nnapi/nnapi_model_loader.cpp | 264 +++++ aten/src/ATen/nnapi/nnapi_model_loader.h | 29 + aten/src/ATen/nnapi/nnapi_wrapper.cpp | 325 ++++++ aten/src/ATen/nnapi/nnapi_wrapper.h | 62 ++ binaries/CMakeLists.txt | 4 + torch/backends/_nnapi/__init__.py | 0 torch/backends/_nnapi/prepare.py | 167 +++ torch/backends/_nnapi/serializer.py | 1131 ++++++++++++++++++++ 13 files changed, 2450 insertions(+) create mode 100644 aten/src/ATen/nnapi/CMakeLists.txt create mode 100644 aten/src/ATen/nnapi/NeuralNetworks.h create mode 100755 aten/src/ATen/nnapi/codegen.py create mode 100644 aten/src/ATen/nnapi/nnapi_bind.cpp create mode 100644 aten/src/ATen/nnapi/nnapi_model_loader.cpp create mode 100644 aten/src/ATen/nnapi/nnapi_model_loader.h create mode 100644 aten/src/ATen/nnapi/nnapi_wrapper.cpp create mode 100644 aten/src/ATen/nnapi/nnapi_wrapper.h create mode 100644 torch/backends/_nnapi/__init__.py create mode 100644 torch/backends/_nnapi/prepare.py create mode 100644 torch/backends/_nnapi/serializer.py diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index bf9029ff6c6b..5a261a5dd37a 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -446,6 +446,14 @@ list(APPEND ATen_MOBILE_BENCHMARK_SRCS list(APPEND ATen_MOBILE_BENCHMARK_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/benchmarks/stateful_conv1d.cpp) +if(LINUX OR ANDROID) + # NNAPI is primarily for Android, but also build on Linux + # to allow easy experimentation when a host build of libneuralnetworks + # is available. We don't have any build-time dependencies on NNAPI, + # so this should be safe. + add_subdirectory(nnapi) +endif() + # Pass source, includes, and libs to parent set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE) set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE) diff --git a/aten/src/ATen/nnapi/CMakeLists.txt b/aten/src/ATen/nnapi/CMakeLists.txt new file mode 100644 index 000000000000..d4505ca5a7c2 --- /dev/null +++ b/aten/src/ATen/nnapi/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.5 FATAL_ERROR) +project(pytorch_nnapi) + +# Define this to build the NNAPI binding out of tree. +if(PYTORCH_NNAPI_STANDALONE) + set(CMAKE_CXX_STANDARD 14) + find_package(Torch REQUIRED) +endif() + +set(NNAPI_SRCS + nnapi_bind.cpp + nnapi_wrapper.cpp + nnapi_model_loader.cpp + ) + +# Static build on Android so we can just bundle with the benchmarker +# or with PyTorch, but use shared on host so we don't load by accident. +if(ANDROID) + add_library(pytorch_nnapi STATIC ${NNAPI_SRCS}) +else() + add_library(pytorch_nnapi SHARED ${NNAPI_SRCS}) +endif() + +target_link_libraries(pytorch_nnapi torch) diff --git a/aten/src/ATen/nnapi/NeuralNetworks.h b/aten/src/ATen/nnapi/NeuralNetworks.h new file mode 100644 index 000000000000..bfc3ea4ac49d --- /dev/null +++ b/aten/src/ATen/nnapi/NeuralNetworks.h @@ -0,0 +1,84 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + +Most of NeuralNetworks.h has been stripped for simplicity. +We don't need any of the function declarations since +we call them all through dlopen/dlsym. +Operation codes are pulled directly from serialized models. + +*/ + +#ifndef MINIMAL_NEURAL_NETWORKS_H +#define MINIMAL_NEURAL_NETWORKS_H + +#include + +typedef enum { + ANEURALNETWORKS_NO_ERROR = 0, + ANEURALNETWORKS_OUT_OF_MEMORY = 1, + ANEURALNETWORKS_INCOMPLETE = 2, + ANEURALNETWORKS_UNEXPECTED_NULL = 3, + ANEURALNETWORKS_BAD_DATA = 4, + ANEURALNETWORKS_OP_FAILED = 5, + ANEURALNETWORKS_BAD_STATE = 6, + ANEURALNETWORKS_UNMAPPABLE = 7, + ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE = 8, + ANEURALNETWORKS_UNAVAILABLE_DEVICE = 9, +} ResultCode; + +typedef enum { + ANEURALNETWORKS_FLOAT32 = 0, + ANEURALNETWORKS_INT32 = 1, + ANEURALNETWORKS_UINT32 = 2, + ANEURALNETWORKS_TENSOR_FLOAT32 = 3, + ANEURALNETWORKS_TENSOR_INT32 = 4, + ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5, + ANEURALNETWORKS_BOOL = 6, + ANEURALNETWORKS_TENSOR_QUANT16_SYMM = 7, + ANEURALNETWORKS_TENSOR_FLOAT16 = 8, + ANEURALNETWORKS_TENSOR_BOOL8 = 9, + ANEURALNETWORKS_FLOAT16 = 10, + ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL = 11, + ANEURALNETWORKS_TENSOR_QUANT16_ASYMM = 12, + ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13, +} OperandCode; + +typedef enum { + ANEURALNETWORKS_PREFER_LOW_POWER = 0, + ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1, + ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2, +} PreferenceCode; + +typedef struct ANeuralNetworksMemory ANeuralNetworksMemory; +typedef struct ANeuralNetworksModel ANeuralNetworksModel; +typedef struct ANeuralNetworksDevice ANeuralNetworksDevice; +typedef struct ANeuralNetworksCompilation ANeuralNetworksCompilation; +typedef struct ANeuralNetworksExecution ANeuralNetworksExecution; +typedef struct ANeuralNetworksEvent ANeuralNetworksEvent; + +typedef int32_t ANeuralNetworksOperationType; + +typedef struct ANeuralNetworksOperandType { + int32_t type; + uint32_t dimensionCount; + const uint32_t* dimensions; + float scale; + int32_t zeroPoint; +} ANeuralNetworksOperandType; + +#endif // MINIMAL_NEURAL_NETWORKS_H diff --git a/aten/src/ATen/nnapi/codegen.py b/aten/src/ATen/nnapi/codegen.py new file mode 100755 index 000000000000..12ae40d4645e --- /dev/null +++ b/aten/src/ATen/nnapi/codegen.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +""" +Code generator for NNAPI wrapper. We can't link directly against +libneuralnetworks.so because we want PyTorch to work on Android +devices that don't have it available. Instead, we generate a wrapper +that opens libneuralnetworks.so with dlopen and finds the functions +we need with dlsym. We also generate a "check" wrapper that checks +return values and throws C++ exceptions on errors. +""" +import sys +import re +import pathlib +import textwrap + + +PREFIX = """\ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is generated by nnapi/codegen.py +""" + + +NNAPI_FUNCTIONS = [ + ("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), + ("int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device"), + ("int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name"), + ("int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version"), + ("int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel"), + ("int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps"), + ("int", "ANeuralNetworksCompilation_createForDevices", "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation"), + ("int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution"), + ("int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory"), + ("void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory"), + ("int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model"), + ("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), + ("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), + ("int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type"), + ("int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length"), + ("int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), + ("int", "ANeuralNetworksModel_addOperation", "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), + ("int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs"), + ("int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow"), + ("int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation"), + ("void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation"), + ("int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference"), + ("int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation"), + ("int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution"), + ("void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution"), + ("int", "ANeuralNetworksExecution_setInput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length"), + ("int", "ANeuralNetworksExecution_setInputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), + ("int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length"), + ("int", "ANeuralNetworksExecution_setOutputFromMemory", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length"), + ("int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event"), + ("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), + ("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), + ("int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank"), + ("int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions"), +] + + +def main(argv): + struct_members = [] + load_functions = [] + define_checks = [] + + for ret, name, args in NNAPI_FUNCTIONS: + short_name = name.replace("ANeuralNetworks", "", 1) + + struct_members.append(f" {ret}(*{short_name})({args});") + + load_functions.append(f' *(void**)&nnapi_.{short_name} = dlsym(handle, "{name}");') + load_functions.append(f' check_nnapi_.{short_name} = check_{short_name};') + + call_args = "".join(re.findall("\w+(?:,|$)", args)) + if ret == "void": + define_checks.append(textwrap.dedent(f"""\ + {ret} check_{short_name}({args}) {{ + CAFFE_ENFORCE(nnapi_.{short_name}); + nnapi_.{short_name}({call_args}); + }}""")) + if ret == "int": + define_checks.append(textwrap.dedent(f"""\ + {ret} check_{short_name}({args}) {{ + CAFFE_ENFORCE(nnapi_.{short_name}); + int ret = nnapi_.{short_name}({call_args}); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; + }}""")) + + out_dir = pathlib.Path(__file__).parent + + (out_dir / "nnapi_wrapper.h").write_text( + PREFIX + + textwrap.dedent("""\ + #ifndef NNAPI_WRAPPER_H_ + #define NNAPI_WRAPPER_H_ + #include + #include + #include "NeuralNetworks.h" + struct nnapi_wrapper { + __STRUCT_MEMBERS__ + }; + #ifdef __cplusplus + void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi); + #endif + #endif + """) + .replace("__STRUCT_MEMBERS__", "\n".join(struct_members)) + ) + + (out_dir / "nnapi_wrapper.cpp").write_text( + PREFIX + + textwrap.dedent("""\ + #include + #include "nnapi_wrapper.h" + #include "c10/util/Logging.h" + static int loaded = 0; + static struct nnapi_wrapper nnapi_; + static struct nnapi_wrapper check_nnapi_; + __DEFINE_CHECK_FUNCTIONS__ + void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi) { + if (!loaded) { + // Clear error flag. + dlerror(); + void* handle = dlopen("libneuralnetworks.so", RTLD_LAZY | RTLD_LOCAL); + CAFFE_ENFORCE(handle, "Failed to load libneuralnetworks.so ", dlerror()); + __LOAD_FUNCTIONS__ + loaded = 1; + } + *nnapi = &nnapi_; + *check_nnapi = &check_nnapi_; + } + """) + .replace("__DEFINE_CHECK_FUNCTIONS__", "\n".join(define_checks)) + .replace("__LOAD_FUNCTIONS__", "\n".join(load_functions)) + ) + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp new file mode 100644 index 000000000000..0b5eb09ff154 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -0,0 +1,197 @@ +#include + +#include +#include + +#include "nnapi_wrapper.h" +#include "nnapi_model_loader.h" + + +namespace torch { +namespace nnapi { +namespace { + +nnapi_wrapper* nnapi; +nnapi_wrapper* check_nnapi; + +void load_platform_library() { + static int run_once = [](){ + nnapi_wrapper_load(&nnapi, &check_nnapi); + CAFFE_ENFORCE(nnapi); + CAFFE_ENFORCE(nnapi->Model_free); + CAFFE_ENFORCE(nnapi->Compilation_free); + CAFFE_ENFORCE(nnapi->Execution_free); + return 0; + }(); + (void)run_once; +} + +#define MAKE_SMART_PTR(type) \ + struct type ## Freer { \ + void operator()(ANeuralNetworks ## type * obj) { \ + if (!nnapi) { /* obj must be null. */ return; } \ + nnapi-> type ## _free(obj); \ + } \ + }; \ + typedef std::unique_ptr type ## Ptr; + +MAKE_SMART_PTR(Model) +MAKE_SMART_PTR(Compilation) +MAKE_SMART_PTR(Execution) + +#undef MAKE_SMART_PTR + +struct NnapiCompilation : torch::jit::CustomClassHolder { + NnapiCompilation() { + // Could possibly call load_platform_library here, but error reporting + // can be complicated if the constructor is called during model loading. + // Instead, delay all work until the explicit init call. + } + + ~NnapiCompilation() { + } + + void init( + torch::Tensor serialized_model_tensor, + std::vector parameter_buffers) { + TORCH_CHECK(!model_, "Attempted to re-initialize NnapiCompilation."); + + load_platform_library(); + + std::vector buffers; + std::vector buffer_sizes; + for (auto& t : parameter_buffers) { + TORCH_CHECK(t.is_contiguous()); + buffers.push_back(t.data_ptr()); + buffer_sizes.push_back(t.nbytes()); + } + + TORCH_CHECK(serialized_model_tensor.is_contiguous()); + c10::ArrayRef ser_model = { + serialized_model_tensor.data_ptr(), + serialized_model_tensor.nbytes() + }; + TORCH_CHECK(ser_model.size() > 0); + + ANeuralNetworksModel* model; + check_nnapi->Model_create(&model); + CAFFE_ENFORCE(model); + model_.reset(model); + + int load_result = ::caffe2::nnapi::load_nnapi_model( + nnapi, + model_.get(), + ser_model.data(), + ser_model.size(), + buffers.size(), + buffers.data(), + buffer_sizes.data(), + 0, + nullptr, + nullptr, + &num_inputs_, + &num_outputs_, + nullptr); + CAFFE_ENFORCE(load_result == 0); + + check_nnapi->Model_finish(model_.get()); + + ANeuralNetworksCompilation* compilation; + check_nnapi->Compilation_create(model_.get(), &compilation); + // TODO: Make this configurable. + check_nnapi->Compilation_setPreference(compilation, ANEURALNETWORKS_PREFER_SUSTAINED_SPEED); + check_nnapi->Compilation_finish(compilation); + compilation_.reset(compilation); + } + + void run( + std::vector inputs, + std::vector outputs) { + ANeuralNetworksExecution* execution; + check_nnapi->Execution_create(compilation_.get(), &execution); + ExecutionPtr execution_unique_ptr(execution); + + TORCH_CHECK((int32_t)inputs.size() == num_inputs_); + TORCH_CHECK((int32_t)outputs.size() == num_outputs_); + + for (size_t i = 0; i < inputs.size(); i++) { + auto& t = inputs[i]; + // TODO: Check contiguous and dtype. + ANeuralNetworksOperandType op_type; + std::vector dim; + get_operand_type(t, &op_type, &dim); + check_nnapi->Execution_setInput( + execution, + i, + &op_type, + t.data_ptr(), + t.nbytes()); + } + + for (size_t i = 0; i < outputs.size(); i++) { + auto& t = outputs[i]; + // TODO: Check contiguous and dtype. + check_nnapi->Execution_setOutput( + execution, + i, + nullptr, + t.data_ptr(), + t.nbytes()); + } + + check_nnapi->Execution_compute(execution); + + // TODO: Maybe skip this for fixed-size outputs? + for (size_t i = 0; i < outputs.size(); i++) { + auto& t = outputs[i]; + uint32_t rank; + check_nnapi->Execution_getOutputOperandRank(execution, i, &rank); + std::vector dims(rank); + check_nnapi->Execution_getOutputOperandDimensions(execution, i, dims.data()); + std::vector long_dims(dims.begin(), dims.end()); + // TODO: Maybe check that only the batch dimension is changed? + t.resize_(long_dims); + } + } + + static void get_operand_type(const Tensor& t, ANeuralNetworksOperandType* operand, std::vector* dims) { + operand->dimensionCount = t.dim(); + TORCH_CHECK(operand->dimensionCount == t.dim()); // Check for overflow. + dims->resize(t.dim()); + operand->dimensions = dims->data(); + for (size_t i = 0; i < dims->size(); i++) { + (*dims)[i] = t.sizes()[i]; + TORCH_CHECK((*dims)[i] == t.sizes()[i]); // Check for overflow. + } + if (t.scalar_type() == torch::kFloat32) { + operand->type = ANEURALNETWORKS_TENSOR_FLOAT32; + operand->scale = 0; + operand->zeroPoint = 0; + return; + } + // TODO: Support more dtypes. + CAFFE_THROW("Bad dtype"); + } + + ModelPtr model_; + CompilationPtr compilation_; + int32_t num_inputs_; + int32_t num_outputs_; +}; + +static auto register_NnapiCompilation = [](){ + try { + return torch::jit::class_("_nnapi", "Compilation") + .def(torch::jit::init<>()) + .def("init", &NnapiCompilation::init) + .def("run", &NnapiCompilation::run) + ; + } catch (std::exception& exn) { + LOG(ERROR) << "Failed to register class nnapi.Compilation: " << exn.what(); + throw; + } +}(); + +} // namespace +} // namespace nnapi +} // namespace torch diff --git a/aten/src/ATen/nnapi/nnapi_model_loader.cpp b/aten/src/ATen/nnapi/nnapi_model_loader.cpp new file mode 100644 index 000000000000..ed51851c7138 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_model_loader.cpp @@ -0,0 +1,264 @@ +#include + +#include "NeuralNetworks.h" +#include "nnapi_wrapper.h" +#include "nnapi_model_loader.h" + + +#ifndef NNAPI_LOADER_STANDALONE + +# include + +#else + +#define CAFFE_ENFORCE(cond, ...) do { if (!cond) { return -1; } } while (0) + +#endif + + +#define NNAPI_CHECK(res) CAFFE_ENFORCE(res == ANEURALNETWORKS_NO_ERROR) + + +namespace caffe2 { +namespace nnapi { + +namespace { + +/* +Serialized format for NNAPI models. It is basically just a list arguments +for calls to be made to NNAPI. +*/ + +typedef enum _SourceType { + SOURCE_IMMEDIATE = 0, + SOURCE_NUMBERED_BUFFER = 2, + SOURCE_NUMBERED_MEMORY = 3, +} SourceType; + +typedef struct _SerializedOperand { + int32_t type; + uint32_t dimension_count; + float scale; + int32_t zero_point; +} SerializedOperand; + +typedef struct _SerializedValue { + int32_t index; + int32_t source_type; + uint32_t source_length; +} SerializedValue; + +typedef struct _SerializedOperation { + int32_t operation_type; + uint32_t input_count; + uint32_t output_count; +} SerializedOperation; + +typedef struct _SerializedModel { + int32_t version; + int32_t operand_count; + int32_t value_count; + int32_t operation_count; + int32_t input_count; + int32_t output_count; + // SerializedOperand operands[operand_count]; + // SerializedValue values[value_count]; + // SerializedOperation operations[operation_count]; + // uint32_t operand_dimensions[sum(dimension_count)] + // uint32_t value_data[sum(source_length+pad)/4] + // uint32_t operation_args[sum(input_count + output_count)] + // uint32_t model_inputs[input_count] + // uint32_t model_outputs[output_count] +} SerializedModel; + + +/** + * Get the physically stored size of a value. All values are padded out + * to a multiple of 4 bytes to ensure the next value is 4-byte aligned. + */ +static uint32_t value_physical_size(uint32_t len) { + uint32_t phys = len; + if (len % 4 == 0) { + return len; + } + return len + 4 - (phys % 4); +} + +} // namespace + + +int load_nnapi_model( + struct nnapi_wrapper* nnapi, + ANeuralNetworksModel* model, + const void* serialized_model, + int64_t model_length, + size_t num_buffers, + const void** buffer_ptrs, + int32_t* buffer_sizes, + size_t num_memories, + ANeuralNetworksMemory** memories, + int32_t* memory_sizes, + int32_t* out_input_count, + int32_t* out_output_count, + size_t* out_bytes_consumed) { + int64_t required_size = 0; + const uint8_t* next_pointer = (const uint8_t*)serialized_model; + const uint8_t* end_of_buf = (const uint8_t*)serialized_model + model_length; + + required_size += sizeof(SerializedModel); + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedModel* ser_model = (SerializedModel*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + CAFFE_ENFORCE(ser_model->version == 1); + // Keep these small to avoid integer overflow. + CAFFE_ENFORCE(ser_model->operand_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->value_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->operation_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->input_count < (1 << 24)); + CAFFE_ENFORCE(ser_model->output_count < (1 << 24)); + + required_size += sizeof(SerializedOperand) * ser_model->operand_count; + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedOperand* operands = (const SerializedOperand*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + required_size += sizeof(SerializedValue) * ser_model->value_count; + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedValue* values = (const SerializedValue*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + required_size += sizeof(SerializedOperation) * ser_model->operation_count; + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + const SerializedOperation* operations = (const SerializedOperation*)next_pointer; + next_pointer = (uint8_t*)serialized_model + required_size; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + for (int i = 0; i < ser_model->operand_count; i++) { + required_size += 4 * operands[i].dimension_count; + } + + for (int i = 0; i < ser_model->value_count; i++) { + required_size += value_physical_size(values[i].source_length); + } + + for (int i = 0; i < ser_model->operation_count; i++) { + required_size += 4 * (operations[i].input_count + operations[i].output_count); + } + + required_size += 4 * (ser_model->input_count + ser_model->output_count); + + CAFFE_ENFORCE(model_length >= required_size, "Model is too small. Size = ", model_length); + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + for (int i = 0; i < ser_model->operand_count; i++) { + ANeuralNetworksOperandType operand; + operand.type = operands[i].type; + operand.scale = operands[i].scale; + operand.zeroPoint = operands[i].zero_point; + operand.dimensionCount = operands[i].dimension_count; + operand.dimensions = operands[i].dimension_count ? (const uint32_t*)next_pointer : NULL; + + next_pointer += 4 * operands[i].dimension_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_addOperand(model, &operand); + NNAPI_CHECK(result); + } + + for (int i = 0; i < ser_model->value_count; i++) { + uint32_t len = values[i].source_length; + const uint8_t* stored_pointer = next_pointer; + const void* value_pointer = NULL; + size_t value_length; + + switch ((SourceType)values[i].source_type) { + case SOURCE_IMMEDIATE: + { + value_pointer = stored_pointer; + value_length = len; + } + break; + case SOURCE_NUMBERED_BUFFER: + { + CAFFE_ENFORCE(len == 12); + uint32_t buffer_number = *(uint32_t*)stored_pointer; + uint32_t buffer_offset = *(uint32_t*)(stored_pointer + 4); + uint32_t operand_length = *(uint32_t*)(stored_pointer + 8); + CAFFE_ENFORCE(buffer_number < num_buffers); + CAFFE_ENFORCE(buffer_offset + operand_length >= buffer_offset); // No integer overflow + CAFFE_ENFORCE(buffer_offset + operand_length <= (uint32_t)buffer_sizes[buffer_number]); // No buffer overflow + value_pointer = (uint8_t*)buffer_ptrs[buffer_number] + buffer_offset; + value_length = operand_length; + } + break; + case SOURCE_NUMBERED_MEMORY: + CAFFE_ENFORCE(false, "Memory inputs not implemented yet."); + break; + default: + CAFFE_ENFORCE(false, "Unknown source type: ", values[i].source_type); + } + + CAFFE_ENFORCE(value_pointer != NULL); + + next_pointer += value_physical_size(len); + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_setOperandValue( + model, + values[i].index, + value_pointer, + value_length); + NNAPI_CHECK(result); + } + + for (int i = 0; i < ser_model->operation_count; i++) { + const uint32_t* inputs = (const uint32_t*)next_pointer; + next_pointer += 4 * operations[i].input_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + const uint32_t* outputs = (const uint32_t*)next_pointer; + next_pointer += 4 * operations[i].output_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_addOperation( + model, + operations[i].operation_type, + operations[i].input_count, + inputs, + operations[i].output_count, + outputs); + NNAPI_CHECK(result); + } + + const uint32_t* model_inputs = (const uint32_t*)next_pointer; + next_pointer += 4 * ser_model->input_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + const uint32_t* model_outputs = (const uint32_t*)next_pointer; + next_pointer += 4 * ser_model->output_count; + CAFFE_ENFORCE(next_pointer <= end_of_buf); + + int result = nnapi->Model_identifyInputsAndOutputs( + model, + ser_model->input_count, + model_inputs, + ser_model->output_count, + model_outputs); + NNAPI_CHECK(result); + + *out_input_count = ser_model->input_count; + *out_output_count = ser_model->output_count; + + // TODO: Maybe eliminate required_size and just rely on next_pointer for bounds checking. + CAFFE_ENFORCE(next_pointer <= end_of_buf); + CAFFE_ENFORCE(next_pointer == (const uint8_t*)serialized_model + required_size); + if (out_bytes_consumed != NULL) { + *out_bytes_consumed = next_pointer - (const uint8_t*)serialized_model; + } + + return 0; +} + +}} // namespace caffe2::nnapi diff --git a/aten/src/ATen/nnapi/nnapi_model_loader.h b/aten/src/ATen/nnapi/nnapi_model_loader.h new file mode 100644 index 000000000000..870879b3aaa9 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_model_loader.h @@ -0,0 +1,29 @@ +#ifndef NNAPI_MODEL_LOADER_H_ +#define NNAPI_MODEL_LOADER_H_ + +#include + +#include "NeuralNetworks.h" +#include "nnapi_wrapper.h" + +namespace caffe2 { +namespace nnapi { + +int load_nnapi_model( + struct nnapi_wrapper* nnapi, + ANeuralNetworksModel* model, + const void* serialized_model, + int64_t model_length, + size_t num_buffers, + const void** buffer_ptrs, + int32_t* buffer_sizes, + size_t num_memories, + ANeuralNetworksMemory** memories, + int32_t* memory_sizes, + int32_t* out_input_count, + int32_t* out_output_count, + size_t* out_bytes_consumed); + +}} // namespace caffe2::nnapi + +#endif // NNAPI_MODEL_LOADER_H_ diff --git a/aten/src/ATen/nnapi/nnapi_wrapper.cpp b/aten/src/ATen/nnapi/nnapi_wrapper.cpp new file mode 100644 index 000000000000..740902597f8c --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_wrapper.cpp @@ -0,0 +1,325 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is generated by nnapi/codegen.py +#include +#include "nnapi_wrapper.h" +#include "c10/util/Logging.h" +static int loaded = 0; +static struct nnapi_wrapper nnapi_; +static struct nnapi_wrapper check_nnapi_; +int check__getDeviceCount(uint32_t* numDevices) { + CAFFE_ENFORCE(nnapi_._getDeviceCount); + int ret = nnapi_._getDeviceCount(numDevices); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check__getDevice(uint32_t devIndex, ANeuralNetworksDevice** device) { + CAFFE_ENFORCE(nnapi_._getDevice); + int ret = nnapi_._getDevice(devIndex,device); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Device_getName(const ANeuralNetworksDevice* device, const char** name) { + CAFFE_ENFORCE(nnapi_.Device_getName); + int ret = nnapi_.Device_getName(device,name); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Device_getVersion(const ANeuralNetworksDevice* device, const char** version) { + CAFFE_ENFORCE(nnapi_.Device_getVersion); + int ret = nnapi_.Device_getVersion(device,version); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Device_getFeatureLevel(const ANeuralNetworksDevice* device, int64_t* featureLevel) { + CAFFE_ENFORCE(nnapi_.Device_getFeatureLevel); + int ret = nnapi_.Device_getFeatureLevel(device,featureLevel); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_getSupportedOperationsForDevices( const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps) { + CAFFE_ENFORCE(nnapi_.Model_getSupportedOperationsForDevices); + int ret = nnapi_.Model_getSupportedOperationsForDevices(model,devices,numDevices,supportedOps); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Compilation_createForDevices(ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_createForDevices); + int ret = nnapi_.Compilation_createForDevices(model,devices,numDevices,compilation); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_compute(ANeuralNetworksExecution* execution) { + CAFFE_ENFORCE(nnapi_.Execution_compute); + int ret = nnapi_.Execution_compute(execution); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Memory_createFromFd(size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory) { + CAFFE_ENFORCE(nnapi_.Memory_createFromFd); + int ret = nnapi_.Memory_createFromFd(size,protect,fd,offset,memory); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Memory_free(ANeuralNetworksMemory* memory) { + CAFFE_ENFORCE(nnapi_.Memory_free); + nnapi_.Memory_free(memory); +} +int check_Model_create(ANeuralNetworksModel** model) { + CAFFE_ENFORCE(nnapi_.Model_create); + int ret = nnapi_.Model_create(model); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Model_free(ANeuralNetworksModel* model) { + CAFFE_ENFORCE(nnapi_.Model_free); + nnapi_.Model_free(model); +} +int check_Model_finish(ANeuralNetworksModel* model) { + CAFFE_ENFORCE(nnapi_.Model_finish); + int ret = nnapi_.Model_finish(model); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_addOperand(ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) { + CAFFE_ENFORCE(nnapi_.Model_addOperand); + int ret = nnapi_.Model_addOperand(model,type); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_setOperandValue(ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length) { + CAFFE_ENFORCE(nnapi_.Model_setOperandValue); + int ret = nnapi_.Model_setOperandValue(model,index,buffer,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + CAFFE_ENFORCE(nnapi_.Model_setOperandValueFromMemory); + int ret = nnapi_.Model_setOperandValueFromMemory(model,index,memory,offset,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_addOperation(ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs) { + CAFFE_ENFORCE(nnapi_.Model_addOperation); + int ret = nnapi_.Model_addOperation(model,type,inputCount,inputs,outputCount,outputs); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs) { + CAFFE_ENFORCE(nnapi_.Model_identifyInputsAndOutputs); + int ret = nnapi_.Model_identifyInputsAndOutputs(model,inputCount,inputs,outputCount,outputs); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Model_relaxComputationFloat32toFloat16(ANeuralNetworksModel* model, bool allow) { + CAFFE_ENFORCE(nnapi_.Model_relaxComputationFloat32toFloat16); + int ret = nnapi_.Model_relaxComputationFloat32toFloat16(model,allow); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Compilation_create(ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_create); + int ret = nnapi_.Compilation_create(model,compilation); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Compilation_free(ANeuralNetworksCompilation* compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_free); + nnapi_.Compilation_free(compilation); +} +int check_Compilation_setPreference(ANeuralNetworksCompilation* compilation, int32_t preference) { + CAFFE_ENFORCE(nnapi_.Compilation_setPreference); + int ret = nnapi_.Compilation_setPreference(compilation,preference); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Compilation_finish(ANeuralNetworksCompilation* compilation) { + CAFFE_ENFORCE(nnapi_.Compilation_finish); + int ret = nnapi_.Compilation_finish(compilation); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_create(ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution) { + CAFFE_ENFORCE(nnapi_.Execution_create); + int ret = nnapi_.Execution_create(compilation,execution); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Execution_free(ANeuralNetworksExecution* execution) { + CAFFE_ENFORCE(nnapi_.Execution_free); + nnapi_.Execution_free(execution); +} +int check_Execution_setInput(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setInput); + int ret = nnapi_.Execution_setInput(execution,index,type,buffer,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setInputFromMemory); + int ret = nnapi_.Execution_setInputFromMemory(execution,index,type,memory,offset,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_setOutput(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setOutput); + int ret = nnapi_.Execution_setOutput(execution,index,type,buffer,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length) { + CAFFE_ENFORCE(nnapi_.Execution_setOutputFromMemory); + int ret = nnapi_.Execution_setOutputFromMemory(execution,index,type,memory,offset,length); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_startCompute(ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event) { + CAFFE_ENFORCE(nnapi_.Execution_startCompute); + int ret = nnapi_.Execution_startCompute(execution,event); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Event_wait(ANeuralNetworksEvent* event) { + CAFFE_ENFORCE(nnapi_.Event_wait); + int ret = nnapi_.Event_wait(event); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void check_Event_free(ANeuralNetworksEvent* event) { + CAFFE_ENFORCE(nnapi_.Event_free); + nnapi_.Event_free(event); +} +int check_Execution_getOutputOperandRank(ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank) { + CAFFE_ENFORCE(nnapi_.Execution_getOutputOperandRank); + int ret = nnapi_.Execution_getOutputOperandRank(execution,index,rank); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +int check_Execution_getOutputOperandDimensions(ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions) { + CAFFE_ENFORCE(nnapi_.Execution_getOutputOperandDimensions); + int ret = nnapi_.Execution_getOutputOperandDimensions(execution,index,dimensions); + // TODO: Maybe add better logging here. + CAFFE_ENFORCE(ret == ANEURALNETWORKS_NO_ERROR); + return ret; +} +void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi) { + if (!loaded) { + // Clear error flag. + dlerror(); + void* handle = dlopen("libneuralnetworks.so", RTLD_LAZY | RTLD_LOCAL); + CAFFE_ENFORCE(handle, "Failed to load libneuralnetworks.so ", dlerror()); + *(void**)&nnapi_._getDeviceCount = dlsym(handle, "ANeuralNetworks_getDeviceCount"); + check_nnapi_._getDeviceCount = check__getDeviceCount; + *(void**)&nnapi_._getDevice = dlsym(handle, "ANeuralNetworks_getDevice"); + check_nnapi_._getDevice = check__getDevice; + *(void**)&nnapi_.Device_getName = dlsym(handle, "ANeuralNetworksDevice_getName"); + check_nnapi_.Device_getName = check_Device_getName; + *(void**)&nnapi_.Device_getVersion = dlsym(handle, "ANeuralNetworksDevice_getVersion"); + check_nnapi_.Device_getVersion = check_Device_getVersion; + *(void**)&nnapi_.Device_getFeatureLevel = dlsym(handle, "ANeuralNetworksDevice_getFeatureLevel"); + check_nnapi_.Device_getFeatureLevel = check_Device_getFeatureLevel; + *(void**)&nnapi_.Model_getSupportedOperationsForDevices = dlsym(handle, "ANeuralNetworksModel_getSupportedOperationsForDevices"); + check_nnapi_.Model_getSupportedOperationsForDevices = check_Model_getSupportedOperationsForDevices; + *(void**)&nnapi_.Compilation_createForDevices = dlsym(handle, "ANeuralNetworksCompilation_createForDevices"); + check_nnapi_.Compilation_createForDevices = check_Compilation_createForDevices; + *(void**)&nnapi_.Execution_compute = dlsym(handle, "ANeuralNetworksExecution_compute"); + check_nnapi_.Execution_compute = check_Execution_compute; + *(void**)&nnapi_.Memory_createFromFd = dlsym(handle, "ANeuralNetworksMemory_createFromFd"); + check_nnapi_.Memory_createFromFd = check_Memory_createFromFd; + *(void**)&nnapi_.Memory_free = dlsym(handle, "ANeuralNetworksMemory_free"); + check_nnapi_.Memory_free = check_Memory_free; + *(void**)&nnapi_.Model_create = dlsym(handle, "ANeuralNetworksModel_create"); + check_nnapi_.Model_create = check_Model_create; + *(void**)&nnapi_.Model_free = dlsym(handle, "ANeuralNetworksModel_free"); + check_nnapi_.Model_free = check_Model_free; + *(void**)&nnapi_.Model_finish = dlsym(handle, "ANeuralNetworksModel_finish"); + check_nnapi_.Model_finish = check_Model_finish; + *(void**)&nnapi_.Model_addOperand = dlsym(handle, "ANeuralNetworksModel_addOperand"); + check_nnapi_.Model_addOperand = check_Model_addOperand; + *(void**)&nnapi_.Model_setOperandValue = dlsym(handle, "ANeuralNetworksModel_setOperandValue"); + check_nnapi_.Model_setOperandValue = check_Model_setOperandValue; + *(void**)&nnapi_.Model_setOperandValueFromMemory = dlsym(handle, "ANeuralNetworksModel_setOperandValueFromMemory"); + check_nnapi_.Model_setOperandValueFromMemory = check_Model_setOperandValueFromMemory; + *(void**)&nnapi_.Model_addOperation = dlsym(handle, "ANeuralNetworksModel_addOperation"); + check_nnapi_.Model_addOperation = check_Model_addOperation; + *(void**)&nnapi_.Model_identifyInputsAndOutputs = dlsym(handle, "ANeuralNetworksModel_identifyInputsAndOutputs"); + check_nnapi_.Model_identifyInputsAndOutputs = check_Model_identifyInputsAndOutputs; + *(void**)&nnapi_.Model_relaxComputationFloat32toFloat16 = dlsym(handle, "ANeuralNetworksModel_relaxComputationFloat32toFloat16"); + check_nnapi_.Model_relaxComputationFloat32toFloat16 = check_Model_relaxComputationFloat32toFloat16; + *(void**)&nnapi_.Compilation_create = dlsym(handle, "ANeuralNetworksCompilation_create"); + check_nnapi_.Compilation_create = check_Compilation_create; + *(void**)&nnapi_.Compilation_free = dlsym(handle, "ANeuralNetworksCompilation_free"); + check_nnapi_.Compilation_free = check_Compilation_free; + *(void**)&nnapi_.Compilation_setPreference = dlsym(handle, "ANeuralNetworksCompilation_setPreference"); + check_nnapi_.Compilation_setPreference = check_Compilation_setPreference; + *(void**)&nnapi_.Compilation_finish = dlsym(handle, "ANeuralNetworksCompilation_finish"); + check_nnapi_.Compilation_finish = check_Compilation_finish; + *(void**)&nnapi_.Execution_create = dlsym(handle, "ANeuralNetworksExecution_create"); + check_nnapi_.Execution_create = check_Execution_create; + *(void**)&nnapi_.Execution_free = dlsym(handle, "ANeuralNetworksExecution_free"); + check_nnapi_.Execution_free = check_Execution_free; + *(void**)&nnapi_.Execution_setInput = dlsym(handle, "ANeuralNetworksExecution_setInput"); + check_nnapi_.Execution_setInput = check_Execution_setInput; + *(void**)&nnapi_.Execution_setInputFromMemory = dlsym(handle, "ANeuralNetworksExecution_setInputFromMemory"); + check_nnapi_.Execution_setInputFromMemory = check_Execution_setInputFromMemory; + *(void**)&nnapi_.Execution_setOutput = dlsym(handle, "ANeuralNetworksExecution_setOutput"); + check_nnapi_.Execution_setOutput = check_Execution_setOutput; + *(void**)&nnapi_.Execution_setOutputFromMemory = dlsym(handle, "ANeuralNetworksExecution_setOutputFromMemory"); + check_nnapi_.Execution_setOutputFromMemory = check_Execution_setOutputFromMemory; + *(void**)&nnapi_.Execution_startCompute = dlsym(handle, "ANeuralNetworksExecution_startCompute"); + check_nnapi_.Execution_startCompute = check_Execution_startCompute; + *(void**)&nnapi_.Event_wait = dlsym(handle, "ANeuralNetworksEvent_wait"); + check_nnapi_.Event_wait = check_Event_wait; + *(void**)&nnapi_.Event_free = dlsym(handle, "ANeuralNetworksEvent_free"); + check_nnapi_.Event_free = check_Event_free; + *(void**)&nnapi_.Execution_getOutputOperandRank = dlsym(handle, "ANeuralNetworksExecution_getOutputOperandRank"); + check_nnapi_.Execution_getOutputOperandRank = check_Execution_getOutputOperandRank; + *(void**)&nnapi_.Execution_getOutputOperandDimensions = dlsym(handle, "ANeuralNetworksExecution_getOutputOperandDimensions"); + check_nnapi_.Execution_getOutputOperandDimensions = check_Execution_getOutputOperandDimensions; + loaded = 1; + } + *nnapi = &nnapi_; + *check_nnapi = &check_nnapi_; +} diff --git a/aten/src/ATen/nnapi/nnapi_wrapper.h b/aten/src/ATen/nnapi/nnapi_wrapper.h new file mode 100644 index 000000000000..c9dfdaa640c1 --- /dev/null +++ b/aten/src/ATen/nnapi/nnapi_wrapper.h @@ -0,0 +1,62 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is generated by nnapi/codegen.py +#ifndef NNAPI_WRAPPER_H_ +#define NNAPI_WRAPPER_H_ +#include +#include +#include "NeuralNetworks.h" +struct nnapi_wrapper { + int(*_getDeviceCount)(uint32_t* numDevices); + int(*_getDevice)(uint32_t devIndex, ANeuralNetworksDevice** device); + int(*Device_getName)(const ANeuralNetworksDevice* device, const char** name); + int(*Device_getVersion)(const ANeuralNetworksDevice* device, const char** version); + int(*Device_getFeatureLevel)(const ANeuralNetworksDevice* device, int64_t* featureLevel); + int(*Model_getSupportedOperationsForDevices)( const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps); + int(*Compilation_createForDevices)(ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation); + int(*Execution_compute)(ANeuralNetworksExecution* execution); + int(*Memory_createFromFd)(size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory); + void(*Memory_free)(ANeuralNetworksMemory* memory); + int(*Model_create)(ANeuralNetworksModel** model); + void(*Model_free)(ANeuralNetworksModel* model); + int(*Model_finish)(ANeuralNetworksModel* model); + int(*Model_addOperand)(ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type); + int(*Model_setOperandValue)(ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length); + int(*Model_setOperandValueFromMemory)(ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length); + int(*Model_addOperation)(ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs); + int(*Model_identifyInputsAndOutputs)(ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs); + int(*Model_relaxComputationFloat32toFloat16)(ANeuralNetworksModel* model, bool allow); + int(*Compilation_create)(ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation); + void(*Compilation_free)(ANeuralNetworksCompilation* compilation); + int(*Compilation_setPreference)(ANeuralNetworksCompilation* compilation, int32_t preference); + int(*Compilation_finish)(ANeuralNetworksCompilation* compilation); + int(*Execution_create)(ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution); + void(*Execution_free)(ANeuralNetworksExecution* execution); + int(*Execution_setInput)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length); + int(*Execution_setInputFromMemory)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length); + int(*Execution_setOutput)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length); + int(*Execution_setOutputFromMemory)(ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length); + int(*Execution_startCompute)(ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event); + int(*Event_wait)(ANeuralNetworksEvent* event); + void(*Event_free)(ANeuralNetworksEvent* event); + int(*Execution_getOutputOperandRank)(ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank); + int(*Execution_getOutputOperandDimensions)(ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions); +}; +#ifdef __cplusplus +void nnapi_wrapper_load(struct nnapi_wrapper** nnapi, struct nnapi_wrapper** check_nnapi); +#endif +#endif diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt index 075bc05b5ecf..34e4c9cfe62d 100644 --- a/binaries/CMakeLists.txt +++ b/binaries/CMakeLists.txt @@ -4,6 +4,10 @@ if(INTERN_BUILD_MOBILE) caffe2_binary_target("speed_benchmark.cc") else() caffe2_binary_target("speed_benchmark_torch.cc") + + if(ANDROID) + target_link_libraries(speed_benchmark_torch -Wl,--whole-archive pytorch_nnapi -Wl,--no-whole-archive) + endif() endif() return() endif() diff --git a/torch/backends/_nnapi/__init__.py b/torch/backends/_nnapi/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/backends/_nnapi/prepare.py b/torch/backends/_nnapi/prepare.py new file mode 100644 index 000000000000..a324e9a90054 --- /dev/null +++ b/torch/backends/_nnapi/prepare.py @@ -0,0 +1,167 @@ +from typing import Optional, List, Tuple, Any + +import torch +from torch.backends._nnapi.serializer import serialize_model + +class NnapiModule(torch.nn.Module): + """Torch Module that wraps an NNAPI Compilation. + + This module handles preparing the weights, initializing the + NNAPI TorchBind object, and adjusting the memory formats + of all inputs and outputs. + """ + + comp: Optional[torch.classes._nnapi.Compilation] + + def __init__( + self, + ser_model: torch.Tensor, + weights: List[torch.Tensor], + inp_mem_fmts: List[int], + out_mem_fmts: List[int], + out_templates: List[torch.Tensor]): + super().__init__() + self.ser_model = ser_model + self.weights = weights + self.inp_mem_fmts = inp_mem_fmts + self.out_mem_fmts = out_mem_fmts + self.out_templates = out_templates + self.comp = None + + @torch.jit.export + def init(self): + assert self.comp is None + self.weights = [ w.contiguous() for w in self.weights ] + comp = torch.classes._nnapi.Compilation() + comp.init(self.ser_model, self.weights) + self.comp = comp + + def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]: + comp = self.comp + assert comp is not None + outs = [ torch.empty_like(out) for out in self.out_templates ] + + assert len(args) == len(self.inp_mem_fmts) + fixed_args = [] + for idx in range(len(args)): + fmt = self.inp_mem_fmts[idx] + if fmt == 0: + fixed_args.append(args[idx].contiguous()) + elif fmt == 1: + fixed_args.append(args[idx].permute(0,2,3,1).contiguous()) + else: + raise Exception("Invalid mem_fmt") + comp.run(fixed_args, outs) + assert len(outs) == len(self.out_mem_fmts) + for idx in range(len(self.out_templates)): + fmt = self.out_mem_fmts[idx] + if fmt == 0: + pass + elif fmt == 1: + outs[idx] = outs[idx].permute(0,3,1,2) + else: + raise Exception("Invalid mem_fmt") + return outs + + +class NnapiInitWrapper(torch.nn.Module): + """Wrapper module to ensure NNAPI init is called.""" + def __init__(self, nnapi_module): + super().__init__() + self.nnapi_module = nnapi_module + + def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]: + return self.nnapi_module(args) + + @torch.jit.export + def __getstate__(self): + return self.nnapi_module + + @torch.jit.export + def __setstate__(self, nnapi_module): + self.training = False + self.nnapi_module = nnapi_module + self.nnapi_module.init() + + +class ListWrapper(torch.nn.Module): + """NNAPI list-ifying wrapper. + + NNAPI always expects a list of inputs. This module provides a + single-tensor input interface for models that want it. + """ + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, t: torch.Tensor) -> List[torch.Tensor]: + return self.mod([t]) + +class DelistWrapper(torch.nn.Module): + """NNAPI de-list-ifying wrapper. + + NNAPI always provides a list of outputs. This module provides a + single-tensor output interface for models that want it. + """ + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, ts: List[torch.Tensor]) -> torch.Tensor: + outs = self.mod(ts) + assert len(outs) == 1 + return outs[0] + +class ListDelistWrapper(torch.nn.Module): + """NNAPI list-ifying and de-list-ifying wrapper. + + NNAPI always expects a list of inputs and provides a list of outputs. + This module provides a single-tensor input/output interface + for models that want it. + """ + def __init__(self, mod): + super().__init__() + self.mod = mod + + def forward(self, t: torch.Tensor) -> torch.Tensor: + outs = self.mod([t]) + assert len(outs) == 1 + return outs[0] + + +def convert_model_to_nnapi(model, inputs): + model = torch.jit.freeze(model) + + outputs = model(inputs) + + if isinstance(inputs, torch.Tensor): + nnapi_inputs = [inputs] + list_inputs = True + else: + list_inputs = False + + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + delist_outputs = True + else: + delist_outptus = False + + ser_model, used_weights, inp_mem_fmts, out_mem_fmts = serialize_model(model, nnapi_inputs) + ser_model_tensor = torch.tensor(list(ser_model), dtype=torch.uint8) + + out_templates = [torch.zeros(1, dtype=out.dtype).expand(out.shape) for out in outputs] + nnapi_model = NnapiInitWrapper(NnapiModule( + ser_model_tensor, + used_weights, + inp_mem_fmts, + out_mem_fmts, + out_templates)) + + if list_inputs and delist_outputs: + nnapi_model = ListDelistWrapper(nnapi_model) + elif list_inputs: + nnapi_model = ListWrapper(nnapi_model) + elif delist_outputs: + nnapi_model = DelistWrapper(nnapi_model) + + return nnapi_model diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py new file mode 100644 index 000000000000..dfa60f751154 --- /dev/null +++ b/torch/backends/_nnapi/serializer.py @@ -0,0 +1,1131 @@ +import enum +import collections +import struct +import operator +import functools +import logging +from typing import ( + Tuple, + NamedTuple, +) + +import torch + + + +#TODO: Add type annotations +#TODO: Check tensor types for ops + + +LOG = logging.getLogger("nnapi_serialize") + + +class NNAPI_OperandCode(object): + FLOAT32 = 0 + INT32 = 1 + UINT32 = 2 + TENSOR_FLOAT32 = 3 + TENSOR_INT32 = 4 + TENSOR_QUANT8_ASYMM = 5 + BOOL = 6 + TENSOR_QUANT16_SYMM = 7 + TENSOR_FLOAT16 = 8 + TENSOR_BOOL8 = 9 + FLOAT16 = 10 + TENSOR_QUANT8_SYMM_PER_CHANNEL = 11 + TENSOR_QUANT16_ASYMM = 12 + + +class NNAPI_OperationCode(object): + ADD = 0 + AVERAGE_POOL_2D = 1 + CONCATENATION = 2 + CONV_2D = 3 + DEPTHWISE_CONV_2D = 4 + DEPTH_TO_SPACE = 5 + DEQUANTIZE = 6 + EMBEDDING_LOOKUP = 7 + FLOOR = 8 + FULLY_CONNECTED = 9 + HASHTABLE_LOOKUP = 10 + L2_NORMALIZATION = 11 + L2_POOL_2D = 12 + LOCAL_RESPONSE_NORMALIZATION = 13 + LOGISTIC = 14 + LSH_PROJECTION = 15 + LSTM = 16 + MAX_POOL_2D = 17 + MUL = 18 + RELU = 19 + RELU1 = 20 + RELU6 = 21 + RESHAPE = 22 + RESIZE_BILINEAR = 23 + RNN = 24 + SOFTMAX = 25 + SPACE_TO_DEPTH = 26 + SVDF = 27 + TANH = 28 + BATCH_TO_SPACE_ND = 29 + DIV = 30 + MEAN = 31 + PAD = 32 + SPACE_TO_BATCH_ND = 33 + SQUEEZE = 34 + STRIDED_SLICE = 35 + SUB = 36 + TRANSPOSE = 37 + ABS = 38 + ARGMAX = 39 + ARGMIN = 40 + AXIS_ALIGNED_BBOX_TRANSFORM = 41 + BIDIRECTIONAL_SEQUENCE_LSTM = 42 + BIDIRECTIONAL_SEQUENCE_RNN = 43 + BOX_WITH_NMS_LIMIT = 44 + CAST = 45 + CHANNEL_SHUFFLE = 46 + DETECTION_POSTPROCESSING = 47 + EQUAL = 48 + EXP = 49 + EXPAND_DIMS = 50 + GATHER = 51 + GENERATE_PROPOSALS = 52 + GREATER = 53 + GREATER_EQUAL = 54 + GROUPED_CONV_2D = 55 + HEATMAP_MAX_KEYPOINT = 56 + INSTANCE_NORMALIZATION = 57 + LESS = 58 + LESS_EQUAL = 59 + LOG = 60 + LOGICAL_AND = 61 + LOGICAL_NOT = 62 + LOGICAL_OR = 63 + LOG_SOFTMAX = 64 + MAXIMUM = 65 + MINIMUM = 66 + NEG = 67 + NOT_EQUAL = 68 + PAD_V2 = 69 + POW = 70 + PRELU = 71 + QUANTIZE = 72 + QUANTIZED_16BIT_LSTM = 73 + RANDOM_MULTINOMIAL = 74 + REDUCE_ALL = 75 + REDUCE_ANY = 76 + REDUCE_MAX = 77 + REDUCE_MIN = 78 + REDUCE_PROD = 79 + REDUCE_SUM = 80 + ROI_ALIGN = 81 + ROI_POOLING = 82 + RSQRT = 83 + SELECT = 84 + SIN = 85 + SLICE = 86 + SPLIT = 87 + SQRT = 88 + TILE = 89 + TOPK_V2 = 90 + TRANSPOSE_CONV_2D = 91 + UNIDIRECTIONAL_SEQUENCE_LSTM = 92 + UNIDIRECTIONAL_SEQUENCE_RNN = 93 + RESIZE_NEAREST_NEIGHBOR = 94 + + +class NNAPI_FuseCode(object): + FUSED_NONE = 0 + FUSED_RELU = 1 + FUSED_RELU1 = 2 + FUSED_RELU6 = 3 + + +class OperandValueSourceType(object): + IMMEDIATE = 0 + NUMBERED_BUFFER = 2 + NUMBERED_MEMORY = 3 + + + +# Scalar types that appear explicitly in models. +# These must be kept in sync with +# AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS. +# TODO: Expose these directly to Python to avoid maintaining this list. +class TorchScalarTypes(enum.Enum): + QUINT8 = 13 + + +def approx_equal(lhs, rhs, tolerance = 1e-6): + return abs(lhs - rhs) <= tolerance * min(lhs, rhs) + + +def tensor_size(op_type, dims): + ITEM_SIZES = { + NNAPI_OperandCode.TENSOR_FLOAT32: 4, + NNAPI_OperandCode.TENSOR_INT32: 4, + NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: 1, + NNAPI_OperandCode.TENSOR_QUANT16_SYMM: 2, + } + size = ITEM_SIZES[op_type] + for d in dims: + size *= d + return size + + +class ConvPoolArgs2d(NamedTuple): + """Configuration arguments for a convolution.""" + kernel_h: int + kernel_w: int + stride_h: int + stride_w: int + pad_t: int + pad_b: int + pad_l: int + pad_r: int + dilation_h: int + dilation_w: int + group: int + + +class DimOrder(enum.Enum): + PRESUMED_CONTIGUOUS = 0 + CHANNELS_LAST = 1 + SCALAR_OR_VECTOR = 2 + UNKNOWN_CONSTANT = 999 + + +class Operand(NamedTuple): + """Represenation of an NNAPI operand.""" + + # NNAPI operand type. One of NNAPI_OperandCode. + # TODO: Make this an enum. + op_type: int + + # This is always the PyTorch shape, which is NCHW for feature maps. + # The actual NNAPI operand might have a transposed shape. + shape: Tuple[int, ...] + + # Specifies how the shape of the operand that we define in NNAPI + # relates to the shape we track above. + # - PRESUMED_CONTIGUOUS: physical NNAPI operand will exactly match + # the shape of the PyTorch tensor. + # - CHANNELS_LAST: The PyTorch tensor is expected to be NCHW, and + # the NNAPI operand will be represented explicitly as NHWC. + dim_order: DimOrder + + # Quantization params + scale: float + zero_point: int + + def use_nchw(self): + if self.dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return True + if self.dim_order is DimOrder.CHANNELS_LAST: + return False + raise Exception("Unknown dim order") + + + +def broadcast_shapes(shape1, shape2): + assert len(shape1) > 0 + assert len(shape2) > 0 + s1 = list(shape1) + s2 = list(shape2) + if len(s1) > len(s2): + #s2 = [1] * (len(s1) - len(s2)) + s2 + raise Exception("Non-equal-rank broadcast is too dangerous because XXX.") + if len(s2) > len(s1): + #s3 = [1] * (len(s2) - len(s1)) + s1 + raise Exception("Non-equal-rank broadcast is too dangerous because XXX.") + ret = [] + for d1, d2 in zip(s1, s2): + if d1 == 1: + ret.append(d2) + elif d2 == 1: + ret.append(d1) + elif d1 == d2: + ret.append(d1) + else: + raise Exception("Cannot broadcast shapes: {} and {}".format(shape1, shape2)) + return tuple(ret) + +def get_conv_pool_shape(image_shape, args, out_ch, transpose): + batch, in_c, in_h, in_w = image_shape + + # TODO: Handle dilation + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("Dilation not supported yet.") + + if transpose: + out_h = (in_h - 1) * args.stride_h + args.kernel_h - args.pad_t - args.pad_b + out_w = (in_w - 1) * args.stride_w + args.kernel_w - args.pad_l - args.pad_l + else: + out_h = (in_h - args.kernel_h + args.pad_t + args.pad_b) // args.stride_h + 1 + out_w = (in_w - args.kernel_w + args.pad_l + args.pad_r) // args.stride_w + 1 + + # Handle variable-sized tensors. + if in_h == 0: + out_h = 0 + if in_w == 0: + out_w = 0 + + out_shape = (batch, out_ch, out_h, out_w) + return out_shape + + +def fix_shape(shape, dim_order): + # Return the actual shape that an operand should have in NNAPI, + # given a PyTorch shape and dimension order. This is where we + # convert from PyTorch's "always NCHW" shape to explicit NHWC. + if dim_order is DimOrder.PRESUMED_CONTIGUOUS: + return shape + if dim_order is DimOrder.CHANNELS_LAST: + return tuple([shape[0]] + list(shape[2:]) + [shape[1]]) + if dim_order is DimOrder.SCALAR_OR_VECTOR: + assert len(shape) == 0 or len(shape) == 1 + return shape + if dim_order is DimOrder.UNKNOWN_CONSTANT: + # XXX think this through + return shape + raise Exception(f"Bad dim_order: {dim_order!r}.") + + +class _NnapiSerializer(object): + def __init__(self, config): + self.operands = [] + self.values = [] + self.operations = [] + self.value_data = [] + self.operation_args = [] + self.inputs = [] + self.outputs = [] + + self.modules = {} + self.constants = {} + self.jitval_operand_map = {} + self.cached_immediates = {} + self.used_weights = [] + self.weight_offset = 0 + + if config is None: + config = {} + + self.solid_weights = config.get("solid_weights", False) + + + def add_tensor_operand(self, jitval, oper): + assert isinstance(oper, Operand) + if jitval in self.jitval_operand_map: + raise Exception("Duplicate tensor: %r" % jitval) + + operand_id = len(self.operands) + self.operands.append(oper) + self.jitval_operand_map[jitval] = operand_id + return operand_id + + + @staticmethod + def torch_tensor_to_operand(tensor, dim_order): + dtype = str(tensor.dtype).replace("torch.", "") + scale = 0.0 + zero_point = 0 + if dtype == "float32": + op_type = NNAPI_OperandCode.TENSOR_FLOAT32 + elif dtype == "quint8": + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + elif dtype == "qint32": + op_type = NNAPI_OperandCode.TENSOR_INT32 + scale = tensor.q_scale() + zero_point = tensor.q_zero_point() + assert zero_point == 0 + else: + raise Exception(f"Can't handle input with dtype '{tensor.dtype}'") + return Operand( + shape=tuple(tensor.shape), + op_type=op_type, + dim_order=dim_order, + scale=scale, + zero_point=zero_point, + ) + + def add_tensor_operand_for_input(self, jitval, tensor): + dim_order = ( + DimOrder.CHANNELS_LAST if getattr(tensor, "nnapi_nhwc", False) + else DimOrder.PRESUMED_CONTIGUOUS) + toper = self.torch_tensor_to_operand(tensor, dim_order) + operand_id = self.add_tensor_operand(jitval, toper) + self.inputs.append(operand_id) + return operand_id + + + def add_tensor_operand_for_weight(self, tensor): + toper = self.torch_tensor_to_operand(tensor, DimOrder.UNKNOWN_CONSTANT) + operand_id = len(self.operands) + self.operands.append(toper) + tsize = tensor_size(toper.op_type, toper.shape) + psize = ((tsize - 1) | 0x3) + 1 + self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) + if self.solid_weights: + buf_num = 0 + offset = self.weight_offset + self.weight_offset += psize + else: + buf_num = len(self.used_weights) + offset = 0 + self.value_data.append(struct.pack( + "iii", + buf_num, + offset, + tsize)) + self.used_weights.append(tensor) + return operand_id + + + def add_immediate_operand(self, code, value, dims): + assert isinstance(dims, tuple) + cache_key = (code, value) + if cache_key not in self.cached_immediates: + operand_id = len(self.operands) + self.operands.append(Operand(code, dims, DimOrder.SCALAR_OR_VECTOR, 0.0, 0)) + self.values.append((operand_id, OperandValueSourceType.IMMEDIATE)) + self.value_data.append(value) + self.cached_immediates[cache_key] = operand_id + return self.cached_immediates[cache_key] + + + def add_immediate_int_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.INT32, + struct.pack("i", value), + ()) + + + def add_immediate_float_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.FLOAT32, + struct.pack("f", value), + ()) + + + def add_immediate_bool_scalar(self, value): + return self.add_immediate_operand( + NNAPI_OperandCode.BOOL, + b"\x01" if value else b"\x00", + ()) + + + def get_tensor_operand_by_jitval(self, jitval): + operand_id = self.jitval_operand_map[jitval] + return (operand_id, self.operands[operand_id]) + + + def get_tensor_operand_or_constant(self, jitval): + operand_id = self.jitval_operand_map.get(jitval) + if operand_id is None: + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value) + return (operand_id, self.operands[operand_id]) + + + def get_tensor_operand_for_weight(self, jitval): + _, value = self.get_constant_value(jitval, "TensorType") + operand_id = self.add_tensor_operand_for_weight(value) + return (operand_id, self.operands[operand_id]) + + + def add_operation(self, opcode, inputs, outputs): + self.operations.append((opcode, len(inputs), len(outputs))) + self.operation_args.extend(inputs + outputs) + + def add_constant_value(self, jitval, ctype, value): + assert jitval not in self.constants + self.constants[jitval] = (ctype, value) + + def get_constant_value(self, jitval, typekind=None): + record = self.constants.get(jitval) + if record is None: + raise Exception(f"Could not find constant value for '{jitval!r}'.") + ctype, _ = record + if typekind is not None and ctype.kind() != typekind: + raise Exception(f"Expected constant value of type {typekind}, but got {ctype.kind()} for value '{jitval!r}'") + return record + + def get_size_arg(self, jitval): + ctype, value = self.get_constant_value(jitval) + if ctype.kind() == "ListType": + assert ctype.getElementType().kind() == "IntType" + return value + raise Exception(f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'") + + def expand_sizes(self, size): + return [ s.item() for s in size ] + + def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): + pc = [ i.item() for i in packed_config ] + assert pc[0] == 2 + strides = [pc[1], pc[2]] + paddings = [pc[3], pc[4]] + dilations = [pc[5], pc[6]] + output_padding = [pc[7], pc[8]] + group_num = pc[9] + transpose = pc[10] + + assert len(pc) == 11 + assert output_padding == [0, 0] + assert transpose == 0 + + return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num) + + def get_conv_pool_args_2d_from_jit(self, kernel_size, stride, padding, dilation, group=None): + strides = self.get_size_arg(stride) + paddings = self.get_size_arg(padding) + dilations = self.get_size_arg(dilation) + if group is not None: + _, group_num = self.get_constant_value(group, "IntType") + else: + group_num = None + return self.get_conv_pool_args_2d_common(kernel_size, strides, paddings, dilations, group_num) + + def get_conv_pool_args_2d_common(self, kernel_size, strides, paddings, dilations, group_num): + kernels = list(kernel_size) + + assert len(kernels) == 2 + assert len(strides) == 2 + assert len(paddings) == 2 + assert len(dilations) == 2 + + # NNAPI uses 4 values for padding. + ph, pw = paddings + real_paddings = [ph, ph, pw, pw] + + return ConvPoolArgs2d(*(kernels + strides + real_paddings + dilations + [group_num])) + + + + def serialize_model(self, model, inputs): + self.add_immediate_bool_scalar(False) + self.add_immediate_bool_scalar(True) + + inp_dim_orders = [] + out_dim_orders = [] + + self_jitval = next(model.graph.inputs()) + self.add_constant_value(self_jitval, self_jitval.type(), model) + + for input_value, input_tensor in zip(list(model.graph.inputs())[1:], inputs): + op_id = self.add_tensor_operand_for_input(input_value, input_tensor) + inp_dim_orders.append(self.operands[op_id].dim_order.value) + + for idx, node in enumerate(model.graph.nodes()): + LOG.debug("Processing node #%d: %r", idx, node) + self.add_node(node) + + retn = model.graph.return_node() + assert retn.inputsSize() == 1 + assert retn.outputsSize() == 0 + # TODO: Make outputs a local variable? + # TODO: Handle tuple-of-tensor return + for idx in range(1): + op_id = self.jitval_operand_map[retn.inputsAt(0)] + self.outputs.append(op_id) + out_dim_orders.append(self.operands[op_id].dim_order.value) + + model = [] + + version = 1 + header = struct.pack( + "iiiiii", + version, + len(self.operands), + len(self.values), + len(self.operations), + len(self.inputs), + len(self.outputs), + ) + model.append(header) + + serialized_values, serialized_value_data = self.serialize_values() + + model.extend(struct.pack("iifi", t, len(d), s, z) for (t,d,_m,s,z) in self.operands) + model.extend(serialized_values) + model.extend(struct.pack("iii", *x) for x in self.operations) + model.extend(self.serialize_ints(fix_shape(dims, mf)) for (_, dims, mf, _, _) in self.operands) + model.extend(serialized_value_data) + model.append(self.serialize_ints(self.operation_args)) + model.append(self.serialize_ints(self.inputs)) + model.append(self.serialize_ints(self.outputs)) + + #return (b"".join(model), self.used_weight_tensor_names) + return (b"".join(model), self.used_weights, inp_dim_orders, out_dim_orders) + + + def serialize_values(self): + serialized_values = [] + serialized_value_data = [] + assert len(self.values) == len(self.value_data) + for ((op_index, source_type), data) in zip(self.values, self.value_data): + source_length = len(data) + + # Pad with 0 bytes out to a multiple of 4 for alignment. + physical_length = ((source_length - 1) | 0x3) + 1 + padded_data = data + (b"\0" * (physical_length - source_length)) + + serialized_values.append(struct.pack("iii", op_index, source_type, source_length)) + serialized_value_data.append(padded_data) + + return serialized_values, serialized_value_data + + + @staticmethod + def serialize_ints(ints): + return struct.pack("i" * len(ints), *ints) + + + ADDER_MAP = { + "prim::GetAttr": lambda self, node: + self.add_getattr(node), + "prim::Constant": lambda self, node: + self.add_constant_node(node), + "prim::ListConstruct": lambda self, node: + self.add_list_construct(node), + "aten::quantize_per_tensor": lambda self, node: + self.add_quantize(node), + "aten::dequantize": lambda self, node: + self.add_dequantize(node), + "aten::add": lambda self, node: + self.add_add_sub_op(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE), + "aten::sub": lambda self, node: + self.add_add_sub_op(node, NNAPI_OperationCode.SUB, NNAPI_FuseCode.FUSED_NONE), + "aten::mul": lambda self, node: + self.add_pointwise_simple_binary_broadcast_op(node, NNAPI_OperationCode.MUL), + "aten::relu": lambda self, node: + self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.RELU), + "aten::sigmoid": lambda self, node: + self.add_pointwise_simple_unary_op(node, NNAPI_OperationCode.LOGISTIC), + "aten::max_pool2d": lambda self, node: + self.add_pool2d_node(node, NNAPI_OperationCode.MAX_POOL_2D), + "aten::upsample_nearest2d": lambda self, node: + self.add_upsample_nearest2d(node), + "aten::prelu": lambda self, node: + self.add_prelu_op(node), + "aten::_convolution": lambda self, node: + self.add_conv(node), + "quantized::conv2d": lambda self, node: + self.add_qconv2d(node, NNAPI_FuseCode.FUSED_NONE), + "quantized::conv2d_relu": lambda self, node: + self.add_qconv2d(node, NNAPI_FuseCode.FUSED_RELU), + "quantized::add": lambda self, node: + self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_NONE), + "quantized::add_relu": lambda self, node: + self.add_qadd(node, NNAPI_OperationCode.ADD, NNAPI_FuseCode.FUSED_RELU), + } + + + def add_node(self, node): + adder = self.ADDER_MAP.get(node.kind()) + if not adder: + print(node) + raise Exception("Unsupported node type: %r" % node.kind()) + adder(self, node) + + + def add_getattr(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + obj_ctype, obj = self.get_constant_value(node.inputsAt(0)) + assert str(obj_ctype).startswith("__torch__.") + name = node.s("name") + value = getattr(obj, name) + output = node.outputsAt(0) + ctype = output.type() + self.add_constant_value(output, ctype, value) + + + def add_constant_node(self, node): + assert node.inputsSize() == 0 + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + value = output.toIValue() + #if ctype.kind() == "NoneType": + # value = None + #else: + # assert node.hasAttribute("value") + # valueKind = node.kindOf("value") + # print(valueKind) + # # JIT: Is this dirty? Is there a better way? + # value = getattr(node, valueKind)("value") + self.add_constant_value(output, ctype, value) + + + def add_list_construct(self, node): + assert node.outputsSize() == 1 + output = node.outputsAt(0) + ctype = output.type() + values = [] + for inp in node.inputs(): + _, val = self.get_constant_value(inp) + values.append(val) + self.add_constant_value(output, ctype, values) + + + def add_quantize(self, node): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + if in_oper.dim_order != DimOrder.CHANNELS_LAST: + raise Exception( + "Most hardware backends prefer NHWC quantized tensors. " + "Try setting `t.nnapi_nhwc = True` on your tensor inputs. ") + _, scale = self.get_constant_value(node.inputsAt(1), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(2), "IntType") + _, scalar_type = self.get_constant_value(node.inputsAt(3), "IntType") + if scalar_type != TorchScalarTypes.QUINT8.value: + raise Exception( + "PyTorch NNAPI export only supports quantized tensors " + "with the quint8 dtype.") + op_type = NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + + out_oper = in_oper._replace( + op_type=op_type, + scale=scale, + zero_point=zero_point, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.QUANTIZE, inputs, outputs) + + + def add_dequantize(self, node): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + out_oper = in_oper._replace( + op_type=NNAPI_OperandCode.TENSOR_FLOAT32, + scale=0.0, + zero_point=0, + ) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(NNAPI_OperationCode.DEQUANTIZE, inputs, outputs) + + + def add_pointwise_simple_unary_op(self, node, opcode): + assert node.inputsSize() == 1 + assert node.outputsSize() == 1 + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + + inputs = [None] * 1 + inputs[0] = in_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(opcode, inputs, outputs) + + + def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): + """Helper for pointwise binary broadcast ops with superfluous extra args""" + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + # TODO: Should support constant as either operand. + in0_id, in0_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + in1_id, in1_oper = self.get_tensor_operand_by_jitval(node.inputsAt(1)) + + assert in0_oper.op_type == in1_oper.op_type + assert in0_oper.dim_order == in1_oper.dim_order + # NOTE: PyTorch and NNAPI have the same broadcast semantics. + out_shape = broadcast_shapes(in0_oper.shape, in1_oper.shape) + out_oper = in0_oper._replace(shape=out_shape) + if qparams is not None: + scale, zp = qparams + out_oper = out_oper._replace(scale=scale, zero_point=zp) + + inputs = [None] * 3 + inputs[0] = in0_id + inputs[1] = in1_id + inputs[2] = self.add_immediate_int_scalar(fuse_code) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper) + + self.add_operation(opcode, inputs, outputs) + + + def add_pointwise_simple_binary_broadcast_op(self, node, opcode): + assert node.inputsSize() == 2 + self._do_add_binary(node, opcode) + + + def add_add_sub_op(self, node, opcode, fuse_code): + assert node.inputsSize() == 3 + + _, alpha = self.get_constant_value(node.inputsAt(2), "IntType") + if alpha != 1: + raise Exception(f"NNAPI does not support add/sub with alpha.") + + self._do_add_binary(node, opcode, fuse_code) + + + def add_qadd(self, node, opcode, fuse_code): + assert node.inputsSize() == 4 + + _, scale = self.get_constant_value(node.inputsAt(2), "FloatType") + _, zero_point = self.get_constant_value(node.inputsAt(3), "IntType") + + self._do_add_binary(node, opcode, fuse_code, qparams=(scale, zero_point)) + + + def add_prelu_op(self, node): + assert node.inputsSize() == 2 + assert node.outputsSize() == 1 + + assert node.inputsAt(0).type().kind() == "TensorType" + assert node.inputsAt(1).type().kind() == "TensorType" + + in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + w_id, w_oper = self.get_tensor_operand_for_weight(node.inputsAt(1)) + assert len(w_oper.shape) == 1 + assert w_oper.shape[0] > 0 + if w_oper.shape[0] > 1: + if in_oper.use_nchw(): + # TODO: Support this by adding trailing 1 dims. + raise Exception("Per-channel PReLU only supports channels_last right now.") + + inputs = [None] * 2 + inputs[0] = in_id + inputs[1] = w_id + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), in_oper) + + self.add_operation(NNAPI_OperationCode.PRELU, inputs, outputs) + + + def add_pool2d_node(self, node, opcode): + assert node.inputsSize() == 6 + assert node.outputsSize() == 1 + image, kernel, stride, padding, dilation, ceil_mode = node.inputs() + + stride = stride or kernel + + # TODO: Validate ceil_mode semantics. + + args = self.get_conv_pool_args_2d_from_jit(self.get_size_arg(kernel), stride, padding, dilation) + if args.dilation_h != 1 or args.dilation_w != 1: + raise Exception("NNAPI does not support dilated pooling.") + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + out_shape = get_conv_pool_shape(image_oper.shape, args, image_oper.shape[1], False) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 11 + inputs[0] = image_id + inputs[1] = self.add_immediate_int_scalar(args.pad_l) + inputs[2] = self.add_immediate_int_scalar(args.pad_r) + inputs[3] = self.add_immediate_int_scalar(args.pad_t) + inputs[4] = self.add_immediate_int_scalar(args.pad_b) + inputs[5] = self.add_immediate_int_scalar(args.stride_w) + inputs[6] = self.add_immediate_int_scalar(args.stride_h) + inputs[7] = self.add_immediate_int_scalar(args.kernel_w) + inputs[8] = self.add_immediate_int_scalar(args.kernel_h) + inputs[9] = self.add_immediate_int_scalar(NNAPI_FuseCode.FUSED_NONE) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + + self.add_operation(opcode, inputs, outputs) + + + def add_upsample_nearest2d(self, node): + assert node.inputsSize() == 3 + assert node.outputsSize() == 1 + image, size_jit, scale_jit = node.inputs() + size_ctype, size_arg = self.get_constant_value(size_jit) + scale_ctype, scale_arg = self.get_constant_value(scale_jit) + + image_id, image_oper = self.get_tensor_operand_by_jitval(image) + assert len(image_oper.shape) == 4 + + if size_ctype.kind() != "NoneType" and scale_ctype.kind() != "NoneType": + raise Exception("Size and scale cannot both be non-None.") + elif size_ctype.kind() != "NoneType": + assert size_ctype.kind() == "ListType" + assert size_ctype.getElementType().kind() == "IntType" + assert scale_ctype.kind() == "NoneType" + assert scale_arg is None + assert isinstance(size_arg, list) + assert size_arg + assert all(isinstance(val, int) for val in size_arg) + if len(size_arg) == 1: + size_arg = size_arg * 2 + assert len(size_arg) == 2 + out_h = size_arg[0] + out_w = size_arg[1] + arg_h = self.add_immediate_int_scalar(out_h) + arg_w = self.add_immediate_int_scalar(out_w) + elif scale_ctype.kind() != "NoneType": + assert scale_ctype.kind() == "ListType" + assert scale_ctype.getElementType().kind() == "FloatType" + assert size_ctype.kind() == "NoneType" + assert size_arg is None + assert isinstance(scale_arg, list) + assert scale_arg + assert all(isinstance(val, float) for val in scale_arg) + if len(scale_arg) == 1: + scale_arg = scale_arg * 2 + assert len(scale_arg) == 2 + out_h = int(scale_arg[0] * image_oper.shape[2]) + out_w = int(scale_arg[1] * image_oper.shape[3]) + arg_h = self.add_immediate_float_scalar(scale_arg[0]) + arg_w = self.add_immediate_float_scalar(scale_arg[1]) + else: + raise Exception("Size and scale cannot both be None.") + + out_shape = (image_oper.shape[0], image_oper.shape[1], out_h, out_w) + use_nchw = image_oper.use_nchw() + + inputs = [None] * 4 + inputs[0] = image_id + inputs[1] = arg_w + inputs[2] = arg_h + inputs[3] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(node.outputsAt(0), image_oper._replace(shape=out_shape)) + + self.add_operation(NNAPI_OperationCode.RESIZE_NEAREST_NEIGHBOR, inputs, outputs) + + + def add_conv(self, node): + assert node.inputsSize() == 13 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_weight, + jit_bias, + jit_stride, + jit_pad, + jit_dilation, + jit_transpose, + _, + jit_groups, + _, + _, + _, + _, + ) = node.inputs() + + _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") + bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias) + args = self.get_conv_pool_args_2d_from_jit(weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups) + + return self.add_conv2d_common( + node.outputsAt(0), + 0.0, + 0, + jit_image, + weight_tensor, + bias_id, + args, + False, # transpose + NNAPI_FuseCode.FUSED_NONE, + ) + + + def add_qconv2d(self, node, fuse_code): + assert node.inputsSize() == 4 + assert node.outputsSize() == 1 + + ( + jit_image, + jit_packed_weight, + jit_scale, + jit_zero_point, + ) = node.inputs() + + _, out_scale = self.get_constant_value(jit_scale, "FloatType") + _, out_zero_point = self.get_constant_value(jit_zero_point, "IntType") + weight_ctype, packed_weight = self.get_constant_value(jit_packed_weight) + assert weight_ctype.name() == "Conv2dPackedParamsBase" + ( + pack_version, + tensors, + opt_tensors, + ) = packed_weight.__getstate__()[0] + assert pack_version == "2" + packed_config, raw_weight = tensors + raw_bias, = opt_tensors + assert raw_bias is not None + args = self.get_conv_pool_args_2d_from_pack(raw_weight.shape[2:4], packed_config) + + assert raw_weight.qscheme() == torch.per_tensor_affine + if raw_weight.dtype == torch.quint8: + unsigned_weight = raw_weight + else: + assert raw_weight.dtype == torch.qint8 + unsigned_weight = torch._make_per_tensor_quantized_tensor( + (raw_weight.int_repr().int() + 128).to(torch.uint8), + scale=raw_weight.q_scale(), + zero_point=raw_weight.q_zero_point() + 128) + weight_scale = unsigned_weight.q_scale() + _, image_oper = self.get_tensor_operand_by_jitval(jit_image) + bias_scale = image_oper.scale * weight_scale + int_bias = torch.quantize_per_tensor(raw_bias, bias_scale, 0, torch.qint32) + bias_id = self.add_tensor_operand_for_weight(int_bias) + + multiplier = image_oper.scale * weight_scale / out_scale + assert multiplier > 0 + if multiplier >= 1: + raise Exception( + "Quantized convolution multiplier is greater than 1. " + "This is supported by NNAPI, but not by most hardware backends. " + "Try training a model without quantization-aware training. ") + + return self.add_conv2d_common( + node.outputsAt(0), + out_scale, + out_zero_point, + jit_image, + unsigned_weight, + bias_id, + args, + False, # transpose + fuse_code, + ) + + + def add_conv2d_common(self, + jit_out, + out_scale, + out_zero_point, + jit_image, + weight_tensor, + bias_id, + args, + transpose, + fuse_code): + image_id, image_oper = self.get_tensor_operand_by_jitval(jit_image) + in_c = image_oper.shape[1] + + if args.group == 1: + # Full convolution + depthwise = False + weight_permutation = (0, 2, 3, 1) + elif args.group == in_c: + # Depthwise convolution + depthwise = True + weight_permutation = (1, 2, 3, 0) + else: + raise Exception("Group convolution not supported yet.") + + # TODO: Transform at load time to share weights with CPU model. + nnapi_weight_tensor = weight_tensor.permute(*weight_permutation).contiguous() + weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor) + weight_oper = self.operands[weight_id] + + bias_oper = self.operands[bias_id] + + if image_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_FLOAT32 + elif image_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM: + assert weight_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM + assert bias_oper.op_type == NNAPI_OperandCode.TENSOR_INT32 + assert approx_equal(image_oper.scale * weight_oper.scale, bias_oper.scale) + assert bias_oper.zero_point == 0 + else: + raise Exception( + "Unsupported input type for conv2d: {}" + .format(image_oper.op_type)) + + assert len(image_oper.shape) == 4 + assert len(weight_oper.shape) == 4 + assert len(bias_oper.shape) == 1 + + + if depthwise: + # Depthwise convolution + one, kern_h, kern_w, out_c = weight_oper.shape + assert one == 1 + assert out_c % in_c == 0 + channel_multiplier = out_c // in_c + assert channel_multiplier == 1 # Don't support multiplier + assert out_c == in_c + else: + # Full convolution + kern_nf, kern_h, kern_w, kern_d = weight_oper.shape + out_c = kern_nf + assert kern_d == in_c + + assert out_c == bias_oper.shape[0] + + out_shape = get_conv_pool_shape(image_oper.shape, args, out_c, transpose) + out_oper = image_oper._replace( + shape=out_shape, + scale=out_scale, + zero_point=out_zero_point, + ) + + use_nchw = image_oper.use_nchw() + + if depthwise: + num_args = 12 + opcode = NNAPI_OperationCode.DEPTHWISE_CONV_2D + else: + num_args = 11 + if transpose: + opcode = NNAPI_OperationCode.TRANSPOSE_CONV_2D + else: + opcode = NNAPI_OperationCode.CONV_2D + + inputs = [None] * num_args + inputs[0] = image_id + inputs[1] = weight_id + inputs[2] = bias_id + inputs[3] = self.add_immediate_int_scalar(args.pad_l) + inputs[4] = self.add_immediate_int_scalar(args.pad_r) + inputs[5] = self.add_immediate_int_scalar(args.pad_t) + inputs[6] = self.add_immediate_int_scalar(args.pad_b) + inputs[7] = self.add_immediate_int_scalar(args.stride_w) + inputs[8] = self.add_immediate_int_scalar(args.stride_h) + if depthwise: + inputs[9] = self.add_immediate_int_scalar(1) + inputs[10] = self.add_immediate_int_scalar(fuse_code) + inputs[11] = self.add_immediate_bool_scalar(use_nchw) + else: + inputs[9] = self.add_immediate_int_scalar(fuse_code) + inputs[10] = self.add_immediate_bool_scalar(use_nchw) + + outputs = [None] * 1 + outputs[0] = self.add_tensor_operand(jit_out, out_oper) + + self.add_operation(opcode, inputs, outputs) + + + +def serialize_model(module, inputs, config=None): + return _NnapiSerializer(config).serialize_model(module, inputs)