forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
partition_graph.cc
557 lines (493 loc) · 19.2 KB
/
partition_graph.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/relay/transforms/partition_graph.cc
*
* \brief Partition an input function into multiple functions based on the
* inserted annotation nodes (i.e. compiler_begin and compiler_end).
* These nodes are used as boundaries to partition the Relay function into
* multiple regions that can be offloaded to different accelerators/backends.
*
* Each of these partitioned functions, a.k.a. regions, will be viewed as
* external functions, and they will use the provided compiler for codegen.
*/
#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../analysis/annotated_region_set.h"
#include "../backend/utils.h"
namespace tvm {
namespace relay {
namespace partitioning {
// Cache the compiler_begin and compiler_end annotation ops. The rest of this
// file compares call->op against these two references, so the operator
// registry is only queried once per op instead of at every comparison.
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
/*!
* \brief The checker that verifies if a Relay program is annotated correctly
* for partitioning.
*/
class AnnotationChecker : public ExprVisitor {
public:
bool Check() {
if (!found_start_ && !found_end_) {
LOG(WARNING) << "No compiler annotation found";
} else if (!found_start_) {
LOG(ERROR) << "compiler_begin annotation is missing";
return false;
} else if (!found_end_) {
LOG(ERROR) << "compiler_end annotation is missing";
return false;
}
return true;
}
void VisitExpr_(const CallNode* call) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
return;
} else if (call->op == compiler_begin_op) {
found_start_ = true;
} else if (call->op == compiler_end_op) {
found_end_ = true;
}
}
private:
bool found_start_{false};
bool found_end_{false};
};
/*! \brief This class partitions the expr labeled with begin and end annotations
* into function containing multiple regions. Each region is labeled with
* a compiler attribute so that it will be handled by any compilers that are not
* in the TVM stack.
*
* Input : A Relay module that has functions with disjoint annotated regions
* using compiler_begin and compiler_end. There could be multiple
* outputs.
*
* Output : A Relay module with global functions for such disjoint annotated
* regions with calls inserted at the respective location
*
* Dependencies : AnnotatedRegionSet Utility class.
*
* Methodology :
* 1) The AnnotatedRegionSet utility class is able to construct a collection
* of nodes that are bound by a given annotation -- here we use
* compiler_begin and compiler_end
* 2) Initially, for each function in the module RegionSets are populated.
* 3) Then, the visitor pass traverses each function until a compiler_end
* node that belongs to a "region" is encountered.
* 4) When the first compiler_end of a given annotated region is found,
* a function is formed and inserted.
* a) if the region has multiple outputs, a Tuple node (capturing
* all outputs) is returned.
* 5) Thereafter, if we encounter another output of the same annotated
* region, it is important to note that the function is already formed.
* Therefore, it will lookup the function and add a TupleGetItemNode.
* a) We will use the location index of "rets" of each "Region" of
* AnnotatedRegionSet as TupleGetItemNode index.
* 6) Therefore, functions will be created for all annotated regions.
* The name for each global function is created using "Region" id and
* the compiler name.
*/
class Partitioner : public ExprMutator {
public:
/*!
 * \brief Construct a partitioner and build one AnnotatedRegionSet per
 * function of the module.
 * \param module The module to partition. It is stored in module_.
 */
explicit Partitioner(const IRModule& module) : module_(module) {
  for (const auto& kv : module->functions) {
    // Regions are delimited by the cached compiler_begin / compiler_end ops.
    // The global var (kv.first) is not needed here, only the function itself.
    auto region_set = AnnotatedRegionSet::Create(kv.second, partitioning::compiler_begin_op,
                                                 partitioning::compiler_end_op);
    regions_sets_[region_set] = kv.second;
  }
}
/*!
 * \brief Rewrite compiler_begin / compiler_end annotation calls.
 *
 * Calls that are not operator calls carrying CompilerAttrs are forwarded to
 * the default ExprMutator.
 *
 * compiler_begin: the call is replaced by a variable named
 *   "<compiler>_<region id>_i<input index>". The variable and the rewritten
 *   argument are recorded in region_args so they can later become a parameter
 *   of the region function and an argument of the call to it. The variable is
 *   cached in shared_output_ (keyed by the first non-annotation ancestor and
 *   the region) and reused on later hits.
 *
 * compiler_end: when the first output of a region is encountered, a global
 *   function named "<compiler>_<region id>" is created from all outputs of
 *   the region, added to module_, and the call to it is recorded in
 *   region_function_calls. A region with a single output is replaced by that
 *   call; otherwise the function returns a tuple and each output is replaced
 *   by a TupleGetItem on the call.
 *
 * \param call The call node being visited.
 * \return The rewritten expression.
 */
Expr VisitExpr_(const CallNode* call) final {
  const auto* compiler_attrs = call->attrs.as<CompilerAttrs>();
  if (call->op.as<OpNode>() == nullptr || compiler_attrs == nullptr) {
    return ExprMutator::VisitExpr_(call);
  }

  const Call call_ref = GetRef<Call>(call);

  if (call->op == compiler_begin_op) {
    // The annotation node is inserted on edge so it must have only one
    // argument.
    CHECK_EQ(call->args.size(), 1U);

    // Traverse the rest graph.
    Expr parent = call->args[0];
    auto input_expr = VisitExpr(parent);

    // Backtrace the parent to find the first ancestor node that is not a begin or end op
    while (const auto* parent_call = parent.as<CallNode>()) {
      if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
        parent = parent_call->args[0];
      } else {
        break;
      }
    }

    AnnotatedRegion sg = GetRegion(call_ref);
    // The region is dereferenced below; fail with a diagnostic instead of
    // dereferencing an undefined region (mirrors the compiler_end branch).
    CHECK(sg.defined()) << "Region not defined for " << GetRef<Call>(call);
    int index = GetArgIdx(sg, call_ref);
    CHECK_NE(index, -1);

    // Reuse the variable if this producer already feeds this region.
    auto producer_it = shared_output_.find(parent);
    if (producer_it != shared_output_.end()) {
      auto var_it = producer_it->second.find(sg);
      if (var_it != producer_it->second.end()) {
        return var_it->second;
      }
    }

    // The type of the created variable is the same as the compiler_begin
    // node.
    std::string target = compiler_attrs->compiler;
    std::string varname =
        target + "_" + std::to_string(sg->GetID()) + "_i" + std::to_string(index);
    auto var = Var(varname, call_ref->checked_type_);

    // Record the (parameter, argument) pair once per region.
    std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
    auto& args_of_region = region_args[sg];
    if (std::find(args_of_region.begin(), args_of_region.end(), cand) == args_of_region.end()) {
      args_of_region.push_back(cand);
    }
    shared_output_[parent][sg] = var;
    // std::move lets the Var be moved into the returned Expr.
    return std::move(var);
  }

