Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rfc] Use logging.getLogger for projects/pt1/e2e_testing #3173

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 5104 files
5 changes: 3 additions & 2 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
do_bcast = true;
} else {
op->emitError("The size of tensor a (")
<< inDim << ")" << "must match the size of tensor b (" << outDim
<< ")" << "at non-singleton dimension " << inPos;
<< inDim << ")"
<< "must match the size of tensor b (" << outDim << ")"
<< "at non-singleton dimension " << inPos;
}
}
std::reverse(bcastDims.begin(), bcastDims.end());
Expand Down
3 changes: 2 additions & 1 deletion lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ class LowerToBackendContractPass
return signalPassFailure();
} while (!satisfiesBackendContract(module, target));
LLVM_DEBUG({
llvm::dbgs() << "LowerToBackendContractPass: " << "succeeded after " << i
llvm::dbgs() << "LowerToBackendContractPass: "
<< "succeeded after " << i
<< " iterations of the simplification pipeline\n";
});
}
Expand Down
50 changes: 48 additions & 2 deletions projects/pt1/e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Also available under a BSD-style license. See LICENSE.

import argparse
import logging
renxida marked this conversation as resolved.
Show resolved Hide resolved
import re
import sys

Expand Down Expand Up @@ -103,6 +104,18 @@ def _get_argparse():
default=".*",
help="""
Regular expression specifying which tests to include in this run.
<<<<<<< HEAD
""")
parser.add_argument("--log_level", default="WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="set the log level")
renxida marked this conversation as resolved.
Show resolved Hide resolved
parser.add_argument("-d", "--debug", action="store_const", dest="log_level", const="DEBUG", help="set log level to DEBUG for detailed debug output")
parser.add_argument("-v", "--verbose", action="store_const", dest="log_level", const="INFO", help="set log level to INFO to report a more detailed but still user-friendly level of verbosity")
parser.add_argument("-q", "--quiet", action="store_const", dest="log_level", const="ERROR", help="suppress all logs except errors")

parser.add_argument("-s", "--sequential",
default=False,
action="store_true",
help="""Run tests sequentially rather than in parallel.
=======
""",
)
parser.add_argument(
Expand All @@ -118,6 +131,7 @@ def _get_argparse():
default=False,
action="store_true",
help="""Run tests sequentially rather than in parallel.
>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1
This can be useful for debugging, since it runs the tests in the same process,
which make it easier to attach a debugger or get a stack trace.""",
)
Expand All @@ -139,6 +153,19 @@ def _get_argparse():

def main():
args = _get_argparse().parse_args()
args.log_level = args.log_level.upper()

logger = logging.getLogger() # use root logger by default. Easy to change later.
logger.setLevel(logging.NOTSET)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(args.log_level)
if args.log_level != "DEBUG":
fmt = "%(levelname)s: %(message)s"
else:
fmt = "%(levelname)s: %(filename)s:%(lineno)d:\n%(message)s"
formatter = logging.Formatter(fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)

all_test_unique_names = set(test.unique_name for test in GLOBAL_TEST_REGISTRY)

Expand Down Expand Up @@ -197,32 +224,51 @@ def main():
if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
if arg not in all_test_unique_names:
<<<<<<< HEAD
logger.error(f"--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name")
=======
print(
f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name"
)
>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1
sys.exit(1)

# Find the selected tests, and emit a diagnostic if none are found.
tests = [
test for test in available_tests if re.match(args.filter, test.unique_name)
]
available_tests = [test.unique_name for test in available_tests]
if len(tests) == 0:
<<<<<<< HEAD
logger.error(
f"the provided filter {args.filter!r} does not match any tests. The available tests are:\n\t" + "\n\t".join(available_tests)
)
=======
print(f"ERROR: the provided filter {args.filter!r} does not match any tests")
print("The available tests are:")
for test in available_tests:
print(test.unique_name)
>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1
sys.exit(1)

# Run the tests.
results = run_tests(tests, config, args.sequential, args.verbose)
results = run_tests(tests, config, args.sequential,
verbose=logger.level >= logging.INFO)

# Report the test results.
failed = report_results(results, xfail_set, args.verbose, args.config)
failed = report_results(results, xfail_set,
verbose=logger.level >= logging.INFO,
config=args.config)
if args.config == "torchdynamo":
<<<<<<< HEAD
logger.warning("the TorchScript based dynamo support is deprecated. "
"The config for torchdynamo is planned to be removed in the future.")
=======
print(
"\033[91mWarning: the TorchScript based dynamo support is deprecated. "
"The config for torchdynamo is planned to be removed in the future.\033[0m"
)
>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1
if args.ignore_failures:
sys.exit(0)
sys.exit(1 if failed else 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
logger = logging.getLogger()

from typing import Any

Expand Down Expand Up @@ -32,9 +34,14 @@ def __init__(self, backend: LinalgOnTensorsBackend):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(
<<<<<<< HEAD
program, example_args, output_type="linalg-on-tensors")
logger.debug("MLIR produced by LinalgOnTensorsBackendTestConfig:\n" + str(module))
=======
program, example_args, output_type="linalg-on-tensors"
)

>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1
return self.backend.compile(module)

def run(self, artifact: Any, trace: Trace) -> Trace:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torch_mlir.extras import onnx_importer
from torch_mlir.dialects import torch as torch_d
from torch_mlir.ir import Context, Module

import logging
logger = logging.getLogger()

def import_onnx(contents):
# Import the ONNX model proto from the file contents:
Expand All @@ -39,7 +40,7 @@ def import_onnx(contents):
return m


def convert_onnx(model, inputs):
def convert_onnx(model: torch.nn.Module, inputs):
buffer = io.BytesIO()

# Process the type information so we export with the dynamic shape information
Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
onnx_module = convert_onnx(program, example_args)
logger.debug("MLIR produced by OnnxBackendTestConfig:\n" + str(onnx_module))
compiled_module = self.backend.compile(onnx_module)
return compiled_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
logger = logging.getLogger()

from typing import Any

Expand Down Expand Up @@ -31,7 +33,7 @@ def __init__(self, backend: StablehloBackend):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(program, example_args, output_type="stablehlo")

logger.debug("MLIR produced by StablehloBackendTestConfig:\n" + str(module))
return self.backend.compile(module)

def run(self, artifact: Any, trace: Trace) -> Trace:
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
Utilities for reporting the results of the test framework.
"""

from logging import getLogger
logger = getLogger()

from typing import Any, List, Optional, Set

import collections
Expand Down Expand Up @@ -310,6 +313,20 @@ def report_results(
expected_failure = result.unique_name in expected_failures
if expected_failure:
if report.failed:
<<<<<<< HEAD
logger.info(f'XFAIL - "{result.unique_name}"')
results_by_outcome['XFAIL'].append((result, report))
else:
logger.info(f'XPASS - "{result.unique_name}"')
results_by_outcome['XPASS'].append((result, report))
else:
if not report.failed:
logger.info(f'PASS - "{result.unique_name}"')
results_by_outcome['PASS'].append((result, report))
else:
logger.info(f'FAIL - "{result.unique_name}"')
results_by_outcome['FAIL'].append((result, report))
=======
print(f'XFAIL - "{result.unique_name}"')
results_by_outcome["XFAIL"].append((result, report))
else:
Expand All @@ -322,6 +339,7 @@ def report_results(
else:
print(f'FAIL - "{result.unique_name}"')
results_by_outcome["FAIL"].append((result, report))
>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1

OUTCOME_MEANINGS = collections.OrderedDict()
OUTCOME_MEANINGS["PASS"] = "Passed"
Expand All @@ -348,8 +366,13 @@ def report_results(
for result, report in results:
print(f' {outcome} - "{result.unique_name}"')
# If the test failed, print the error message.
<<<<<<< HEAD
if outcome == 'FAIL':
logger.info(textwrap.indent(report.error_str(), ' ' * 8))
=======
if outcome == "FAIL" and verbose:
print(textwrap.indent(report.error_str(), " " * 8))
>>>>>>> 0a2d21b108602d2b11c208ca1a713a72f483f6c1

# Print a summary for easy scanning.
print("\nSummary:")
Expand Down
Loading