Skip to content

Commit

Permalink
[mlir][llvm] Fastmath flags import from LLVM IR.
Browse files Browse the repository at this point in the history
This revision adds support to import fastmath flags from LLVMIR. It
implement the import using a listener attached to the builder. The
listener gets notified if an operation is created and then checks if
there are fastmath flags to import from LLVM IR to the MLIR. The
listener based approach allows us to perform the import without changing
the mlirBuilders used to create the imported operations.

An alternative solution, could be to update the builders so that they
return the created operation using FailureOr<Operation*> instead of
LogicalResult. However, this solution implies an LLVM IR instruction
always maps to exatly one MLIR operation. While mostly true, there are
already exceptions to this such as the PHI instruciton. Additionally, an
mlirBuilder based solution also further complicates the builder
implementations, which led to the listener based solution.

Depends on D139405

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D139620
  • Loading branch information
gysit committed Dec 16, 2022
1 parent 66bf54a commit b6ebecc
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 43 deletions.
56 changes: 30 additions & 26 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Expand Up @@ -13,52 +13,55 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
// "intr." to avoid potential name clashes.

class LLVM_UnaryIntrOpBase<string func, Type element,
list<Trait> traits = [],
dag addAttrs = (ins)> :
list<Trait> traits = [], bit requiresFastmath = 0> :
LLVM_OneResultIntrOp<func, [], [0],
!listconcat([Pure, SameOperandsAndResultType], traits)> {
dag args = (ins LLVM_ScalarOrVectorOf<element>:$in);
let arguments = !con(args, addAttrs);
!listconcat([Pure, SameOperandsAndResultType], traits),
requiresFastmath> {
dag commonArgs = (ins LLVM_ScalarOrVectorOf<element>:$in);
let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
"functional-type(operands, results)";
}

class LLVM_UnaryIntrOpI<string func, list<Trait> traits = []> :
LLVM_UnaryIntrOpBase<func, AnySignlessInteger, traits>;
LLVM_UnaryIntrOpBase<func, AnySignlessInteger, traits> {
let arguments = commonArgs;
}

class LLVM_UnaryIntrOpF<string func, list<Trait> traits = []> :
LLVM_UnaryIntrOpBase<func, LLVM_AnyFloat,
!listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>],
traits),
(ins DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags)>;
LLVM_UnaryIntrOpBase<func, LLVM_AnyFloat, traits, /*requiresFastmath=*/1> {
dag fmfArg = (
ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let arguments = !con(commonArgs, fmfArg);
}

class LLVM_BinarySameArgsIntrOpBase<string func, Type element,
list<Trait> traits = [],
dag addAttrs = (ins)> :
list<Trait> traits = [], bit requiresFastmath = 0> :
LLVM_OneResultIntrOp<func, [], [0],
!listconcat([Pure, SameOperandsAndResultType], traits)> {
dag args = (ins LLVM_ScalarOrVectorOf<element>:$a,
LLVM_ScalarOrVectorOf<element>:$b);
let arguments = !con(args, addAttrs);
!listconcat([Pure, SameOperandsAndResultType], traits),
requiresFastmath> {
dag commonArgs = (ins LLVM_ScalarOrVectorOf<element>:$a,
LLVM_ScalarOrVectorOf<element>:$b);
let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
"functional-type(operands, results)";
}

class LLVM_BinarySameArgsIntrOpI<string func, list<Trait> traits = []> :
LLVM_BinarySameArgsIntrOpBase<func, AnySignlessInteger, traits>;
LLVM_BinarySameArgsIntrOpBase<func, AnySignlessInteger, traits> {
let arguments = commonArgs;
}

class LLVM_BinarySameArgsIntrOpF<string func, list<Trait> traits = []> :
LLVM_BinarySameArgsIntrOpBase<func, LLVM_AnyFloat,
!listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>],
traits),
(ins DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags)>;
LLVM_BinarySameArgsIntrOpBase<func, LLVM_AnyFloat, traits,
/*requiresFastmath=*/1> {
dag fmfArg = (
ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let arguments = !con(commonArgs, fmfArg);
}

class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
LLVM_OneResultIntrOp<func, [], [0],
!listconcat([DeclareOpInterfaceMethods<FastmathFlagsInterface>,
Pure, SameOperandsAndResultType], traits)> {
!listconcat([Pure, SameOperandsAndResultType], traits),
/*requiresFastmath=*/1> {
let arguments = (ins LLVM_ScalarOrVectorOf<AnyFloat>:$a,
LLVM_ScalarOrVectorOf<AnyFloat>:$b,
LLVM_ScalarOrVectorOf<AnyFloat>:$c,
Expand Down Expand Up @@ -106,7 +109,8 @@ def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">;
def LLVM_SqrtOp : LLVM_UnaryIntrOpF<"sqrt">;
def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">;
def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1],
[DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure]> {
[DeclareOpInterfaceMethods<FastmathFlagsInterface>, Pure],
/*requiresFastmath=*/1> {
let arguments =
(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$val,
AnySignlessInteger:$power,
Expand Down
27 changes: 19 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Expand Up @@ -345,8 +345,13 @@ def LLVM_IntrPatterns {
class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
list<int> overloadedResults, list<int> overloadedOperands,
list<Trait> traits, int numResults,
bit requiresAccessGroup = 0, bit requiresAliasScope = 0>
: LLVM_OpBase<dialect, opName, traits>,
bit requiresAccessGroup = 0, bit requiresAliasScope = 0,
bit requiresFastmath = 0>
: LLVM_OpBase<dialect, opName, !listconcat(
!if(!gt(requiresFastmath, 0),
[DeclareOpInterfaceMethods<FastmathFlagsInterface>],
[]),
traits)>,
Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
string resultPattern = !if(!gt(numResults, 1),
LLVM_IntrPatterns.structResult,
Expand Down Expand Up @@ -378,20 +383,23 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
return failure();
SmallVector<Type> resultTypes =
}] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
Operation *op = $_builder.create<$_qualCppClassName>(
auto op = $_builder.create<$_qualCppClassName>(
$_location, resultTypes, *mlirOperands);
}] # !if(!gt(numResults, 0), "$res = op->getResult(0);", "(void)op;");
}] # !if(!gt(requiresFastmath, 0),
"setFastmathFlagsAttr(inst, op);", "")
# !if(!gt(numResults, 0), "$res = op;", "(void)op;");
}

// Base class for LLVM intrinsic operations, should not be used directly. Places
// the intrinsic into the LLVM dialect and prefixes its name with "intr.".
class LLVM_IntrOp<string mnem, list<int> overloadedResults,
list<int> overloadedOperands, list<Trait> traits,
int numResults, bit requiresAccessGroup = 0,
bit requiresAliasScope = 0>
bit requiresAliasScope = 0, bit requiresFastmath = 0>
: LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
overloadedResults, overloadedOperands, traits,
numResults, requiresAccessGroup, requiresAliasScope>;
numResults, requiresAccessGroup, requiresAliasScope,
requiresFastmath>;

// Base class for LLVM intrinsic operations returning no results. Places the
// intrinsic into the LLVM dialect and prefixes its name with "intr.".
Expand Down Expand Up @@ -419,8 +427,11 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
// empty otherwise.
class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
list<int> overloadedOperands = [],
list<Trait> traits = []>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1>;
list<Trait> traits = [],
bit requiresFastmath = 0>
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
/*requiresAccessGroup=*/0, /*requiresAliasScope=*/0,
requiresFastmath>;

def LLVM_OneResultOpBuilder :
OpBuilder<(ins "Type":$resultType, "ValueRange":$operands,
Expand Down
22 changes: 15 additions & 7 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Expand Up @@ -44,14 +44,14 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$lhs `,` $rhs custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
string llvmInstName = instName;
string mlirBuilder = [{
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
}];
}
class LLVM_IntArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName, traits> {
let arguments = commonArgs;
string mlirBuilder = [{
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
}];
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
Expand All @@ -60,6 +60,11 @@ class LLVM_FloatArithmeticOp<string mnemonic, string instName,
dag fmfArg = (
ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
let arguments = !con(commonArgs, fmfArg);
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
setFastmathFlagsAttr(inst, op);
$res = op;
}];
}