  CHECK_EQ(call->op, compiler_end_op);
  // The annotation node is inserted on edge so it must have only one
  // argument.
  CHECK_EQ(call->args.size(), 1U);

  AnnotatedRegion region = GetRegion(call_ref);
  // Validate before traversing so that nothing is mutated when the region is
  // missing.
  CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);

  // TODO(@manupa-arm) : need to use the parent function (to which region
  // belongs to) name/key for the functions that are created

  // Traverse subgraph inputs. The result is not used here; the traversal
  // rewrites the compiler_begin nodes of the region, which fills region_args.
  VisitExpr(call->args[0]);

  // Functions are created for each annotated region when their first output
  // is encountered. If there are multiple outputs, a tuple node is inserted
  // at the end.
  // region_function_calls maps (each annotated region) --> created call.
  auto call_it = region_function_calls.find(region);
  if (call_it != region_function_calls.end()) {
    // This section is executed only if there are multiple outputs in the
    // region or the same output is being accessed multiple times.
    if (region->GetOutputs().size() == 1) {
      return call_it->second;
    }
    // The function returns a tuple: select this output from the existing call.
    auto sg_call = call_it->second;
    int index = GetRetIdx(region, call_ref);
    CHECK_NE(index, -1);

    auto tuple_get_item_ = TupleGetItem(sg_call, index);
    tuple_get_item_->checked_type_ = call_ref->args[0]->checked_type_;
    return std::move(tuple_get_item_);
  }

  // First time this region is encountered in the traversal.
  // Collect the rewritten producers of every output of the region.
  Array<Expr> fields;
  for (auto ret : region->GetOutputs()) {
    auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
    fields.push_back(ret_expr);
  }
  int index = GetRetIdx(region, call_ref);
  CHECK_NE(index, -1);

