Add required example_inputs argument to prepare_fx and prepare_qat_fx
Summary:
X-link: pytorch/pytorch#77608

X-link: pytorch/fx2trt#76

X-link: facebookresearch/d2go#249

X-link: fairinternal/ClassyVision#104

X-link: pytorch/benchmark#916

X-link: facebookresearch/ClassyVision#791

Pull Request resolved: #68

FX Graph Mode Quantization needs to know whether an fx node is a floating point Tensor before it can decide whether to insert an observer/fake_quantize module, since we only insert these modules for floating point Tensors.
Currently we support this with some hacks, such as the NON_OBSERVABLE_ARG_DICT rules (https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fx/utils.py#L496), but this approach is fragile and we do not plan to maintain it long term in the PyTorch code base.

As we discussed in the design review, we need to ask users to provide sample args and sample keyword args so that we can infer the types in a more robust way. This PR starts by changing the prepare_fx and prepare_qat_fx APIs to require users to provide
example arguments through example_inputs. Note that this API does not support kwargs. Supporting kwargs could make cases like pytorch/pytorch#76496 (comment) simpler, but
those cases are rare, and even then they can be worked around with positional arguments, as sketched below. Also, torch.jit.trace (https://pytorch.org/docs/stable/generated/torch.jit.trace.html) and ShapeProp (https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py#L140) only take a single tuple of positional args, so we will use a single example_inputs argument for now.
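
As a concrete illustration of the positional workaround, here is a minimal sketch; MyModule is a hypothetical module introduced only for this example and is not part of this PR:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

# Hypothetical module whose forward is usually called with a keyword argument.
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale

m = MyModule().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# example_inputs only takes positional args, so the keyword argument is
# supplied positionally: (x, scale) rather than scale=2.0.
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 8), 2.0))
```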

If needed, we can extend the API later with an optional example_kwargs, e.g. for cases where forward takes many arguments and it is more natural to pass them by keyword.

BC-breaking Note:
Before:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict)
# or
m = prepare_qat_fx(m, qconfig_dict)
```
After:
```python
m = resnet18(...)
m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
# or
m = prepare_qat_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 3, 224, 224),))
```
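
For completeness, a minimal end-to-end sketch of the post-training flow with the new argument (assuming torchvision's resnet18 and random calibration data as placeholders; calibration and convert_fx are unchanged by this PR):

```python
import torch
from torchvision.models import resnet18
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

m = resnet18().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 3, 224, 224),)
m = prepare_fx(m, qconfig_dict, example_inputs=example_inputs)

# Calibrate with representative data (random tensors here, as placeholders).
with torch.no_grad():
    for _ in range(4):
        m(torch.randn(1, 3, 224, 224))

m = convert_fx(m)
```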

Reviewed By: vkuzo, andrewor14

Differential Revision: D35984526

fbshipit-source-id: 706c8df71722c9aa5082a6491734f0144f0dd670
jerryzh168 authored and facebook-github-bot committed May 21, 2022
1 parent f313fb8 commit 659d501
Showing 6 changed files with 20 additions and 7 deletions.
6 changes: 5 additions & 1 deletion mobile_cv/arch/tests/test_fbnet_v2_quantize.py
@@ -86,7 +86,11 @@ def test_qat(self):
         model.train()
 
         qconfig_dict = {"": torch.ao.quantization.get_default_qat_qconfig("fbgemm")}
-        model_prepared = quantize_fx.prepare_qat_fx(model, qconfig_dict)
+        example_inputs = (torch.rand(2, 3, 8, 8),)
+        model_prepared = quantize_fx.prepare_qat_fx(
+            model, qconfig_dict, example_inputs=example_inputs
+        )
 
         print(f"Prepared model {model_prepared}")
 
         # calibration
3 changes: 2 additions & 1 deletion mobile_cv/arch/tests/test_fbnet_v2_res_block.py
@@ -94,7 +94,8 @@ def test_res_block_quantize_partial(self):
         data = torch.zeros(1, 8, 4, 4)
 
         qconfig_dict = qu.get_qconfig_dict(model, qconfig)
-        model = prepare_fx(model, qconfig_dict)
+        example_inputs = (data,)
+        model = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
         model = convert_fx(model)
         print(model)
 
3 changes: 2 additions & 1 deletion mobile_cv/arch/tests/test_utils_quantize_utils.py
@@ -361,7 +361,8 @@ def forward(self, x):
         model = MM().eval()
         qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
         qconfig_dict = qu.get_qconfig_dict(model, qconfig)
-        model = prepare_fx(model, qconfig_dict)
+        example_inputs = (torch.rand(1, 1, 3, 3),)
+        model = prepare_fx(model, qconfig_dict, example_inputs=example_inputs)
         model = convert_fx(model)
         print(model)
 
6 changes: 4 additions & 2 deletions mobile_cv/arch/utils/quantize_utils.py
@@ -186,13 +186,15 @@ def set_quant_config(self, quant_cfg):
         self.qconfig = quant_cfg
         return self
 
-    def prepare(self, qconfig_dict=None):
+    def prepare(self, example_inputs, qconfig_dict=None):
         if qconfig_dict is None:
             qconfig_dict = get_qconfig_dict(self.model, self.qconfig)
         if qconfig_dict is None:
             qconfig_dict = {"": self.qconfig}
         self._prepared_model = torch.ao.quantization.quantize_fx.prepare_fx(
-            self.model, qconfig_dict
+            self.model,
+            qconfig_dict=qconfig_dict,
+            example_inputs=example_inputs,
         )
         return self
 
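For callers of this helper, the chained usage now looks roughly like the following; this is a sketch with a placeholder model and calibration batch, mirroring the create_model.py and model_exporter.py changes below:

```python
import torch
from mobile_cv.arch.utils import quantize_utils as qu

# Placeholder model and calibration batch for illustration only.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
data = [torch.randn(1, 3, 8, 8)]
example_inputs = tuple(data)

quant_model = (
    qu.PostQuantizationFX(model)
    .set_quant_backend("default")
    .prepare(example_inputs=example_inputs)
    .calibrate_model([data], 1)
    .convert_model()
)
```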
3 changes: 2 additions & 1 deletion mobile_cv/model_zoo/tools/create_model.py
@@ -140,9 +140,10 @@ def convert_int8_jit(args, model, data, folder_name="int8_jit"):
         )
     else:
         quant = qu.PostQuantizationFX(model)
+        example_inputs = tuple(data)
         quant_model = (
             quant.set_quant_backend("default")
-            .prepare()
+            .prepare(example_inputs=example_inputs)
             .calibrate_model([data], 1)
             .convert_model()
         )
6 changes: 5 additions & 1 deletion mobile_cv/model_zoo/tools/model_exporter.py
@@ -309,19 +309,23 @@ def export_to_torchscript_int8(
 ):
     cur_loader = itertools.chain([inputs], data_iter)
 
+    example_inputs = tuple(inputs)
     if hasattr(task, "get_quantized_model"):
         print("calling get quantized model")
         ptq_model = task.get_quantized_model(model, cur_loader)
         model_attrs = _get_model_attributes(ptq_model)
         print("after calling get quantized model")
     elif args.use_graph_mode_quant:
         print(f"Post quantization using {args.post_quant_backend} backend fx mode...")
         model_attrs = _get_model_attributes(model)
         quant = quantize_utils.PostQuantizationFX(model)
         ptq_model = (
             quant.set_quant_backend(args.post_quant_backend)
-            .prepare()
+            .prepare(example_inputs=example_inputs)
             .calibrate_model(cur_loader, 1)
             .convert_model()
         )
         print("after calling callback")
     else:
         print(f"Post quantization using {args.post_quant_backend} backend...")
         qa_model = task.get_quantizable_model(model)
