Skip to content

Commit

Permalink
[mlir][llvm] Introduce a mapValue function in LLVMIR import (nfc).
Browse files Browse the repository at this point in the history
The revision adds a mapValue function to the Importer, which can be used
in the MLIR builders to provide controlled accesses to the result
mapping of the imported instructions. Additionally, the change allows us
to avoid accessing a private member variable of the Importer class,
which simplifies future refactorings that aim at factoring out a
conversion interface (similar to the MLIR to LLVM translation). The
revision also renames the variables used when emitting the MLIR builders
to prepare the generalization to non-intrinsic instructions. In
particular, it renames callInst to inst and it passes in the instruction
arguments using an llvmOperands array rather than accessing the call
arguments directly.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135645
  • Loading branch information
gysit committed Oct 11, 2022
1 parent a93d033 commit a2122a0
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 57 deletions.
18 changes: 9 additions & 9 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,12 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
# !if(!gt(numResults, 0), "$res = inst;", "");

// A builder to construct the MLIR LLVM dialect operation given the matching
// LLVM IR intrinsic instruction `callInst`. The following $-variables exist:
// - $name - substituted by the remapped `callInst` argument using the
// the index of the MLIR operation argument with the given name;
// LLVM IR instruction `inst` and its operands `llvmOperands`. The
// following $-variables exist:
// - $name - substituted by the remapped `inst` operand value at the index
// of the MLIR operation argument with the given name, or if the
// name matches the result name, by a reference to store the
// result of the newly created MLIR operation to;
// - $_int_attr - substituted by a call to an integer attribute matcher;
// - $_resultType - substituted with the MLIR result type;
// - $_location - substituted with the MLIR location;
Expand All @@ -348,16 +351,13 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
// NOTE: The $name variable resolution assumes the MLIR and LLVM argument
// orders match and there are no optional or variadic arguments.
string mlirBuilder = [{
SmallVector<llvm::Value *> operands(callInst->args());
SmallVector<Type> resultTypes =
}] # !if(!gt(numResults, 0),
"{$_resultType};", "{};") # [{
}] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
Operation *op = $_builder.create<$_qualCppClassName>(
$_location,
resultTypes,
processValues(operands));
}] # !if(!gt(numResults, 0),
"instMap[callInst] = op->getResult(0);", "(void)op;");
processValues(llvmOperands));
}] # !if(!gt(numResults, 0), "$res = op->getResult(0);", "(void)op;");
}

// Base class for LLVM intrinsic operations, should not be used directly. Places
Expand Down
111 changes: 67 additions & 44 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ class Importer {
b.setInsertionPointToStart(module.getBody());
}

/// Stores the mapping between an LLVM value and its MLIR counterpart.
void mapValue(llvm::Value *llvm, Value mlir) { mapValue(llvm) = mlir; }

/// Provides write-once access to store the MLIR value corresponding to the
/// given LLVM value.
Value &mapValue(llvm::Value *value) {
Value &mlir = valueMapping[value];
assert(mlir == nullptr &&
"attempting to map a value that is already mapped");
return mlir;
}