  // Every recorded variable becomes a function parameter. Constant arguments
  // are bound by name later; all other arguments are passed at the call site.
  Array<Var> params;
  Array<Expr> param_expr;
  std::unordered_map<std::string, runtime::NDArray> params_bind;
  for (const auto& arg : region_args[region]) {
    params.push_back(arg.first);
    if (const auto* cn = arg.second.as<ConstantNode>()) {
      params_bind[arg.first->name_hint()] = cn->data;
    } else {
      param_expr.push_back(arg.second);
    }
  }

  Function global_region_func;
  if (region->GetOutputs().size() == 1) {
    // If there are only a single output; no need to add a tuple
    global_region_func =
        Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
  } else {
    auto tuple = Tuple(fields);
    global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
  }

  std::string target = compiler_attrs->compiler;
  std::string name = target + "_" + std::to_string(region->GetID());

  global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
                                runtime::String(name));
  global_region_func =
      WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
  global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
                                tvm::runtime::String(target));
  global_region_func =
      WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

  // Constant propagation
  if (!params_bind.empty()) {
    global_region_func = backend::BindParamsByName(global_region_func, params_bind);
  }

  CHECK(!module_->ContainGlobalVar(name)) << "Global function " << name << " already exists";
  // Create a global function and add it to the IRModule for the region.
  // This way we lift the functions that should be handled by external
  // codegen to the module scope and rely on the pass manager to prevent
  // relay function level passes (i.e. simplify inference and fusion)
  // optimizing it.
  GlobalVar glob_func(name);
  module_->Add(glob_func, global_region_func);

  // The return type of callnode is the same as the type of the
  // compiler_end node.
  auto ret = Call(glob_func, param_expr);
  region_function_calls[region] = ret;

  if (region->GetOutputs().size() == 1) {
    // If there is only a single output; no need to add a tuplegetitem
    // node
    return std::move(ret);
  }
  // Add a tuplegetitem node to select this output out of many
  auto tuple_get_item_ = TupleGetItem(ret, index);
  tuple_get_item_->checked_type_ = call_ref->args[0]->checked_type_;
  return std::move(tuple_get_item_);
}
/*!
 * \brief Rewrite a tuple node.
 * \param op The tuple node being visited.
 * \return The default ExprMutator result when the tuple is not in any region;
 *         otherwise a newly built tuple of the rewritten fields.
 */
