diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index 8039cef87a625..b7c02e2c2cacf 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -477,6 +477,7 @@ struct IntrinsicLibrary {
   fir::ExtendedValue genLbound(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genLen(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genLenTrim(mlir::Type, llvm::ArrayRef);
+  fir::ExtendedValue genMatmul(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genMaxloc(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genMaxval(mlir::Type, llvm::ArrayRef);
   fir::ExtendedValue genMinloc(mlir::Type, llvm::ArrayRef);
@@ -692,6 +693,10 @@ static constexpr IntrinsicHandler handlers[]{
     {"lgt", &I::genCharacterCompare},
     {"lle", &I::genCharacterCompare},
     {"llt", &I::genCharacterCompare},
+    {"matmul",
+     &I::genMatmul,
+     {{{"matrix_a", asAddr}, {"matrix_b", asAddr}}},
+     /*isElemental=*/false},
     {"max", &I::genExtremum},
     {"maxloc",
      &I::genMaxloc,
@@ -2374,6 +2379,34 @@ IntrinsicLibrary::genCharacterCompare(mlir::Type type,
       fir::getBase(args[1]), fir::getLen(args[1]));
 }
 
+// MATMUL: lower to a runtime call that fills a temporary result box.
+fir::ExtendedValue
+IntrinsicLibrary::genMatmul(mlir::Type resultType,
+                            llvm::ArrayRef args) {
+  assert(args.size() == 2);
+
+  // Box both required operands; their ranks decide the result rank below.
+  fir::BoxValue matrixTmpA = builder.createBox(loc, args[0]);
+  mlir::Value matrixA = fir::getBase(matrixTmpA);
+  fir::BoxValue matrixTmpB = builder.createBox(loc, args[1]);
+  mlir::Value matrixB = fir::getBase(matrixTmpB);
+  unsigned resultRank =
+      (matrixTmpA.rank() == 1 || matrixTmpB.rank() == 1) ? 1 : 2;
+
+  // Create the temporary mutable fir.box passed to the runtime for the result.
+  mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, resultRank);
+  fir::MutableBoxValue resultMutableBox =
+      fir::factory::createTempMutableBox(builder, loc, resultArrayType);
+  mlir::Value resultIrBox =
+      fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
+  // Call the runtime, which allocates the result.
+  fir::runtime::genMatmul(builder, loc, resultIrBox, matrixA, matrixB);
+  // Read the result back from the mutable fir.box and add it to the list of
+  // temps to be finalized by the StatementContext.
+  return readAndAddCleanUp(resultMutableBox, resultType,
+                           "unexpected result for MATMUL");
+}
+
 // Compare two FIR values and return boolean result as i1.
 template
 static mlir::Value createExtremumCompare(mlir::Location loc,
diff --git a/flang/test/Lower/Intrinsics/matmul.f90 b/flang/test/Lower/Intrinsics/matmul.f90
new file mode 100644
index 0000000000000..6c3c721063c9f
--- /dev/null
+++ b/flang/test/Lower/Intrinsics/matmul.f90
@@ -0,0 +1,68 @@
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+
+! Test lowering of the MATMUL intrinsic to the _FortranAMatmul runtime call.
+
+! CHECK-LABEL: matmul_test
+! CHECK-SAME: (%[[X:.*]]: !fir.ref>{{.*}}, %[[Y:.*]]: !fir.ref>{{.*}}, %[[Z:.*]]: !fir.ref>{{.*}})
+! CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box>>
+! CHECK: %[[C3:.*]] = arith.constant 3 : index
+! CHECK: %[[C1:.*]] = arith.constant 1 : index
+! CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+! CHECK: %[[C3_1:.*]] = arith.constant 3 : index
+! CHECK: %[[Z_BOX:.*]] = fir.array_load %[[Z]]({{.*}}) : (!fir.ref>, !fir.shape<2>) -> !fir.array<2x2xf32>
+! CHECK: %[[X_SHAPE:.*]] = fir.shape %[[C3]], %[[C1]] : (index, index) -> !fir.shape<2>
+! CHECK: %[[X_BOX:.*]] = fir.embox %[[X]](%[[X_SHAPE]]) : (!fir.ref>, !fir.shape<2>) -> !fir.box>
+! CHECK: %[[Y_SHAPE:.*]] = fir.shape %[[C1_0]], %[[C3_1]] : (index, index) -> !fir.shape<2>
+! CHECK: %[[Y_BOX:.*]] = fir.embox %[[Y]](%[[Y_SHAPE]]) : (!fir.ref>, !fir.shape<2>) -> !fir.box>
+! CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap>
+! CHECK: %[[C0:.*]] = arith.constant 0 : index
+! CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]], %[[C0]] : (index, index) -> !fir.shape<2>
+! CHECK: %[[RESULT_BOX_VAL:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap>, !fir.shape<2>) -> !fir.box>>
+! CHECK: fir.store %[[RESULT_BOX_VAL]] to %[[RESULT_BOX_ADDR]] : !fir.ref>>>
+! CHECK: %[[RESULT_BOX_ADDR_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref>>>) -> !fir.ref>
+! CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box>) -> !fir.box
+! CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box>) -> !fir.box
+! CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_ADDR_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}} : (!fir.ref>, !fir.box, !fir.box, !fir.ref, i32) -> none
+! CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref>>>
+! CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box>>) -> !fir.heap>
+! CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
+! CHECK: {{.*}}fir.array_fetch
+! CHECK: {{.*}}fir.array_update
+! CHECK: fir.result
+! CHECK: }
+! CHECK: fir.array_merge_store %[[Z_BOX]], %[[Z_COPY_FROM_RESULT]] to %[[Z]] : !fir.array<2x2xf32>, !fir.array<2x2xf32>, !fir.ref>
+! CHECK: fir.freemem %[[RESULT_TMP]] : >
+subroutine matmul_test(x,y,z)
+  real :: x(3,1), y(1,3), z(2,2)
+  z = matmul(x,y)
+end subroutine
+
+! CHECK-LABEL: matmul_test2
+! CHECK-SAME: (%[[X_BOX:.*]]: !fir.box>>{{.*}}, %[[Y_BOX:.*]]: !fir.box>>{{.*}}, %[[Z_BOX:.*]]: !fir.box>>{{.*}})
+!CHECK: %[[RESULT_BOX_ADDR:.*]] = fir.alloca !fir.box>>>
+!CHECK: %[[Z:.*]] = fir.array_load %[[Z_BOX]] : (!fir.box>>) -> !fir.array>
+!CHECK: %[[ZERO_INIT:.*]] = fir.zero_bits !fir.heap>>
+!CHECK: %[[C0:.*]] = arith.constant 0 : index
+!CHECK: %[[RESULT_SHAPE:.*]] = fir.shape %[[C0]] : (index) -> !fir.shape<1>
+!CHECK: %[[RESULT_BOX:.*]] = fir.embox %[[ZERO_INIT]](%[[RESULT_SHAPE]]) : (!fir.heap>>, !fir.shape<1>) -> !fir.box>>>
+!CHECK: fir.store %[[RESULT_BOX]] to %[[RESULT_BOX_ADDR]] : !fir.ref>>>>
+!CHECK: %[[RESULT_BOX_RUNTIME:.*]] = fir.convert %[[RESULT_BOX_ADDR]] : (!fir.ref>>>>) -> !fir.ref>
+!CHECK: %[[X_BOX_RUNTIME:.*]] = fir.convert %[[X_BOX]] : (!fir.box>>) -> !fir.box
+!CHECK: %[[Y_BOX_RUNTIME:.*]] = fir.convert %[[Y_BOX]] : (!fir.box>>) -> !fir.box
+!CHECK: {{.*}}fir.call @_FortranAMatmul(%[[RESULT_BOX_RUNTIME]], %[[X_BOX_RUNTIME]], %[[Y_BOX_RUNTIME]], {{.*}}, {{.*}}) : (!fir.ref>, !fir.box, !fir.box, !fir.ref, i32) -> none
+!CHECK: %[[RESULT_BOX:.*]] = fir.load %[[RESULT_BOX_ADDR]] : !fir.ref>>>>
+!CHECK: %[[RESULT_TMP:.*]] = fir.box_addr %[[RESULT_BOX]] : (!fir.box>>>) -> !fir.heap>>
+!CHECK: %[[Z_COPY_FROM_RESULT:.*]] = fir.do_loop
+!CHECK: {{.*}}fir.array_fetch
+!CHECK: {{.*}}fir.array_update
+!CHECK: fir.result
+!CHECK: }
+!CHECK: fir.array_merge_store %[[Z]], %[[Z_COPY_FROM_RESULT]] to %[[Z_BOX]] : !fir.array>, !fir.array>, !fir.box>>
+!CHECK: fir.freemem %[[RESULT_TMP]] : >>
+subroutine matmul_test2(X, Y, Z)
+  logical :: X(:,:)
+  logical :: Y(:)
+  logical :: Z(:)
+  Z = matmul(X, Y)
+end subroutine