// Class for arithmetic unary operations.
Expand All @@ -76,8 +81,10 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
let assemblyFormat = "$operand custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
string llvmInstName = instName;
string mlirBuilder = [{
$res = $_builder.create<$_qualCppClassName>($_location, $operand);
}];
auto op = $_builder.create<$_qualCppClassName>($_location, $operand);
setFastmathFlagsAttr(inst, op);
$res = op;
}];
}

// Integer binary operations.
Expand Down Expand Up @@ -146,11 +153,12 @@ def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
string llvmBuilder = [{
$res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
}];
// FIXME: Import fastmath flags.
string mlirBuilder = [{
auto *fCmpInst = cast<llvm::FCmpInst>(inst);
$res = $_builder.create<$_qualCppClassName>(
auto op = $_builder.create<$_qualCppClassName>(
$_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
setFastmathFlagsAttr(inst, op);
$res = op;
}];
// Set the $predicate index to -1 to indicate there is no matching operand
// and decrement the following indices.
Expand Down
40 changes: 38 additions & 2 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
Expand Up @@ -38,6 +38,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Error.h"
Expand Down Expand Up @@ -326,8 +327,13 @@ getTopologicallySortedBlocks(llvm::Function *func) {
return blocks;
}

// Handles importing globals and functions from an LLVM module.
namespace {
/// Module import implementation class that provides methods to import globals
/// and functions from an LLVM module into an MLIR module. It holds mappings
/// between the original and translated globals, basic blocks, and values used
/// during the translation. Additionally, it keeps track of the current constant
/// insertion point since LLVM immediate values translate to MLIR operations
/// that are introduced at the beginning of the region.
class Importer {
public:
Importer(MLIRContext *context, ModuleOp module)
Expand Down Expand Up @@ -421,6 +427,10 @@ class Importer {
constantInsertionOp = nullptr;
}

/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const;
/// Returns personality of `func` as a FlatSymbolRefAttr.
FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func);
/// Imports `bb` into `block`, which must be initially empty.
Expand Down Expand Up @@ -487,6 +497,31 @@ class Importer {
};
} // namespace

void Importer::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);

// Even if the imported operation implements the fastmath interface, the
// original instruction may not have fastmath flags set. Exit if an
// instruction, such as a non floating-point function call, does not have
// fastmath flags.
if (!isa<llvm::FPMathOperator>(inst))
return;
llvm::FastMathFlags flags = inst->getFastMathFlags();

// Set the fastmath bits flag-by-flag.
FastmathFlags value = {};
value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs());
value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs());
value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros());
value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal());
value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract());
value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc());
value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc());
FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value);
iface->setAttr(iface.getFastmathAttrName(), attr);
}

