[mlir][x86vector] Python bindings for x86vector dialect#179958
Merged
Conversation
Registers python bindings for x86vector dialect and transform ops.
Member
|
@llvm/pr-subscribers-mlir-vector Author: Adam Siemieniuk (adam-smnk) ChangesRegisters python bindings for x86vector dialect and transform ops. Full diff: https://github.com/llvm/llvm-project/pull/179958.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 891829fca017f..d57ed1f1cd171 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -73,4 +73,3 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
#endif // X86VECTOR_TRANSFORM_OPS
-
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..50143f700f5a1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -325,6 +325,15 @@ declare_mlir_dialect_extension_python_bindings(
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
)
+declare_mlir_dialect_extension_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/X86VectorTransformOps.td
+ SOURCES
+ dialects/transform/x86vector.py
+ DIALECT_NAME transform
+ EXTENSION_NAME x86vector_transform)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -510,6 +519,13 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/X86Vector.td
+ SOURCES dialects/x86vector.py
+ DIALECT_NAME x86vector)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/X86Vector.td b/mlir/python/mlir/dialects/X86Vector.td
new file mode 100644
index 0000000000000..d8a846bf9e905
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86Vector.td
@@ -0,0 +1,14 @@
+//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTOR
+#define PYTHON_BINDINGS_X86VECTOR
+
+include "mlir/Dialect/X86Vector/X86Vector.td"
+
+#endif // PYTHON_BINDINGS_X86VECTOR
diff --git a/mlir/python/mlir/dialects/X86VectorTransformOps.td b/mlir/python/mlir/dialects/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..ad6a693923703
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86VectorTransformOps.td
@@ -0,0 +1,14 @@
+//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+
+include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
+
+#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
diff --git a/mlir/python/mlir/dialects/transform/x86vector.py b/mlir/python/mlir/dialects/transform/x86vector.py
new file mode 100644
index 0000000000000..cccd300522797
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/x86vector.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._x86vector_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/x86vector.py b/mlir/python/mlir/dialects/x86vector.py
new file mode 100644
index 0000000000000..eddc93dbe6460
--- /dev/null
+++ b/mlir/python/mlir/dialects/x86vector.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._x86vector_ops_gen import *
+from ._x86vector_ops_gen import _Dialect
diff --git a/mlir/test/python/dialects/transform_x86vector_ext.py b/mlir/test/python/dialects/transform_x86vector_ext.py
new file mode 100644
index 0000000000000..ad8dab8175ef2
--- /dev/null
+++ b/mlir/test/python/dialects/transform_x86vector_ext.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import x86vector
+
+
+def run_apply_patterns(f):
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ apply = transform.ApplyPatternsOp(sequence.bodyTarget)
+ with InsertionPoint(apply.patterns):
+ f()
+ transform.YieldOp()
+ print("\nTEST:", f.__name__)
+ print(module)
+ return f
+
+
+@run_apply_patterns
+def non_configurable_patterns():
+ # CHECK-LABEL: TEST: non_configurable_patterns
+ # CHECK: apply_patterns
+ # CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
+ x86vector.ApplyVectorContractToFMAPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
+ x86vector.ApplySinkVectorProducerOpsPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+ x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
diff --git a/mlir/test/python/dialects/x86vector.py b/mlir/test/python/dialects/x86vector.py
new file mode 100644
index 0000000000000..c7d680792fb66
--- /dev/null
+++ b/mlir/test/python/dialects/x86vector.py
@@ -0,0 +1,72 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.func as func
+import mlir.dialects.x86vector as x86vector
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ f()
+ return f
+
+
+# CHECK-LABEL: TEST: testAvxOp
+@run
+def testAvxOp():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
+ def avx_op(arg):
+ return x86vector.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
+
+ # CHECK-LABEL: func @avx_op(
+ # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
+ # CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
+ # CHECK: return %[[VAL]] : vector<8xf32>
+ # CHECK: }
+ print(module)
+
+# CHECK-LABEL: TEST: testAvx512Op
+@run
+def testAvx512Op():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
+ def avx512_op(arg):
+ return x86vector.CvtPackedF32ToBF16Op(a=arg, dst=VectorType.get((8,), BF16Type.get()))
+
+ # CHECK-LABEL: func @avx512_op(
+ # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
+ # CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
+ # CHECK: return %[[VAL]] : vector<8xbf16>
+ # CHECK: }
+ print(module)
+
+# CHECK-LABEL: TEST: testAvx10Op
+@run
+def testAvx10Op():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ VectorType.get((16,), IntegerType.get(32)),
+ VectorType.get((64,), IntegerType.get(8)),
+ VectorType.get((64,), IntegerType.get(8)),
+ )
+ def avx10_op(*args):
+ return x86vector.AVX10DotInt8Op(
+ w=args[0], a=args[1], b=args[2]
+ )
+
+ # CHECK-LABEL: func @avx10_op(
+ # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
+ # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
+ # CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
+ # CHECK: return %[[VAL]] : vector<16xi32>
+ # CHECK: }
+ print(module)
|
Member
|
@llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesRegisters python bindings for x86vector dialect and transform ops. Full diff: https://github.com/llvm/llvm-project/pull/179958.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
index 891829fca017f..d57ed1f1cd171 100644
--- a/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td
@@ -73,4 +73,3 @@ def ApplyShuffleVectorFMAOpsPatternsOp : Op<Transform_Dialect,
#endif // X86VECTOR_TRANSFORM_OPS
-
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8ab145ada85dd..50143f700f5a1 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -325,6 +325,15 @@ declare_mlir_dialect_extension_python_bindings(
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
)
+declare_mlir_dialect_extension_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/X86VectorTransformOps.td
+ SOURCES
+ dialects/transform/x86vector.py
+ DIALECT_NAME transform
+ EXTENSION_NAME x86vector_transform)
+
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -510,6 +519,13 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")
+declare_mlir_dialect_python_bindings(
+ ADD_TO_PARENT MLIRPythonSources.Dialects
+ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/X86Vector.td
+ SOURCES dialects/x86vector.py
+ DIALECT_NAME x86vector)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/X86Vector.td b/mlir/python/mlir/dialects/X86Vector.td
new file mode 100644
index 0000000000000..d8a846bf9e905
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86Vector.td
@@ -0,0 +1,14 @@
+//===-- X86Vector.td - Entry point for x86vector bindings --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTOR
+#define PYTHON_BINDINGS_X86VECTOR
+
+include "mlir/Dialect/X86Vector/X86Vector.td"
+
+#endif // PYTHON_BINDINGS_X86VECTOR
diff --git a/mlir/python/mlir/dialects/X86VectorTransformOps.td b/mlir/python/mlir/dialects/X86VectorTransformOps.td
new file mode 100644
index 0000000000000..ad6a693923703
--- /dev/null
+++ b/mlir/python/mlir/dialects/X86VectorTransformOps.td
@@ -0,0 +1,14 @@
+//===-- X86VectorTransformOps.td ---------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+#define PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
+
+include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.td"
+
+#endif // PYTHON_BINDINGS_X86VECTORTRANSFORMOPS
diff --git a/mlir/python/mlir/dialects/transform/x86vector.py b/mlir/python/mlir/dialects/transform/x86vector.py
new file mode 100644
index 0000000000000..cccd300522797
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/x86vector.py
@@ -0,0 +1,5 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._x86vector_transform_ops_gen import *
diff --git a/mlir/python/mlir/dialects/x86vector.py b/mlir/python/mlir/dialects/x86vector.py
new file mode 100644
index 0000000000000..eddc93dbe6460
--- /dev/null
+++ b/mlir/python/mlir/dialects/x86vector.py
@@ -0,0 +1,6 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._x86vector_ops_gen import *
+from ._x86vector_ops_gen import _Dialect
diff --git a/mlir/test/python/dialects/transform_x86vector_ext.py b/mlir/test/python/dialects/transform_x86vector_ext.py
new file mode 100644
index 0000000000000..ad8dab8175ef2
--- /dev/null
+++ b/mlir/test/python/dialects/transform_x86vector_ext.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import x86vector
+
+
+def run_apply_patterns(f):
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ apply = transform.ApplyPatternsOp(sequence.bodyTarget)
+ with InsertionPoint(apply.patterns):
+ f()
+ transform.YieldOp()
+ print("\nTEST:", f.__name__)
+ print(module)
+ return f
+
+
+@run_apply_patterns
+def non_configurable_patterns():
+ # CHECK-LABEL: TEST: non_configurable_patterns
+ # CHECK: apply_patterns
+ # CHECK: transform.apply_patterns.x86vector.vector_contract_to_fma
+ x86vector.ApplyVectorContractToFMAPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.vector_contract_to_packed_type_dot_product
+ x86vector.ApplyVectorContractToPackedTypeDotProductPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.vector_contract_bf16_to_fma
+ x86vector.ApplyVectorContractBF16ToFMAPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.sink_vector_producer_ops
+ x86vector.ApplySinkVectorProducerOpsPatternsOp()
+ # CHECK: transform.apply_patterns.x86vector.shuffle_vector_fma_ops
+ x86vector.ApplyShuffleVectorFMAOpsPatternsOp()
diff --git a/mlir/test/python/dialects/x86vector.py b/mlir/test/python/dialects/x86vector.py
new file mode 100644
index 0000000000000..c7d680792fb66
--- /dev/null
+++ b/mlir/test/python/dialects/x86vector.py
@@ -0,0 +1,72 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.func as func
+import mlir.dialects.x86vector as x86vector
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ f()
+ return f
+
+
+# CHECK-LABEL: TEST: testAvxOp
+@run
+def testAvxOp():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
+ def avx_op(arg):
+ return x86vector.BcstToPackedF32Op(a=arg, dst=VectorType.get((8,), F32Type.get()))
+
+ # CHECK-LABEL: func @avx_op(
+ # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
+ # CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
+ # CHECK: return %[[VAL]] : vector<8xf32>
+ # CHECK: }
+ print(module)
+
+# CHECK-LABEL: TEST: testAvx512Op
+@run
+def testAvx512Op():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
+ def avx512_op(arg):
+ return x86vector.CvtPackedF32ToBF16Op(a=arg, dst=VectorType.get((8,), BF16Type.get()))
+
+ # CHECK-LABEL: func @avx512_op(
+ # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
+ # CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
+ # CHECK: return %[[VAL]] : vector<8xbf16>
+ # CHECK: }
+ print(module)
+
+# CHECK-LABEL: TEST: testAvx10Op
+@run
+def testAvx10Op():
+ module = Module.create()
+ with InsertionPoint(module.body):
+
+ @func.FuncOp.from_py_func(
+ VectorType.get((16,), IntegerType.get(32)),
+ VectorType.get((64,), IntegerType.get(8)),
+ VectorType.get((64,), IntegerType.get(8)),
+ )
+ def avx10_op(*args):
+ return x86vector.AVX10DotInt8Op(
+ w=args[0], a=args[1], b=args[2]
+ )
+
+ # CHECK-LABEL: func @avx10_op(
+ # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
+ # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
+ # CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
+ # CHECK: return %[[VAL]] : vector<16xi32>
+ # CHECK: }
+ print(module)
|
|
✅ With the latest revision this PR passed the Python code formatter. |
rengolin
reviewed
Feb 5, 2026
AlexisPerry
pushed a commit
to llvm-project-tlp/llvm-project
that referenced
this pull request
Feb 6, 2026
Registers python bindings for x86vector dialect and transform ops.
rishabhmadan19
pushed a commit
to rishabhmadan19/llvm-project
that referenced
this pull request
Feb 9, 2026
Registers python bindings for x86vector dialect and transform ops.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Registers python bindings for x86vector dialect and transform ops.