Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir-c/Dialect/Linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#ifndef MLIR_C_DIALECT_LINALG_H
#define MLIR_C_DIALECT_LINALG_H

#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

Expand All @@ -34,6 +35,10 @@ typedef struct MlirLinalgContractionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensions(MlirOperation op);

MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
size_t numMaps);

MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);

typedef struct MlirLinalgConvolutionDimensions {
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Bindings/Python/DialectLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,28 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"op.",
nb::arg("op"));

m.def(
"infer_contraction_dimensions_from_maps",
[](std::vector<MlirAffineMap> indexingMaps)
-> std::optional<MlirLinalgContractionDimensions> {
if (indexingMaps.empty())
return std::nullopt;

MlirLinalgContractionDimensions dims =
mlirLinalgInferContractionDimensionsFromMaps(indexingMaps.data(),
indexingMaps.size());

// Detect "empty" result from invalid input or failed inference.
if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
return dims;
},
"Infers contraction dimensions (batch/m/n/k) from a list of affine "
"maps.",
nb::arg("indexing_maps"));

m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
"Checks if the given operation is a Linalg convolution operation.",
nb::arg("op"));
Expand Down
38 changes: 35 additions & 3 deletions mlir/lib/CAPI/Dialect/Linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Linalg.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

Expand Down Expand Up @@ -62,9 +63,8 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
const linalg::ContractionDimensions &contractionDims = *maybeDims;
MLIRContext *ctx = linalgOp.getContext();

auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
return wrap(
DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
};

result.batch = toAttr(contractionDims.batch);
Expand All @@ -75,6 +75,38 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
return result;
}

MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
mlirLinalgInferContractionDimensionsFromMaps(const MlirAffineMap *indexingMaps,
size_t numMaps) {
MlirLinalgContractionDimensions result{};
if (!indexingMaps || numMaps == 0)
return result;

SmallVector<AffineMap, 3> maps;
maps.reserve(numMaps);
for (size_t i = 0; i < numMaps; ++i) {
maps.push_back(unwrap(indexingMaps[i]));
}

FailureOr<linalg::ContractionDimensions> maybeDims =
linalg::inferContractionDims(maps);
if (failed(maybeDims))
return result;

MLIRContext *ctx = maps[0].getContext();

auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
};

result.batch = toAttr(maybeDims->batch);
result.m = toAttr(maybeDims->m);
result.n = toAttr(maybeDims->n);
result.k = toAttr(maybeDims->k);

return result;
}

MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
if (!linalgOp)
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/python/dialects/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,43 @@ def matmul_func(a, b, c):
assert maps[0] == a_map
assert maps[1] == b_map
assert maps[2] == c_map


@run
def test_infer_contraction_dimensions_from_maps():
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
# === Test valid contraction (matmul) ===
dim_m = AffineDimExpr.get(0)
dim_n = AffineDimExpr.get(1)
dim_k = AffineDimExpr.get(2)
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
c_map = AffineMap.get(3, 0, [dim_m, dim_n])

dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map, c_map])
assert dims is not None

# Expect m=[0], n=[1], k=[2] as per standard matmul.
assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
assert list(dims.batch) == [], f"Expected batch=[], got {list(dims.batch)}"

# === Test invalid input (wrong number of maps) ===
invalid_dims = linalg.infer_contraction_dimensions_from_maps([a_map, b_map])
assert invalid_dims is None

# === Test element-wise operation ===
dim_i = AffineDimExpr.get(0)
dim_j = AffineDimExpr.get(1)
elementwise_map = AffineMap.get(2, 0, [dim_i, dim_j])
elementwise_dims = linalg.infer_contraction_dimensions_from_maps(
[elementwise_map, elementwise_map, elementwise_map]
)
assert elementwise_dims is not None
assert len(elementwise_dims.m) == 0
assert len(elementwise_dims.n) == 0
assert len(elementwise_dims.k) == 0
assert list(elementwise_dims.batch) == [0, 1]
Loading