-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][linalg][python] Add Python Bindings for Inferring Contraction Dimensions from Affine Maps #167587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Bangtian Liu (bangtianliu) ChangesThis patch exposes Full diff: https://github.com/llvm/llvm-project/pull/167587.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 339e63d667c5e..989835485bdd9 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -10,6 +10,7 @@
#ifndef MLIR_C_DIALECT_LINALG_H
#define MLIR_C_DIALECT_LINALG_H
+#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
+ intptr_t numMaps);
+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
typedef struct MlirLinalgConvolutionDimensions {
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 015502371c65b..c179dac3c4df1 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -80,6 +80,29 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"op.",
nb::arg("op"));
+ m.def(
+ "infer_contraction_dimensions_from_maps",
+ [](std::vector<MlirAffineMap> indexingMaps)
+ -> std::optional<MlirLinalgContractionDimensions> {
+ if (indexingMaps.empty())
+ return std::nullopt;
+
+ MlirLinalgContractionDimensions dims =
+ mlirLinalgInferContractionDimensionsFromMaps(
+ indexingMaps.data(), indexingMaps.size());
+
+ // Detect "empty" result. This occurs when the input is invalid
+ // or when `linalg::inferContractionDims` fails.
+ if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+ mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+ return std::nullopt;
+ }
+ return dims;
+ },
+ "Infers contraction dimensions (batch/m/n/k) from a list of affine "
+ "maps.",
+ nb::arg("indexing_maps"));
+
m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
"Checks if the given operation is a Linalg convolution operation.",
nb::arg("op"));
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 5c2a65d2c4c8a..ef2658167bb46 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Linalg.h"
+#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -75,6 +76,40 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
+ intptr_t numMaps) {
+ MlirLinalgContractionDimensions result{};
+ if (!indexingMaps || numMaps <= 0)
+ return result;
+
+ SmallVector<AffineMap> maps;
+ maps.reserve(numMaps);
+ for (intptr_t i = 0; i < numMaps; ++i) {
+ maps.push_back(unwrap(indexingMaps[i]));
+ }
+
+ FailureOr<linalg::ContractionDimensions> maybeDims =
+ linalg::inferContractionDims(maps);
+ if (failed(maybeDims))
+ return result;
+
+ const linalg::ContractionDimensions &contractionDims = *maybeDims;
+ MLIRContext *ctx = maps[0].getContext();
+
+ auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+ return wrap(
+ DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+ };
+
+ result.batch = toAttr(contractionDims.batch);
+ result.m = toAttr(contractionDims.m);
+ result.n = toAttr(contractionDims.n);
+ result.k = toAttr(contractionDims.k);
+
+ return result;
+}
+
MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index 5f7cb6a6c83cb..c0fd0efed7c65 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -208,3 +208,45 @@ def matmul_func(a, b, c):
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map
+
+
+@run
+def test_infer_contraction_dimensions_from_maps():
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ # === Test valid contraction (matmul) ===
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+ dims = linalg.infer_contraction_dimensions_from_maps(
+ [a_map, b_map, c_map]
+ )
+ assert dims is not None
+
+ # Expect m=[0], n=[1], k=[2] as per standard matmul.
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert (
+ list(dims.batch) == []
+ ), f"Expected batch=[], got {list(dims.batch)}"
+
+ # === Test invalid input (wrong number of maps) ===
+ invalid_dims = linalg.infer_contraction_dimensions_from_maps(
+ [a_map, b_map]
+ )
+ assert invalid_dims is None
+
+ # === Test non-contraction (element-wise operation) ===
+ dim_i = AffineDimExpr.get(0)
+ dim_j = AffineDimExpr.get(1)
+ elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
+ non_contraction_dims = linalg.infer_contraction_dimensions_from_maps(
+ [elementwise_map, elementwise_map, elementwise_map]
+ )
+ assert non_contraction_dims is None
|
5ba54b2 to
1a520ea
Compare
|
✅ With the latest revision this PR passed the Python code formatter. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
9984adf to
f875428
Compare
6bc3096 to
a8d064e
Compare
f1d9e3a to
b76fce9
Compare
makslevental
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
rolfmorel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
0d047ad to
e9f7979
Compare
39a3d04 to
21cc334
Compare
…dimensions from affine maps Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
…Dimensions from Affine Maps (llvm#167587) This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation. --------- Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
This PR exposes
linalg::inferContractionDims(ArrayRef<AffineMap>)to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation.