Skip to content

Commit

Permalink
[mlir] Add call_intrinsic op to LLVMIIR
Browse files Browse the repository at this point in the history
The call_intrinsic op allows us to call LLVM intrinsics from the LLVMDialect without implementing a new op every time.

Reviewed By: lattner, rriddle

Differential Revision: https://reviews.llvm.org/D137187
  • Loading branch information
electriclilies authored and Mogball committed Nov 2, 2022
1 parent 1ceafe5 commit 0efff7c
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 4 deletions.
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Expand Up @@ -707,6 +707,25 @@ def LLVM_vector_extract
}];
}

//===--------------------------------------------------------------------===//
// CallIntrinsicOp
//===--------------------------------------------------------------------===//
def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic", [Pure]> {
let summary = "Call to an LLVM intrinsic function.";
let description = [{
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
the MLIR function type of this op to determine which intrinsic to call.
}];
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args);
let results = (outs Variadic<LLVM_Type>:$results);
let llvmBuilder = [{
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
}];
let assemblyFormat = [{
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
}];
}

//
// LLVM Vector Predication operations.
//
Expand Down
72 changes: 68 additions & 4 deletions mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
Expand Up @@ -258,6 +258,70 @@ static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) {
return position;
}

/// Get the declaration of an overloaded llvm intrinsic. First we get the
/// overloaded argument types and/or result type from the CallIntrinsicOp, and
/// then use those to get the correct declaration of the overloaded intrinsic.
static FailureOr<llvm::Function *>
getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id,
llvm::Module *module,
LLVM::ModuleTranslation &moduleTranslation) {
SmallVector<llvm::Type *, 8> allArgTys;
for (Type type : op->getOperandTypes())
allArgTys.push_back(moduleTranslation.convertType(type));

llvm::Type *resTy;
if (op.getNumResults() == 0)
resTy = llvm::Type::getVoidTy(module->getContext());
else
resTy = moduleTranslation.convertType(op.getResult(0).getType());

// ATM we do not support variadic intrinsics.
llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false);

SmallVector<llvm::Intrinsic::IITDescriptor, 8> table;
getIntrinsicInfoTableEntries(id, table);
ArrayRef<llvm::Intrinsic::IITDescriptor> tableRef = table;

SmallVector<llvm::Type *, 8> overloadedArgTys;
if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
overloadedArgTys) !=
llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
return op.emitOpError("intrinsic type is not a match");
}

ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTysRef);
}

/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Intrinsic::ID id =
llvm::Function::lookupIntrinsicID(op.getIntrinAttr());
if (!id)
return op.emitOpError()
<< "couldn't find intrinsic: " << op.getIntrinAttr();

llvm::Function *fn = nullptr;
if (llvm::Intrinsic::isOverloaded(id)) {
auto fnOrFailure =
getOverloadedDeclaration(op, id, module, moduleTranslation);
if (failed(fnOrFailure))
return failure();
fn = fnOrFailure.value();
} else {
fn = llvm::Intrinsic::getDeclaration(module, id, {});
}

auto *inst =
builder.CreateCall(fn, moduleTranslation.lookupValues(op.getOperands()));
if (op.getNumResults() == 1)
moduleTranslation.mapValue(op->getResults().front()) = inst;
return success();
}

static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
Expand All @@ -272,8 +336,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
// Emit function calls. If the "callee" attribute is present, this is a
// direct function call and we also need to look up the remapped function
// itself. Otherwise, this is an indirect call and the callee is the first
// operand, look it up as a normal value. Return the llvm::Value representing
// the function result, which may be of llvm::VoidTy type.
// operand, look it up as a normal value. Return the llvm::Value
// representing the function result, which may be of llvm::VoidTy type.
auto convertCall = [&](Operation &op) -> llvm::Value * {
auto operands = moduleTranslation.lookupValues(op.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
Expand Down Expand Up @@ -404,8 +468,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}

// Emit branches. We need to look up the remapped blocks and ignore the block
// arguments that were transformed into PHI nodes.
// Emit branches. We need to look up the remapped blocks and ignore the
// block arguments that were transformed into PHI nodes.
if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
llvm::BranchInst *branch =
builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
Expand Down
82 changes: 82 additions & 0 deletions mlir/test/Dialect/LLVMIR/call-intrin.mlir
@@ -0,0 +1,82 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s

// CHECK: ; ModuleID = 'LLVMDialectModule'
// CHECK: source_filename = "LLVMDialectModule"
// CHECK: declare ptr @malloc(i64)
// CHECK: declare void @free(ptr)
// CHECK: define <4 x float> @round_sse41() {
// CHECK: %1 = call <4 x float> @llvm.x86.sse41.round.ss(<4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, <4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, i32 1)
// CHECK: ret <4 x float> %1
// CHECK: }
llvm.func @round_sse41() -> vector<4xf32> {
%0 = llvm.mlir.constant(1 : i32) : i32
%1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
%res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {}
llvm.return %res: vector<4xf32>
}

// -----

// CHECK: ; ModuleID = 'LLVMDialectModule'
// CHECK: source_filename = "LLVMDialectModule"

// CHECK: declare ptr @malloc(i64)

// CHECK: declare void @free(ptr)

// CHECK: define float @round_overloaded() {
// CHECK: %1 = call float @llvm.round.f32(float 1.000000e+00)
// CHECK: ret float %1
// CHECK: }
llvm.func @round_overloaded() -> f32 {
%0 = llvm.mlir.constant(1.0 : f32) : f32
%res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
llvm.return %res: f32
}

// -----

// CHECK: ; ModuleID = 'LLVMDialectModule'
// CHECK: source_filename = "LLVMDialectModule"
// CHECK: declare ptr @malloc(i64)
// CHECK: declare void @free(ptr)
// CHECK: define void @lifetime_start() {
// CHECK: %1 = alloca float, i8 1, align 4
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1)
// CHECK: ret void
// CHECK: }
llvm.func @lifetime_start() {
%0 = llvm.mlir.constant(4 : i64) : i64
%1 = llvm.mlir.constant(1 : i8) : i8
%2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
llvm.return
}

// -----

llvm.func @variadic() {
%0 = llvm.mlir.constant(1 : i8) : i8
%1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr
llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> ()
llvm.return
}

// -----

llvm.func @no_intrinsic() {
// expected-error@below {{'llvm.call_intrinsic' op couldn't find intrinsic: "llvm.does_not_exist"}}
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
llvm.call_intrinsic "llvm.does_not_exist"() : () -> ()
llvm.return
}

// -----

llvm.func @bad_types() {
%0 = llvm.mlir.constant(1 : i8) : i8
// expected-error@below {{'llvm.call_intrinsic' op intrinsic type is not a match}}
// expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}}
llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {}
llvm.return
}

0 comments on commit 0efff7c

Please sign in to comment.