Skip to content

Commit

Permalink
[quant][graphmode] Support prim:TupleUnpack and prim::TupleConstruct (p…
Browse files Browse the repository at this point in the history
…ytorch#39895)

Summary: Pull Request resolved: pytorch#39895

Test Plan: Imported from OSS

Differential Revision: D22009854

fbshipit-source-id: a5dab2b4f943e5e047ba9e8573088adf66f5da6b
  • Loading branch information
jerryzh168 authored and facebook-github-bot committed Jun 16, 2020
1 parent eb358f4 commit f37b8e7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
12 changes: 9 additions & 3 deletions test/quantization/test_quantize_script.py
Expand Up @@ -2184,7 +2184,7 @@ def forward(self, x):
FileCheck().check_count("aten::dequantize", 1, exactly=True) \
.run(m.graph)

def test_quantize_general_shape_ops(self):
def test_general_shape_ops(self):
""" A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
without actually checking for execution of these ops
Expand All @@ -2210,8 +2210,14 @@ def forward(self, x):
x = x.reshape([-1])
x = x.resize_(1, 1, x.numel())
x = x.view(-1)
# prim::ListConstruct
xs = [x, x]
y, x = xs
# prim::ListUnpack
x, y = xs
# prim::TupleConstruct
xs = (x, x)
# prim::TupleUnpack
x, y = xs
x = x.transpose(1, 2)
x = x.contiguous()
x, y = torch.chunk(x, 2)
Expand Down Expand Up @@ -2251,7 +2257,7 @@ def forward(self, x):
.check("aten::dequantize") \
.run(m.graph)

def test_quantize_general_value_ops(self):
def test_general_value_ops(self):
""" A test that checks correct patterns are produced for
all supported general value ops like aten::avg_pool2d \
without actually checking for execution of these ops
Expand Down
5 changes: 3 additions & 2 deletions torch/csrc/jit/passes/quantization/helper.cpp
Expand Up @@ -295,9 +295,10 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
inputs.push_back(output);
}
return inputs;
} else if (n->kind() == prim::ListUnpack) {
} else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
return {n->input(0)};
} else if (n->kind() == prim::ListConstruct) {
} else if (
n->kind() == prim::ListConstruct || n->kind() == prim::TupleConstruct) {
std::vector<Value*> inputs;
for (auto* v : n->inputs()) {
inputs.push_back(v);
Expand Down

0 comments on commit f37b8e7

Please sign in to comment.