diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 0000000..4f35a54
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,22 @@
+name: Lint
+
+on:
+  push:
+  pull_request:
+
+jobs:
+  Lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v5
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+
+      - name: Install the project and its dependencies
+        run: |
+          uv sync
+
+      - name: Run pre-commit
+        run: |-
+          uv run pre-commit run --all-files
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..daca500
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,10 @@
+repos:
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.14.5
+  hooks:
+    # Run the linter.
+    - id: ruff-check
+      args: [ --fix ]
+    # Run the formatter.
+    - id: ruff-format
diff --git a/ingress/mlir-gen/mlir_gen/main.py b/ingress/mlir-gen/mlir_gen/main.py
index 3a421d0..ba27074 100644
--- a/ingress/mlir-gen/mlir_gen/main.py
+++ b/ingress/mlir-gen/mlir_gen/main.py
@@ -136,9 +136,9 @@ def weights(
     assert k_as_num_inputs % block.k == 0, "invalid tile size for K dim"
     assert n_as_num_outputs % block.n == 0, "invalid tile size for N dim"
     if block.vnni:
-        assert (
-            block.n % block.vnni == 0
-        ), "incompatible tile sizes for N and VNNI dims"
+        assert block.n % block.vnni == 0, (
+            "incompatible tile sizes for N and VNNI dims"
+        )
     shape = (
         n_as_num_outputs // block.n,
         k_as_num_inputs // block.k,
diff --git a/pyproject.toml b/pyproject.toml
index 9b4b2e8..e738e28 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,9 @@ dependencies = [
 
 [dependency-groups]
 dev = [
-    "lit==18.1.8" # Tool to configure, discover and run tests
+    "lit==18.1.8",  # Tool to configure, discover and run tests
+    "ruff==0.14.5",  # Python linter and formatter
+    "pre-commit",  # Tool to manage and apply pre-commit hooks
 ]
 
 [project.optional-dependencies]
@@ -82,3 +84,33 @@ include = ["lighthouse*"]
 
 [tool.setuptools.dynamic]
 version = {attr = "lighthouse.__version__"}
+
+[tool.ruff]
+src = ["lighthouse"]
+target-version = "py310"
+line-length = 88
+
+[tool.ruff.format]
+docstring-code-format = true
+quote-style = "double"
+
+# List of rules:
+# https://docs.astral.sh/ruff/rules/
+[tool.ruff.lint]
+select = [
+    "D419",    # empty-docstring
+    "E",       # Error
+    "F",       # Pyflakes
+    "PERF",    # Perflint
+    "RUF022",  # __all__ is not sorted
+    "RUF030",  # print() call in assert
+    "RUF034",  # useless if-else
+    "RUF047",  # empty else
+    "RUF200",  # invalid pyproject.toml
+    "W",       # Warning
+]
+ignore = [
+    "E501",     # line-too-long
+    "PERF203",  # try-except-in-loop
+    "PERF401",  # manual-list-comprehension
+]
diff --git a/python/examples/ingress/torch/MLPModel/model.py b/python/examples/ingress/torch/MLPModel/model.py
index 5bbdbbe..18b963f 100644
--- a/python/examples/ingress/torch/MLPModel/model.py
+++ b/python/examples/ingress/torch/MLPModel/model.py
@@ -3,16 +3,11 @@
 import torch
 import torch.nn as nn
-import os
 
 
 class MLPModel(nn.Module):
     def __init__(self):
         super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(10, 32),
-            nn.ReLU(),
-            nn.Linear(32, 2)
-        )
+        self.net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
 
     def forward(self, x):
         return self.net(x)
diff --git a/python/examples/ingress/torch/mlp_from_file.py b/python/examples/ingress/torch/mlp_from_file.py
index b8afc36..748de85 100644
--- a/python/examples/ingress/torch/mlp_from_file.py
+++ b/python/examples/ingress/torch/mlp_from_file.py
@@ -34,14 +34,16 @@
 # - Loads the MLPModel class and instantiates it with arguments obtained from 'get_init_inputs()'
 # - Calls get_sample_inputs() to get sample input tensors for shape inference
 # - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
+# fmt: off
 mlir_module_ir: ir.Module = import_from_file(
     model_path,                              # Path to the Python file containing the model
     model_class_name="MLPModel",             # Name of the PyTorch nn.Module class to convert
     init_args_fn_name="get_init_inputs",     # Function that returns args for model.__init__()
     sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)'
     dialect="linalg-on-tensors",             # Target MLIR dialect (linalg ops on tensor types)
-    ir_context=ir_context                    # MLIR context for the conversion
+    ir_context=ir_context,                   # MLIR context for the conversion
 )
+# fmt: on
 
 # The PyTorch model is now converted to MLIR at this point. You can now convert
 # the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file.
diff --git a/python/examples/ingress/torch/mlp_from_model.py b/python/examples/ingress/torch/mlp_from_model.py
index 6621e90..4590429 100644
--- a/python/examples/ingress/torch/mlp_from_model.py
+++ b/python/examples/ingress/torch/mlp_from_model.py
@@ -31,9 +31,7 @@
 ir_context = ir.Context()
 # Step 2: Convert the PyTorch model to MLIR
 mlir_module_ir: ir.Module = import_from_model(
-    model,
-    sample_args=(sample_input,),
-    ir_context=ir_context
+    model, sample_args=(sample_input,), ir_context=ir_context
 )
 
 # The PyTorch model is now converted to MLIR at this point. You can now convert
diff --git a/python/lighthouse/ingress/torch/__init__.py b/python/lighthouse/ingress/torch/__init__.py
index d73f426..d7fc600 100644
--- a/python/lighthouse/ingress/torch/__init__.py
+++ b/python/lighthouse/ingress/torch/__init__.py
@@ -1,3 +1,8 @@
 """Provides functions to convert PyTorch models to MLIR."""
 
 from .importer import import_from_file, import_from_model
+
+__all__ = [
+    "import_from_file",
+    "import_from_model",
+]
diff --git a/python/lighthouse/ingress/torch/importer.py b/python/lighthouse/ingress/torch/importer.py
index 87c7655..f95808f 100644
--- a/python/lighthouse/ingress/torch/importer.py
+++ b/python/lighthouse/ingress/torch/importer.py
@@ -3,7 +3,10 @@
 from pathlib import Path
 from typing import Iterable, Mapping
 
-from lighthouse.ingress.torch.utils import load_and_run_callable, maybe_load_and_run_callable
+from lighthouse.ingress.torch.utils import (
+    load_and_run_callable,
+    maybe_load_and_run_callable,
+)
 
 try:
     import torch
@@ -25,6 +28,7 @@
 
 from mlir import ir
 
+
 def import_from_model(
     model: nn.Module,
     sample_args: Iterable,
@@ -49,10 +53,10 @@ def import_from_model(
         ir_context (ir.Context, optional): An optional MLIR context to use for parsing the module.
             If not provided, the module is returned as a string.
         **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
-    
+
     Returns:
         str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
-    
+
     Examples:
         >>> import torch
         >>> import torch.nn as nn
@@ -61,17 +65,22 @@ def import_from_model(
         >>> class SimpleModel(nn.Module):
         ...     def __init__(self):
         ...         super().__init__()
         ...         self.fc = nn.Linear(10, 5)
+        ...
         ...     def forward(self, x):
         ...         return self.fc(x)
         >>> model = SimpleModel()
         >>> sample_input = (torch.randn(1, 10),)
         >>> #
         >>> # option 1: get MLIR module as a string
-        >>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors")
-        >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
+        >>> mlir_module: str = import_from_model(
+        ...     model, sample_input, dialect="linalg-on-tensors"
+        ... )
+        >>> print(mlir_module)  # prints the MLIR module in linalg-on-tensors dialect
         >>> # option 2: get MLIR module as an ir.Module
         >>> ir_context = ir.Context()
-        >>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context)
+        >>> mlir_module_ir: ir.Module = import_from_model(
+        ...     model, sample_input, dialect="tosa", ir_context=ir_context
+        ... )
     """
     if dialect == "linalg":
         raise ValueError(
@@ -134,23 +143,26 @@ def import_from_file(
         ir_context (ir.Context, optional): An optional MLIR context to use for parsing the module.
             If not provided, the module is returned as a string.
        **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
-    
+
     Returns:
         str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
-    
+
     Examples:
         Given a file `path/to/model_file.py` with the following content:
         ```python
         import torch
         import torch.nn as nn
 
+
         class MyModel(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc = nn.Linear(10, 5)
+
             def forward(self, x):
                 return self.fc(x)
 
+
         def get_inputs():
             return (torch.randn(1, 10),)
         ```
@@ -158,21 +170,21 @@ def get_inputs():
     The import script would look like:
         >>> from lighthouse.ingress.torch_import import import_from_file
         >>> # option 1: get MLIR module as a string
-        >>> mlir_module : str = import_from_file(
+        >>> mlir_module: str = import_from_file(
         ...     "path/to/model_file.py",
         ...     model_class_name="MyModel",
         ...     init_args_fn_name=None,
-        ...     dialect="linalg-on-tensors"
+        ...     dialect="linalg-on-tensors",
         ... )
-        >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
+        >>> print(mlir_module)  # prints the MLIR module in linalg-on-tensors dialect
         >>> # option 2: get MLIR module as an ir.Module
         >>> ir_context = ir.Context()
-        >>> mlir_module_ir : ir.Module = import_from_file(
+        >>> mlir_module_ir: ir.Module = import_from_file(
         ...     "path/to/model_file.py",
         ...     model_class_name="MyModel",
         ...     init_args_fn_name=None,
         ...     dialect="linalg-on-tensors",
-        ...     ir_context=ir_context
+        ...     ir_context=ir_context,
         ... )
     """
     if isinstance(filepath, str):
@@ -191,24 +203,24 @@ def get_inputs():
         module,
         init_args_fn_name,
         default=tuple(),
-        error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}"
+        error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
     )
     model_init_kwargs = maybe_load_and_run_callable(
         module,
         init_kwargs_fn_name,
         default={},
-        error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}"
+        error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
     )
     sample_args = load_and_run_callable(
         module,
         sample_args_fn_name,
-        f"Sample args function '{sample_args_fn_name}' not found in {filepath}"
+        f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
     )
     sample_kwargs = maybe_load_and_run_callable(
         module,
         sample_kwargs_fn_name,
         default={},
-        error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}"
+        error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}",
     )
 
     nn_model: nn.Module = model(*model_init_args, **model_init_kwargs)
diff --git a/python/lighthouse/ingress/torch/utils.py b/python/lighthouse/ingress/torch/utils.py
index 9a464b3..d3aec36 100644
--- a/python/lighthouse/ingress/torch/utils.py
+++ b/python/lighthouse/ingress/torch/utils.py
@@ -43,8 +43,4 @@ def maybe_load_and_run_callable(
     """
     if symbol_name is None:
         return default
-    return load_and_run_callable(
-        module,
-        symbol_name,
-        error_msg=error_msg
-    )
+    return load_and_run_callable(module, symbol_name, error_msg=error_msg)
diff --git a/python/lighthouse/utils/__init__.py b/python/lighthouse/utils/__init__.py
index 22799cc..4cff9a5 100644
--- a/python/lighthouse/utils/__init__.py
+++ b/python/lighthouse/utils/__init__.py
@@ -7,3 +7,11 @@
     torch_to_memref,
     torch_to_packed_args,
 )
+
+__all__ = [
+    "get_packed_arg",
+    "memref_to_ctype",
+    "memrefs_to_packed_args",
+    "torch_to_memref",
+    "torch_to_packed_args",
+]