diff --git a/llvm/unittests/Transforms/Utils/LocalTest.cpp b/llvm/unittests/Transforms/Utils/LocalTest.cpp index 4fa6a09f2022f..3862a418603b2 100644 --- a/llvm/unittests/Transforms/Utils/LocalTest.cpp +++ b/llvm/unittests/Transforms/Utils/LocalTest.cpp @@ -1019,12 +1019,12 @@ TEST(Local, CanReplaceOperandWithVariable) { BasicBlock *BB0 = BasicBlock::Create(Ctx, "", TestBody); B.SetInsertPoint(BB0); - Value *Intrin = M.getOrInsertFunction("llvm.foo", FnType).getCallee(); - Value *Func = M.getOrInsertFunction("foo", FnType).getCallee(); - Value *VarArgFunc - = M.getOrInsertFunction("foo.vararg", VarArgFnType).getCallee(); - Value *VarArgIntrin - = M.getOrInsertFunction("llvm.foo.vararg", VarArgFnType).getCallee(); + FunctionCallee Intrin = M.getOrInsertFunction("llvm.foo", FnType); + FunctionCallee Func = M.getOrInsertFunction("foo", FnType); + FunctionCallee VarArgFunc + = M.getOrInsertFunction("foo.vararg", VarArgFnType); + FunctionCallee VarArgIntrin + = M.getOrInsertFunction("llvm.foo.vararg", VarArgFnType); auto *CallToIntrin = B.CreateCall(Intrin); auto *CallToFunc = B.CreateCall(Func); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 4af83eec969ca..84b64b910eeff 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -334,7 +334,12 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, return builder.CreateCall(functionMapping.lookup(attr.getValue()), operandsRef); } else { - return builder.CreateCall(operandsRef.front(), operandsRef.drop_front()); + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); + return builder.CreateCall(calleeType, operandsRef.front(), + operandsRef.drop_front()); } }; @@ -353,14 +358,19 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, if (auto invOp = dyn_cast(opInst)) { auto operands = lookupValues(opInst.getOperands()); ArrayRef operandsRef(operands); - if (auto attr = opInst.getAttrOfType("callee")) + if (auto attr = opInst.getAttrOfType("callee")) { builder.CreateInvoke(functionMapping.lookup(attr.getValue()), blockMapping[invOp.getSuccessor(0)], blockMapping[invOp.getSuccessor(1)], operandsRef); - else + } else { + auto *calleePtrType = + cast(operandsRef.front()->getType()); + auto *calleeType = + cast(calleePtrType->getElementType()); builder.CreateInvoke( - operandsRef.front(), blockMapping[invOp.getSuccessor(0)], + calleeType, operandsRef.front(), blockMapping[invOp.getSuccessor(0)], blockMapping[invOp.getSuccessor(1)], operandsRef.drop_front()); + } return success(); }