diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index fa18e46b2c6..40cbd9dea87 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -110,22 +110,28 @@ constexpr uint64_t kMinProducedFileFormatVersion = 0x3L; // 0x2L: (Comment missing) // 0x3L: (Comment missing) // 0x4L: (update) Added schema to function tuple. Forward-compatible change. -// 0x5L: (update) Update bytecode is sharing constant tensor files from torchscript, and only serialize -// extra tensors that are not in the torchscript constant table. Also update tensor storage schema adapting -// to the unify format, the root key of tensor storage is updated from {index} to -// {the_pointer_value_the_tensor.storage}, for example: `140245072983168.storage` -// Forward-compatibility change. -// 0x6L: Implicit opereator versioning using number of specified argument. -// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845 -// for details. -// 0x7L: Enable support for operators with default arguments plus out arguments. -constexpr uint64_t kProducedBytecodeVersion = 0x7L; +// 0x5L: (update) Update bytecode is sharing constant tensor files from +// torchscript, and only serialize extra tensors that are not in the +// torchscript constant table. Also update tensor storage schema adapting to +// the unify format, the root key of tensor storage is updated from {index} to +// {the_pointer_value_the_tensor.storage}, for example: +// `140245072983168.storage` Forward-compatibility change. 0x6L: Implicit +// operator versioning using number of specified arguments. Refer to the +// summary of https://github.com/pytorch/pytorch/pull/56845 for details. 0x7L: +// Enable support for operators with default arguments plus out arguments. 
+// 0x8L: Emit promoted operators as instructions +constexpr uint64_t kProducedBytecodeVersion = 0x8L; + +// static_assert( +// kProducedBytecodeVersion >= kProducedFileFormatVersion, +// "kProducedBytecodeVersion must be higher or equal to +// kProducedFileFormatVersion."); // Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion // for limited backward/forward compatibility support of bytecode. If -// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion (in loader), -// we should support this model_version. For example, we provide a wrapper to -// handle an updated operator. +// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion +// (in loader), we should support this model_version. For example, we provide a +// wrapper to handle an updated operator. constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L; diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 0e40e48514d..5e00eafa738 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -571,19 +571,34 @@ namespace { void compareModelOutput( c10::ArrayRef actual_result_list, - const std::vector& expect_result_list) { + const std::vector& expect_result_list) { AT_ASSERT(actual_result_list.size() == expect_result_list.size()); - AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0])); AT_ASSERT( - actual_result_list[1].toTensor().dim() == expect_result_list[1].dim()); - AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2])); - AT_ASSERT(actual_result_list[3].toTensor().equal(expect_result_list[3])); + actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor())); + AT_ASSERT( + actual_result_list[1].toTensor().dim() == + expect_result_list[1].toTensor().dim()); + AT_ASSERT( + actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor())); + 
AT_ASSERT( + actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor())); + ASSERT_EQ( + actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef()); + ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool()); + ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool()); + ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool()); + AT_ASSERT( + actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor())); + ASSERT_EQ( + actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef()); + ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt()); + ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool()); } void runAndCheckTorchScriptModel( std::stringstream& input_model_stream, const std::vector& input_data, - const std::vector& expect_result_list, + const std::vector& expect_result_list, const int64_t expect_version) { auto actual_version = _get_model_bytecode_version(input_model_stream); AT_ASSERT(actual_version == expect_version); @@ -600,7 +615,7 @@ void runAndCheckTorchScriptModel( void runAndCheckBytecodeModel( std::stringstream& input_model_stream, const std::vector& input_data, - const std::vector& expect_result_list, + const std::vector& expect_result_list, const int64_t expect_version) { auto actual_version = _get_model_bytecode_version(input_model_stream); AT_ASSERT(actual_version == expect_version); @@ -618,7 +633,7 @@ void runAndCheckBytecodeModel( void backportAllVersionCheck( std::stringstream& test_model_file_stream, std::vector& input_data, - std::vector& expect_result_list, + std::vector& expect_result_list, const int64_t expect_from_version) { auto from_version = _get_model_bytecode_version(test_model_file_stream); AT_ASSERT(from_version == expect_from_version); @@ -668,6 +683,9 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) 
module.register_parameter("bias", torch::ones({20}), false); module.define(R"( + def fn(self, x:float=1.0): + return x + def forward(self, input): x1 = torch.zeros(2, 2) x2 = torch.empty_like(torch.empty(2, 2)) @@ -677,8 +695,22 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) { x = 2 * torch.ones(1) h = torch.ones(1) torch.add(x, h, out=x) - return (x1, x2, x3, x) - )"); + device = torch.ones(1, 1).cpu().device.type + is_cuda = x1.is_cuda + bool_val = True + check_is = [] is None + check_is_not = [1] is not None + check_not = not bool_val + num_to_tensor = torch.tensor([self.fn()]) + d = {"a": "abc"} + check_dict_index = d["a"] + check_dim = x1.dim() + return ( + x1, x2, x3, x, device, is_cuda, check_is, + check_is_not, num_to_tensor, check_dict_index, + check_dim, check_not + ) + )"); torch::jit::Module module_freeze = freeze(module); @@ -686,12 +718,21 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) { module_freeze._save_for_mobile(input_model_stream); std::vector input_data = std::vector({torch::ones({1, 1, 28, 28})}); - std::vector expect_result_list; + std::vector expect_result_list; expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0); expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float)); expect_result_list.emplace_back( at::ones({1, 20, 24, 24}, ScalarType::Float) * 26); expect_result_list.emplace_back(3 * at::ones({1})); + // "cpu" False, False, True, tensor(1), "abc", 2, False) + expect_result_list.emplace_back(c10::IValue("cpu")); + expect_result_list.emplace_back(c10::IValue(false)); + expect_result_list.emplace_back(c10::IValue(false)); + expect_result_list.emplace_back(c10::IValue(true)); + expect_result_list.emplace_back(c10::IValue(at::ones({1}))); + expect_result_list.emplace_back(c10::IValue("abc")); + expect_result_list.emplace_back(c10::IValue(2)); + expect_result_list.emplace_back(c10::IValue(false)); backportAllVersionCheck( input_model_stream, diff --git 
a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py index bb42702f536..9b728df8808 100644 --- a/test/test_mobile_optimizer.py +++ b/test/test_mobile_optimizer.py @@ -151,7 +151,7 @@ def forward(self, x): bn_scripted_module = torch.jit.script(bn_test_module) bn_scripted_module.eval() - self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14) + self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11) FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ .run(str(get_forward(bn_scripted_module._c).graph)) @@ -252,7 +252,7 @@ def foo(self, x): bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module) bn_no_forward_scripted_module.eval() - self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14) + self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11) FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \ .run(bn_no_forward_scripted_module.foo.graph) diff --git a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp index 5cf6fea9701..09377093e0b 100644 --- a/torch/csrc/jit/mobile/compatibility/backport_manager.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport_manager.cpp @@ -27,6 +27,7 @@ constexpr int64_t kBytecodeVersionV4 = 0x4L; constexpr int64_t kBytecodeVersionV5 = 0x5L; constexpr int64_t kBytecodeVersionV6 = 0x6L; constexpr int64_t kBytecodeVersionV7 = 0x7L; +constexpr int64_t kBytecodeVersionV8 = 0x8L; } // namespace /********************** Utility Functions **********************/ @@ -434,7 +435,8 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) { { BytecodeEmitModeGuard argNumGuard( true /*emit_default_input_instructions*/, - false /*enable_defaults_args_with_out_args*/); + false /*enable_defaults_args_with_out_args*/, + false /*enable_emit_promoted_ops*/); torch_script._save_for_mobile( 
intermediate_model_stream, extra_files, hasBytecodeDebug); } @@ -501,7 +503,8 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { { BytecodeEmitModeGuard argNumGuard( false /*emit_default_input_instructions*/, - false /*enable_defaults_args_with_out_args*/); + false /*enable_defaults_args_with_out_args*/, + false /*enable_emit_promoted_ops*/); torch_script._save_for_mobile( intermediate_model_stream, extra_files, hasBytecodeDebug); } @@ -512,6 +515,39 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) { return output_model_stream; } +std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) { + std::shared_ptr rai = + std::make_shared(&input_model_stream); + auto reader = std::make_shared(rai); + // extra_files are kept + auto records = reader->getAllRecords(); + bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl"); + ExtraFilesMap extra_files; + for (const auto& record : records) { + std::size_t found = record.find_last_of("/\\"); + auto path = record.substr(0, found); + if ("extra" == path) { + extra_files.emplace(record.substr(found + 1), ""); + } + } + Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files); + std::stringstream intermediate_model_stream; + { + BytecodeEmitModeGuard argNumGuard( + false /*emit_default_input_instructions*/, + true /*enable_defaults_args_with_out_args*/, + false /*enable_emit_promoted_ops*/); + torch_script._save_for_mobile( + intermediate_model_stream, extra_files, hasBytecodeDebug); + } + + // Update the bytecode version (from 8 to 7) + std::stringstream output_model_stream = + update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7); + + return output_model_stream; +} + } // namespace /********************** BackportManager **********************/ @@ -528,6 +564,7 @@ BackportManager::BackportManager() { registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4); 
registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5); registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6); + registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7); } std::unordered_map< diff --git a/torch/csrc/jit/mobile/upgrader_mobile.cpp b/torch/csrc/jit/mobile/upgrader_mobile.cpp index 83e23342d5c..3e876de4766 100644 --- a/torch/csrc/jit/mobile/upgrader_mobile.cpp +++ b/torch/csrc/jit/mobile/upgrader_mobile.cpp @@ -346,7 +346,7 @@ const std::vector& getUpgraderBytecodeList() { Instruction{OpCode::STOREN, 1, 7}, Instruction{OpCode::LOAD, 3, 0}, Instruction{OpCode::LOADC, 0, 0}, - Instruction{OpCode::OP, 0, 0}, + Instruction{OpCode::__IS__, 0, 0}, Instruction{OpCode::JF, 10, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, @@ -355,17 +355,17 @@ const std::vector& getUpgraderBytecodeList() { Instruction{OpCode::LOAD, 5, 0}, Instruction{OpCode::LOAD, 6, 0}, Instruction{OpCode::LOAD, 7, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::JMP, 10, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, Instruction{OpCode::LOAD, 3, 0}, - Instruction{OpCode::OP, 2, 0}, + Instruction{OpCode::OP, 1, 0}, Instruction{OpCode::LOAD, 4, 0}, Instruction{OpCode::LOAD, 5, 0}, Instruction{OpCode::LOAD, 6, 0}, Instruction{OpCode::LOAD, 7, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::STORE, 8, 0}, Instruction{OpCode::DROPR, 7, 0}, Instruction{OpCode::DROPR, 6, 0}, @@ -385,7 +385,6 @@ const std::vector& getUpgraderBytecodeList() { 8 ), std::vector({ - OperatorString({"aten::__is__", "", 2}), OperatorString({"aten::linspace", "", 7}), OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list @@ -397,20 +396,20 @@ const std::vector& getUpgraderBytecodeList() { Instruction{OpCode::STOREN, 1, 4}, Instruction{OpCode::LOAD, 3, 0}, Instruction{OpCode::LOADC, 0, 0}, - Instruction{OpCode::OP, 0, 0}, + 
Instruction{OpCode::__IS__, 0, 0}, Instruction{OpCode::JF, 7, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, Instruction{OpCode::LOADC, 1, 0}, Instruction{OpCode::LOAD, 4, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::JMP, 7, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, Instruction{OpCode::LOAD, 3, 0}, - Instruction{OpCode::OP, 2, 0}, - Instruction{OpCode::LOAD, 4, 0}, Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::LOAD, 4, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::STORE, 5, 0}, Instruction{OpCode::DROPR, 4, 0}, Instruction{OpCode::DROPR, 2, 0}, @@ -427,7 +426,6 @@ const std::vector& getUpgraderBytecodeList() { 5 ), std::vector({ - OperatorString({"aten::__is__", "", 2}), OperatorString({"aten::linspace", "out", 4}), OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list @@ -439,7 +437,7 @@ const std::vector& getUpgraderBytecodeList() { Instruction{OpCode::STOREN, 1, 8}, Instruction{OpCode::LOAD, 3, 0}, Instruction{OpCode::LOADC, 0, 0}, - Instruction{OpCode::OP, 0, 0}, + Instruction{OpCode::__IS__, 0, 0}, Instruction{OpCode::JF, 11, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, @@ -449,18 +447,18 @@ const std::vector& getUpgraderBytecodeList() { Instruction{OpCode::LOAD, 6, 0}, Instruction{OpCode::LOAD, 7, 0}, Instruction{OpCode::LOAD, 8, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::JMP, 11, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, Instruction{OpCode::LOAD, 3, 0}, - Instruction{OpCode::OP, 2, 0}, + Instruction{OpCode::OP, 1, 0}, Instruction{OpCode::LOAD, 4, 0}, Instruction{OpCode::LOAD, 5, 0}, Instruction{OpCode::LOAD, 6, 0}, Instruction{OpCode::LOAD, 7, 0}, Instruction{OpCode::LOAD, 8, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::STORE, 9, 0}, Instruction{OpCode::DROPR, 8, 0}, 
Instruction{OpCode::DROPR, 7, 0}, @@ -481,7 +479,6 @@ const std::vector& getUpgraderBytecodeList() { 9 ), std::vector({ - OperatorString({"aten::__is__", "", 2}), OperatorString({"aten::logspace", "", 8}), OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list @@ -493,22 +490,22 @@ const std::vector& getUpgraderBytecodeList() { Instruction{OpCode::STOREN, 1, 5}, Instruction{OpCode::LOAD, 3, 0}, Instruction{OpCode::LOADC, 0, 0}, - Instruction{OpCode::OP, 0, 0}, + Instruction{OpCode::__IS__, 0, 0}, Instruction{OpCode::JF, 8, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, Instruction{OpCode::LOADC, 1, 0}, Instruction{OpCode::LOAD, 4, 0}, Instruction{OpCode::LOAD, 5, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::JMP, 8, 0}, Instruction{OpCode::LOAD, 1, 0}, Instruction{OpCode::LOAD, 2, 0}, Instruction{OpCode::LOAD, 3, 0}, - Instruction{OpCode::OP, 2, 0}, + Instruction{OpCode::OP, 1, 0}, Instruction{OpCode::LOAD, 4, 0}, Instruction{OpCode::LOAD, 5, 0}, - Instruction{OpCode::OP, 1, 0}, + Instruction{OpCode::OP, 0, 0}, Instruction{OpCode::STORE, 6, 0}, Instruction{OpCode::DROPR, 5, 0}, Instruction{OpCode::DROPR, 4, 0}, @@ -526,7 +523,6 @@ const std::vector& getUpgraderBytecodeList() { 6 ), std::vector({ - OperatorString({"aten::__is__", "", 2}), OperatorString({"aten::logspace", "out", 5}), OperatorString({"prim::unchecked_cast", "", 1}), }), // operators list diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index a01da0b8c05..e421815d7e7 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -1059,12 +1059,14 @@ MobileCode::MobileCode( std::string function_name, bool emit_default_input_instructions, bool support_default_args_before_out, + bool emit_promoted_ops, size_t remaining_bailout_depth) : Code(new interpreter::MobileCodeImpl( graph, std::move(function_name), emit_default_input_instructions, 
support_default_args_before_out, + emit_promoted_ops, remaining_bailout_depth)) {} MobileCode::~MobileCode() = default; diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 12441735ae6..19f997981f4 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -88,6 +88,7 @@ struct TORCH_API MobileCode : Code { std::string function_name, bool emit_default_input_instructions = true, bool support_default_args_before_out = true, + bool emit_promoted_ops = true, size_t remaining_bailout_depth = 0); ~MobileCode(); }; diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 03411a19632..63844c4e981 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -869,10 +869,12 @@ struct MobileCodeImpl : CodeImpl { std::string function_name, bool emit_default_input_instructions, bool support_default_args_before_out, + bool emit_promoted_ops, size_t remaining_bailout_depth) : CodeImpl(graph, function_name, remaining_bailout_depth, false), emit_default_input_instructions_(emit_default_input_instructions), - support_default_args_before_out_(support_default_args_before_out) { + support_default_args_before_out_(support_default_args_before_out), + emit_promoted_ops_(emit_promoted_ops) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) run(); } @@ -965,7 +967,6 @@ struct MobileCodeImpl : CodeImpl { int64_t X = 0, uint64_t N = 0, bool emit_inputs = true) override { - bool emit_promoted_ops_ = false; if (emit_promoted_ops_) { CodeImpl::emitOperatorOrInstruction(node, op, X, N, emit_inputs); } else { @@ -977,6 +978,8 @@ struct MobileCodeImpl : CodeImpl { bool emit_default_input_instructions_; // To support forward compatibility for bytecode version bump from v6 to v7 bool support_default_args_before_out_; + // To support forward compatibility for bytecode version bump from v7 to v8 + 
bool emit_promoted_ops_; }; } // namespace interpreter diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index b73817fb23c..17996a8ec05 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -201,6 +201,9 @@ struct TORCH_API BytecodeEmitMode { static bool is_default_args_before_out_args_enabled(); static void set_default_args_before_out_args_enabled(bool enabled); + + static bool is_emit_promoted_ops_enabled(); + static void set_default_emit_promoted_ops_enabled(bool enabled); }; // RAII guard to switch the way JIT emits the bytecode for inputs. @@ -216,24 +219,32 @@ struct TORCH_API BytecodeEmitMode { struct TORCH_API BytecodeEmitModeGuard { BytecodeEmitModeGuard( bool enable_default_value_for_unspecified_arg, - bool enable_default_args_before_out_args) + bool enable_default_args_before_out_args, + bool enable_emit_promoted_ops) : prev_default_value_for_unspecified_arg_mode( BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()), prev_default_args_before_out_args( - BytecodeEmitMode::is_default_args_before_out_args_enabled()) { + BytecodeEmitMode::is_default_args_before_out_args_enabled()), + prev_default_emit_promoted_ops( + BytecodeEmitMode::is_emit_promoted_ops_enabled()) { BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled( enable_default_value_for_unspecified_arg); BytecodeEmitMode::set_default_args_before_out_args_enabled( enable_default_args_before_out_args); + BytecodeEmitMode::set_default_emit_promoted_ops_enabled( + enable_emit_promoted_ops); } ~BytecodeEmitModeGuard() { BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled( prev_default_value_for_unspecified_arg_mode); BytecodeEmitMode::set_default_args_before_out_args_enabled( prev_default_args_before_out_args); + BytecodeEmitMode::set_default_emit_promoted_ops_enabled( + prev_default_emit_promoted_ops); } bool prev_default_value_for_unspecified_arg_mode; bool 
prev_default_args_before_out_args; + bool prev_default_emit_promoted_ops; }; TORCH_API IValue to_tuple(std::vector ivalues); diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index 007e29ec7c3..cb2b104e039 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -142,7 +142,8 @@ mobile::Code compileGraphToMobileCode( graph, name, compilation_options.enable_default_value_for_unspecified_arg, - compilation_options.enable_default_args_before_out_args); + compilation_options.enable_default_args_before_out_args, + compilation_options.enable_emit_promoted_ops); mobile::Code mobile_code; diff --git a/torch/csrc/jit/serialization/export_bytecode.h b/torch/csrc/jit/serialization/export_bytecode.h index 4fb0b5043f5..96397a56eac 100644 --- a/torch/csrc/jit/serialization/export_bytecode.h +++ b/torch/csrc/jit/serialization/export_bytecode.h @@ -20,6 +20,7 @@ struct TORCH_API CompilationOptions { bool incl_interface_call = false; bool enable_default_value_for_unspecified_arg = false; bool enable_default_args_before_out_args = true; + bool enable_emit_promoted_ops = true; int model_version = caffe2::serialize::kProducedBytecodeVersion; }; diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 23bd357130f..cbfe143c0e7 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -44,6 +44,8 @@ CompilationOptions getOptionsFromGlobal() { BytecodeEmitMode::is_default_args_before_out_args_enabled(); compilation_options.enable_default_value_for_unspecified_arg = BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled(); + compilation_options.enable_emit_promoted_ops = + BytecodeEmitMode::is_emit_promoted_ops_enabled(); compilation_options.incl_interface_call = getMobileInterfaceCallExport(); compilation_options.model_version = 
caffe2::serialize::kProducedBytecodeVersion; @@ -864,5 +866,14 @@ void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) { emitDefautlArgsWithOutArgs = enabled; } +thread_local bool emitDefaultEmitPromotedOps = + caffe2::serialize::kProducedBytecodeVersion <= 7 ? false : true; +bool BytecodeEmitMode::is_emit_promoted_ops_enabled() { + return emitDefaultEmitPromotedOps; +} +void BytecodeEmitMode::set_default_emit_promoted_ops_enabled(bool enabled) { + emitDefaultEmitPromotedOps = enabled; +} + } // namespace jit } // namespace torch