22 changes: 22 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,22 @@
name: Lint

on:
push:
pull_request:

jobs:
Lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5

- name: Install uv
uses: astral-sh/setup-uv@v6

- name: Install the project and its dependencies
run: |
uv sync

- name: Run pre-commit
run: |-
uv run pre-commit run --all-files
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.5
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
# Run the formatter.
- id: ruff-format
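
For orientation, here is a minimal before/after sketch (hypothetical code, not taken from this PR) of what the two hooks do: `ruff-check --fix` can drop an unused import flagged by Pyflakes (much like the `import os` removal in model.py further down), and `ruff-format` normalizes quoting and layout, following the `quote-style = "double"` setting added to pyproject.toml below.

```python
# Hypothetical input file before the hooks run:
#
#     import os                      # unused import, flagged as F401
#
#     def greet(name):
#         return 'hello ' + name     # single quotes, to be reformatted
#
# After `ruff check --fix` (removes the unused import) and `ruff format`
# (rewrites quotes to double quotes), the file would look like:
def greet(name):
    return "hello " + name
```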
6 changes: 3 additions & 3 deletions ingress/mlir-gen/mlir_gen/main.py
@@ -136,9 +136,9 @@ def weights(
assert k_as_num_inputs % block.k == 0, "invalid tile size for K dim"
assert n_as_num_outputs % block.n == 0, "invalid tile size for N dim"
if block.vnni:
assert (
block.n % block.vnni == 0
), "incompatible tile sizes for N and VNNI dims"
assert block.n % block.vnni == 0, (
"incompatible tile sizes for N and VNNI dims"
)
shape = (
n_as_num_outputs // block.n,
k_as_num_inputs // block.k,
34 changes: 33 additions & 1 deletion pyproject.toml
@@ -8,7 +8,9 @@ dependencies = [

[dependency-groups]
dev = [
"lit==18.1.8" # Tool to configure, discover and run tests
"lit==18.1.8", # Tool to configure, discover and run tests
"ruff==0.14.5", # Python linter and formatter
"pre-commit", # Tool to manage and apply pre-commit hooks
]

[project.optional-dependencies]
@@ -82,3 +84,33 @@ include = ["lighthouse*"]

[tool.setuptools.dynamic]
version = {attr = "lighthouse.__version__"}

[tool.ruff]
src = ["lighthouse"]
target-version = "py310"
line-length = 88

[tool.ruff.format]
docstring-code-format = true
quote-style = "double"

# List of rules:
# https://docs.astral.sh/ruff/rules/
[tool.ruff.lint]
select = [
"D419", # empty-docstring
"E", # Error
"F", # Pyflakes
"PERF", # Perflint
"RUF022", # __all__ is not sorted
"RUF030", # print() call in assert
"RUF034", # useless if-else
"RUF047", # empty else
"RUF200", # invalid pyproject.toml
"W", # Warning
]
ignore = [
"E501", # line-too-long
"PERF203", # try-except-in-loop
"PERF401", # manual-list-comprehension
]
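
As an illustration of the rule selection above, the following hypothetical snippet (not part of this PR) contains patterns the configured linter would flag, plus one that the `ignore` list deliberately lets through:

```python
def summarize(values: list[int]) -> list[int]:
    # RUF030: print() used as the assert message; its return value is None,
    # so nothing useful is reported when the assertion fails.
    assert values, print("no values given")

    # RUF034: both branches produce the same value, so the conditional is useless.
    limit = 10 if values else 10

    # PERF401 would normally suggest a list comprehension here, but that rule
    # is listed under `ignore`, so this loop is left alone.
    doubled = []
    for value in values:
        doubled.append(value * 2)

    return [v for v in doubled if v < limit]
```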
7 changes: 1 addition & 6 deletions python/examples/ingress/torch/MLPModel/model.py
@@ -3,16 +3,11 @@
import torch
import torch.nn as nn

import os

class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(10, 32),
nn.ReLU(),
nn.Linear(32, 2)
)
self.net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

def forward(self, x):
return self.net(x)
4 changes: 3 additions & 1 deletion python/examples/ingress/torch/mlp_from_file.py
@@ -34,14 +34,16 @@
# - Loads the MLPModel class and instantiates it with arguments obtained from 'get_init_inputs()'
# - Calls get_sample_inputs() to get sample input tensors for shape inference
# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
# fmt: off
mlir_module_ir: ir.Module = import_from_file(
model_path, # Path to the Python file containing the model
model_class_name="MLPModel", # Name of the PyTorch nn.Module class to convert
init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__()
sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)'
dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types)
ir_context=ir_context # MLIR context for the conversion
ir_context=ir_context, # MLIR context for the conversion
)
# fmt: on

# The PyTorch model is now converted to MLIR at this point. You can now convert
# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file.
4 changes: 1 addition & 3 deletions python/examples/ingress/torch/mlp_from_model.py
@@ -31,9 +31,7 @@
ir_context = ir.Context()
# Step 2: Convert the PyTorch model to MLIR
mlir_module_ir: ir.Module = import_from_model(
model,
sample_args=(sample_input,),
ir_context=ir_context
model, sample_args=(sample_input,), ir_context=ir_context
)

# The PyTorch model is now converted to MLIR at this point. You can now convert
5 changes: 5 additions & 0 deletions python/lighthouse/ingress/torch/__init__.py
@@ -1,3 +1,8 @@
"""Provides functions to convert PyTorch models to MLIR."""

from .importer import import_from_file, import_from_model

__all__ = [
"import_from_file",
"import_from_model",
]
46 changes: 29 additions & 17 deletions python/lighthouse/ingress/torch/importer.py
@@ -3,7 +3,10 @@
from pathlib import Path
from typing import Iterable, Mapping

from lighthouse.ingress.torch.utils import load_and_run_callable, maybe_load_and_run_callable
from lighthouse.ingress.torch.utils import (
load_and_run_callable,
maybe_load_and_run_callable,
)

try:
import torch
@@ -25,6 +28,7 @@

from mlir import ir


def import_from_model(
model: nn.Module,
sample_args: Iterable,
@@ -49,10 +53,10 @@ def import_from_model(
ir_context (ir.Context, optional): An optional MLIR context to use for parsing
the module. If not provided, the module is returned as a string.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.

Returns:
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.

Examples:
>>> import torch
>>> import torch.nn as nn
@@ -61,17 +65,22 @@ def import_from_model(
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(10, 5)
...
... def forward(self, x):
... return self.fc(x)
>>> model = SimpleModel()
>>> sample_input = (torch.randn(1, 10),)
>>> #
>>> # option 1: get MLIR module as a string
>>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors")
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
>>> mlir_module: str = import_from_model(
... model, sample_input, dialect="linalg-on-tensors"
... )
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
>>> # option 2: get MLIR module as an ir.Module
>>> ir_context = ir.Context()
>>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context)
>>> mlir_module_ir: ir.Module = import_from_model(
... model, sample_input, dialect="tosa", ir_context=ir_context
... )
"""
if dialect == "linalg":
raise ValueError(
@@ -134,45 +143,48 @@ def import_from_file(
ir_context (ir.Context, optional): An optional MLIR context to use for parsing
the module. If not provided, the module is returned as a string.
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.

Returns:
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.

Examples:
Given a file `path/to/model_file.py` with the following content:
```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 5)

def forward(self, x):
return self.fc(x)


def get_inputs():
return (torch.randn(1, 10),)
```

The import script would look like:
>>> from lighthouse.ingress.torch import import_from_file
>>> # option 1: get MLIR module as a string
>>> mlir_module : str = import_from_file(
>>> mlir_module: str = import_from_file(
... "path/to/model_file.py",
... model_class_name="MyModel",
... init_args_fn_name=None,
... dialect="linalg-on-tensors"
... dialect="linalg-on-tensors",
... )
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
>>> # option 2: get MLIR module as an ir.Module
>>> ir_context = ir.Context()
>>> mlir_module_ir : ir.Module = import_from_file(
>>> mlir_module_ir: ir.Module = import_from_file(
... "path/to/model_file.py",
... model_class_name="MyModel",
... init_args_fn_name=None,
... dialect="linalg-on-tensors",
... ir_context=ir_context
... ir_context=ir_context,
... )
"""
if isinstance(filepath, str):
@@ -191,24 +203,24 @@ def get_inputs():
module,
init_args_fn_name,
default=tuple(),
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}"
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
)
model_init_kwargs = maybe_load_and_run_callable(
module,
init_kwargs_fn_name,
default={},
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}"
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
)
sample_args = load_and_run_callable(
module,
sample_args_fn_name,
f"Sample args function '{sample_args_fn_name}' not found in {filepath}"
f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
)
sample_kwargs = maybe_load_and_run_callable(
module,
sample_kwargs_fn_name,
default={},
error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}"
error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}",
)

nn_model: nn.Module = model(*model_init_args, **model_init_kwargs)
6 changes: 1 addition & 5 deletions python/lighthouse/ingress/torch/utils.py
@@ -43,8 +43,4 @@ def maybe_load_and_run_callable(
"""
if symbol_name is None:
return default
return load_and_run_callable(
module,
symbol_name,
error_msg=error_msg
)
return load_and_run_callable(module, symbol_name, error_msg=error_msg)
8 changes: 8 additions & 0 deletions python/lighthouse/utils/__init__.py
@@ -7,3 +7,11 @@
torch_to_memref,
torch_to_packed_args,
)

__all__ = [
"get_packed_arg",
"memref_to_ctype",
"memrefs_to_packed_args",
"torch_to_memref",
"torch_to_packed_args",
]