Skip to content

Commit

Permalink
Bugfix operator fusion for residual block with layout transform
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Sep 22, 2018
1 parent 7beafdd commit 5b25512
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
3 changes: 2 additions & 1 deletion nnvm/src/compiler/graph_compile.cc
Expand Up @@ -109,13 +109,14 @@ nnvm::Graph GraphCompile(const nnvm::Graph& g) {
inputs.push_back(it->second);
}
// Find master idx in the subgraph.
int sub_master_idx = 0;
int sub_master_idx = -1;
for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
if (subidx[i].source->op() == idx[master].source->op()) {
sub_master_idx = i;
break;
}
}
CHECK_NE(sub_master_idx, -1);
fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
for (LoweredFunc f : fe.compiled_func->funcs) {
if (!func_set.count(f.get())) {
Expand Down
17 changes: 14 additions & 3 deletions nnvm/src/compiler/graph_fuse.cc
Expand Up @@ -136,11 +136,15 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {

// Point to the group root id of each node.
GroupVec group_vec(idx.num_nodes(), -1);
std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
uint32_t nid = i - 1;
const auto& inode = idx[nid];
bool is_root = false;
if (group_vec[nid] == -1) {
group_vec[nid] = nid;
node_ids_per_group[nid].push_back(nid);
is_root = true;
}

// Check if injective op and out_ewise_fusable op (e.g. conv2d) are in the same group.
Expand All @@ -156,7 +160,15 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
}
}
// Change the master node from out_ewise_fusable op to itself
if (parent_injective && parent_out_ewise) master_vec[nid] = nid;
if (parent_injective && parent_out_ewise) {
master_vec[nid] = nid;
if (!is_root) {
// Children nodes in the same group might be pointing to a master node in a different group.
for (uint32_t j : node_ids_per_group[group_vec[nid]]) {
master_vec[j] = nid;
}
}
}

// Propagate the group id.
for (const auto& e : inode.inputs) {
Expand All @@ -172,6 +184,7 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
CHECK(group_vec[e.node_id] == -1||
group_vec[e.node_id] == group_vec[nid]);
group_vec[e.node_id] = group_vec[nid];
node_ids_per_group[group_vec[nid]].push_back(e.node_id);
}
}
}
Expand Down Expand Up @@ -223,12 +236,10 @@ nnvm::Graph GraphFindFusibleGroups(nnvm::Graph g) {
*/
if (opt_level >= 1) {
std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
CHECK_NE(group_vec[nid], -1);
node_ids_per_group[group_vec[nid]].push_back(nid);
if (inode.inputs.size() != 1) continue;
const uint32_t parent_nid = inode.inputs[0].node_id;
// if parent node has more than one child, record each child's group id.
Expand Down

0 comments on commit 5b25512

Please sign in to comment.