diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index b2257a163932c..bb2668790dbfb 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -707,6 +707,25 @@ def LLVM_vector_extract }]; } +//===--------------------------------------------------------------------===// +// CallIntrinsicOp +//===--------------------------------------------------------------------===// +def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic", [Pure]> { + let summary = "Call to an LLVM intrinsic function."; + let description = [{ + Call the specified llvm intrinsic. If the intrinsic is overloaded, use + the MLIR function type of this op to determine which intrinsic to call. + }]; + let arguments = (ins StrAttr:$intrin, Variadic:$args); + let results = (outs Variadic:$results); + let llvmBuilder = [{ + return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation); + }]; + let assemblyFormat = [{ + $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict + }]; +} + // // LLVM Vector Predication operations. // diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 1f89a55ee363e..abc2fadbbc9ac 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -258,6 +258,70 @@ static SmallVector extractPosition(ArrayRef indices) { return position; } +/// Get the declaration of an overloaded llvm intrinsic. First we get the +/// overloaded argument types and/or result type from the CallIntrinsicOp, and +/// then use those to get the correct declaration of the overloaded intrinsic. +static FailureOr +getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id, + llvm::Module *module, + LLVM::ModuleTranslation &moduleTranslation) { + SmallVector allArgTys; + for (Type type : op->getOperandTypes()) + allArgTys.push_back(moduleTranslation.convertType(type)); + + llvm::Type *resTy; + if (op.getNumResults() == 0) + resTy = llvm::Type::getVoidTy(module->getContext()); + else + resTy = moduleTranslation.convertType(op.getResult(0).getType()); + + // ATM we do not support variadic intrinsics. + llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false); + + SmallVector table; + getIntrinsicInfoTableEntries(id, table); + ArrayRef tableRef = table; + + SmallVector overloadedArgTys; + if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef, + overloadedArgTys) != + llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) { + return op.emitOpError("intrinsic type is not a match"); + } + + ArrayRef overloadedArgTysRef = overloadedArgTys; + return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef); +} + +/// Builder for LLVM_CallIntrinsicOp +static LogicalResult +convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Intrinsic::ID id = + llvm::Function::lookupIntrinsicID(op.getIntrinAttr()); + if (!id) + return op.emitOpError() + << "couldn't find intrinsic: " << op.getIntrinAttr(); + + llvm::Function *fn = nullptr; + if (llvm::Intrinsic::isOverloaded(id)) { + auto fnOrFailure = + getOverloadedDeclaration(op, id, module, moduleTranslation); + if (failed(fnOrFailure)) + return failure(); + fn = fnOrFailure.value(); + } else { + fn = llvm::Intrinsic::getDeclaration(module, id, {}); + } + + auto *inst = + builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands())); + if (op.getNumResults() == 1) + moduleTranslation.mapValue(op->getResults().front()) = inst; + return success(); +} + static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -272,8 +336,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, // Emit function calls. If the "callee" attribute is present, this is a // direct function call and we also need to look up the remapped function // itself. Otherwise, this is an indirect call and the callee is the first - // operand, look it up as a normal value. Return the llvm::Value representing - // the function result, which may be of llvm::VoidTy type. + // operand, look it up as a normal value. Return the llvm::Value + // representing the function result, which may be of llvm::VoidTy type. auto convertCall = [&](Operation &op) -> llvm::Value * { auto operands = moduleTranslation.lookupValues(op.getOperands()); ArrayRef operandsRef(operands); @@ -404,8 +468,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } - // Emit branches. We need to look up the remapped blocks and ignore the block - // arguments that were transformed into PHI nodes. + // Emit branches. We need to look up the remapped blocks and ignore the + // block arguments that were transformed into PHI nodes. if (auto brOp = dyn_cast(opInst)) { llvm::BranchInst *branch = builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor())); diff --git a/mlir/test/Dialect/LLVMIR/call-intrin.mlir b/mlir/test/Dialect/LLVMIR/call-intrin.mlir new file mode 100644 index 0000000000000..30f5c9fb82572 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/call-intrin.mlir @@ -0,0 +1,82 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK: ; ModuleID = 'LLVMDialectModule' +// CHECK: source_filename = "LLVMDialectModule" +// CHECK: declare ptr @malloc(i64) +// CHECK: declare void @free(ptr) +// CHECK: define <4 x float> @round_sse41() { +// CHECK: %1 = call <4 x float> @llvm.x86.sse41.round.ss(<4 x float> , <4 x float> , i32 1) +// CHECK: ret <4 x float> %1 +// CHECK: } +llvm.func @round_sse41() -> vector<4xf32> { + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32> + %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {} + llvm.return %res: vector<4xf32> +} + +// ----- + +// CHECK: ; ModuleID = 'LLVMDialectModule' +// CHECK: source_filename = "LLVMDialectModule" + +// CHECK: declare ptr @malloc(i64) + +// CHECK: declare void @free(ptr) + +// CHECK: define float @round_overloaded() { +// CHECK: %1 = call float @llvm.round.f32(float 1.000000e+00) +// CHECK: ret float %1 +// CHECK: } +llvm.func @round_overloaded() -> f32 { + %0 = llvm.mlir.constant(1.0 : f32) : f32 + %res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {} + llvm.return %res: f32 +} + +// ----- + +// CHECK: ; ModuleID = 'LLVMDialectModule' +// CHECK: source_filename = "LLVMDialectModule" +// CHECK: declare ptr @malloc(i64) +// CHECK: declare void @free(ptr) +// CHECK: define void @lifetime_start() { +// CHECK: %1 = alloca float, i8 1, align 4 +// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1) +// CHECK: ret void +// CHECK: } +llvm.func @lifetime_start() { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i8) : i8 + %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr + llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {} + llvm.return +} + +// ----- + +llvm.func @variadic() { + %0 = llvm.mlir.constant(1 : i8) : i8 + %1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr + llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> () + llvm.return +} + +// ----- + +llvm.func @no_intrinsic() { + // expected-error@below {{'llvm.call_intrinsic' op couldn't find intrinsic: "llvm.does_not_exist"}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.does_not_exist"() : () -> () + llvm.return +} + +// ----- + +llvm.func @bad_types() { + %0 = llvm.mlir.constant(1 : i8) : i8 + // expected-error@below {{'llvm.call_intrinsic' op intrinsic type is not a match}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {} + llvm.return +}