diff --git a/python/examples/mlir/compile_and_run.py b/python/examples/mlir/compile_and_run.py index 345251a..d0529f6 100644 --- a/python/examples/mlir/compile_and_run.py +++ b/python/examples/mlir/compile_and_run.py @@ -1,5 +1,5 @@ import torch -import os +import argparse from mlir import ir from mlir.dialects import transform @@ -65,7 +65,8 @@ def create_schedule(ctx: ir.Context) -> ir.Module: func = structured.MatchOp.match_op_names( named_seq.bodyTarget, ["func.func"] ) - # Use C interface wrappers - required to make function executable after jitting. + # Use C interface wrappers - required to make function executable + # after jitting. func = transform.apply_registered_pass( anytype, func, "llvm-request-c-wrappers" ) @@ -126,7 +127,7 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager: # The example's entry point. -def main(): +def main(args): ### Baseline computation ### # Create inputs. a = torch.randn(16, 32, dtype=torch.float32) @@ -149,26 +150,23 @@ def main(): pm.run(kernel.operation) ### Compilation ### - # External shared libraries, containing MLIR runner utilities, are generally - # required to execute the compiled module. - # In this case, MLIR runner utils libraries are expected: - # - libmlir_runner_utils.so - # - libmlir_c_runner_utils.so + # Parse additional libraries if present. # - # Get paths to MLIR runner shared libraries through an environment variable. + # External shared libraries, runtime utilities, might be needed to execute + # the compiled module. # The execution engine requires full paths to the libraries. - # For example, the env variable can be set as: - # LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so - mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":") + mlir_libs = [] + if args.shared_libs: + mlir_libs += args.shared_libs.split(",") # JIT the kernel. eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs) # Initialize the JIT engine. # - # The deferred initialization executes global constructors that might have been - # created by the module during engine creation (for example, when `gpu.module` - # is present) or registered afterwards. + # The deferred initialization executes global constructors that might + # have been created by the module during engine creation (for example, + # when `gpu.module` is present) or registered afterwards. # # Initialization is not strictly necessary in this case. # However, it is a good practice to perform it regardless. @@ -194,4 +192,21 @@ def main(): if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + + # External shared libraries, runtime utilities, might be needed to + # execute the compiled module. + # For example, MLIR runner utils libraries such as: + # - libmlir_runner_utils.so + # - libmlir_c_runner_utils.so + # + # Full paths to the libraries should be provided. + # For example: + # --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so + parser.add_argument( + "--shared-libs", + type=str, + help="Comma-separated list of libraries to link dynamically", + ) + args = parser.parse_args() + main(args)