Skip to content

Conversation

@bangtianliu
Copy link
Contributor

@bangtianliu bangtianliu commented Nov 11, 2025

This PR exposes linalg::inferContractionDims(ArrayRef<AffineMap>) to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation.

@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Bangtian Liu (bangtianliu)

Changes

This patch exposes linalg::inferContractionDims(ArrayRef&lt;AffineMap&gt;) to Python, allowing users to infer contraction dimensions (batch/m/n/k) directly from a list of affine maps without needing an operation.


Full diff: https://github.com/llvm/llvm-project/pull/167587.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Linalg.h (+5)
  • (modified) mlir/lib/Bindings/Python/DialectLinalg.cpp (+23)
  • (modified) mlir/lib/CAPI/Dialect/Linalg.cpp (+35)
  • (modified) mlir/test/python/dialects/linalg/utils.py (+42)
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 339e63d667c5e..989835485bdd9 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -10,6 +10,7 @@
 #ifndef MLIR_C_DIALECT_LINALG_H
 #define MLIR_C_DIALECT_LINALG_H
 
+#include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Support.h"
 
@@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
 MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
 mlirLinalgInferContractionDimensions(MlirOperation op);
 
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
+                                              intptr_t numMaps);
+
 MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
 
 typedef struct MlirLinalgConvolutionDimensions {
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 015502371c65b..c179dac3c4df1 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -80,6 +80,29 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
         "op.",
         nb::arg("op"));
 
+  m.def(
+      "infer_contraction_dimensions_from_maps",
+      [](std::vector<MlirAffineMap> indexingMaps)
+          -> std::optional<MlirLinalgContractionDimensions> {
+        if (indexingMaps.empty())
+          return std::nullopt;
+
+        MlirLinalgContractionDimensions dims =
+            mlirLinalgInferContractionDimensionsFromMaps(
+                indexingMaps.data(), indexingMaps.size());
+
+        // Detect "empty" result. This occurs when the input is invalid
+        // or when `linalg::inferContractionDims` fails.
+        if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+            mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+          return std::nullopt;
+        }
+        return dims;
+      },
+      "Infers contraction dimensions (batch/m/n/k) from a list of affine "
+      "maps.",
+      nb::arg("indexing_maps"));
+
   m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
         "Checks if the given operation is a Linalg convolution operation.",
         nb::arg("op"));
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 5c2a65d2c4c8a..ef2658167bb46 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/Dialect/Linalg.h"
+#include "mlir/CAPI/AffineMap.h"
 #include "mlir/CAPI/Registration.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 
@@ -75,6 +76,40 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
   return result;
 }
 
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensionsFromMaps(MlirAffineMap const *indexingMaps,
+                                              intptr_t numMaps) {
+  MlirLinalgContractionDimensions result{};
+  if (!indexingMaps || numMaps <= 0)
+    return result;
+
+  SmallVector<AffineMap> maps;
+  maps.reserve(numMaps);
+  for (intptr_t i = 0; i < numMaps; ++i) {
+    maps.push_back(unwrap(indexingMaps[i]));
+  }
+
+  FailureOr<linalg::ContractionDimensions> maybeDims =
+      linalg::inferContractionDims(maps);
+  if (failed(maybeDims))
+    return result;
+
+  const linalg::ContractionDimensions &contractionDims = *maybeDims;
+  MLIRContext *ctx = maps[0].getContext();
+
+  auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+    return wrap(
+        DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
 MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
   auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
   if (!linalgOp)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index 5f7cb6a6c83cb..c0fd0efed7c65 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -208,3 +208,45 @@ def matmul_func(a, b, c):
                 assert maps[0] == a_map
                 assert maps[1] == b_map
                 assert maps[2] == c_map
+
+
+@run
+def test_infer_contraction_dimensions_from_maps():
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            # === Test valid contraction (matmul) ===
+            dim_m = AffineDimExpr.get(0)
+            dim_n = AffineDimExpr.get(1)
+            dim_k = AffineDimExpr.get(2)
+            a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+            b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+            c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+            dims = linalg.infer_contraction_dimensions_from_maps(
+                [a_map, b_map, c_map]
+            )
+            assert dims is not None
+
+            # Expect m=[0], n=[1], k=[2] as per standard matmul.
+            assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+            assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+            assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+            assert (
+                list(dims.batch) == []
+            ), f"Expected batch=[], got {list(dims.batch)}"
+
+            # === Test invalid input (wrong number of maps) ===
+            invalid_dims = linalg.infer_contraction_dimensions_from_maps(
+                [a_map, b_map]
+            )
+            assert invalid_dims is None
+
+            # === Test non-contraction (element-wise operation) ===
+            dim_i = AffineDimExpr.get(0)
+            dim_j = AffineDimExpr.get(1)
+            elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
+            non_contraction_dims = linalg.infer_contraction_dimensions_from_maps(
+                [elementwise_map, elementwise_map, elementwise_map]
+            )
+            assert non_contraction_dims is None

@bangtianliu bangtianliu force-pushed the infer_contraction_dims branch from 5ba54b2 to 1a520ea Compare November 11, 2025 21:49
@github-actions
Copy link

github-actions bot commented Nov 11, 2025

✅ With the latest revision this PR passed the Python code formatter.

@github-actions
Copy link

github-actions bot commented Nov 11, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@bangtianliu bangtianliu force-pushed the infer_contraction_dims branch 4 times, most recently from 9984adf to f875428 Compare November 11, 2025 22:39
@bangtianliu bangtianliu force-pushed the infer_contraction_dims branch from 6bc3096 to a8d064e Compare November 11, 2025 23:52
@bangtianliu bangtianliu force-pushed the infer_contraction_dims branch 2 times, most recently from f1d9e3a to b76fce9 Compare November 12, 2025 00:27
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

Copy link
Contributor

@rolfmorel rolfmorel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@bangtianliu bangtianliu requested a review from kuhar November 12, 2025 15:20
@bangtianliu bangtianliu force-pushed the infer_contraction_dims branch from 0d047ad to e9f7979 Compare November 12, 2025 17:09
@bangtianliu bangtianliu requested a review from kuhar November 12, 2025 17:09
@bangtianliu bangtianliu force-pushed the infer_contraction_dims branch 3 times, most recently from 39a3d04 to 21cc334 Compare November 12, 2025 18:19
…dimensions from affine maps

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
@bangtianliu bangtianliu merged commit a5a78d0 into llvm:main Nov 12, 2025
10 checks passed
git-crd pushed a commit to git-crd/crd-llvm-project that referenced this pull request Nov 13, 2025
…Dimensions from Affine Maps (llvm#167587)

This PR exposes `linalg::inferContractionDims(ArrayRef<AffineMap>)` to
Python, allowing users to infer contraction dimensions (batch/m/n/k)
directly from a list of affine maps without needing an operation.

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants