Skip to content

Commit

Permalink
Fixed version error for PyTorch (intel#1231)
Browse files Browse the repository at this point in the history
  • Loading branch information
PenghuiCheng committed Sep 9, 2022
1 parent 81f63e2 commit 1ec33aa
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def _observer(algorithm,

if algorithm == 'placeholder' or dtype == torch.float: # pragma: no cover
return torch.quantization.PlaceholderObserver \
if get_torch_version() <= Version("1.7.0-rc1") \
if get_torch_version() <= Version("1.7.1") \
else torch.quantization.PlaceholderObserver.with_args(dtype=dtype,
compute_dtype=compute_dtype)
if algorithm == 'minmax':
Expand Down Expand Up @@ -491,12 +491,12 @@ def _fake_quantize(algorithm, scheme, granularity, dtype, compute_dtype='uint8')
Return:
fake quantization (object)
"""
version = get_torch_version()
if scheme == 'asym_float' \
and get_torch_version() >= Version("1.7.0-rc1"):
and version >= Version("1.7.0-rc1"):
return torch.quantization.default_float_qparams_observer
if algorithm == 'placeholder' or dtype == 'fp32': # pragma: no cover
return _observer(algorithm, scheme, granularity, dtype, compute_dtype=compute_dtype)
version = get_torch_version()
fake_quant = torch.quantization.FakeQuantize \
if version < Version("1.10.0-rc1") else \
torch.quantization.FusedMovingAvgObsFakeQuantize
Expand Down Expand Up @@ -1833,7 +1833,7 @@ def inspect_tensor(self,
iteration_list=None,
inspect_type='activation',
save_to_disk=False):
if self.version > Version("1.7.0-rc1"):
if self.version >= Version("1.8.0-rc1"):
from torch.fx import GraphModule
if type(model._model) == GraphModule: # pragma: no cover
assert False, "Inspect_tensor didn't support fx graph model now!"
Expand Down Expand Up @@ -2100,7 +2100,7 @@ def __init__(self, framework_specific_info):
assert IPEX_110 is not None and IPEX_112 is not None, 'Please install intel_extension_for_pytorch.'
self.version = get_torch_version()
IPEX_110 = True if Version("1.10.0-rc1") <= self.version and \
self.version <= Version("1.11.0-rc1") else False
self.version <= Version("1.11.0") else False
IPEX_112 = True if self.version >= Version("1.12.0-rc1") else False
query_config_file = "pytorch_ipex.yaml"
self.query_handler = PyTorchQuery(
Expand Down

0 comments on commit 1ec33aa

Please sign in to comment.