[PyTorchEdge] backport v8 to v7 to support promoted ops as instruction (#71662)

Summary:
Pull Request resolved: pytorch/pytorch#71662

Backport bytecode v8 to v7 so that promoted ops can be emitted as instructions in v8 while older runtimes keep them as operators.

Adds a flag that controls the export path: promoted ops are exported as instructions for v8, and as operators for v7 and below.
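
For context, a minimal usage sketch (not part of this diff) of the existing backport entry points exercised by the test below; the include paths and the `downgrade_to_v7` wrapper are illustrative assumptions:

```cpp
// Illustrative sketch: downgrade a bytecode-v8 mobile model to v7 so that
// promoted ops are re-emitted as regular operators for older runtimes.
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>

#include <sstream>

std::stringstream downgrade_to_v7(std::stringstream& v8_model_stream) {
  // Version baked into the model archive (expected to be 8 after this change).
  auto version = torch::jit::_get_model_bytecode_version(v8_model_stream);
  std::stringstream v7_model_stream;
  if (version > 7) {
    // Applies the one-step backports registered in BackportManager (v8 -> v7).
    torch::jit::_backport_for_mobile(
        v8_model_stream, v7_model_stream, /*to_version=*/7);
  }
  return v7_model_stream;
}
```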

Test Plan:
```
buck test caffe2/test/cpp/jit:jit -- LiteInterpreterTest.BackPortByteCodeModelAllVersions

Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
    ✓ ListingSuccess: caffe2/test/cpp/jit:jit : 461 tests discovered (15.693)
    ✓ Pass: caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions (2.712)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/5629499620570927
```

```
buck run mode/opt //caffe2/torch/fb/mobile/upgrader_codegen:upgrader_codegen

buck test mode/opt //caffe2/test:upgrader_codegen -- mobile.test_upgrader_codegen.TestLiteScriptModule
Parsing buck files: finished in 0.8 sec
Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules)
Building: finished in 01:39.4 min (100%) 11031/11031 jobs, 2/11031 updated
  Total time: 01:40.2 min
More details at https://www.internalfb.com/intern/buck/build/a8b0e417-019c-44ba-be6b-23379411a965
BUILD SUCCEEDED
Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details.
Running with tpx session id: 44fbfa66-cce8-4277-82ac-f89d79558581
Trace available for this run at /tmp/tpx-20220202-160956.915412/trace.log
RemoteExecution session id: reSessionID-44fbfa66-cce8-4277-82ac-f89d79558581-tpx
Started reporting to test run: https://www.internalfb.com/intern/testinfra/testrun/281475200877601
    ✓ ListingSuccess: caffe2/test:upgrader_codegen : 1 tests discovered (1.249)
    ✓ Pass: caffe2/test:upgrader_codegen - test_generate_bytecode (mobile.test_upgrader_codegen.TestLiteScriptModule) (1.365)
Summary
  Pass: 1
  ListingSuccess: 1
If you need help understanding your runs, please follow the wiki: https://fburl.com/posting_in_tpx_users
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/281475200877601
```

Reviewed By: iseeyuan

Differential Revision: D33719098

fbshipit-source-id: e2d2b23d298f98e4d4fcdfc344f7b8c6f92cff26
(cherry picked from commit 81b956c23abc19489b69eee986721252474d00dc)
pavithranrao authored and cyyever committed Feb 16, 2022
1 parent 6507a16 commit a654e20
Showing 12 changed files with 163 additions and 53 deletions.
32 changes: 19 additions & 13 deletions caffe2/serialize/versions.h
@@ -110,22 +110,28 @@ constexpr uint64_t kMinProducedFileFormatVersion = 0x3L;
// 0x2L: (Comment missing)
// 0x3L: (Comment missing)
// 0x4L: (update) Added schema to function tuple. Forward-compatible change.
// 0x5L: (update) Update bytecode is sharing constant tensor files from torchscript, and only serialize
// extra tensors that are not in the torchscript constant table. Also update tensor storage schema adapting
// to the unify format, the root key of tensor storage is updated from {index} to
// {the_pointer_value_the_tensor.storage}, for example: `140245072983168.storage`
// Forward-compatibility change.
// 0x6L: Implicit operator versioning using the number of specified arguments.
// Refer to the summary of https://github.com/pytorch/pytorch/pull/56845
// for details.
// 0x7L: Enable support for operators with default arguments plus out arguments.
constexpr uint64_t kProducedBytecodeVersion = 0x7L;
// 0x5L: (update) Update bytecode is sharing constant tensor files from
// torchscript, and only serialize extra tensors that are not in the
// torchscript constant table. Also update tensor storage schema adapting to
// the unify format, the root key of tensor storage is updated from {index} to
// {the_pointer_value_the_tensor.storage}, for example:
// `140245072983168.storage` Forward-compatibility change. 0x6L: Implicit
// opereator versioning using number of specified argument. Refer to the
// summary of https://github.com/pytorch/pytorch/pull/56845 for details. 0x7L:
// Enable support for operators with default arguments plus out arguments.
// 0x8L: Emit promoted operators as instructions
constexpr uint64_t kProducedBytecodeVersion = 0x8L;

// static_assert(
// kProducedBytecodeVersion >= kProducedFileFormatVersion,
// "kProducedBytecodeVersion must be higher or equal to
// kProducedFileFormatVersion.");

// Introduce kMinSupportedBytecodeVersion and kMaxSupportedBytecodeVersion
// for limited backward/forward compatibility support of bytecode. If
// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion (in loader),
// we should support this model_version. For example, we provide a wrapper to
// handle an updated operator.
// kMinSupportedBytecodeVersion <= model_version <= kMaxSupportedBytecodeVersion
// (in loader), we should support this model_version. For example, we provide a
// wrapper to handle an updated operator.
constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;

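The constants and comments above define the supported window. As a reading aid, a minimal sketch of the range check they describe (assumed shape only; the real check lives in the mobile model loader, not in this header):

```cpp
// Hypothetical loader-side gate implied by the comment above: accept a model
// only when its bytecode version falls inside the supported window.
#include <cstdint>
#include <stdexcept>

constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;

void check_bytecode_version(uint64_t model_version) {
  if (model_version < kMinSupportedBytecodeVersion ||
      model_version > kMaxSupportedBytecodeVersion) {
    throw std::runtime_error("Unsupported bytecode version");
  }
}
```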
63 changes: 52 additions & 11 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -571,19 +571,34 @@ namespace {

void compareModelOutput(
c10::ArrayRef<IValue> actual_result_list,
const std::vector<Tensor>& expect_result_list) {
const std::vector<IValue>& expect_result_list) {
AT_ASSERT(actual_result_list.size() == expect_result_list.size());
AT_ASSERT(actual_result_list[0].toTensor().equal(expect_result_list[0]));
AT_ASSERT(
actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
AT_ASSERT(actual_result_list[3].toTensor().equal(expect_result_list[3]));
actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
AT_ASSERT(
actual_result_list[1].toTensor().dim() ==
expect_result_list[1].toTensor().dim());
AT_ASSERT(
actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
AT_ASSERT(
actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
ASSERT_EQ(
actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
AT_ASSERT(
actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
ASSERT_EQ(
actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
}

void runAndCheckTorchScriptModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
const std::vector<Tensor>& expect_result_list,
const std::vector<IValue>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);
@@ -600,7 +615,7 @@ void runAndCheckTorchScriptModel(
void runAndCheckBytecodeModel(
std::stringstream& input_model_stream,
const std::vector<IValue>& input_data,
const std::vector<Tensor>& expect_result_list,
const std::vector<IValue>& expect_result_list,
const int64_t expect_version) {
auto actual_version = _get_model_bytecode_version(input_model_stream);
AT_ASSERT(actual_version == expect_version);
@@ -618,7 +633,7 @@ void runAndCheckBytecodeModel(
void backportAllVersionCheck(
std::stringstream& test_model_file_stream,
std::vector<IValue>& input_data,
std::vector<Tensor>& expect_result_list,
std::vector<IValue>& expect_result_list,
const int64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);
@@ -668,6 +683,9 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
module.register_parameter("bias", torch::ones({20}), false);
module.define(R"(
def fn(self, x:float=1.0):
return x
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
@@ -677,21 +695,44 @@
x = 2 * torch.ones(1)
h = torch.ones(1)
torch.add(x, h, out=x)
return (x1, x2, x3, x)
)");
device = torch.ones(1, 1).cpu().device.type
is_cuda = x1.is_cuda
bool_val = True
check_is = [] is None
check_is_not = [1] is not None
check_not = not bool_val
num_to_tensor = torch.tensor([self.fn()])
d = {"a": "abc"}
check_dict_index = d["a"]
check_dim = x1.dim()
return (
x1, x2, x3, x, device, is_cuda, check_is,
check_is_not, num_to_tensor, check_dict_index,
check_dim, check_not
)
)");

torch::jit::Module module_freeze = freeze(module);

std::stringstream input_model_stream;
module_freeze._save_for_mobile(input_model_stream);
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
std::vector<Tensor> expect_result_list;
std::vector<IValue> expect_result_list;
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
expect_result_list.emplace_back(
at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
expect_result_list.emplace_back(3 * at::ones({1}));
// "cpu" False, False, True, tensor(1), "abc", 2, False)
expect_result_list.emplace_back(c10::IValue("cpu"));
expect_result_list.emplace_back(c10::IValue(false));
expect_result_list.emplace_back(c10::IValue(false));
expect_result_list.emplace_back(c10::IValue(true));
expect_result_list.emplace_back(c10::IValue(at::ones({1})));
expect_result_list.emplace_back(c10::IValue("abc"));
expect_result_list.emplace_back(c10::IValue(2));
expect_result_list.emplace_back(c10::IValue(false));

backportAllVersionCheck(
input_model_stream,
4 changes: 2 additions & 2 deletions test/test_mobile_optimizer.py
@@ -151,7 +151,7 @@ def forward(self, x):
bn_scripted_module = torch.jit.script(bn_test_module)
bn_scripted_module.eval()

self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 14)
self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11)
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward(bn_scripted_module._c).graph))

@@ -252,7 +252,7 @@ def foo(self, x):
bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
bn_no_forward_scripted_module.eval()

self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 14)
self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(bn_no_forward_scripted_module.foo.graph)

41 changes: 39 additions & 2 deletions torch/csrc/jit/mobile/compatibility/backport_manager.cpp
@@ -27,6 +27,7 @@ constexpr int64_t kBytecodeVersionV4 = 0x4L;
constexpr int64_t kBytecodeVersionV5 = 0x5L;
constexpr int64_t kBytecodeVersionV6 = 0x6L;
constexpr int64_t kBytecodeVersionV7 = 0x7L;
constexpr int64_t kBytecodeVersionV8 = 0x8L;
} // namespace

/********************** Utility Functions **********************/
@@ -434,7 +435,8 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
{
BytecodeEmitModeGuard argNumGuard(
true /*emit_default_input_instructions*/,
false /*enable_defaults_args_with_out_args*/);
false /*enable_defaults_args_with_out_args*/,
false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
@@ -501,7 +503,8 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
{
BytecodeEmitModeGuard argNumGuard(
false /*emit_default_input_instructions*/,
false /*enable_defaults_args_with_out_args*/);
false /*enable_defaults_args_with_out_args*/,
false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}
@@ -512,6 +515,39 @@ std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
return output_model_stream;
}

std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
std::shared_ptr<IStreamAdapter> rai =
std::make_shared<IStreamAdapter>(&input_model_stream);
auto reader = std::make_shared<PyTorchStreamReader>(rai);
// extra_files are kept
auto records = reader->getAllRecords();
bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");
ExtraFilesMap extra_files;
for (const auto& record : records) {
std::size_t found = record.find_last_of("/\\");
auto path = record.substr(0, found);
if ("extra" == path) {
extra_files.emplace(record.substr(found + 1), "");
}
}
Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files);
std::stringstream intermediate_model_stream;
{
BytecodeEmitModeGuard argNumGuard(
false /*emit_default_input_instructions*/,
true /*enable_defaults_args_with_out_args*/,
false /*enable_emit_promoted_ops*/);
torch_script._save_for_mobile(
intermediate_model_stream, extra_files, hasBytecodeDebug);
}

// Update the bytecode version (from 8 to 7)
std::stringstream output_model_stream =
update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7);

return output_model_stream;
}

} // namespace

/********************** BackportManager **********************/
@@ -528,6 +564,7 @@ BackportManager::BackportManager() {
registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7);
}

std::unordered_map<
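For orientation, a simplified sketch of how the single-step backport functions registered above compose; the `backport_chain` helper and `BackportFn` alias are illustrative assumptions, not the real BackportManager API:

```cpp
// Conceptual chain: each registered function lowers the bytecode by exactly
// one version, so backporting v8 -> v5 runs v8->v7, v7->v6, then v6->v5.
#include <cstdint>
#include <functional>
#include <sstream>
#include <unordered_map>

using BackportFn = std::function<std::stringstream(std::stringstream&)>;

std::stringstream backport_chain(
    std::stringstream model,
    int64_t from_version,
    int64_t to_version,
    const std::unordered_map<int64_t, BackportFn>& registry) {
  for (int64_t v = from_version; v > to_version; --v) {
    // e.g. registry.at(8) is backport_v8_to_v7 after this change.
    model = registry.at(v)(model);
  }
  return model;
}
```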
