Add IREE support in TorchScript e2e tests.
- Add support for "expected failures" in test reporting. The new error
  reports look like
  [this](https://gist.github.com/silvasean/6ffd95e1d55302b699673da201da210d).
  - We will now be able to put these tests into CI, since the harness
    understands which tests are expected to pass and fail.
- Refactor RefBackendTestConfig to NpcompBackendTestConfig which
  supports both RefBackend and IREE.
- Add instructions for installing IREE dependencies (both from packages
  and for local builds of IREE).
- Add `tools/torchscript_e2e_test.sh` for invoking the e2e test
  harness (this makes invoking a bit easier, as it doesn't rely on a
  loose Python invocation).
silvasean committed Jun 30, 2021
1 parent 79928cd commit d5108b9
Showing 22 changed files with 284 additions and 64 deletions.
31 changes: 31 additions & 0 deletions README.md
@@ -164,6 +164,37 @@ cd /src/mlir-npcomp
cmake --build /build/npcomp --target check-npcomp check-frontends-pytorch
```

### IREE Backend (from IREE packages)

```shell
# We currently track and require the latest snapshot.
pip3 install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases

# Run TorchScript E2E tests targeting IREE.
# Make sure to run "PyTorch Frontend" setup instructions first.
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
```
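As a quick sanity check before running the suite, the freshly installed bindings should be importable. A minimal sketch, assuming the snapshot packages expose the `iree.compiler` and `iree.runtime` modules (the module names are an assumption, not something this repo pins down):

```python
# Sanity-check sketch: confirm the IREE snapshot packages are importable.
# The module names below are assumed from the snapshot package layout.
import iree.compiler  # from iree-compiler-snapshot (assumed name)
import iree.runtime   # from iree-runtime-snapshot (assumed name)

print("IREE compiler and runtime Python bindings are importable")
```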

### IREE Backend (from local IREE build)

This configuration is useful for iterating locally, as you can
poke/debug/rebuild things in IREE.

```shell
# Locally build IREE.
# See https://google.github.io/iree/building-from-source/getting-started/
# Make sure IREE is configured with `-DIREE_BUILD_PYTHON_BINDINGS=ON`.

echo 'PYTHONPATH="${PYTHONPATH}:/path/to/iree-build/bindings/python"' >> .env

# Run TorchScript E2E tests targeting IREE.
# Make sure to run "PyTorch Frontend" setup instructions first.
python frontends/pytorch/e2e_testing/torchscript/main.py --config=iree
```
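For the local-build flow, the same import check works once the bindings directory is on the path; a sketch reusing the `/path/to/iree-build` placeholder from the snippet above:

```python
# Sketch: verify a locally built IREE's Python bindings are reachable.
# "/path/to/iree-build" is the same placeholder as in the .env line above.
import sys
sys.path.append("/path/to/iree-build/bindings/python")

import iree.runtime  # assumed module name exposed by the bindings tree
print("local IREE Python bindings are importable")
```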


### VSCode with a Docker Dev Image

#### Start a docker dev container based on our image
45 changes: 28 additions & 17 deletions frontends/pytorch/e2e_testing/torchscript/main.py
@@ -12,31 +12,39 @@

# Available test configs.
from torch_mlir.torchscript.e2e_test.configs import (
RefBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
NpcompBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig
)

from npcomp.compiler.pytorch.backend import is_iree_enabled
IREE_ENABLED = is_iree_enabled()
if IREE_ENABLED:
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend

from .xfail_sets import XFAIL_SETS

# Import tests to register them in the global registry.
# TODO: Use a relative import.
# That requires invoking this file as a "package" though, which makes it
# not possible to just do `python main.py`. Instead, it requires something
# like `python -m tochscript.main` which is annoying because it can only
# be run from a specific directory.
# TODO: Find out best practices for python "main" files.
import basic
import vision_models
import mlp
import batchnorm
import quantized_models
import elementwise
# Make sure to use `tools/torchscript_e2e_test.sh` wrapper for invoking
# this script.
from . import basic
from . import vision_models
from . import mlp
from . import batchnorm
from . import quantized_models
from . import elementwise

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend']
if IREE_ENABLED:
config_choices += ['iree']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('--config',
choices=['native_torch', 'torchscript', 'refbackend'],
choices=config_choices,
default='refbackend',
help='''
help=f'''
Meaning of options:
"refbackend": run through npcomp's RefBackend.
"iree"{'' if IREE_ENABLED else '(disabled)'}: run through npcomp's IREE backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
''')
@@ -54,7 +62,9 @@ def main():

# Find the selected config.
if args.config == 'refbackend':
config = RefBackendTestConfig()
config = NpcompBackendTestConfig(RefjitNpcompBackend())
elif args.config == 'iree':
config = NpcompBackendTestConfig(IreeNpcompBackend())
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
elif args.config == 'torchscript':
@@ -78,7 +88,8 @@ def main():
results = run_tests(tests, config)

# Report the test results.
report_results(results, args.verbose)
failed = report_results(results, XFAIL_SETS[args.config], args.verbose)
sys.exit(1 if failed else 0)

if __name__ == '__main__':
main()
29 changes: 29 additions & 0 deletions frontends/pytorch/e2e_testing/torchscript/xfail_sets.py
@@ -0,0 +1,29 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file describes the sets of tests expected to fail for each config.
# This information is deliberately kept in a side table, rather than
# in-situ on the test, as a deliberate layering decision: tests should
# have unique keys to identify them and enable side tables of various kinds
# (this includes down into lower parts of the stack, where a side table
# might be used to keep more elaborate sets of testing configurations).

XFAIL_SETS = {}

# Lists of tests that fail to even reach the backends.
# These represent further work needed in npcomp to lower them properly
# to the backend contract.
_common_npcomp_lowering_xfails = {
'ResNet18Module_basic',
'QuantizedMLP_basic',
}

XFAIL_SETS['refbackend'] = _common_npcomp_lowering_xfails

XFAIL_SETS['iree'] = _common_npcomp_lowering_xfails | {
# https://github.com/google/iree/issues/6368
'MmDagModule_basic',
'Mlp1LayerModule_basic',
'Mlp2LayerModule_basic',
}
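Because the table is keyed by each test's `unique_name`, downstream configurations can derive their own expected-failure sets with ordinary set operations, mirroring how the `iree` entry above is built. An illustrative sketch (the backend and test names here are hypothetical):

```python
# Hypothetical downstream config composing its own XFAIL set from the
# common lowering failures, as XFAIL_SETS['iree'] does above.
_common_npcomp_lowering_xfails = {
    'ResNet18Module_basic',
    'QuantizedMLP_basic',
}

XFAIL_SETS = {}
XFAIL_SETS['my_backend'] = _common_npcomp_lowering_xfails | {
    'SomeOtherTest_basic',  # hypothetical test unique_name
}
```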
2 changes: 1 addition & 1 deletion frontends/pytorch/examples/cos_e2e.py
@@ -21,7 +21,7 @@
result = torch.cos(input)
f.returns([result])

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))

logging.debug(f"Executing jit_module.cos")
2 changes: 1 addition & 1 deletion frontends/pytorch/examples/div_inplace_e2e.py
@@ -25,7 +25,7 @@ def fun(a, b):
with mb.capture_function("test", [arg0, arg1]) as f:
f.returns([fun(arg0, arg1)])

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))

test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
2 changes: 1 addition & 1 deletion frontends/pytorch/examples/mm_e2e.py
@@ -22,7 +22,7 @@
result = torch.mm(lhs, rhs)
f.returns([result])

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))

test_utils.compare_outputs(torch.mm, jit_module.mm, lhs, rhs)
2 changes: 1 addition & 1 deletion frontends/pytorch/examples/mul_maximum_e2e.py
@@ -28,7 +28,7 @@ def mul_maximum(lhs, rhs, threshold, bias):
result = mul_maximum(lhs, rhs, threshold, bias)
f.returns([result])

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))

test_utils.compare_outputs(mul_maximum, jit_module.mul_maximum, lhs, rhs,
2 changes: 1 addition & 1 deletion frontends/pytorch/examples/tanh_out_e2e.py
@@ -26,7 +26,7 @@ def fun(a):
with mb.capture_function("test", [arg0]) as f:
f.returns([fun(arg0)])

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
jit_module = backend.load(backend.compile(frontend_lowering.lower_module(mb.module)))

test_utils.compare_outputs(torch.mm, jit_module.test, arg0, arg1)
2 changes: 1 addition & 1 deletion frontends/pytorch/examples/torchscript_mm_e2e.py
@@ -48,7 +48,7 @@ def forward(self, lhs, rhs):
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

2 changes: 1 addition & 1 deletion frontends/pytorch/examples/torchscript_tanh_e2e.py
@@ -40,7 +40,7 @@ def forward(self, x):
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()

backend = refjit.CompilerBackend()
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

4 changes: 2 additions & 2 deletions frontends/pytorch/examples/torchscript_tanh_e2e_iree.py
@@ -34,13 +34,13 @@ def forward(self, x):
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
class_annotator.annotateArgs(recursivescriptmodule._c._type(), ['forward'], [
None,
([2, 3, -1], torch.float32)
([2, 3, -1], torch.float32, True)
])
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()

backend = iree.CompilerBackend()
backend = iree.IreeNpcompBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)

@@ -2,6 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .ref_backend import RefBackendTestConfig
from .npcomp_backend import NpcompBackendTestConfig
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
@@ -14,15 +14,32 @@

import torch_mlir
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.pytorch.backend.abc import NpcompBackend
from torch_mlir.torchscript.e2e_test.framework import TestConfig, Trace, TraceItem
from torch_mlir.torchscript.annotations import extract_annotations

class PrettyErrorReportForIrOperation(object):
def __init__(self, module, module_name_for_ir_dump: str):
sys.stderr = StringIO()
self.filename_for_ir_dump = os.path.join(tempfile.gettempdir(),
module_name_for_ir_dump + '.mlir')
self.asm_for_error_report = module.get_asm(
large_elements_limit=10, enable_debug_info=True)
def __enter__(self):
pass
def __exit__(self, type, value, traceback):
with open(self.filename_for_ir_dump, 'w') as f:
f.write(self.asm_for_error_report)

class RefBackendTestConfig(TestConfig):
"""TestConfig that just runs the torch.nn.Module through RefBackend."""
def __init__(self):
class NpcompBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with npcomp.
This class handles all the common lowering that npcomp does before reaching
its backends.
"""
def __init__(self, backend: NpcompBackend):
super().__init__()
self.backend = refjit.CompilerBackend()
self.backend = backend

def compile(self, program: torch.nn.Module) -> Any:
mb = torch_mlir.ModuleBuilder()
@@ -79,14 +96,36 @@ def compile(self, program: torch.nn.Module) -> Any:
""") from None
finally:
sys.stderr = sys.__stderr__
return self.backend.compile(mb.module)
try:
sys.stderr = StringIO()
asm_for_error_report = mb.module.operation.get_asm(
large_elements_limit=10, enable_debug_info=True)
return self.backend.compile(mb.module)
except Exception as e:
filename = os.path.join(tempfile.gettempdir(),
scripted.original_name + '.mlir')
with open(filename, 'w') as f:
f.write(asm_for_error_report)
raise Exception(f"""
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
## Exception:
{e}
## Stderr:
{sys.stderr.getvalue()}
## Input IR has been saved in {filename}
""") from None
finally:
sys.stderr = sys.__stderr__


def run(self, artifact: Any, trace: Trace) -> Trace:
jit_module = self.backend.load(artifact)
backend_module = self.backend.load(artifact)
result: Trace = []
for item in trace:
numpy_inputs = [t.numpy() for t in item.inputs]
outputs = getattr(jit_module, item.symbol)(*numpy_inputs)
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
if isinstance(outputs, np.ndarray):
outputs = [outputs]
torch_outputs = [torch.tensor(ndarray) for ndarray in outputs]
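The net effect of the refactor is that a single config class now covers both backends, with the backend chosen at construction time. A short sketch of the resulting usage (imports taken from the diffs above; this mirrors `main.py` rather than introducing new API):

```python
# One TestConfig class, parameterized by the npcomp backend to target.
# Mirrors the construction in main.py above; requires npcomp with IREE enabled.
from torch_mlir.torchscript.e2e_test.configs import NpcompBackendTestConfig
from npcomp.compiler.pytorch.backend.refjit import RefjitNpcompBackend
from npcomp.compiler.pytorch.backend.iree import IreeNpcompBackend

refbackend_config = NpcompBackendTestConfig(RefjitNpcompBackend())
iree_config = NpcompBackendTestConfig(IreeNpcompBackend())
```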
@@ -5,8 +5,9 @@
Utilities for reporting the results of the test framework.
"""

from typing import List, Optional
from typing import List, Optional, Set

import collections
import io
import textwrap

@@ -70,7 +71,7 @@ def error_str(self):
assert self.failed
if self.value.size() != self.golden_value.size():
return self.context.format_error(
f'tensor shape mismatch: got {tensor.size()!r}, expected {golden_tensor.size()!r}'
f'tensor shape mismatch: got {self.value.size()!r}, expected {self.golden_value.size()!r}'
)
f = io.StringIO()
p = lambda *x: print(*x, file=f)
@@ -167,17 +168,60 @@ def error_str(self):
return f.getvalue()


def report_results(results: List[TestResult], verbose: bool = False):
"""Provide a basic error report summarizing various TestResult's.
def report_results(results: List[TestResult],
expected_failures: Set[str],
verbose: bool = False):
"""Print a basic error report summarizing various TestResult's.
This report uses the PASS/FAIL/XPASS/XFAIL nomenclature of LLVM's
"lit" testing utility. See
https://llvm.org/docs/CommandGuide/lit.html#test-status-results
The `expected_failures` set should contain the names of tests
(according to their `unique_name`) which are expected to fail.
The overall passing/failing status of the report requires these to fail
in order to succeed (this catches cases where things suddenly
start working).
If `verbose` is True, then provide an explanation of what failed.
Returns True if the run resulted in any unexpected pass/fail behavior.
Otherwise False.
"""
summary = collections.Counter()
for result in results:
report = SingleTestReport(result, ErrorContext.empty())
if not report.failed:
print(f'SUCCESS - "{result.unique_name}"')
expected_failure = result.unique_name in expected_failures
if expected_failure:
if report.failed:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'XFAIL - "{result.unique_name}"' + error_str)
summary['XFAIL'] += 1
else:
print(f'XPASS - "{result.unique_name}"')
summary['XPASS'] += 1
else:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'FAILURE - "{result.unique_name}"' + error_str)
if not report.failed:
print(f'PASS - "{result.unique_name}"')
summary['PASS'] += 1
else:
error_str = ''
if verbose:
error_str = '\n' + textwrap.indent(report.error_str(), ' ')
print(f'FAIL - "{result.unique_name}"' + error_str)
summary['FAIL'] += 1

# Print a summary for easy scanning.
print('\nSummary:')
KEY_MEANINGS = {
'PASS': 'Passed',
'FAIL': 'Failed',
'XFAIL': 'Expectedly Failed',
'XPASS': 'Unexpectedly Passed',
}
for key in ['PASS', 'FAIL', 'XFAIL', 'XPASS']:
if summary[key]:
print(f' {KEY_MEANINGS[key]}: {summary[key]}')
return summary['FAIL'] != 0 or summary['XPASS'] != 0
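Condensed to its essentials, the reporting change is a four-way classification plus a counter, with the run failing iff anything behaved unexpectedly. A standalone sketch of just that rule (`classify` is illustrative, not the harness API):

```python
import collections

def classify(failed: bool, expected_failure: bool) -> str:
    """lit-style status: XFAIL/XPASS for expected failures, else PASS/FAIL."""
    if expected_failure:
        return 'XFAIL' if failed else 'XPASS'
    return 'FAIL' if failed else 'PASS'

# Tally a toy run: one expected failure that failed, one normal pass,
# and one expected failure that unexpectedly passed.
summary = collections.Counter(
    classify(failed, xfail)
    for failed, xfail in [(True, True), (False, False), (False, True)]
)
had_unexpected = summary['FAIL'] != 0 or summary['XPASS'] != 0
assert had_unexpected  # the XPASS makes the whole run fail
```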