diff --git a/README.md b/README.md
index 98a6ba3fdde..69dd7709531 100644
--- a/README.md
+++ b/README.md
@@ -164,6 +164,37 @@ cd /src/mlir-npcomp
 cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
 ```
 
+### IREE Backend (from IREE packages)
+
+```shell
+# We currently track and require the latest snapshot.
+pip3 install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases
+
+
+
+# Run TorchScript E2E tests targeting IREE.
+# Make sure to run "PyTorch Frontend" setup instructions first.
+python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
+```
+
+### IREE Backend (from local IREE build)
+
+This configuration is useful for iterating locally, as you can
+poke/debug/rebuild things in IREE.
+
+```shell
+# Locally build IREE.
+# See https://google.github.io/iree/building-from-source/getting-started/
+# Make sure IREE is configured with `-DIREE_BUILD_PYTHON_BINDINGS=ON`.
+
+echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env
+
+# Run TorchScript E2E tests targeting IREE.
+# Make sure to run "PyTorch Frontend" setup instructions first.
+python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
+```
+
+
 ### VSCode with a Docker Dev Image
 
 #### Start a docker dev container based on our image
diff --git a/frontends/pytorch/e2e_testing/torchscript/main.py b/frontends/pytorch/e2e_testing/torchscript/main.py
index 97a94290c06..6bf3fa248cd 100644
--- a/frontends/pytorch/e2e_testing/torchscript/main.py
+++ b/frontends/pytorch/e2e_testing/torchscript/main.py
@@ -12,31 +12,39 @@
 # Available test configs.
 from torch_mlir.torchscript.e2e_test.configs import (
-    RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
+    NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
 )
 
+from npcomp.compiler.pytorch.backend import is_iree_enabled
+IREE_ENABLED = is_iree_enabled()
+if IREE_ENABLED:
+    from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
+from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend
+
+from .xfail_sets import XFAIL_SETS
+
 # Import tests to register them in the global registry.
-# TODO: Use a relative import.
-# That requires invoking this file as a "package" though, which makes it
-# not possible to just do `python main.py`. Instead, it requires something
-# like `python -m tochscript.main` which is annoying because it can only
-# be run from a specific directory.
-# TODO: Find out best practices for python "main" files.
-import basic
-import vision_models
-import mlp
-import batchnorm
-import quantized_models
-import elementwise
+# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
+# this script.
+from . import basic
+from . import vision_models
+from . import mlp
+from . import batchnorm
+from . import quantized_models
+from . import elementwise
 
 
 def _get_argparse():
+    config_choices = ['native_torch', 'torchscript', 'refbackend']
+    if IREE_ENABLED:
+        config_choices += ['iree']
     parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
     parser.add_argument('--config',
-        choices=['native_torch', 'torchscript', 'refbackend'],
+        choices=config_choices,
        default='refbackend',
-        help='''
+        help=f'''
 Meaning of options:
 "refbackend": run through npcomp's RefBackend.
+"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
 "native_torch": run the torch.nn.Module as-is without compiling (useful
 for verifying model is deterministic; ALL tests should pass in this
 configuration).
 "torchscript": compile the model to a torch.jit.ScriptModule, and then run
 that as-is (useful for verifying TorchScript is modeling the program
 correctly).
 ''')
@@ -54,7 +62,9 @@ def main():
 
     # Find the selected config.
     if args.config == 'refbackend':
-        config = RefBackendTestConfig()
+        config = NpcompBackendTestConfig(RefjitNpcompBackend())
+    elif args.config == 'iree':
+        config = NpcompBackendTestConfig(IreeNpcompBackend())
     elif args.config == 'native_torch':
         config = NativeTorchTestConfig()
     elif args.config == 'torchscript':
@@ -78,7 +88,8 @@ def main():
     results = run_tests(tests, config)
 
     # Report the test results.
-    report_results(results, args.verbose)
+    failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
+    sys.exit(1 if failed else 0)
 
 if __name__ == '__main__':
     main()
diff --git a/frontends/pytorch/e2e_testing/torchscript/xfail_sets.py b/frontends/pytorch/e2e_testing/torchscript/xfail_sets.py
new file mode 100644
index 00000000000..4ce42781461
--- /dev/null
+++ b/frontends/pytorch/e2e_testing/torchscript/xfail_sets.py
@@ -0,0 +1,29 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file describes the sets of tests expected to fail for each config.
+# This information is deliberately kept in a side table, rather than
+# in-situ on the test, as a layering decision: tests should
+# have unique keys to identify them and enable side tables of various kinds
+# (this includes down into lower parts of the stack, where a side table
+# might be used to keep more elaborate sets of testing configurations).
+
+XFAIL_SETS = {}
+
+# Sets of tests that fail to even reach the backends.
+# These represent further work needed in npcomp to lower them properly
+# to the backend contract.
+_common_npcomp_lowering_xfails = {
+    'ResNet18Module_basic',
+    'QuantizedMLP_basic',
+}
+
+XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails
+
+XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
+    # https://github.com/google/iree/issues/6368
+    'MmDagModule_basic',
+    'Mlp1LayerModule_basic',
+    'Mlp2LayerModule_basic',
+}
diff --git a/frontends/pytorch/examples/cos_e2e.py b/frontends/pytorch/examples/cos_e2e.py
index 848dcbfe447..455f8191210 100644
--- a/frontends/pytorch/examples/cos_e2e.py
+++ b/frontends/pytorch/examples/cos_e2e.py
@@ -21,7 +21,7 @@
 result = torch.cos(input)
 f.returns([result])
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
 
 logging.debug(f"Executing jit_module.cos")
diff --git a/frontends/pytorch/examples/div_inplace_e2e.py b/frontends/pytorch/examples/div_inplace_e2e.py
index 4e9513d474f..44e0e254c6f 100644
--- a/frontends/pytorch/examples/div_inplace_e2e.py
+++ b/frontends/pytorch/examples/div_inplace_e2e.py
@@ -25,7 +25,7 @@ def fun(a, b):
 with mb.capture_function("test", [arg0, arg1]) as f:
     f.returns([fun(arg0, arg1)])
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
 
 test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
diff --git a/frontends/pytorch/examples/mm_e2e.py b/frontends/pytorch/examples/mm_e2e.py
index 433373bb00c..ac0956a98e5 100644
--- a/frontends/pytorch/examples/mm_e2e.py
+++ b/frontends/pytorch/examples/mm_e2e.py
@@ -22,7 +22,7 @@
 result = torch.mm(lhs, rhs)
 f.returns([result])
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
 
 test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
diff --git a/frontends/pytorch/examples/mul_maximum_e2e.py b/frontends/pytorch/examples/mul_maximum_e2e.py
index 4f23f9cc45c..58f06e69ba2 100644
--- a/frontends/pytorch/examples/mul_maximum_e2e.py
+++ b/frontends/pytorch/examples/mul_maximum_e2e.py
@@ -28,7 +28,7 @@ def mul_maximum(lhs, rhs, threshold, bias):
 result = mul_maximum(lhs, rhs, threshold, bias)
 f.returns([result])
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
 
 test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
diff --git a/frontends/pytorch/examples/tanh_out_e2e.py b/frontends/pytorch/examples/tanh_out_e2e.py
index 064deebd657..15837521753 100644
--- a/frontends/pytorch/examples/tanh_out_e2e.py
+++ b/frontends/pytorch/examples/tanh_out_e2e.py
@@ -26,7 +26,7 @@ def fun(a):
 with mb.capture_function("test", [arg0]) as f:
     f.returns([fun(arg0)])
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))
 
 test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
diff --git a/frontends/pytorch/examples/torchscript_mm_e2e.py b/frontends/pytorch/examples/torchscript_mm_e2e.py
index fd7fb106c08..6ed576f0256 100644
--- a/frontends/pytorch/examples/torchscript_mm_e2e.py
+++ b/frontends/pytorch/examples/torchscript_mm_e2e.py
@@ -48,7 +48,7 @@ def forward(self, lhs, rhs):
 mb.import_module(recursivescriptmodule._c, class_annotator)
 #mb.module.operation.print()
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
 jit_module = backend.load(compiled)
diff --git a/frontends/pytorch/examples/torchscript_tanh_e2e.py b/frontends/pytorch/examples/torchscript_tanh_e2e.py
index 21b28934aa1..964765bb666 100644
--- a/frontends/pytorch/examples/torchscript_tanh_e2e.py
+++ b/frontends/pytorch/examples/torchscript_tanh_e2e.py
@@ -40,7 +40,7 @@ def forward(self, x):
 mb.import_module(recursivescriptmodule._c, class_annotator)
 #mb.module.operation.print()
 
-backend = refjit.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
 jit_module = backend.load(compiled)
diff --git a/frontends/pytorch/examples/torchscript_tanh_e2e_iree.py b/frontends/pytorch/examples/torchscript_tanh_e2e_iree.py
index 71ab71ddffd..9d2863307df 100644
--- a/frontends/pytorch/examples/torchscript_tanh_e2e_iree.py
+++ b/frontends/pytorch/examples/torchscript_tanh_e2e_iree.py
@@ -34,13 +34,13 @@ def forward(self, x):
 
 class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
 class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
     None,
-    ([2, 3, -1], torch.float32)
+    ([2, 3, -1], torch.float32, True)
 ])
 # TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
 mb.import_module(recursivescriptmodule._c, class_annotator)
 #mb.module.operation.print()
 
-backend = iree.CompilerBackend()
+backend = iree.IreeNpcompBackend()
 compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
 jit_module = backend.load(compiled)
diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py
index 935aca84a62..44f85c719d8 100644
--- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py
+++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/__init__.py
@@ -2,6 +2,6 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from .ref_backend import RefBackendTestConfig
+from .npcomp_backend import NpcompBackendTestConfig
 from .native_torch import NativeTorchTestConfig
 from .torchscript import TorchScriptTestConfig
diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/npcomp_backend.py
similarity index 64%
rename from frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py
rename to frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/npcomp_backend.py
index 9e19d5f48fa..09759b66c71 100644
--- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/ref_backend.py
+++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/configs/npcomp_backend.py
@@ -14,15 +14,32 @@
 import torch_mlir
 
 from npcomp.compiler.pytorch.backend import refjit
+from npcomp.compiler.pytorch.backend.abc import NpcompBackend
 from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
 from torch_mlir.torchscript.annotations import extract_annotations
 
+class PrettyErrorReportForIrOperation(object):
+    def __init__(self, module, module_name_for_ir_dump: str):
+        sys.stderr = StringIO()
+        self.filename_for_ir_dump = os.path.join(tempfile.gettempdir(),
+                                                 module_name_for_ir_dump + '.mlir')
+        self.asm_for_error_report = module.get_asm(
+            large_elements_limit=10, enable_debug_info=True)
+    def __enter__(self):
+        pass
+    def __exit__(self, type, value, traceback):
+        with open(self.filename_for_ir_dump, 'w') as f:
+            f.write(self.asm_for_error_report)
 
-class RefBackendTestConfig(TestConfig):
-    """TestConfig that just runs the torch.nn.Module through RefBackend."""
-    def __init__(self):
+class NpcompBackendTestConfig(TestConfig):
+    """Base class for TestConfigs that are implemented with npcomp.
+
+    This class handles all the common lowering that npcomp does before reaching
+    its backends.
+    """
+    def __init__(self, backend: NpcompBackend):
         super().__init__()
-        self.backend = refjit.CompilerBackend()
+        self.backend = backend
 
     def compile(self, program: torch.nn.Module) -> Any:
         mb = torch_mlir.ModuleBuilder()
@@ -79,14 +96,36 @@ def compile(self, program: torch.nn.Module) -> Any:
 """) from None
         finally:
             sys.stderr = sys.__stderr__
-        return self.backend.compile(mb.module)
+        try:
+            sys.stderr = StringIO()
+            asm_for_error_report = mb.module.operation.get_asm(
+                large_elements_limit=10, enable_debug_info=True)
+            return self.backend.compile(mb.module)
+        except Exception as e:
+            filename = os.path.join(tempfile.gettempdir(),
+                                    scripted.original_name + '.mlir')
+            with open(filename, 'w') as f:
+                f.write(asm_for_error_report)
+            raise Exception(f"""
+NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
+## Exception:
+{e}
+
+## Stderr:
+{sys.stderr.getvalue()}
+
+## Input IR has been saved in {filename}
+""") from None
+        finally:
+            sys.stderr = sys.__stderr__
+
 
     def run(self, artifact: Any, trace: Trace) -> Trace:
-        jit_module = self.backend.load(artifact)
+        backend_module = self.backend.load(artifact)
         result: Trace = []
         for item in trace:
             numpy_inputs = [t.numpy() for t in item.inputs]
-            outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
+            outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
             if isinstance(outputs, np.ndarray):
                 outputs = [outputs]
             torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
diff --git a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py
index 1301f33c759..5934f56d4ba 100644
--- a/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py
+++ b/frontends/pytorch/python/torch_mlir/torchscript/e2e_test/reporting.py
@@ -5,8 +5,9 @@
 Utilities for reporting the results of the test framework.
 """
 
-from typing import List, Optional
+from typing import List, Optional, Set
 
+import collections
 import io
 import textwrap
 
@@ -70,7 +71,7 @@ def error_str(self):
         assert self.failed
         if self.value.size() != self.golden_value.size():
             return self.context.format_error(
-                f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
+                f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
             )
         f = io.StringIO()
         p = lambda *x: print(*x, file=f)
@@ -167,17 +168,60 @@ def error_str(self):
         return f.getvalue()
 
 
-def report_results(results: List[TestResult], verbose: bool = False):
-    """Provide a basic error report summarizing various TestResult's.
+def report_results(results: List[TestResult],
+                   expected_failures: Set[str],
+                   verbose: bool = False):
+    """Print a basic error report summarizing various TestResults.
+
+    This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
+    "lit" testing utility. See
+    https://llvm.org/docs/CommandGuide/lit.html#test-status-results
+
+    The `expected_failures` set should contain the names of tests
+    (according to their `unique_name`) which are expected to fail.
+    The overall passing/failing status of the report requires these to fail
+    in order to succeed (this catches cases where things suddenly
+    start working).
 
     If `verbose` is True, then provide an explanation of what failed.
+
+    Returns True if the run resulted in any unexpected pass/fail behavior.
+    Otherwise False.
     """
+    summary = collections.Counter()
     for result in results:
         report = SingleTestReport(result, ErrorContext.empty())
-        if not report.failed:
-            print(f'SUCCESS - "{result.unique_name}"')
+        expected_failure = result.unique_name in expected_failures
+        if expected_failure:
+            if report.failed:
+                error_str = ''
+                if verbose:
+                    error_str = '\n' + textwrap.indent(report.error_str(), '    ')
+                print(f'XFAIL - "{result.unique_name}"' + error_str)
+                summary['XFAIL'] += 1
+            else:
+                print(f'XPASS - "{result.unique_name}"')
+                summary['XPASS'] += 1
         else:
-            error_str = ''
-            if verbose:
-                error_str = '\n' + textwrap.indent(report.error_str(), '    ')
-            print(f'FAILURE - "{result.unique_name}"' + error_str)
+            if not report.failed:
+                print(f'PASS - "{result.unique_name}"')
+                summary['PASS'] += 1
+            else:
+                error_str = ''
+                if verbose:
+                    error_str = '\n' + textwrap.indent(report.error_str(), '    ')
+                print(f'FAIL - "{result.unique_name}"' + error_str)
+                summary['FAIL'] += 1
+
+    # Print a summary for easy scanning.
+    print('\nSummary:')
+    KEY_MEANINGS = {
+        'PASS': 'Passed',
+        'FAIL': 'Failed',
+        'XFAIL': 'Expectedly Failed',
+        'XPASS': 'Unexpectedly Passed',
+    }
+    for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
+        if summary[key]:
+            print(f'    {KEY_MEANINGS[key]}: {summary[key]}')
+    return summary['FAIL'] != 0 or summary['XPASS'] != 0
diff --git a/frontends/pytorch/test/torchscript_e2e_test/basic.py b/frontends/pytorch/test/torchscript_e2e_test/basic.py
index 762d6e7f1c9..7918f2bab32 100644
--- a/frontends/pytorch/test/torchscript_e2e_test/basic.py
+++ b/frontends/pytorch/test/torchscript_e2e_test/basic.py
@@ -21,13 +21,13 @@ def forward(self, lhs, rhs):
 
 
 # TODO: Refine messages.
-# CHECK: SUCCESS - "MmModule_basic"
+# CHECK: PASS - "MmModule_basic"
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
 
 
-# CHECK: SUCCESS - "MmModule_basic2"
+# CHECK: PASS - "MmModule_basic2"
 @register_test_case(module_factory=lambda: MmModule())
 def MmModule_basic2(module, tu: TestUtils):
     module.forward(tu.rand(4, 4), tu.rand(4, 4))
@@ -36,7 +36,7 @@ def MmModule_basic2(module, tu: TestUtils):
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)
-    report_results(results)
+    report_results(results, set())
 
 
 if __name__ == '__main__':
diff --git a/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py b/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py
index fc4041cc0b3..0d36a4b6ac8 100644
--- a/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py
+++ b/frontends/pytorch/test/torchscript_e2e_test/compilation_failure.py
@@ -25,7 +25,7 @@ def forward(self, t):
         return 3
 
 
-# CHECK: FAILURE - "MmModule_basic"
+# CHECK: FAIL - "MmModule_basic"
 # CHECK: compilation error
 # Assume that the diagnostic from the TorchScript compiler will at least contain
 # the offending "return 3".
@@ -38,7 +38,7 @@ def MmModule_basic(module, tu: TestUtils):
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)
-    report_results(results, verbose=True)
+    report_results(results, set(), verbose=True)
 
 
 if __name__ == '__main__':
diff --git a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py
index a3f2488cf0d..64c8ab64dba 100644
--- a/frontends/pytorch/test/torchscript_e2e_test/miscompile.py
+++ b/frontends/pytorch/test/torchscript_e2e_test/miscompile.py
@@ -26,7 +26,7 @@ def forward(self, lhs, rhs):
 
 
 # TODO: Refine error messages.
-# CHECK: FAILURE - "MmModule_basic"
+# CHECK: FAIL - "MmModule_basic"
 # CHECK: @ trace item #0 - call to "forward"
 # CHECK: @ output #0
 # CHECK: ERROR: values mismatch
@@ -40,7 +40,7 @@ def MmModule_basic(module, tu: TestUtils):
 def main():
     config = TorchScriptTestConfig()
     results = run_tests(GLOBAL_TEST_REGISTRY, config)
-    report_results(results, verbose=True)
+    report_results(results, set(), verbose=True)
 
 
 if __name__ == '__main__':
diff --git a/python/npcomp/compiler/pytorch/backend/__init__.py b/python/npcomp/compiler/pytorch/backend/__init__.py
index e69de29bb2d..95558afbd9e 100644
--- a/python/npcomp/compiler/pytorch/backend/__init__.py
+++ b/python/npcomp/compiler/pytorch/backend/__init__.py
@@ -0,0 +1,7 @@
+def is_iree_enabled():
+  try:
+    import iree.runtime
+    import iree.compiler
+  except:
+    return False
+  return True
diff --git a/python/npcomp/compiler/pytorch/backend/abc.py b/python/npcomp/compiler/pytorch/backend/abc.py
new file mode 100644
index 00000000000..a5f0c80b2f8
--- /dev/null
+++ b/python/npcomp/compiler/pytorch/backend/abc.py
@@ -0,0 +1,45 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import abc
+from typing import TypeVar
+
+import torch
+
+from mlir.ir import Module
+
+# A type shared between the result of `NpcompBackend.compile` and the input
+# to `NpcompBackend.load`. Each backend will likely have a different definition
+# of this type.
+CompiledArtifact = TypeVar('CompiledArtifact')
+
+# A wrapper around a backend-specific loaded program representation
+# that uniformly translates the `x.method(...)` interface expected of
+# Torch modules into appropriate lower-level operations.
+Invoker = TypeVar('Invoker')
+
+
+class NpcompBackend(abc.ABC):
+  """The interface to an npcomp backend.
+  """
+  @abc.abstractmethod
+  def compile(self, module: Module) -> CompiledArtifact:
+    """Compile the provided MLIR module into a compiled artifact.
+
+    The module adheres to the npcomp backend contract
+    (see the VerifyBackendContract pass).
+
+    The compiled artifact can be any type, but must be correctly
+    interpreted by the `load` method.
+    """
+
+  @abc.abstractmethod
+  def load(self, artifact: CompiledArtifact) -> Invoker:
+    """Load the compiled artifact into a uniformly invokable form.
+
+    The compiled artifact is the result of a previous call to `compile`.
+
+    See the description of `Invoker` for the requirements on the returned
+    type.
+    """
diff --git a/python/npcomp/compiler/pytorch/backend/iree.py b/python/npcomp/compiler/pytorch/backend/iree.py
index 6fd0c76b09b..2edb1ab0239 100644
--- a/python/npcomp/compiler/pytorch/backend/iree.py
+++ b/python/npcomp/compiler/pytorch/backend/iree.py
@@ -5,6 +5,7 @@
 import os
 
 import torch
+import numpy as np
 
 from mlir.ir import *
 from mlir.passmanager import *
@@ -12,8 +13,10 @@
 import iree.runtime as ireert
 import iree.compiler as ireec
 
+from .abc import NpcompBackend
+
 __all__ = [
-    "CompilerBackend",
+    "IreeNpcompBackend",
 ]
 
 PREPARE_FOR_IREE_PASSES = (
@@ -34,6 +37,8 @@ def __getitem__(self, function_name):
 
     def invoke(*args):
       results = self._iree_module[function_name](*args)
+      if isinstance(results, np.ndarray):
+        return results
       if len(results) == 1:
         # De-tuple.
         return results[0]
@@ -58,7 +63,7 @@ def invoke(*args):
     return invoke
 
 
-class CompilerBackend:
+class IreeNpcompBackend(NpcompBackend):
   """Main entry-point for the backend."""
 
   def __init__(self):
@@ -67,9 +72,8 @@ def __init__(self):
   def compile(self, imported_module: Module):
     """Compiles an imported module, with a flat list of functions.
 
-    The module is expected to be in "TCP + scalar code" form.
-    TODO: More clearly define the backend contract. Generally this will
-    extend to support globals, lists, and other stuff.
+    The module is expected to conform to the npcomp backend contract.
+    See the VerifyBackendContract pass for more details.
 
     Args:
       imported_module: The MLIR module consisting of funcs in the torch
@@ -97,12 +101,13 @@ def compile(self, imported_module: Module):
     # Backend.
     binary = ireec.compile_str(str(imported_module),
                                target_backends=["dylib-llvm-aot"])
-    iree_config = ireert.Config(driver_name="dylib")
-
-    iree_module = ireert.load_module(ireert.VmModule.from_flatbuffer(binary),
-                                     config=iree_config)
-    return iree_module
+    return binary
 
   def load(self, iree_module) -> TorchIreeModuleInvoker:
     """Loads a compiled artifact into the runtime."""
-    return TorchIreeModuleInvoker(iree_module)
+    vm_module = ireert.VmModule.from_flatbuffer(iree_module)
+
+    iree_config = ireert.Config(driver_name="dylib")
+    ctx = ireert.SystemContext(config=iree_config)
+    ctx.add_vm_module(vm_module)
+    return TorchIreeModuleInvoker(ctx.modules.module)
diff --git a/python/npcomp/compiler/pytorch/backend/refjit.py b/python/npcomp/compiler/pytorch/backend/refjit.py
index ece83ae6b2d..6eefaf9748b 100644
--- a/python/npcomp/compiler/pytorch/backend/refjit.py
+++ b/python/npcomp/compiler/pytorch/backend/refjit.py
@@ -10,10 +10,11 @@
 from mlir.passmanager import *
 from npcomp.compiler.generic.backend import refjit as refjit_backend
 from npcomp.compiler.utils import logging
+from .abc import NpcompBackend
 
 __all__ = [
     "is_enabled",
-    "CompilerBackend",
+    "RefjitNpcompBackend",
 ]
 
 # Re-export.
@@ -34,7 +35,7 @@ def invoke(*args):
     return invoke
 
 
-class CompilerBackend:
+class RefjitNpcompBackend(NpcompBackend):
   """Main entry-point for the backend."""
 
   def __init__(self):
diff --git a/tools/torchscript_e2e_test.sh b/tools/torchscript_e2e_test.sh
new file mode 100755
index 00000000000..b1940f7a724
--- /dev/null
+++ b/tools/torchscript_e2e_test.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -euo pipefail
+
+src_dir="$(realpath $(dirname $0)/..)"
+
+cd "$src_dir"
+source .env
+python -m frontends.pytorch.e2e_testing.torchscript.main "$@"
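Taken together, the pieces above give npcomp a uniform backend seam: `is_iree_enabled()` gates the optional IREE path, and both `RefjitNpcompBackend` and `IreeNpcompBackend` expose the same `compile()`/`load()` pair defined by `abc.NpcompBackend`. The following is a minimal sketch, not part of the change itself, of driving either backend outside the test framework; the `run_with_default_backend` helper name is hypothetical, and only the imports and the compile/load/getattr pattern are taken from the code above.

```python
from mlir.ir import Module

from npcomp.compiler.pytorch.backend import is_iree_enabled
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend


def run_with_default_backend(module: Module, symbol: str, *numpy_args):
  """Compile `module` (already lowered to the backend contract) and invoke `symbol`."""
  if is_iree_enabled():
    # Imported lazily so a missing IREE install does not break the refjit path.
    from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
    backend = IreeNpcompBackend()
  else:
    backend = RefjitNpcompBackend()
  compiled = backend.compile(module)  # opaque, backend-specific artifact
  invoker = backend.load(compiled)    # uniform `x.method(...)` interface
  return getattr(invoker, symbol)(*numpy_args)
```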
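Because `NpcompBackendTestConfig` relies only on the `NpcompBackend` interface, plugging in another backend is a matter of subclassing it. A sketch under stated assumptions: the `TextualIrBackend` class and its behavior are invented for illustration, and only the base class and its two abstract methods come from `abc.py`.

```python
from mlir.ir import Module

from npcomp.compiler.pytorch.backend.abc import NpcompBackend


class TextualIrBackend(NpcompBackend):
  """Hypothetical backend whose 'compiled artifact' is just textual IR."""

  def compile(self, module: Module):
    # Any artifact type is allowed, as long as load() below understands it.
    return str(module)

  def load(self, artifact: str):
    # A real backend must return an invoker whose attributes are callables,
    # since NpcompBackendTestConfig.run() calls
    # getattr(backend_module, item.symbol)(*numpy_inputs).
    class _Invoker:
      def __getattr__(self, symbol):
        def invoke(*args):
          raise NotImplementedError(
              f"TextualIrBackend cannot execute '{symbol}'; it only keeps IR")
        return invoke
    return _Invoker()
```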
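Since `report_results` keys expected failures off each test's `unique_name`, growing the side table in `xfail_sets.py` is just a set update. For instance, a further config could reuse the common lowering failures the same way the `'iree'` entry does; the `'mybackend'` key and the test name below are placeholders, not part of this change.

```python
XFAIL_SETS['mybackend'] = _common_npcomp_lowering_xfails | {
    'SomeUnsupportedOpModule_basic',  # placeholder unique_name
}
```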