Skip to content

Commit

Permalink
[Flang] Change fir.divc to perform library call rather than generate …
Browse files Browse the repository at this point in the history
…inline operations.

Currently `fir.divc` is always lowered to a sequence of LLVM operations that perform the complex division inline; however, this causes issues for extreme values, where the intermediate calculations overflow. While this behaviour would be acceptable at -Ofast, it is currently what is generated at all optimization levels.

This patch changes `fir.divc` to lower to a library call instead, except for KIND=2 and KIND=3, where there is no appropriate library function and the inline sequence is kept.

Reviewed By: vzakhari

Differential Revision: https://reviews.llvm.org/D145808
  • Loading branch information
Sacha Ballantyne authored and Sacha Ballantyne committed Apr 4, 2023
1 parent 0109f8d commit a7bb8e2
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 34 deletions.
86 changes: 66 additions & 20 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Expand Up @@ -41,6 +41,7 @@
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/IR/ValueRange.h>

namespace fir {
#define GEN_PASS_DEF_FIRTOLLVMLOWERING
Expand Down Expand Up @@ -3512,42 +3513,87 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
}
};

/// Replace the complex division \p op with a call to the library routine
/// named \p funcName (for example `__divdc3`).
///
/// If the parent module already contains an `llvm.func` with that name it is
/// reused; otherwise a declaration with signature `returnType(argType...)` is
/// created in the module body first. The call is emitted at the rewriter's
/// current insertion point with \p args as operands, and \p op is replaced by
/// the call results. Always returns success.
static mlir::LogicalResult getDivc3(fir::DivcOp op,
                                    mlir::ConversionPatternRewriter &rewriter,
                                    std::string funcName, mlir::Type returnType,
                                    llvm::SmallVector<mlir::Type> argType,
                                    llvm::SmallVector<mlir::Value> args) {
  auto parentModule = op->getParentOfType<mlir::ModuleOp>();

  // Look for an existing declaration of the routine; declare it in the module
  // body when it is not there yet.
  auto callee = parentModule.lookupSymbol<mlir::LLVM::LLVMFuncOp>(funcName);
  if (!callee) {
    mlir::OpBuilder declBuilder(parentModule.getBodyRegion());
    auto calleeTy = mlir::LLVM::LLVMFunctionType::get(returnType, argType,
                                                      /*isVarArg=*/false);
    callee = declBuilder.create<mlir::LLVM::LLVMFuncOp>(
        rewriter.getUnknownLoc(), funcName, calleeTy);
  }

  // Emit the call and let its results stand in for the original op.
  auto libCall = rewriter.create<mlir::LLVM::CallOp>(
      op.getLoc(), returnType, mlir::SymbolRefAttr::get(callee), args);
  rewriter.replaceOp(op, libCall->getResults());
  return mlir::success();
}

/// Complex division.
///
/// `fir.divc` is lowered to a call to the matching complex-division library
/// routine (`__divsc3`, `__divdc3`, `__divxc3`, `__divtc3` for KIND 4, 8, 10
/// and 16 respectively). For KIND 2 and 3 no such routine is used and the
/// division is expanded inline instead.
struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
  using FIROpConversion::FIROpConversion;

  mlir::LogicalResult
  matchAndRewrite(fir::DivcOp divc, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Value a = adaptor.getOperands()[0];
    mlir::Value b = adaptor.getOperands()[1];
    auto loc = divc.getLoc();
    mlir::Type eleTy = convertType(getComplexEleTy(divc.getType()));
    // The library routines take the four real/imaginary parts as scalars.
    llvm::SmallVector<mlir::Type> argTy = {eleTy, eleTy, eleTy, eleTy};
    mlir::Type firReturnTy = divc.getType();
    mlir::Type ty = convertType(firReturnTy);

    // Split both operands into their real (x) and imaginary (y) parts.
    auto x0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 0);
    auto y0 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, a, 1);
    auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
    auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);

    // `cast` rather than `dyn_cast`: the result is used unconditionally, so a
    // type mismatch should trip the cast assertion instead of dereferencing a
    // null type.
    fir::KindTy kind = firReturnTy.cast<fir::ComplexType>().getFKind();
    llvm::SmallVector<mlir::Value> args = {x0, y0, x1, y1};
    switch (kind) {
    case 4:
      return getDivc3(divc, rewriter, "__divsc3", ty, argTy, args);
    case 8:
      return getDivc3(divc, rewriter, "__divdc3", ty, argTy, args);
    case 10:
      return getDivc3(divc, rewriter, "__divxc3", ty, argTy, args);
    case 16:
      return getDivc3(divc, rewriter, "__divtc3", ty, argTy, args);
    case 3:
    case 2: {
      // No library function for bfloat or half in compiler_rt, generate
      // inline instead.
      // given: (x + iy) / (x' + iy')
      // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
      auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
      auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
      auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
      auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
      auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
      auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
      auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
      auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
      auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
      auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
      auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
      // Reassemble the (real, imaginary) pair into the result struct.
      auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
      auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
      auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
      rewriter.replaceOp(divc, r0.getResult());
      return mlir::success();
    }
    default:
      llvm_unreachable("Unsupported complex type");
    }
  }
};

Expand Down
48 changes: 34 additions & 14 deletions flang/test/Fir/convert-to-llvm.fir
Expand Up @@ -586,22 +586,42 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128
// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128
// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128
// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128
// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
// CHECK: %[[CALL:.*]] = llvm.call @__divtc3(%[[X0]], %[[Y0]], %[[X1]], %[[Y1]]) : (f128, f128, f128, f128) -> !llvm.struct<(f128, f128)>
// CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>

// -----

// Test that FIR complex division is lowered to inline arithmetic for KIND=3
// (checked below as bf16 operations) rather than to a library call.

func.func @fir_complex_div(%a: !fir.complex<3>, %b: !fir.complex<3>) -> !fir.complex<3> {
%c = fir.divc %a, %b : !fir.complex<3>
return %c : !fir.complex<3>
}

// CHECK-LABEL: llvm.func @fir_complex_div(
// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(bf16, bf16)>,
// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(bf16, bf16)>) -> !llvm.struct<(bf16, bf16)> {
// CHECK: %[[X0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(bf16, bf16)>
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(bf16, bf16)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(bf16, bf16)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(bf16, bf16)>
// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : bf16
// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : bf16
// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : bf16
// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : bf16
// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : bf16
// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : bf16
// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : bf16
// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : bf16
// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : bf16
// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : bf16
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(bf16, bf16)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(bf16, bf16)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(bf16, bf16)>
// CHECK: llvm.return %{{.*}} : !llvm.struct<(bf16, bf16)>


// -----

// Test FIR complex negation conversion
Expand Down

0 comments on commit a7bb8e2

Please sign in to comment.