180 changes: 180 additions & 0 deletions ingress/Torch-MLIR/convert-kernel-bench-to-mlir.py
@@ -0,0 +1,180 @@
#!/usr/bin/env python3

import importlib
import importlib.util

from pathlib import Path

from mlir import ir, passmanager
from torch_mlir import fx

kernels_as_pytorch_folder = Path(__file__).parent / "KernelBench" / "KernelBench"
Member:
Since this depends on where the repo was cloned by the bash script, perhaps that last step (the clone) could be done in this script as well?
@rolfmorel (Contributor Author), Oct 8, 2025:
I am not sure.

Doing a git clone in either script feels unclean. I also don't like the idea of making it a submodule, as that seems to imply you have to clone KernelBench to do anything useful with lighthouse. It seems to me KernelBench will be just one source of ingress compute graphs of interest, and it may make sense to allow users/CI to opt in to which paths they want to run tests with. What's the right mechanism for that? I am not sure. One possible shape is sketched below.
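
One possible shape for such an opt-in gate, as a sketch only: the LIGHTHOUSE_INGRESS_SOURCES variable and fetch_kernel_bench helper below are hypothetical, not part of this PR.

import os
import subprocess
from pathlib import Path

def fetch_kernel_bench(dest: Path) -> bool:
    """Clone KernelBench only if the user/CI opted in; report availability."""
    # Opt-in is signalled via a (hypothetical) comma-separated env var,
    # e.g. LIGHTHOUSE_INGRESS_SOURCES=kernelbench,onnx-zoo
    sources = os.environ.get("LIGHTHOUSE_INGRESS_SOURCES", "").split(",")
    if "kernelbench" not in sources:
        return False
    if not dest.exists():
        subprocess.run(
            ["git", "clone",
             "https://github.com/ScalingIntelligence/KernelBench.git", str(dest)],
            check=True,
        )
    return True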

Member:
KernelBench is NOT an ingress. Torch-MLIR is.

We now have three PRs that work with the FX importer, none using the others. We should have one FX importer script that is used by all of them.
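
A minimal sketch of what such a shared helper could look like, assuming the torch_mlir.fx API already used in this script (the lighthouse/fx_import.py module name is hypothetical):

# Hypothetical shared module (say, lighthouse/fx_import.py) that the three
# FX-importer PRs could call instead of each re-rolling the same logic.
import torch
from torch_mlir import fx

def import_to_linalg(model: torch.nn.Module, *example_inputs):
    """Export a PyTorch module with the FX importer, lowered to linalg-on-tensors."""
    return fx.export_and_import(
        model,
        *example_inputs,
        output_type=fx.OutputType.LINALG_ON_TENSORS,
    )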

kernels_as_pytorch_level1 = kernels_as_pytorch_folder / "level1"
kernels_as_pytorch_level2 = kernels_as_pytorch_folder / "level2"

kernels_as_mlir_folder = Path(__file__).parent / "cache"
kernels_as_mlir_level1 = kernels_as_mlir_folder / "level1"
kernels_as_mlir_level1.mkdir(parents=True, exist_ok=True)
kernels_as_mlir_level2 = kernels_as_mlir_folder / "level2"
kernels_as_mlir_level2.mkdir(parents=True, exist_ok=True)

level1, level2 = Path("level1"), Path("level2")
ignore_list = [
    level1 / "12_Matmul_with_diagonal_matrices_.py", # torch.operator "torch.aten.diag"
    level1
    / "34_InstanceNorm.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (93898875033000)
    level1
    / "72_conv_transposed_3D_asymmetric_input_asymmetric_kernel___strided_padded_grouped_.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level1
    / "89_cumsum.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1
    / "90_cumprod.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1
    / "91_cumsum_reverse.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1
    / "92_cumsum_exclusive.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1
    / "93_masked_cumsum.py", # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1
    / "95_CrossEntropyLoss.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level1
    / "96_HuberLoss.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level1
    / "97_ScaledDotProductAttention.py", # AssertionError: Torch not compiled with CUDA enabled
    level1
    / "99_TripletMarginLoss.py", # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level2
    / "17_Conv2d_InstanceNorm_Divide.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
    level2
    / "18_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "28_BMM_InstanceNorm_Sum_ResidualAdd_Multiply.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
    level2
    / "42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "43_Conv3d_Max_LogSumExp_ReLU.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "45_Gemm_Sigmoid_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "52_Conv2d_Activation_BatchNorm.py", # failed to legalize operation 'torch.operator'
    level2 / "55_Matmul_MaxPool_Sum_Scale.py", # MLIR file too big: 16G
    level2 / "59_Matmul_Swish_Scaling.py", # MLIR file too big: 16G
    level2 / "56_Matmul_Sigmoid_Sum.py", # MLIR file too big: 16G
    level2 / "66_Matmul_Dropout_Softmax.py", # MLIR file too big: 4G
    level2 / "68_Matmul_Min_Subtract.py", # MLIR file too big: 4G
    level2 / "94_Gemm_BiasAdd_Hardtanh_Mish_GroupNorm.py", # MLIR file too big: 1G
    level2 / "33_Gemm_Scale_BatchNorm.py", # MLIR file too big: 1G
    level2 / "88_Gemm_GroupNorm_Swish_Multiply_Swish.py", # MLIR file too big: 1G
    level2 / "75_Gemm_GroupNorm_Min_BiasAdd.py", # MLIR file too big: 1G
    level2 / "84_Gemm_BatchNorm_Scaling_Softmax.py", # MLIR file too big: 1G
    level2 / "97_Matmul_BatchNorm_BiasAdd_Divide_Swish.py", # MLIR file too big: 1G
    level2 / "62_Matmul_GroupNorm_LeakyReLU_Sum.py", # MLIR file too big: 1G
    level2 / "30_Gemm_GroupNorm_Hardtanh.py", # MLIR file too big: 1G
    level2 / "95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py", # MLIR file too big: 1G
    level2 / "29_Matmul_Mish_Mish.py", # MLIR file too big: 1G
    level2 / "99_Matmul_GELU_Softmax.py", # MLIR file too big: 1G
    level2 / "98_Matmul_AvgPool_GELU_Scale_Max.py", # MLIR file too big: 1G
    level2 / "80_Gemm_Max_Subtract_GELU.py", # MLIR file too big: 1G
    level2 / "81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py", # MLIR file too big: 1G
    level2 / "12_Gemm_Multiply_LeakyReLU.py", # MLIR file too big: 1G
    level2 / "53_Gemm_Scaling_Hardtanh_GELU.py", # MLIR file too big: 1G
    level2 / "9_Matmul_Subtract_Multiply_ReLU.py", # MLIR file too big: 1G
    level2 / "70_Gemm_Sigmoid_Scaling_ResidualAdd.py", # MLIR file too big: 1G
    level2 / "86_Matmul_Divide_GELU.py", # MLIR file too big: 1G
    level2 / "63_Gemm_ReLU_Divide.py", # MLIR file too big: 1G
    level2 / "76_Gemm_Add_ReLU.py", # MLIR file too big: 1G
    level2 / "14_Gemm_Divide_Sum_Scaling.py", # MLIR file too big: 1G
    level2 / "39_Gemm_Scale_BatchNorm.py", # MLIR file too big: 256M
    level2 / "41_Gemm_BatchNorm_GELU_ReLU.py", # MLIR file too big: 256M
    level2 / "40_Matmul_Scaling_ResidualAdd.py", # MLIR file too big: 256M
    level2 / "37_Matmul_Swish_Sum_GroupNorm.py", # MLIR file too big: 64.3M
    level2
    / "58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "64_Gemm_LogSumExp_LeakyReLU_LeakyReLU_GELU_GELU.py", # error: failed to legalize operation 'torch.constant.int'
    level2
    / "79_Conv3d_Multiply_InstanceNorm_Clamp_Multiply_Max.py", # LLVM ERROR: SmallVector unable to grow. Requested capacity (94312016449768)
    level2
    / "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp.py", # error: failed to legalize operation 'torch.constant.int'
]


ctx = ir.Context()
pm = passmanager.PassManager(context=ctx)
pm.add("linalg-specialize-generic-ops")

for pytorch_level, mlir_level in (
    (kernels_as_pytorch_level1, kernels_as_mlir_level1),
    (kernels_as_pytorch_level2, kernels_as_mlir_level2),
):
    for kernel_pytorch_file in pytorch_level.iterdir():
        level_and_kernel = (
            Path(kernel_pytorch_file.parent.name) / kernel_pytorch_file.name
        )
        if level_and_kernel in ignore_list or not kernel_pytorch_file.is_file():
            print(
                f"Skipping: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
            )
            continue

Member:
print to stderr?
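
A minimal way to take up this suggestion (a sketch; diagnostics would go to stderr so stdout stays clean for machine consumption):

import sys

# Drop-in replacement for the print above: route skip/progress
# diagnostics to stderr.
print(
    f"Skipping: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}",
    file=sys.stderr,
)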

        module_name = kernel_pytorch_file.stem

        kernel_as_mlir_path = mlir_level / (module_name + ".mlir")
        if kernel_as_mlir_path.exists():
            print(
                f"Already in cache: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
            )
            continue
        print(
            f"Processing: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
        )

        module_spec = importlib.util.spec_from_file_location(
            module_name, kernel_pytorch_file
        )

        if module_spec is None or module_spec.loader is None:
            print(f"Error: Could not create module spec for {kernel_pytorch_file}")
            continue

        module = importlib.util.module_from_spec(module_spec)
        # Execute the module to load its contents
        module_spec.loader.exec_module(module)

        if not all(
            hasattr(module, a) for a in ("Model", "get_inputs", "get_init_inputs")
        ):
            print(f"Error: module in file {kernel_pytorch_file} not a proper benchmark")
            continue

Member:
do we want to mark the error so as to return non-zero at the end upon any such continue?
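
A sketch of how that could look (the exit_code variable is hypothetical; each error branch would set it before its continue):

import sys

exit_code = 0  # at the top of the script

# ... in each error branch, instead of a bare `continue`:
#         exit_code = 1
#         continue

sys.exit(exit_code)  # at the very end of the script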

        # TODO: check hasattr(module, "in_features") etc. and adjust to sizes
        # that are more tractable for torch-mlir.

        try:
            m = fx.export_and_import(
                module.Model(*module.get_init_inputs()),
                *module.get_inputs(),
                output_type=fx.OutputType.LINALG_ON_TENSORS,
            )
        except Exception as e:
            print(f"Error: got the following error converting {kernel_pytorch_file}")
            raise e

        before_clean_up = "//" + str(m)[:-1].replace("\n", "\n//") + "\n"
Member:
what's this for?

@rolfmorel (Contributor Author), Oct 8, 2025:
This is to include the pre-clean-up output of torch-mlir in the final file, though as commented-out code.

It helped me see what the specialization pass (and potentially other clean-up) is accomplishing.

I don't have a strong opinion on whether this should stay.

        # Cross the boundary from torch-mlir's MLIR to the environment's MLIR.
        m = ir.Module.parse(str(m), context=ctx)
        # Run clean-up, e.g. linalg "specialization" passes to raise within Linalg.
        try:
            pm.run(m.operation)  # cleanup
        except Exception as e:
            print(f"Error: got the following error cleaning up {module_name}")
            raise e

        with kernel_as_mlir_path.open("w") as f:
            print("// Torch-MLIR output:", file=f)
            print(before_clean_up, file=f)
            print("// MLIR output after clean-up:", file=f)
            print(m, file=f)
14 changes: 6 additions & 8 deletions ingress/Torch-MLIR/install-virtualenv.sh
@@ -9,19 +9,16 @@ else
DEVICE_TYPE=$(lspci | grep VGA)
fi


# Install torch-mlir inside a virtual environment
echo "First ensure uv is installed"

python -m pip install uv --upgrade

echo "Preparing the virtual environment"
python -m uv venv torch-mlir-venv --python 3.12

#echo "Preparing the virtual environment"
#python3 -m venv torch-mlir-venv
source torch-mlir-venv/bin/activate
uv pip install --upgrade pip wheel

echo "Installing mlir-python-bindings and numpy"
uv pip install numpy mlir-python-bindings -f https://makslevental.github.io/wheels

# GPU support ("AMD", "NVIDIA", "Intel")
EXTRA_INDEX_URL=""
@@ -46,13 +43,14 @@ if [ $? != 0 ]; then
exit 1
fi


echo "Installing torch-mlir"
# This only seems to work on Ubuntu
uv pip install --pre torch-mlir \
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
-f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels
if [ $? != 0 ]; then
exit 1
fi

# TODO(rolf): move to a more appropriate place
echo "Obtaining KernelBench repo"
git clone https://github.com/ScalingIntelligence/KernelBench.git