// We only need integers, floats, doubles, and vectors and tensors thereof for
// attributes. Scalar and vector types are converted to the standard
// equivalents. Array types are converted to ranked tensors; nested array types
Expand Down Expand Up @@ -1032,6 +1067,7 @@ LogicalResult Importer::convertOperation(OpBuilder &odsBuilder,
} else {
callOp = builder.create<CallOp>(loc, types, operands);
}
setFastmathFlagsAttr(inst, callOp);
if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
return success();
Expand Down Expand Up @@ -1116,7 +1152,7 @@ LogicalResult Importer::convertOperation(OpBuilder &odsBuilder,

LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
// FIXME: Support uses of SubtargetData.
// FIXME: Add support for fast-math flags and call / operand attributes.
// FIXME: Add support for call / operand attributes.
// FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
// callbr, vaarg, landingpad, catchpad, cleanuppad instructions.

Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Target/LLVMIR/Import/fastmath.ll
@@ -0,0 +1,56 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK-LABEL: @fastmath_inst
define void @fastmath_inst(float %arg1, float %arg2) {
; CHECK: llvm.fadd %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
%1 = fadd nnan ninf float %arg1, %arg2
; CHECK: llvm.fsub %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32
%2 = fsub nsz float %arg1, %arg2
; CHECK: llvm.fmul %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<arcp, contract>} : f32
%3 = fmul arcp contract float %arg1, %arg2
; CHECK: llvm.fdiv %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<afn, reassoc>} : f32
%4 = fdiv afn reassoc float %arg1, %arg2
; CHECK: llvm.fneg %{{.*}} {fastmathFlags = #llvm.fastmath<fast>} : f32
%5 = fneg fast float %arg1
ret void
}

; // -----

; CHECK-LABEL: @fastmath_fcmp
define void @fastmath_fcmp(float %arg1, float %arg2) {
; CHECK: llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32
%1 = fcmp nsz oge float %arg1, %arg2
ret void
}

; // -----

declare float @fn(float)

; CHECK-LABEL: @fastmath_call
define void @fastmath_call(float %arg1) {
; CHECK: llvm.call @fn(%{{.*}}) {fastmathFlags = #llvm.fastmath<ninf>} : (f32) -> f32
%1 = call ninf float @fn(float %arg1)
ret void
}

; // -----

declare float @llvm.exp.f32(float)
declare float @llvm.powi.f32.i32(float, i32)
declare float @llvm.pow.f32(float, float)
declare float @llvm.fmuladd.f32(float, float, float)

; CHECK-LABEL: @fastmath_intr
define void @fastmath_intr(float %arg1, i32 %arg2) {
; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf>} : (f32) -> f32
%1 = call nnan ninf float @llvm.exp.f32(float %arg1)
; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
%2 = call fast float @llvm.powi.f32.i32(float %arg1, i32 %arg2)
; CHECK: llvm.intr.pow(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
%3 = call fast float @llvm.pow.f32(float %arg1, float %arg1)
; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
%4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1)
ret void
}

0 comments on commit b6ebecc

Please sign in to comment.