/// Returns the remapped version of `value` or a placeholder that will be
/// remapped later if the defining instruction has not yet been visited.
Value processValue(llvm::Value *value);
Expand All @@ -196,8 +208,7 @@ class Importer {

/// Converts an LLVM intrinsic to an MLIR LLVM dialect operation if an MLIR
/// counterpart exists. Otherwise, returns failure.
LogicalResult convertIntrinsic(OpBuilder &odsBuilder,
llvm::CallInst *callInst);
LogicalResult convertIntrinsic(OpBuilder &odsBuilder, llvm::CallInst *inst);

/// Imports `f` into the current module.
LogicalResult processFunction(llvm::Function *f);
Expand All @@ -214,7 +225,8 @@ class Importer {
FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f);
/// Imports `bb` into `block`, which must be initially empty.
LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
/// Imports `inst` and populates instMap[inst] with the imported Value.
/// Imports `inst` and populates valueMapping[inst] with the result of the
/// imported operation.
LogicalResult processInstruction(llvm::Instruction *inst);
/// `br` branches to `target`. Append the block arguments to attach to the
/// generated branch op to `blockArguments`. These should be in the same order
Expand Down Expand Up @@ -258,8 +270,8 @@ class Importer {

/// Remapped blocks, for the current function.
DenseMap<llvm::BasicBlock *, Block *> blocks;
/// Remapped values. These are function-local.
DenseMap<llvm::Value *, Value> instMap;
/// Mappings between original and imported values. These are function-local.
DenseMap<llvm::Value *, Value> valueMapping;
/// Instructions that had not been defined when first encountered as a use.
/// Maps to the dummy Operation that was created in processValue().
DenseMap<llvm::Value *, Operation *> unknownInstMap;
Expand All @@ -283,8 +295,9 @@ Type Importer::convertType(llvm::Type *type) {
}

LogicalResult Importer::convertIntrinsic(OpBuilder &odsBuilder,
llvm::CallInst *callInst) {
llvm::Function *callee = callInst->getCalledFunction();
llvm::CallInst *inst) {
// Check if the callee is an intrinsic.
llvm::Function *callee = inst->getCalledFunction();
if (!callee || !callee->isIntrinsic())
return failure();

Expand All @@ -293,6 +306,8 @@ LogicalResult Importer::convertIntrinsic(OpBuilder &odsBuilder,
if (!isConvertibleIntrinsic(intrinsicID))
return failure();

// Copy the call arguments to an operands array used by the conversion.
SmallVector<llvm::Value *> llvmOperands(inst->args());
#include "mlir/Dialect/LLVMIR/LLVMIntrinsicFromLLVMIRConversions.inc"

return failure();
Expand Down Expand Up @@ -504,16 +519,16 @@ Value Importer::processConstant(llvm::Constant *c) {
b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin());
if (failed(processInstruction(i)))
return nullptr;
assert(instMap.count(i));
assert(valueMapping.count(i));

// If we don't remove entry of `i` here, it's totally possible that the
// next time llvm::ConstantExpr::getAsInstruction is called again, which
// always allocates a new Instruction, memory address of the newly
// created Instruction might be the same as `i`. Making processInstruction
// falsely believe that the new Instruction has been processed before
// and raised an assertion error.
Value value = instMap[i];
instMap.erase(i);
Value value = valueMapping[i];
valueMapping.erase(i);
// Remove this zombie LLVM instruction now, leaving us only with the MLIR
// op.
i->deleteValue();
Expand Down Expand Up @@ -574,8 +589,8 @@ Value Importer::processConstant(llvm::Constant *c) {
}

Value Importer::processValue(llvm::Value *value) {
auto it = instMap.find(value);
if (it != instMap.end())
auto it = valueMapping.find(value);
if (it != valueMapping.end())
return it->second;

// We don't expect to see instructions in dominator order. If we haven't seen
Expand Down Expand Up @@ -829,9 +844,13 @@ Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
// FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math
// flags and call / operand attributes are not supported.

// Convert all intrinsics that provide an MLIR builder.
if (auto callInst = dyn_cast<llvm::CallInst>(inst))
if (succeeded(convertIntrinsic(b, callInst)))
return success();

Location loc = translateLoc(inst->getDebugLoc());
assert(!instMap.count(inst) &&
"processInstruction must be called only once per instruction!");
switch (inst->getOpcode()) {
default:
return emitError(loc) << "unknown instruction: " << diag(*inst);
Expand Down Expand Up @@ -886,24 +905,25 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
state.addOperands(ops);
Operation *op = b.create(state);
if (!inst->getType()->isVoidTy())
instMap[inst] = op->getResult(0);
mapValue(inst, op->getResult(0));
return success();
}
case llvm::Instruction::Alloca: {
Value size = processValue(inst->getOperand(0));
auto *allocaInst = cast<llvm::AllocaInst>(inst);
instMap[inst] =
b.create<AllocaOp>(loc, convertType(inst->getType()),
convertType(allocaInst->getAllocatedType()), size,
allocaInst->getAlign().value());
Value res = b.create<AllocaOp>(loc, convertType(inst->getType()),
convertType(allocaInst->getAllocatedType()),
size, allocaInst->getAlign().value());
mapValue(inst, res);
return success();
}
case llvm::Instruction::ICmp: {
Value lhs = processValue(inst->getOperand(0));
Value rhs = processValue(inst->getOperand(1));
instMap[inst] = b.create<ICmpOp>(
Value res = b.create<ICmpOp>(
loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
rhs);
mapValue(inst, res);
return success();
}
case llvm::Instruction::FCmp: {
Expand All @@ -921,9 +941,10 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
resType = VectorType::get({numElements}, boolType);
}

instMap[inst] = b.create<FCmpOp>(
Value res = b.create<FCmpOp>(
loc, resType,
getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs);
mapValue(inst, res);
return success();
}
case llvm::Instruction::Br: {
Expand Down Expand Up @@ -987,17 +1008,12 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
}
case llvm::Instruction::PHI: {
Type type = convertType(inst->getType());
instMap[inst] = b.getInsertionBlock()->addArgument(
type, translateLoc(inst->getDebugLoc()));
mapValue(inst, b.getInsertionBlock()->addArgument(
type, translateLoc(inst->getDebugLoc())));
return success();
}
case llvm::Instruction::Call: {
llvm::CallInst *ci = cast<llvm::CallInst>(inst);

// For all intrinsics, try to generate to the corresponding op.
if (succeeded(convertIntrinsic(b, ci)))
return success();

SmallVector<llvm::Value *> args(ci->args());
SmallVector<Value> ops = processValues(args);
SmallVector<Type, 2> tys;
Expand All @@ -1015,7 +1031,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
op = b.create<CallOp>(loc, tys, ops);
}
if (!ci->getType()->isVoidTy())
instMap[inst] = op->getResult(0);
mapValue(inst, op->getResult(0));
return success();
}
case llvm::Instruction::LandingPad: {
Expand All @@ -1026,7 +1042,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
ops.push_back(processConstant(lpi->getClause(i)));

Type ty = convertType(lpi->getType());
instMap[inst] = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
Value res = b.create<LandingpadOp>(loc, ty, lpi->isCleanup(), ops);
mapValue(inst, res);
return success();
}
case llvm::Instruction::Invoke: {
Expand Down Expand Up @@ -1057,7 +1074,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
}

if (!ii->getType()->isVoidTy())
instMap[inst] = op->getResult(0);
mapValue(inst, op->getResult(0));
return success();
}
case llvm::Instruction::Fence: {
Expand Down Expand Up @@ -1087,7 +1104,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
getLLVMAtomicOrdering(atomicInst->getOrdering());

Type type = convertType(inst->getType());
instMap[inst] = b.create<AtomicRMWOp>(loc, type, binOp, ptr, val, ordering);
Value res = b.create<AtomicRMWOp>(loc, type, binOp, ptr, val, ordering);
mapValue(inst, res);
return success();
}
case llvm::Instruction::AtomicCmpXchg: {
Expand All @@ -1102,8 +1120,9 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
getLLVMAtomicOrdering(cmpXchgInst->getFailureOrdering());

Type type = convertType(inst->getType());
instMap[inst] = b.create<AtomicCmpXchgOp>(loc, type, ptr, cmpVal, newVal,
ordering, failOrdering);
Value res = b.create<AtomicCmpXchgOp>(loc, type, ptr, cmpVal, newVal,
ordering, failOrdering);
mapValue(inst, res);
return success();
}
case llvm::Instruction::GetElementPtr: {
Expand All @@ -1123,8 +1142,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
}

Type type = convertType(inst->getType());
instMap[inst] =
b.create<GEPOp>(loc, type, sourceElementType, basePtr, indices);
Value res = b.create<GEPOp>(loc, type, sourceElementType, basePtr, indices);
mapValue(inst, res);
return success();
}
case llvm::Instruction::InsertValue: {
Expand All @@ -1134,7 +1153,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {

SmallVector<int64_t> indices;
llvm::append_range(indices, ivInst->getIndices());
instMap[inst] = b.create<InsertValueOp>(loc, aggOperand, inserted, indices);
Value res = b.create<InsertValueOp>(loc, aggOperand, inserted, indices);
mapValue(inst, res);
return success();
}
case llvm::Instruction::ExtractValue: {
Expand All @@ -1143,7 +1163,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {

SmallVector<int64_t> indices;
llvm::append_range(indices, evInst->getIndices());
instMap[inst] = b.create<ExtractValueOp>(loc, aggOperand, indices);
Value res = b.create<ExtractValueOp>(loc, aggOperand, indices);
mapValue(inst, res);
return success();
}
case llvm::Instruction::ShuffleVector: {
Expand All @@ -1152,7 +1173,8 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
Value vec2 = processValue(svInst->getOperand(1));

SmallVector<int32_t> mask(svInst->getShuffleMask());
instMap[inst] = b.create<ShuffleVectorOp>(loc, vec1, vec2, mask);
Value res = b.create<ShuffleVectorOp>(loc, vec1, vec2, mask);
mapValue(inst, res);
return success();
}
}
Expand Down Expand Up @@ -1191,7 +1213,7 @@ void Importer::processFunctionAttributes(llvm::Function *func,

LogicalResult Importer::processFunction(llvm::Function *f) {
blocks.clear();
instMap.clear();
valueMapping.clear();
unknownInstMap.clear();

auto functionType =
Expand Down Expand Up @@ -1262,8 +1284,9 @@ LogicalResult Importer::processFunction(llvm::Function *f) {

// Add function arguments to the entry block.
for (const auto &kv : llvm::enumerate(f->args())) {
instMap[&kv.value()] = blockList[0]->addArgument(
functionType.getParamType(kv.index()), fop.getLoc());
mapValue(&kv.value(),
blockList[0]->addArgument(functionType.getParamType(kv.index()),
fop.getLoc()));
}

for (auto bbs : llvm::zip(*f, blockList)) {
Expand All @@ -1274,8 +1297,8 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
// Now that all instructions are guaranteed to have been visited, ensure
// any unknown uses we encountered are remapped.
for (auto &llvmAndUnknown : unknownInstMap) {
assert(instMap.count(llvmAndUnknown.first));
Value newValue = instMap[llvmAndUnknown.first];
assert(valueMapping.count(llvmAndUnknown.first));
Value newValue = valueMapping[llvmAndUnknown.first];
Value oldValue = llvmAndUnknown.second->getResult(0);
oldValue.replaceAllUsesWith(newValue);
llvmAndUnknown.second->erase();
Expand Down
12 changes: 8 additions & 4 deletions mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,17 @@ static LogicalResult emitOneIntrBuilder(const Record &record, raw_ostream &os) {
if (succeeded(argIndex)) {
// Process the argument value assuming the MLIR and LLVM operand orders
// match and there are no optional or variadic arguments.
bs << formatv("processValue(callInst->getArgOperand({0}))", *argIndex);
bs << formatv("processValue(llvmOperands[{0}])", *argIndex);
} else if (isResultName(op, name)) {
assert(op.getNumResults() == 1 &&
"expected operation to have one result");
bs << formatv("mapValue(inst)");
} else if (name == "_int_attr") {
bs << "matchIntegerAttr";
} else if (name == "_resultType") {
bs << "convertType(callInst->getType())";
bs << "convertType(inst->getType())";
} else if (name == "_location") {
bs << "translateLoc(callInst->getDebugLoc())";
bs << "translateLoc(inst->getDebugLoc())";
} else if (name == "_builder") {
bs << "odsBuilder";
} else if (name == "_qualCppClassName") {
Expand All @@ -228,7 +232,7 @@ static LogicalResult emitOneIntrBuilder(const Record &record, raw_ostream &os) {
bs << '$';
} else {
return emitError(name +
" is neither a known keyword nor an argument of " +
" is not a known keyword, argument, or result of " +
op.getOperationName());
}
// Finally, only keep the untraversed part of the string.
Expand Down

0 comments on commit a2122a0

Please sign in to comment.