[mlir][llvm] Ordered traversal in LLVM IR import.
The revision performs a topological sort of the blocks to
ensure the operations are processed in dominance order.
After the change, the importer no longer needs to introduce
dummy instructions for operands that have not yet been
processed. Additionally, the revision moves and simplifies
the control-flow-related tests into a separate test file.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D136230
gysit committed Oct 19, 2022
1 parent 579ca5e commit 3883615
Showing 3 changed files with 163 additions and 193 deletions.
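For context on the approach (not part of the commit): LLVM's `ReversePostOrderTraversal` from `llvm/ADT/PostOrderIterator.h` visits a CFG so that every block is reached only after all of the blocks that dominate it. Below is a minimal sketch of such a walk, assuming an LLVM development tree; the function name `printBlocksInRPO` is illustrative only:

// Minimal sketch (not part of the commit): visit a function's reachable
// blocks in reverse post-order. Every block's dominators are visited
// before the block itself, so SSA definitions precede their uses.
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/raw_ostream.h"

static void printBlocksInRPO(llvm::Function &func) {
  llvm::ReversePostOrderTraversal<llvm::Function *> traversal(&func);
  for (llvm::BasicBlock *bb : traversal)
    llvm::errs() << bb->getName() << "\n";
}

Note that this walk only covers blocks reachable from the entry; the helper added in the patch below additionally seeds a traversal from every not-yet-visited block, so blocks unreachable from the entry are ordered as well.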
98 changes: 54 additions & 44 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -23,6 +23,7 @@
#include "mlir/Target/LLVMIR/TypeFromLLVM.h"
#include "mlir/Tools/mlir-translate/Translation.h"

#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Attributes.h"
@@ -306,6 +307,22 @@ mlir::translateDataLayout(const llvm::DataLayout &dataLayout,
return DataLayoutSpecAttr::get(context, entries);
}

/// Get a topologically sorted list of blocks for the given function.
static SetVector<llvm::BasicBlock *>
getTopologicallySortedBlocks(llvm::Function *func) {
SetVector<llvm::BasicBlock *> blocks;
for (llvm::BasicBlock &bb : *func) {
if (blocks.count(&bb) == 0) {
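// Seed a reverse post-order traversal from every block that has not been
// visited yet; this also orders blocks unreachable from the entry block.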
llvm::ReversePostOrderTraversal<llvm::BasicBlock *> traversal(&bb);
blocks.insert(traversal.begin(), traversal.end());
}
}
assert(blocks.size() == func->getBasicBlockList().size() &&
"some blocks are not sorted");

return blocks;
}
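A hypothetical sanity check (not in the patch) for the property this helper establishes: every block's immediate dominator must precede it in the returned order. The name `isDominanceOrdered` is illustrative; it relies on LLVM's `DominatorTree`:

// Hypothetical check (not part of the commit): returns true if each block's
// immediate dominator appears earlier in `order` than the block itself.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Dominators.h"

static bool isDominanceOrdered(llvm::Function &func,
                               const llvm::SetVector<llvm::BasicBlock *> &order) {
  llvm::DominatorTree domTree(func);
  llvm::DenseMap<llvm::BasicBlock *, unsigned> position;
  for (const auto &it : llvm::enumerate(order))
    position[it.value()] = it.index();
  for (llvm::BasicBlock *bb : order) {
    llvm::DomTreeNode *node = domTree.getNode(bb);
    if (!node || !node->getIDom())
      continue; // Entry block, or block unreachable from the entry.
    if (position.lookup(node->getIDom()->getBlock()) >= position.lookup(bb))
      return false;
  }
  return true;
}

In a debug build, asserting `isDominanceOrdered(*func, blocks)` next to the existing size assertion would catch ordering regressions.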

// Handles importing globals and functions from an LLVM module.
namespace {
class Importer {
@@ -327,6 +344,18 @@ class Importer {
return mlir;
}

/// Stores the mapping between an LLVM block and its MLIR counterpart.
void mapBlock(llvm::BasicBlock *llvm, Block *mlir) {
auto result = blockMapping.try_emplace(llvm, mlir);
(void)result;
assert(result.second && "attempting to map a block that is already mapped");
}

/// Returns the MLIR block mapped to the given LLVM block.
Block *lookupBlock(llvm::BasicBlock *block) const {
return blockMapping.lookup(block);
}

/// Returns the remapped version of `value`. Given the dominance-ordered
/// block traversal, instruction results are mapped before they are used;
/// constants without a mapping are converted on the fly.
Value processValue(llvm::Value *value);
@@ -413,13 +442,10 @@ class Importer {
return std::prev(module.getBody()->end());
}

/// Remapped blocks, for the current function.
DenseMap<llvm::BasicBlock *, Block *> blocks;
/// Mappings between original and imported values. These are function-local.
/// Function-local mapping between original and imported blocks.
DenseMap<llvm::BasicBlock *, Block *> blockMapping;
/// Function-local mapping between original and imported values.
DenseMap<llvm::Value *, Value> valueMapping;
/// Instructions that had not been defined when first encountered as a use.
/// Maps to the dummy Operation that was created in processValue().
DenseMap<llvm::Value *, Operation *> unknownInstMap;
/// Uniquing map of GlobalVariables.
DenseMap<llvm::GlobalVariable *, GlobalOp> globals;
/// The stateful type translator (contains named structs).
@@ -750,16 +776,7 @@ Value Importer::processValue(llvm::Value *value) {
if (it != valueMapping.end())
return it->second;

// We don't expect to see instructions in dominator order. If we haven't seen
// this instruction yet, create an unknown op and remap it later.
if (isa<llvm::Instruction>(value)) {
Type type = convertType(value->getType());
unknownInstMap[value] =
b.create(UnknownLoc::get(context), b.getStringAttr("llvm.unknown"),
/*operands=*/{}, type);
return unknownInstMap[value]->getResult(0);
}

// Process constants such as immediate arguments that have no mapping.
if (auto *c = dyn_cast<llvm::Constant>(value))
return processConstant(c);

@@ -829,7 +846,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
SmallVector<Value, 4> blockArguments;
if (failed(processBranchArgs(brInst, succ, blockArguments)))
return failure();
state.addSuccessors(blocks[succ]);
state.addSuccessors(lookupBlock(succ));
state.addOperands(blockArguments);
operandSegmentSizes[i + 1] = blockArguments.size();
}
@@ -866,10 +883,10 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
return failure();
caseOperandRefs[i] = caseOperands[i];
caseValues[i] = caseHandle.getCaseValue()->getSExtValue();
caseBlocks[i] = blocks[succBB];
caseBlocks[i] = lookupBlock(succBB);
}

b.create<SwitchOp>(loc, condition, blocks[defaultBB], defaultBlockArgs,
b.create<SwitchOp>(loc, condition, lookupBlock(defaultBB), defaultBlockArgs,
caseValues, caseBlocks, caseOperandRefs);
return success();
}
@@ -931,12 +948,12 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
if (llvm::Function *callee = ii->getCalledFunction()) {
op = b.create<InvokeOp>(
loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops,
blocks[ii->getNormalDest()], normalArgs, blocks[ii->getUnwindDest()],
unwindArgs);
lookupBlock(ii->getNormalDest()), normalArgs,
lookupBlock(ii->getUnwindDest()), unwindArgs);
} else {
ops.insert(ops.begin(), processValue(ii->getCalledOperand()));
op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
normalArgs, blocks[ii->getUnwindDest()],
op = b.create<InvokeOp>(loc, tys, ops, lookupBlock(ii->getNormalDest()),
normalArgs, lookupBlock(ii->getUnwindDest()),
unwindArgs);
}

@@ -1001,9 +1018,8 @@ void Importer::processFunctionAttributes(llvm::Function *func,
}

LogicalResult Importer::processFunction(llvm::Function *f) {
blocks.clear();
blockMapping.clear();
valueMapping.clear();
unknownInstMap.clear();

auto functionType =
convertType(f->getFunctionType()).dyn_cast<LLVMFunctionType>();
@@ -1064,34 +1080,28 @@ LogicalResult Importer::processFunction(llvm::Function *f) {
return success();

// Eagerly create all blocks.
SmallVector<Block *, 4> blockList;
for (llvm::BasicBlock &bb : *f) {
blockList.push_back(b.createBlock(&fop.getBody(), fop.getBody().end()));
blocks[&bb] = blockList.back();
Block *block = b.createBlock(&fop.getBody(), fop.getBody().end());
mapBlock(&bb, block);
}
currentEntryBlock = blockList[0];
currentEntryBlock = &fop.getFunctionBody().getBlocks().front();

// Add function arguments to the entry block.
for (const auto &kv : llvm::enumerate(f->args())) {
mapValue(&kv.value(),
blockList[0]->addArgument(functionType.getParamType(kv.index()),
fop.getLoc()));
for (const auto &it : llvm::enumerate(f->args())) {
BlockArgument blockArg = fop.getFunctionBody().addArgument(
functionType.getParamType(it.index()), fop.getLoc());
mapValue(&it.value(), blockArg);
}

for (auto bbs : llvm::zip(*f, blockList)) {
if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))
// Process the blocks in topological order. The ordered traversal ensures
// operands defined in a dominating block have a valid mapping to an MLIR
// value once a block is translated.
SetVector<llvm::BasicBlock *> blocks = getTopologicallySortedBlocks(f);
for (llvm::BasicBlock *bb : blocks) {
if (failed(processBasicBlock(bb, lookupBlock(bb))))
return failure();
}

// Now that all instructions are guaranteed to have been visited, ensure
// any unknown uses we encountered are remapped.
for (auto &llvmAndUnknown : unknownInstMap) {
assert(valueMapping.count(llvmAndUnknown.first));
Value newValue = valueMapping[llvmAndUnknown.first];
Value oldValue = llvmAndUnknown.second->getResult(0);
oldValue.replaceAllUsesWith(newValue);
llvmAndUnknown.second->erase();
}
return success();
}

149 changes: 0 additions & 149 deletions mlir/test/Target/LLVMIR/Import/basic.ll
@@ -176,49 +176,6 @@ if.end:
}
; CHECK-DBG: } loc(#[[UNKNOWNLOC]])

; Test that instructions that dominate can be out of sequential order.
; CHECK-LABEL: llvm.func @f2(%arg0: i64) -> i64 {
; CHECK-DAG: %[[c3:[0-9]+]] = llvm.mlir.constant(3 : i64) : i64
define i64 @f2(i64 %a) noduplicate {
entry:
; CHECK: llvm.br ^bb2
br label %next

; CHECK: ^bb1:
end:
; CHECK: llvm.return %1
ret i64 %b

; CHECK: ^bb2:
next:
; CHECK: %1 = llvm.add %arg0, %[[c3]] : i64
%b = add i64 %a, 3
; CHECK: llvm.br ^bb1
br label %end
}

; Test arguments/phis.
; CHECK-LABEL: llvm.func @f2_phis(%arg0: i64) -> i64 {
; CHECK-DAG: %[[c3:[0-9]+]] = llvm.mlir.constant(3 : i64) : i64
define i64 @f2_phis(i64 %a) noduplicate {
entry:
; CHECK: llvm.br ^bb2
br label %next

; CHECK: ^bb1(%1: i64):
end:
%c = phi i64 [ %b, %next ]
; CHECK: llvm.return %1
ret i64 %c

; CHECK: ^bb2:
next:
; CHECK: %2 = llvm.add %arg0, %[[c3]] : i64
%b = add i64 %a, 3
; CHECK: llvm.br ^bb1
br label %end
}

; CHECK-LABEL: llvm.func @f3() -> !llvm.ptr<i32>
define i32* @f3() {
; CHECK: %[[c:[0-9]+]] = llvm.mlir.addressof @g2 : !llvm.ptr<f64>
@@ -342,112 +299,6 @@ define i32 @useFreezeOp(i32 %x) {
ret i32 0
}

; Switch instruction
declare void @g(i32)

; CHECK-LABEL: llvm.func @simple_switch(%arg0: i32) {
define void @simple_switch(i32 %val) {
; CHECK: %[[C0:.+]] = llvm.mlir.constant(11 : i32) : i32
; CHECK: %[[C1:.+]] = llvm.mlir.constant(87 : i32) : i32
; CHECK: %[[C2:.+]] = llvm.mlir.constant(78 : i32) : i32
; CHECK: %[[C3:.+]] = llvm.mlir.constant(94 : i32) : i32
; CHECK: %[[C4:.+]] = llvm.mlir.constant(1 : i32) : i32
; CHECK: llvm.switch %arg0 : i32, ^[[BB5:.+]] [
; CHECK: 0: ^[[BB1:.+]],
; CHECK: 9: ^[[BB2:.+]],
; CHECK: 994: ^[[BB3:.+]],
; CHECK: 1154: ^[[BB4:.+]]
; CHECK: ]
switch i32 %val, label %def [
i32 0, label %one
i32 9, label %two
i32 994, label %three
i32 1154, label %four
]

; CHECK: ^[[BB1]]:
; CHECK: llvm.call @g(%[[C4]]) : (i32) -> ()
; CHECK: llvm.return
one:
call void @g(i32 1)
ret void
; CHECK: ^[[BB2]]:
; CHECK: llvm.call @g(%[[C3]]) : (i32) -> ()
; CHECK: llvm.return
two:
call void @g(i32 94)
ret void
; CHECK: ^[[BB3]]:
; CHECK: llvm.call @g(%[[C2]]) : (i32) -> ()
; CHECK: llvm.return
three:
call void @g(i32 78)
ret void
; CHECK: ^[[BB4]]:
; CHECK: llvm.call @g(%[[C1]]) : (i32) -> ()
; CHECK: llvm.return
four:
call void @g(i32 87)
ret void
; CHECK: ^[[BB5]]:
; CHECK: llvm.call @g(%[[C0]]) : (i32) -> ()
; CHECK: llvm.return
def:
call void @g(i32 11)
ret void
}

; CHECK-LABEL: llvm.func @switch_args(%arg0: i32) {
define void @switch_args(i32 %val) {
; CHECK: %[[C0:.+]] = llvm.mlir.constant(44 : i32) : i32
; CHECK: %[[C1:.+]] = llvm.mlir.constant(34 : i32) : i32
; CHECK: %[[C2:.+]] = llvm.mlir.constant(33 : i32) : i32
%pred = icmp ult i32 %val, 87
br i1 %pred, label %bbs, label %bb1

bb1:
%vx = add i32 %val, 22
%pred2 = icmp ult i32 %val, 94
br i1 %pred2, label %bb2, label %bb3

bb2:
%vx0 = add i32 %val, 23
br label %one

bb3:
br label %def

; CHECK: %[[V1:.+]] = llvm.add %arg0, %[[C2]] : i32
; CHECK: %[[V2:.+]] = llvm.add %arg0, %[[C1]] : i32
; CHECK: %[[V3:.+]] = llvm.add %arg0, %[[C0]] : i32
; CHECK: llvm.switch %arg0 : i32, ^[[BBD:.+]](%[[V3]] : i32) [
; CHECK: 0: ^[[BB1:.+]](%[[V1]], %[[V2]] : i32, i32)
; CHECK: ]
bbs:
%vy = add i32 %val, 33
%vy0 = add i32 %val, 34
%vz = add i32 %val, 44
switch i32 %val, label %def [
i32 0, label %one
]

; CHECK: ^[[BB1]](%[[BA0:.+]]: i32, %[[BA1:.+]]: i32):
one: ; pred: bb2, bbs
%v0 = phi i32 [%vx, %bb2], [%vy, %bbs]
%v1 = phi i32 [%vx0, %bb2], [%vy0, %bbs]
; CHECK: llvm.add %[[BA0]], %[[BA1]] : i32
%vf = add i32 %v0, %v1
call void @g(i32 %vf)
ret void

; CHECK: ^[[BBD]](%[[BA2:.+]]: i32):
def: ; pred: bb3, bbs
%v2 = phi i32 [%vx, %bb3], [%vz, %bbs]
; CHECK: llvm.call @g(%[[BA2]])
call void @g(i32 %v2)
ret void
}

; Variadic function definition
%struct.va_list = type { i8* }