Expr VisitExpr_(const TupleNode* op) final {
  if (!GetRegion(GetRef<Tuple>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  // Inside a region a new tuple is always constructed.
  Array<Expr> new_fields;
  for (auto elem : op->fields) {
    new_fields.push_back(VisitExpr(elem));
  }
  return Tuple(new_fields);
}
/*!
 * \brief Rewrite a tuple-get-item node.
 * \param g The node being visited.
 * \return The default ExprMutator result when the node is not in any region;
 *         otherwise a newly built TupleGetItem on the rewritten tuple.
 */
Expr VisitExpr_(const TupleGetItemNode* g) final {
  if (!GetRegion(GetRef<TupleGetItem>(g)).defined()) {
    return ExprMutator::VisitExpr_(g);
  }
  Expr new_tuple = VisitExpr(g->tuple);
  return TupleGetItem(new_tuple, g->index);
}
/*!
 * \brief Rewrite a (nested) function node.
 * \param op The function node being visited.
 * \return The default ExprMutator result when the function is not in any
 *         region; otherwise a newly built function with rewritten parameters
 *         and body and the original return type, type parameters and attrs.
 */
Expr VisitExpr_(const FunctionNode* op) final {
  if (!GetRegion(GetRef<Function>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  // Parameters are visited before the body.
  Array<Var> new_params;
  for (auto p : op->params) {
    new_params.push_back(Downcast<Var>(VisitExpr(p)));
  }
  Expr new_body = VisitExpr(op->body);
  return Function(new_params, new_body, op->ret_type, op->type_params, op->attrs);
}
/*!
 * \brief Rewrite a let node.
 * \param op The let node being visited.
 * \return The default ExprMutator result when the node is not in any region;
 *         otherwise a newly built Let from the rewritten var, value and body.
 */
Expr VisitExpr_(const LetNode* op) final {
  if (!GetRegion(GetRef<Let>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  // Visit order is fixed: bound variable, then value, then body.
  Var new_var = Downcast<Var>(VisitExpr(op->var));
  Expr new_value = VisitExpr(op->value);
  Expr new_body = VisitExpr(op->body);
  return Let(new_var, new_value, new_body);
}
/*!
 * \brief Rewrite an if node.
 * \param op The if node being visited.
 * \return The default ExprMutator result when the node is not in any region;
 *         otherwise a newly built If from the rewritten condition and branches.
 */
Expr VisitExpr_(const IfNode* op) final {
  if (!GetRegion(GetRef<If>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  // Visit order is fixed: condition, true branch, false branch.
  Expr new_cond = VisitExpr(op->cond);
  Expr new_then = VisitExpr(op->true_branch);
  Expr new_else = VisitExpr(op->false_branch);
  return If(new_cond, new_then, new_else);
}
/*!
 * \brief Rewrite a reference-creation node.
 * \param op The node being visited.
 * \return The default ExprMutator result when the node is not in any region;
 *         otherwise a newly built RefCreate on the rewritten value.
 */
Expr VisitExpr_(const RefCreateNode* op) final {
  if (!GetRegion(GetRef<RefCreate>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  Expr new_value = VisitExpr(op->value);
  return RefCreate(new_value);
}
/*!
 * \brief Rewrite a reference-read node.
 * \param op The node being visited.
 * \return The default ExprMutator result when the node is not in any region;
 *         otherwise a newly built RefRead on the rewritten reference.
 */
Expr VisitExpr_(const RefReadNode* op) final {
  if (!GetRegion(GetRef<RefRead>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  Expr new_ref = VisitExpr(op->ref);
  return RefRead(new_ref);
}
/*!
 * \brief Rewrite a reference-write node.
 * \param op The node being visited.
 * \return The default ExprMutator result when the node is not in any region;
 *         otherwise a newly built RefWrite from the rewritten reference and
 *         value.
 */
Expr VisitExpr_(const RefWriteNode* op) final {
  if (!GetRegion(GetRef<RefWrite>(op)).defined()) {
    return ExprMutator::VisitExpr_(op);
  }
  // Visit order is fixed: reference first, then the value written to it.
  Expr new_ref = VisitExpr(op->ref);
  Expr new_value = VisitExpr(op->value);
  return RefWrite(new_ref, new_value);
}
/*!
 * \brief Rewrite the body of every Relay function in the module.
 *
 * Entries of the function table that are not FunctionNode are skipped.
 *
 * \return The module held by this partitioner after all updates.
 */
IRModule Partition() {
  // Iterate over a separate handle to the function table taken before the
  // loop, because the loop body updates module_.
  auto funcs_before = module_->functions;
  for (const auto& kv : funcs_before) {
    auto* fn_node = kv.second.as<FunctionNode>();
    if (fn_node == nullptr) {
      continue;
    }
    auto old_func = GetRef<Function>(fn_node);
    Expr new_body = VisitExpr(old_func->body);
    auto new_func = Function(old_func->params, new_body, old_func->ret_type,
                             old_func->type_params, old_func->attrs);
    module_->Update(kv.first, new_func);
  }
  return module_;
}
private:
/*!
 * \brief Look up the annotated region that contains an expression.
 * \param e The expression to look up.
 * \return The region reported by the first region set that contains e, or an
 *         undefined AnnotatedRegion when no region set contains it.
 */
AnnotatedRegion GetRegion(const Expr& e) {
  for (const auto& entry : regions_sets_) {
    AnnotatedRegionSet candidate_set = entry.first;
    AnnotatedRegion found = candidate_set->GetRegion(e);
    if (found.defined()) {
      return found;
    }
  }
  return AnnotatedRegion(nullptr);
}
/*!
 * \brief Look up the function whose region set contains an expression.
 * \param e The expression to look up.
 * \return The function associated with the first region set that has a
 *         region containing e, or an undefined BaseFunc when there is none.
 */
BaseFunc GetFunc(const Expr& e) {
  for (const auto& entry : regions_sets_) {
    AnnotatedRegionSet candidate_set = entry.first;
    BaseFunc owner = entry.second;
    AnnotatedRegion found = candidate_set->GetRegion(e);
    if (found.defined()) {
      return owner;
    }
  }
  return BaseFunc(nullptr);
}
/*!
 * \brief Get the position of an expression among the inputs of a region.
 * \param sg The region whose inputs are searched.
 * \param arg The expression to find.
 * \return The zero-based position of arg in sg->GetInputs(), or -1 when arg
 *         is not one of the inputs.
 */
int GetArgIdx(AnnotatedRegion sg, const Expr& arg) {
  int position = 0;
  for (auto input : sg->GetInputs()) {
    if (arg == input) {
      return position;
    }
    ++position;
  }
  return -1;
}
/*!
 * \brief Get the position of an output among the outputs of a region;
 * this is to be used as the tuplegetitem index.
 *
 * Outputs are matched by their first argument, not by the call itself. Both
 * arg and every output of the region are downcast to Call.
 *
 * \param sg The region whose outputs are searched.
 * \param arg The output call to find.
 * \return The zero-based position of the first output whose first argument
 *         equals that of arg, or -1 when there is none.
 */
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
  int position = 0;
  for (auto output : sg->GetOutputs()) {
    const bool same_producer = Downcast<Call>(arg)->args[0] == Downcast<Call>(output)->args[0];
    if (same_producer) {
      return position;
    }
    ++position;
  }
  return -1;
}
/*!
 * \brief This map maintains the already created function calls, keyed by
 * region. It is required in the multi-output scenario to link the remaining
 * outputs of a region to the call that was created for its first output.
 */
std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
/*!
 * \brief This map maintains, per region, the (variable, expression) pairs
 * collected while visiting the compiler_begin nodes of the region. The
 * variables and expressions are used when creating the region function and
 * the call to it.
 */
std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
region_args;
/*!
 * \brief Each region set is associated with a function in the module.
 * This map maintains the mapping between a region set and the function it
 * belongs to.
 */
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
/*!
 * \brief Cache the variable created for an output that is shared by
 * different nodes: producer expression --> (region --> variable).
 */
using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
/*!\brief The IRModule used for partitioning. */
IRModule module_;
};
/*!
 * \brief Mutator that removes every call carrying CompilerAttrs whose
 * compiler is "default" by replacing the call with its (rewritten) first
 * argument.
 */
class DefaultRemover : public ExprMutator {
 public:
  /*!
   * \brief Construct a remover.
   * \param module The module to process. It is stored in module_.
   */
  explicit DefaultRemover(const IRModule& module) : module_(module) {}

  /*!
   * \brief Rewrite the body of every Relay function in the module.
   *
   * Entries of the function table that are not FunctionNode are skipped.
   *
   * \return The module held by this remover after all updates.
   */
  IRModule Remove() {
    // Iterate over a separate handle to the function table taken before the
    // loop, because the loop body updates module_.
    auto funcs_before = module_->functions;
    for (const auto& kv : funcs_before) {
      auto* fn_node = kv.second.as<FunctionNode>();
      if (fn_node == nullptr) {
        continue;
      }
      auto old_func = GetRef<Function>(fn_node);
      Expr new_body = VisitExpr(old_func->body);
      auto new_func = Function(old_func->params, new_body, old_func->ret_type,
                               old_func->type_params, old_func->attrs);
      module_->Update(kv.first, new_func);
    }
    return module_;
  }

  /*!
   * \brief Drop "default" annotation calls; mutate all other calls normally.
   * \param call The call node being visited.
   * \return The rewritten first argument for a "default" annotation call,
   *         otherwise the default ExprMutator result.
   */
  Expr VisitExpr_(const CallNode* call) final {
    auto compiler_attrs = call->attrs.as<CompilerAttrs>();
    const bool is_default = compiler_attrs != nullptr && compiler_attrs->compiler == "default";
    if (!is_default) {
      return ExprMutator::VisitExpr_(call);
    }
    return VisitExpr(call->args[0]);
  }

 private:
  /*! \brief The IRModule being processed. */
  IRModule module_;
};
} // namespace partitioning
namespace transform {
/*!
 * \brief Create the graph partitioning pass.
 * \return A sequential pass that runs the module-level "PartitionGraph" pass
 *         (opt level 0) followed by InferType().
 */
Pass PartitionGraph() {
  auto partition_module = [](IRModule m, PassContext pc) -> IRModule {
    // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
    // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
    // all "default" annotations and should be deleted in the future.
    IRModule without_default = partitioning::DefaultRemover(m).Remove();
    return partitioning::Partitioner(without_default).Partition();
  };
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = partition_module;
  Pass partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {});
  return Sequential({partition_pass, InferType()});
}
// Register transform::PartitionGraph in the global function registry under
// the name "relay._transform.PartitionGraph".
TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph);
} // namespace transform
} // namespace relay
} // namespace tvm