Merged
10 changes: 8 additions & 2 deletions mlir/include/mlir/Transforms/FoldUtils.h
@@ -40,7 +40,10 @@ class OperationFolder {
/// deduplicated constants. If successful, replaces `op`'s uses with
/// folded results, and returns success. If the op was completely folded it is
/// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr);
/// On success, when the op is folded in place, the folder is invoked again
/// until `maxIterations` is reached (default INT_MAX).
LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr,
int maxIterations = INT_MAX);

/// Tries to fold a pre-existing constant operation. `constValue` represents
/// the value of the constant, and can be optionally passed if the value is
@@ -82,7 +85,10 @@ class OperationFolder {

/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results);
/// On success, when the op is folded in place, the folder is invoked again
/// until `maxIterations` is reached (default INT_MAX).
LogicalResult tryToFold(Operation *op, SmallVectorImpl<Value> &results,
int maxIterations = INT_MAX);

/// Try to process a set of fold results. Populates `results` on success,
/// otherwise leaves it unchanged.
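For context, here is a minimal usage sketch of the extended entry point (an illustration, not code from this patch; it assumes `op` is a valid Operation* in the caller's context):

    OperationFolder folder(op->getContext());
    bool inPlaceUpdate = false;
    // Allow at most two rounds of folding: for instance, one round to move a
    // constant operand to the right, and a second to apply the op-specific fold.
    if (succeeded(folder.tryToFold(op, &inPlaceUpdate, /*maxIterations=*/2))) {
      if (inPlaceUpdate) {
        // `op` was updated in place and is still valid.
      } else {
        // `op` was folded away; its uses were replaced with the folded values.
      }
    }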
17 changes: 17 additions & 0 deletions mlir/lib/IR/Builders.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/DebugLog.h"

using namespace mlir;

@@ -486,9 +487,25 @@ OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,

// Try to fold the operation.
SmallVector<OpFoldResult, 4> foldResults;
LDBG() << "Trying to fold: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
if (op->getName().getStringRef() == "vector.extract") {
Operation *parent = op->getParentOp();
while (parent && parent->getName().getStringRef() != "spirv.func")
parent = parent->getParentOp();
if (parent)
parent->dump();
}
Comment on lines +493 to +498

Contributor:
Should this be removed?

@joker-eph (Collaborator, Author) commented on Oct 1, 2025:
Uh... clearly!! (give me a min)

@joker-eph (Collaborator, Author) commented on Oct 1, 2025:
Fixed in e84dcba ; thanks for the report!!

if (failed(op->fold(foldResults)))
return cleanupFailure();

int count = 0;
do {
LDBG() << "Folded in place #" << count
<< " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
count++;
} while (foldResults.empty() && succeeded(op->fold(foldResults)));

// An in-place fold does not require generation of any constants.
if (foldResults.empty())
return success();
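The do/while above retries fold() for as long as it keeps succeeding in place. A distilled sketch of that fixed-point idea (an illustration of the pattern, not the exact code above):

    SmallVector<OpFoldResult> foldResults;
    // success() with an empty `foldResults` means the op was rewritten in
    // place, which may enable a further fold: e.g. commuting
    // `addi %c0, %x` into `addi %x, %c0` unlocks the x + 0 -> x fold.
    while (foldResults.empty() && succeeded(op->fold(foldResults))) {
      // Keep retrying; each in-place rewrite may expose another fold.
    }
    // On exit, either fold() stopped making progress or `foldResults` holds
    // the replacement values of an out-of-place fold.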
21 changes: 16 additions & 5 deletions mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/DebugLog.h"

using namespace mlir;

@@ -67,7 +68,8 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
// OperationFolder
//===----------------------------------------------------------------------===//

LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate,
int maxIterations) {
if (inPlaceUpdate)
*inPlaceUpdate = false;

@@ -86,7 +88,7 @@ LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {

// Try to fold the operation.
SmallVector<Value, 8> results;
if (failed(tryToFold(op, results)))
if (failed(tryToFold(op, results, maxIterations)))
return failure();

// Check to see if the operation was just updated in place.
@@ -224,10 +226,19 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult OperationFolder::tryToFold(Operation *op,
SmallVectorImpl<Value> &results) {
SmallVectorImpl<Value> &results,
int maxIterations) {
SmallVector<OpFoldResult, 8> foldResults;
if (failed(op->fold(foldResults)) ||
failed(processFoldResults(op, results, foldResults)))
if (failed(op->fold(foldResults)))
return failure();
int count = 1;
do {
LDBG() << "Folded in place #" << count
<< " times: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
} while (count++ < maxIterations && foldResults.empty() &&
succeeded(op->fold(foldResults)));

if (failed(processFoldResults(op, results, foldResults)))
return failure();
return success();
}
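Note that `maxIterations` bounds the total number of fold() invocations: the first call happens before the loop and `count` starts at 1, so (a hedged trace, assuming each call keeps folding in place):

    maxIterations = 1:  fold()          -> one invocation, no retry
    maxIterations = 2:  fold(); fold()  -> at most two invocations
    maxIterations = N:  at most N invocations, stopping early once fold()
                        fails or produces out-of-place results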
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Arith/constant-fold.mlir
@@ -0,0 +1,18 @@
// Test with the default (one application of the folder) and then with 2 iterations.
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-single-fold))" | FileCheck %s --check-prefixes=CHECK,CHECK-ONE
// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(test-single-fold{max-iterations=2}))" | FileCheck %s --check-prefixes=CHECK,CHECK-TWO


// Folding this entirely requires moving the constant to the right
// before invoking the op-specific folder.
// With one iteration, we just push the constant to the right.
// With a second iteration, we actually fold the "add" (x + 0 -> x).
// CHECK: func @recurse_fold_traits(%[[ARG0:.*]]: i32)
func.func @recurse_fold_traits(%arg0 : i32) -> i32 {
%cst0 = arith.constant 0 : i32
// CHECK-ONE: %[[ADD:.*]] = arith.addi %[[ARG0]],
%res = arith.addi %cst0, %arg0 : i32
// CHECK-ONE: return %[[ADD]] : i32
// CHECK-TWO: return %[[ARG0]] : i32
return %res : i32
}
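Step by step, the two RUN configurations check for the following states (a trace derived from the CHECK lines above):

    input:               %res = arith.addi %cst0, %arg0
    after 1 iteration:   %res = arith.addi %arg0, %cst0   (constant moved to the right)
    after 2 iterations:  %res replaced by %arg0           (x + 0 -> x)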
8 changes: 2 additions & 6 deletions mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -7,10 +7,8 @@ gpu.module @test {
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
@@ -23,10 +21,8 @@ gpu.module @test {
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
//CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
6 changes: 2 additions & 4 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -27,12 +27,10 @@ gpu.module @test_round_robin_assignment {
//CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
//CHECK: [[C0:%.+]] = arith.constant 0 : index
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
//CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
//CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
//CHECK: [[C128:%.+]] = arith.constant 128 : index
//CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
//CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
//CHECK: [[C64_2:%.+]] = arith.constant 64 : index
//CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
//CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
//CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
20 changes: 7 additions & 13 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -325,12 +325,10 @@ gpu.module @test_distribution {
//CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c0_1:%.+]] = arith.constant 0 : index
//CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
//CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
//CHECK: [[c64:%.+]] = arith.constant 64 : index
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
//CHECK: [[c128:%.+]] = arith.constant 128 : index
//CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
//CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
//CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
%0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
%1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
@@ -349,13 +347,11 @@ gpu.module @test_distribution {
//CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
//CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
//CHECK: [[c32:%.+]] = arith.constant 32 : index
//CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
//CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
//CHECK: [[c32_1:%.+]] = arith.constant 32 : index
//CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
//CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c0_2:%.+]] = arith.constant 0 : index
//CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
//CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
//CHECK: [[c64:%.+]] = arith.constant 64 : index
//CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -412,11 +408,10 @@ gpu.module @test_distribution {
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
//CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
//CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]]
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
//CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
@@ -430,9 +425,8 @@ gpu.module @test_distribution {
//CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
//CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
//CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
//CHECK-DAG: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
//CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
//CHECK-DAG: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
//CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
//CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
//CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
16 changes: 5 additions & 11 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -14,12 +14,10 @@ gpu.module @test_1_1_assignment {
//CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
//CHECK: [[C0:%.+]] = arith.constant 0 : index
//CHECK: [[C0_1:%.+]] = arith.constant 0 : index
//CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
//CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
//CHECK: [[C256:%.+]] = arith.constant 256 : index
//CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
//CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]]
//CHECK: [[C128:%.+]] = arith.constant 128 : index
//CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
//CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]]
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -37,17 +35,13 @@ gpu.module @test_1_1_assignment {
//CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
//CHECK: [[C0:%.+]] = arith.constant 0 : index
//CHECK: [[C0_2:%.+]] = arith.constant 0 : index
//CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
//CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_2]] : index
//CHECK: [[C256:%.+]] = arith.constant 256 : index
//CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
//CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]]
//CHECK: [[C128:%.+]] = arith.constant 128 : index
//CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
//CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]]
//CHECK: [[C0_3:%.+]] = arith.constant 0 : index
//CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_3]]
//CHECK: [[C0_4:%.+]] = arith.constant 0 : index
//CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_4]]
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[Y]], [[X]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
//CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
10 changes: 9 additions & 1 deletion mlir/test/lib/Transforms/TestSingleFold.cpp
@@ -26,6 +26,9 @@ struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
public RewriterBase::Listener {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)

TestSingleFold() = default;
TestSingleFold(const TestSingleFold &pass) : PassWrapper(pass) {}

StringRef getArgument() const final { return "test-single-fold"; }
StringRef getDescription() const final {
return "Test single-pass operation folding and dead constant elimination";
@@ -45,13 +48,18 @@ struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
if (it != existingConstants.end())
existingConstants.erase(it);
}

Option<int> maxIterations{*this, "max-iterations",
llvm::cl::desc("Max iterations in the tryToFold"),
llvm::cl::init(1)};
};
} // namespace

void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
// Attempt to fold the specified operation, including handling unused or
// duplicated constants.
(void)helper.tryToFold(op);
bool inPlaceUpdate = false;
(void)helper.tryToFold(op, &inPlaceUpdate, maxIterations);
}

void TestSingleFold::runOnOperation() {
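For reference, the new option is exercised from a pipeline string exactly as in the test above, e.g. mlir-opt --pass-pipeline="builtin.module(func.func(test-single-fold{max-iterations=2}))". The explicit copy constructor added here follows the usual MLIR pattern for passes that declare Option members, which are not copyable and so require an explicit copy constructor for the pass to be clonable.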