Skip to content

Commit

Permalink
[StaticRuntime] Fusion pass for ClipRanges/GatherRanges/LengthsToOffs…
Browse files Browse the repository at this point in the history
…ets (pytorch#49113)

Summary: Pull Request resolved: pytorch#49113

Reviewed By: ajyu

Differential Revision: D25388512

fbshipit-source-id: 3daa5b9387a3a10b6c220688df06540c4d844aea
  • Loading branch information
Hao Lu authored and hwangdeyu committed Dec 23, 2020
1 parent edea937 commit 1f1c0f5
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions torch/csrc/jit/runtime/static/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,30 @@ void ConcatBatchMatMulBatchGather(std::shared_ptr<torch::jit::Graph>& graph) {
fuse.runOnGraph(graph);
}

void ClipRangesGatherRangesLengthsToOffsets(
std::shared_ptr<torch::jit::Graph>& graph) {
// TODO:: check restrictions for inputs; outputs not used elsewhere
std::string pattern = R"IR(
graph(%a, %b, %c, %d):
%y0 : Tensor = fb::clip_ranges(%b, %c)
%y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0)
%y3 : Tensor = fb::lengths_to_offsets(%y2, %d)
return (%y3, %y1))IR";
std::string fused_pattern = R"IR(
graph(%a, %b, %c, %d):
%y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d)
return (%y1, %y0))IR";
SubgraphRewriter fuse;
fuse.RegisterRewritePattern(pattern, fused_pattern);
fuse.runOnGraph(graph);
}

void FuseInferenceOpsForSparseNN(std::shared_ptr<torch::jit::Graph>& graph) {
#ifdef FBCODE_CAFFE2
ConcatAddMulReplaceNaNClip(graph);
CastedBatchOneHotLengths(graph);
ConcatBatchMatMulBatchGather(graph);
ClipRangesGatherRangesLengthsToOffsets(graph);
#endif
}

Expand Down

0 comments on commit 1f1c0f5

Please sign in to comment.