//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Function level pass performing intraprocedural
// propagation of array shapes through the shape inference op interface.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "toy/ShapeInferenceInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#define DEBUG_TYPE "shape-inference"
using namespace mlir;
using namespace toy;
/// Include the auto-generated definitions for the shape inference interfaces.
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
namespace {
/// ShapeInferencePass runs on a single toy function and performs
/// intra-procedural shape inference on the operations it contains.
///
/// Outline of the algorithm:
///
///   1) Walk the function and collect every operation that has at least one
///      result whose type is not a ranked tensor; these are the operations
///      still waiting for shape inference.
///   2) Repeatedly:
///      a) look for a collected operation whose operands are all ranked
///         tensors (i.e. it is ready to be processed),
///      b) stop iterating when no such operation exists,
///      c) drop the chosen operation from the collection,
///      d) ask it, through the ShapeInference interface, to infer its result
///         shapes.
///   3) The pass succeeds only if the collection ends up empty.
///
struct ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<toy::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass)

  /// Pass entry point: infers shapes for the function this pass instance is
  /// currently running on, and signals pass failure when an operation cannot
  /// be handled or when some operations never become ready.
  void runOnOperation() override {
    toy::FuncOp function = getOperation();

    // Gather the operations that still produce a non-ranked result.
    llvm::SmallPtrSet<mlir::Operation *, 16> pending;
    function.walk([&pending](mlir::Operation *candidate) {
      if (!returnsDynamicShape(candidate))
        return;
      pending.insert(candidate);
    });

    // Keep processing until nothing is pending or nothing is ready anymore
    // (fix point).
    while (!pending.empty()) {
      // Pick any pending operation whose operands are all already ranked.
      auto readyIt = llvm::find_if(pending, allOperandsInferred);
      if (readyIt == pending.end())
        break;

      Operation *ready = *readyIt;
      pending.erase(ready);

      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *ready << "\n");

      // Only operations implementing the ShapeInference interface can be
      // handled; anything else aborts the pass with a diagnostic.
      auto inferable = dyn_cast<ShapeInference>(ready);
      if (!inferable) {
        ready->emitError("unable to infer shape of operation without shape "
                         "inference interface");
        signalPassFailure();
        return;
      }
      inferable.inferShapes();
    }

    // Everything was processed: nothing left to report.
    if (pending.empty())
      return;

    // Leftover operations never became ready, which is reported as a failure.
    function.emitError("Shape inference failed, ")
        << pending.size() << " operations couldn't be inferred\n";
    signalPassFailure();
  }

  /// Returns true when every operand type of `op` is a ranked tensor (which
  /// is trivially the case for an operation without operands).
  static bool allOperandsInferred(Operation *op) {
    const bool hasUnresolvedOperand =
        llvm::any_of(op->getOperandTypes(), [](Type operandType) {
          return !llvm::isa<RankedTensorType>(operandType);
        });
    return !hasUnresolvedOperand;
  }

  /// Returns true when at least one result type of `op` is not a ranked
  /// tensor, meaning the operation still needs shape inference.
  static bool returnsDynamicShape(Operation *op) {
    const bool everyResultRanked =
        llvm::all_of(op->getResultTypes(), [](Type resultType) {
          return llvm::isa<RankedTensorType>(resultType);
        });
    return !everyResultRanked;
  }
};
} // namespace
/// Factory for the shape inference pass: returns a newly allocated
/// ShapeInferencePass owned by the caller through a generic pass handle.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<ShapeInferencePass>();
  return pass;
}