From bdf54e0256ba0daeda06ac943f84688c6a65abd3 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 11 Nov 2025 13:51:06 -0800 Subject: [PATCH 1/4] [mlir][linalg][python] Add Python bindings for inferring contraction dimensions from affine maps Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 5 +++ mlir/lib/Bindings/Python/DialectLinalg.cpp | 23 ++++++++++++ mlir/lib/CAPI/Dialect/Linalg.cpp | 35 ++++++++++++++++++ mlir/test/python/dialects/linalg/utils.py | 41 ++++++++++++++++++++++ 4 files changed, 104 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 339e63d667c5e..42d3e494fe9b4 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -10,6 +10,7 @@ #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions { MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps, + intptr_t numMaps); + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); typedef struct MlirLinalgConvolutionDimensions { diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 015502371c65b..56fa1ace995c8 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -80,6 +80,29 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "op.", nb::arg("op")); + m.def( + "infer_contraction_dimensions_from_maps", + [](std::vector indexingMaps) + -> std::optional { + if (indexingMaps.empty()) + return std::nullopt; + + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(), + indexingMaps.size()); + + // Detect "empty" result. This occurs when the input is invalid + // or when `linalg::inferContractionDims` fails. + if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; + }, + "Infers contraction dimensions (batch/m/n/k) from a list of affine " + "maps.", + nb::arg("indexing_maps")); + m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp, "Checks if the given operation is a Linalg convolution operation.", nb::arg("op")); diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 5c2a65d2c4c8a..e3a0b0977048b 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Linalg.h" +#include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -75,6 +76,40 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps, + intptr_t numMaps) { + MlirLinalgContractionDimensions result{}; + if (!indexingMaps || numMaps <= 0) + return result; + + SmallVector maps; + maps.reserve(numMaps); + for (intptr_t i = 0; i < numMaps; ++i) { + maps.push_back(unwrap(indexingMaps[i])); + } + + FailureOr maybeDims = + linalg::inferContractionDims(maps); + if (failed(maybeDims)) + return result; + + const linalg::ContractionDimensions &contractionDims = *maybeDims; + MLIRContext *ctx = maps[0].getContext(); + + auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap( + DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + result.batch = toAttr(contractionDims.batch); + result.m = toAttr(contractionDims.m); + result.n = toAttr(contractionDims.n); + result.k = toAttr(contractionDims.k); + + return result; +} + MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { auto linalgOp = llvm::dyn_cast(unwrap(op)); if (!linalgOp) diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 5f7cb6a6c83cb..2e2017925826c 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -208,3 +208,44 @@ def matmul_func(a, b, c): assert maps[0] == a_map assert maps[1] == b_map assert maps[2] == c_map + + +@run +def test_infer_contraction_dimensions_from_maps(): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + # === Test valid contraction (matmul) === + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map]) + assert dims is not None + + # Expect m=[0], n=[1], k=[2] as per standard matmul. + assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}" + assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}" + assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}" + assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}" + + # === Test invalid input (wrong number of maps) === + invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map]) + assert invalid_dims is None + + # === Test element-wise operation === + # All dimensions appear in all operands, so they're batch dimensions. + dim_i = AffineDimExpr.get(0) + dim_j = AffineDimExpr.get(1) + elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j]) + elementwise_dims = linalg.infer_contraction_dimensions_from_maps( + [elementwise_map, elementwise_map, elementwise_map] + ) + assert elementwise_dims is not None + assert list(elementwise_dims.m) == [] + assert list(elementwise_dims.n) == [] + assert list(elementwise_dims.k) == [] + assert list(elementwise_dims.batch) == [0, 1] From 0fb90b12faed294985bbbbcd458016108a21e660 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 11 Nov 2025 15:55:18 -0800 Subject: [PATCH 2/4] address reviewer comments Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 2 +- mlir/lib/CAPI/Dialect/Linalg.cpp | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 42d3e494fe9b4..e2cc4a7da7201 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -36,7 +36,7 @@ MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op); MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions -mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps, +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, intptr_t numMaps); MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index e3a0b0977048b..9d4307eb551dd 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -77,7 +77,7 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { } MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions -mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps, +mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, intptr_t numMaps) { MlirLinalgContractionDimensions result{}; if (!indexingMaps || numMaps <= 0) @@ -94,7 +94,6 @@ mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps, if (failed(maybeDims)) return result; - const linalg::ContractionDimensions &contractionDims = *maybeDims; MLIRContext *ctx = maps[0].getContext(); auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { @@ -102,10 +101,10 @@ mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps, DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); }; - result.batch = toAttr(contractionDims.batch); - result.m = toAttr(contractionDims.m); - result.n = toAttr(contractionDims.n); - result.k = toAttr(contractionDims.k); + result.batch = toAttr(maybeDims->batch); + result.m = toAttr(maybeDims->m); + result.n = toAttr(maybeDims->n); + result.k = toAttr(maybeDims->k); return result; } From dd523360274983cdc5c22d469249d3d1e2660755 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 12 Nov 2025 07:23:36 -0800 Subject: [PATCH 3/4] format the code Signed-off-by: Bangtian Liu --- mlir/lib/Bindings/Python/DialectLinalg.cpp | 3 +-- mlir/lib/CAPI/Dialect/Linalg.cpp | 10 ++++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 56fa1ace995c8..0b079b404d42d 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -91,8 +91,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(), indexingMaps.size()); - // Detect "empty" result. This occurs when the input is invalid - // or when `linalg::inferContractionDims` fails. + // Detect "empty" result from invalid input or failed inference. if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { return std::nullopt; diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 9d4307eb551dd..5d02f5dc6b7c5 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -63,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { const linalg::ContractionDimensions &contractionDims = *maybeDims; MLIRContext *ctx = linalgOp.getContext(); - auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { - return wrap( - DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + auto toAttr = [ctx](ArrayRef vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); }; result.batch = toAttr(contractionDims.batch); @@ -96,9 +95,8 @@ mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, MLIRContext *ctx = maps[0].getContext(); - auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { - return wrap( - DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + auto toAttr = [ctx](ArrayRef vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); }; result.batch = toAttr(maybeDims->batch); From 21cc334111968ceb52cbb9dae225cc0b9d605fab Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 12 Nov 2025 09:12:30 -0800 Subject: [PATCH 4/4] use size_t Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 2 +- mlir/lib/CAPI/Dialect/Linalg.cpp | 8 ++++---- mlir/test/python/dialects/linalg/utils.py | 7 +++---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index e2cc4a7da7201..003b0cde39652 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -37,7 +37,7 @@ mlirLinalgInferContractionDimensions(MlirOperation op); MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, - intptr_t numMaps); + size_t numMaps); MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 5d02f5dc6b7c5..75c811aed6cc5 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -77,14 +77,14 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps, - intptr_t numMaps) { + size_t numMaps) { MlirLinalgContractionDimensions result{}; - if (!indexingMaps || numMaps <= 0) + if (!indexingMaps || numMaps == 0) return result; - SmallVector maps; + SmallVector maps; maps.reserve(numMaps); - for (intptr_t i = 0; i < numMaps; ++i) { + for (size_t i = 0; i < numMaps; ++i) { maps.push_back(unwrap(indexingMaps[i])); } diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 2e2017925826c..8ab53b4e28743 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -237,7 +237,6 @@ def test_infer_contraction_dimensions_from_maps(): assert invalid_dims is None # === Test element-wise operation === - # All dimensions appear in all operands, so they're batch dimensions. dim_i = AffineDimExpr.get(0) dim_j = AffineDimExpr.get(1) elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j]) @@ -245,7 +244,7 @@ def test_infer_contraction_dimensions_from_maps(): [elementwise_map, elementwise_map, elementwise_map] ) assert elementwise_dims is not None - assert list(elementwise_dims.m) == [] - assert list(elementwise_dims.n) == [] - assert list(elementwise_dims.k) == [] + assert len(elementwise_dims.m) == 0 + assert len(elementwise_dims.n) == 0 + assert len(elementwise_dims.k) == 0 assert list(elementwise_dims.batch) == [0, 1]