49 changes: 15 additions & 34 deletions mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,31 +179,12 @@ class LLVMIntrinsic {
};
} // namespace

/// Emits C++ code constructing an LLVM IR intrinsic given the generated MLIR
/// operation. In LLVM IR, intrinsics are constructed as function calls.
static void emitBuilder(const LLVMIntrinsic &intr, llvm::raw_ostream &os) {
auto overloadedRes = intr.getOverloadableResultsIdxs();
auto overloadedOps = intr.getOverloadableOperandsIdxs();
os << " llvm::Module *module = builder.GetInsertBlock()->getModule();\n";
os << " llvm::Function *fn = llvm::Intrinsic::getDeclaration(\n";
os << " module, llvm::Intrinsic::" << intr.getProperRecordName()
<< ", {";
for (unsigned idx : overloadedRes.set_bits()) {
os << "\n opInst.getResult(" << idx << ").getType()"
<< ".cast<LLVM::LLVMType>().getUnderlyingType(),";
}
for (unsigned idx : overloadedOps.set_bits()) {
os << "\n opInst.getOperand(" << idx << ").getType()"
<< ".cast<LLVM::LLVMType>().getUnderlyingType(),";
}
if (overloadedRes.any() || overloadedOps.any())
os << "\n ";
os << "});\n";
os << " auto operands =\n";
os << " lookupValues(opInst.getOperands());\n";
os << " " << (intr.getNumResults() > 0 ? "$res = " : "")
<< "builder.CreateCall(fn, operands);\n";
os << " ";
/// Prints the elements in "range" separated by commas and surrounded by "[]".
template <typename Range>
void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
os << '[';
mlir::interleaveComma(range, os);
os << ']';
}

/// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
Expand All @@ -224,16 +205,16 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {

// Emit the definition.
os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass
<< "<\"" << intr.getOperationName() << "\", [";
mlir::interleaveComma(traits, os);
os << "]>, Arguments<(ins" << (operands.empty() ? "" : " ");
<< "<\"" << intr.getOperationName() << "\", ";
printBracketedRange(intr.getOverloadableResultsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os);
os << ", ";
printBracketedRange(traits, os);
os << ", " << (intr.getNumResults() == 0 ? 0 : 1) << ">, Arguments<(ins"
<< (operands.empty() ? "" : " ");
mlir::interleaveComma(operands, os);
os << ")>, Results<(outs"
<< (intr.getNumResults() == 0 ? "" : " LLVM_Type:$res") << ")> {\n"
<< " let llvmBuilder = [{\n";
emitBuilder(intr, os);
os << "}];\n";
os << "}\n\n";
os << ")>;\n\n";

return false;
}
Expand Down