Skip to content

Commit

Permalink
[mlir][Transforms] GreedyPatternRewriteDriver: Do not CSE constants…
Browse files Browse the repository at this point in the history
… during iterations (#75897)

The `GreedyPatternRewriteDriver` tries to iteratively fold ops and apply
rewrite patterns to ops. It has special handling for constants: they are
CSE'd and sometimes moved to parent regions to allow for additional
CSE'ing. This happens in `OperationFolder`.

To allow for efficient CSE'ing, `OperationFolder` maintains an internal
lookup data structure to find the existing constant ops with the same
value for each `IsolatedFromAbove` region:
```c++
/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes;
```

Rewrite patterns are allowed to modify operations. In particular, they
may move operations (including constants) from one region to another
one. Such an IR rewrite can make the above lookup data structure
inconsistent.

We encountered such a bug in a downstream project. This bug materialized
in the form of an op that uses the result of a constant op from a
different `IsolatedFromAbove` region (that is not accessible).

This commit changes the behavior of the `GreedyPatternRewriteDriver`
such that `OperationFolder` is used to CSE constants at the beginning of
each iteration (as the worklist is populated), but no longer during an
iteration. `OperationFolder` is no longer used after populating the
worklist, so we do not have to care about inconsistent state in the
`OperationFolder` due to IR rewrites. The `GreedyPatternRewriteDriver`
now performs the op folding by itself instead of calling
`OperationFolder::tryToFold`.

This change changes the order of constant ops in test cases, but not the
region in which they appear. All broken test cases were fixed by turning
`CHECK` into `CHECK-DAG`.

Alternatives considered: The state of `OperationFolder` could be
partially invalidated with every `notifyOperationModified` notification.
That is more fragile than the solution in this commit because incorrect
rewriter API usage can lead to missing notifications and hard-to-debug
`IsolatedFromAbove` violations. (It did not fix the above mention bug in
a downstream project, which could be due to incorrect rewriter API usage
or due to another conceptual problem that I missed.) Moreover, ops are
frequently getting modified during a greedy pattern rewrite, so we would
likely keep invalidating large parts of the state of `OperationFolder`
over and over.

Migration guide: Turn `CHECK` into `CHECK-DAG` in test cases. Constant
ops are no longer folded during a greedy pattern rewrite. If you rely on
folding (and rematerialization) of constant ops during a greedy pattern
rewrite, turn the folder into a pattern.
  • Loading branch information
matthias-springer committed Jan 5, 2024
1 parent e96e7a9 commit bb6d5c2
Show file tree
Hide file tree
Showing 30 changed files with 357 additions and 315 deletions.
54 changes: 27 additions & 27 deletions flang/test/Lower/array-temp.f90
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ subroutine ss4(N)

! CHECK-LABEL: func @_QPss2(
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK: %[[C_27_i32:[-0-9a-z_]+]] = arith.constant 27 : i32
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK-DAG: %[[C_27_i32:[-0-9a-z_]+]] = arith.constant 27 : i32
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
Expand Down Expand Up @@ -137,15 +137,15 @@ subroutine ss4(N)

! CHECK-LABEL: func @_QPss3(
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK: %[[C_34_i32:[-0-9a-z_]+]] = arith.constant 34 : i32
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK-DAG: %[[C_34_i32:[-0-9a-z_]+]] = arith.constant 34 : i32
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
Expand Down Expand Up @@ -263,15 +263,15 @@ subroutine ss4(N)

! CHECK-LABEL: func @_QPss4(
! CHECK-SAME: %arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
! CHECK: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK: %[[C_41_i32:[-0-9a-z_]+]] = arith.constant 41 : i32
! CHECK: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK-DAG: %[[C_2:[-0-9a-z_]+]] = arith.constant 2 : index
! CHECK-DAG: %[[C_m1:[-0-9a-z_]+]] = arith.constant -1 : index
! CHECK-DAG: %[[C_1:[-0-9a-z_]+]] = arith.constant 1 : index
! CHECK-DAG: %[[C_41_i32:[-0-9a-z_]+]] = arith.constant 41 : i32
! CHECK-DAG: %[[C_6_i32:[-0-9a-z_]+]] = arith.constant 6 : i32
! CHECK-DAG: %[[C_st:[-0-9a-z_]+]] = arith.constant 7.000000e+00 : f32
! CHECK-DAG: %[[C_1_i32:[-0-9a-z_]+]] = arith.constant 1 : i32
! CHECK-DAG: %[[C_st_0:[-0-9a-z_]+]] = arith.constant -2.000000e+00 : f32
! CHECK-DAG: %[[C_0:[-0-9a-z_]+]] = arith.constant 0 : index
! CHECK: %[[V_0:[0-9]+]] = fir.load %arg0 : !fir.ref<i32>
! CHECK: %[[V_1:[0-9]+]] = fir.convert %[[V_0:[0-9]+]] : (i32) -> index
! CHECK: %[[V_2:[0-9]+]] = arith.cmpi sgt, %[[V_1]], %[[C_0]] : index
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
if (origYield.getDefiningOp() == peeledScalarOperation) {
yieldedVals.push_back(origYield);
} else {
// Do not materialize any new ops inside of the decomposed LinalgOp,
// as that would trigger another application of the rewrite pattern
// (infinite loop).
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(peeledGenericOp);
yieldedVals.push_back(
getZero(rewriter, genericOp.getLoc(), origYield.getType()));
}
Expand Down
84 changes: 62 additions & 22 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED

/// Non-pattern based folder for operations.
OperationFolder folder;

/// Configuration information for how to simplify.
const GreedyRewriteConfig config;

Expand Down Expand Up @@ -358,7 +355,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
: PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
: PatternRewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, debugFingerPrints(this)
Expand Down Expand Up @@ -429,15 +426,55 @@ bool GreedyPatternRewriteDriver::processWorklist() {
continue;
}

// Try to fold this op.
if (succeeded(folder.tryToFold(op))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
changed = true;
// Try to fold this op. Do not fold constant ops. That would lead to an
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
if (!op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
changed = true;
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
continue;
}

// Op results can be replaced with `foldResults`.
assert(foldResults.size() == op->getNumResults() &&
"folder produced incorrect number of results");
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
SmallVector<Value> replacements;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
if (auto value = ofr.dyn_cast<Value>()) {
assert(value.getType() == resultType &&
"folder produced value of incorrect type");
replacements.push_back(value);
continue;
}
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
replaceOp(op, replacements);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
}

// Try to match one of the patterns. The rewriter is automatically
Expand Down Expand Up @@ -592,7 +629,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {

addOperandsToWorklist(op->getOperands());
worklist.remove(op);
folder.notifyRemoval(op);

if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
Expand Down Expand Up @@ -672,16 +708,6 @@ class GreedyPatternRewriteIteration
} // namespace

LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return true;
return false;
};

bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
Expand All @@ -691,8 +717,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;

// New iteration: start with an empty worklist.
worklist.clear();

// `OperationFolder` CSE's constant ops (and may move them into parents
// regions to enable more aggressive CSE'ing).
OperationFolder folder(getContext(), this);
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return true;
return false;
};

if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb

// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
Expand Down Expand Up @@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {

// CHECK-LABEL: func.func @transpose_i8(
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>)
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>

// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
Expand Down
30 changes: 15 additions & 15 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
}

// CHECK-SAME: %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
// CHECK: %[[C_0:.*]] = arith.constant 0 : index
// CHECK: %[[C_16:.*]] = arith.constant 16 : index
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C_16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
Expand All @@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
}
// CHECK-LABEL: func.func @vector_print_vector_0d(
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
Expand All @@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
}
// CHECK-LABEL: func.func @vector_print_vector(
// CHECK-SAME: %[[VEC:.*]]: vector<2x2xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
Expand Down Expand Up @@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
}
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
Expand Down Expand Up @@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
// -----

// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
Expand All @@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// -----

// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
Expand Down

0 comments on commit bb6d5c2

Please sign in to comment.