Skip to content

Commit

Permalink
[flang][hlfir] allow recursive intrinsic lowering
Browse files Browse the repository at this point in the history
We need to allow recursive application of intrinsic lowering patterns,
otherwise we cannot lower nested calls of the same intrinsic e.g.
matmul(matmul(a, b), c).

matmul(matmul(a, b), matmul(c, d)) requires hlfir.associate of hlfir
expr with more than one use (TODO).

Differential Revision: https://reviews.llvm.org/D152284
  • Loading branch information
tblah committed Jun 7, 2023
1 parent 89227b6 commit 6bcfab3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
19 changes: 17 additions & 2 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -1,4 +1,4 @@
//===- LowerHLFIRIntrinsics.cpp - Bufferize HLFIR ------------------------===//
//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -37,7 +37,22 @@ namespace {
/// runtime calls
template <class OP>
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
using mlir::OpRewritePattern<OP>::OpRewritePattern;
public:
explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
: mlir::OpRewritePattern<OP>{ctx} {
// required for cases where intrinsics are chained together e.g.
// matmul(matmul(a, b), c)
// because converting the inner operation then invalidates the
// outer operation: causing the pattern to apply recursively.
//
// This is safe because we always progress with each iteration. Circular
// applications of operations are not expressible in MLIR because we use
// an SSA form and one must become first. E.g.
// %a = hlfir.matmul %b %d
// %b = hlfir.matmul %a %d
// cannot be written.
mlir::OpConversionPattern<OP>::setHasBoundedRewriteRecursion(true);
}

protected:
struct IntrinsicArgument {
Expand Down
36 changes: 36 additions & 0 deletions flang/test/HLFIR/matmul-lowering.fir
Expand Up @@ -43,3 +43,39 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
// CHECK: hlfir.destroy %[[ASEXPR]]
// CHECK-NEXT: return
// CHECK-NEXT: }

// nested matmuls leading to recursive pattern application
func.func @_QPtest(%arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"}, %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"}, %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
%c3 = arith.constant 3 : index
%c3_0 = arith.constant 3 : index
%0 = fir.shape %c3, %c3_0 : (index, index) -> !fir.shape<2>
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
%c3_1 = arith.constant 3 : index
%c3_2 = arith.constant 3 : index
%2 = fir.shape %c3_1, %c3_2 : (index, index) -> !fir.shape<2>
%3:2 = hlfir.declare %arg1(%2) {uniq_name = "_QFtestEb"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
%c3_3 = arith.constant 3 : index
%c3_4 = arith.constant 3 : index
%4 = fir.shape %c3_3, %c3_4 : (index, index) -> !fir.shape<2>
%5:2 = hlfir.declare %arg2(%4) {uniq_name = "_QFtestEc"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
%c3_5 = arith.constant 3 : index
%c3_6 = arith.constant 3 : index
%6 = fir.shape %c3_5, %c3_6 : (index, index) -> !fir.shape<2>
%7:2 = hlfir.declare %arg3(%6) {uniq_name = "_QFtestEout"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
%8 = hlfir.matmul %1#0 %3#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
%9 = hlfir.matmul %8 %5#0 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
hlfir.assign %9 to %7#0 : !hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>
hlfir.destroy %9 : !hlfir.expr<3x3xf32>
hlfir.destroy %8 : !hlfir.expr<3x3xf32>
return
}
// just check that we apply the patterns successfully. The details are checked above
// CHECK-LABEL: func.func @_QPtest(
// CHECK: %arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"},
// CHECK-SAME: %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"},
// CHECK-SAME: %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"},
// CHECK-SAME: %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
// CHECK: fir.call @_FortranAMatmul(
// CHECK; fir.call @_FortranAMatmul(%40, %41, %42, %43, %c20_i32) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
// CHECK: return
// CHECK-NEXT: }

0 comments on commit 6bcfab3

Please sign in to comment.