Skip to content

Commit

Permalink
[MLIR] : Add integer mul in scf to openmp conversion
Browse files Browse the repository at this point in the history
Add conversion for integer multiplication in scf reductions in the
SCF to OpenMP dialect conversion.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D145948
  • Loading branch information
kiranchandramohan committed Mar 14, 2023
1 parent f51bdae commit c1125ae
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Expand Up @@ -310,6 +310,10 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
return createDecl(builder, symbolTable, reduce,
builder.getFloatAttr(type, 1.0));
}
if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 1));
}

// Match select-based min/max reductions.
bool isMin;
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Conversion/SCFToOpenMP/reductions.mlir
Expand Up @@ -81,6 +81,43 @@ func.func @reduction2(%arg0 : index, %arg1 : index, %arg2 : index,

// -----

// Check the generation of declaration for arith.muli.
// Mostly, the same check as above, except for the types,
// the name of the op and the init value.

// CHECK: omp.reduction.declare @[[$REDI:.*]] : i32

// CHECK: init
// CHECK: %[[INIT:.*]] = llvm.mlir.constant(1 : i32)
// CHECK: omp.yield(%[[INIT]] : i32)

// CHECK: combiner
// CHECK: ^{{.*}}(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
// CHECK: %[[RES:.*]] = arith.muli %[[ARG0]], %[[ARG1]]
// CHECK: omp.yield(%[[RES]] : i32)

// CHECK-NOT: atomic

// CHECK-LABEL: @reduction_muli
func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
%arg3 : index, %arg4 : index) {
%step = arith.constant 1 : index
%one = arith.constant 1 : i32
scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
step (%arg4, %step) init (%one) -> (i32) {
// CHECK: omp.reduction
%pow2 = arith.constant 2 : i32
scf.reduce(%pow2) : i32 {
^bb0(%lhs : i32, %rhs: i32):
%res = arith.muli %lhs, %rhs : i32
scf.reduce.return %res : i32
}
}
return
}

// -----

// Only check the declaration here, the rest is same as above.
// CHECK: omp.reduction.declare @{{.*}} : f32

Expand Down

0 comments on commit c1125ae

Please sign in to comment.