add torch.amp bf16 support for ipex backend (#1497)
changwangss committed Nov 22, 2022
1 parent 773bb3c · commit 2a361b8
Showing 3 changed files with 37 additions and 17 deletions.
neural_compressor/adaptor/pytorch.py (21 additions, 17 deletions)
@@ -29,17 +29,13 @@
 from ..utils import logger
 from .query import QueryBackendCapability
 from ..experimental.data.dataloaders.base_dataloader import BaseDataLoader
-try:  # pragma: no cover
-    import intel_extension_for_pytorch as ipex
-    IPEX = True
-except:  # pragma: no cover
-    IPEX = False
 
 
 torch = LazyImport("torch")
 json = LazyImport("json")
 hvd = LazyImport("horovod.torch")
 torch_utils = LazyImport("neural_compressor.adaptor.torch_utils")
+ipex = LazyImport("intel_extension_for_pytorch")
 
 REDUCE_RANGE = False if CpuInfo().vnni else True
 logger.debug("Reduce range is {}".format(str(REDUCE_RANGE)))
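
With this change, importing neural_compressor no longer hard-requires intel-extension-for-pytorch: the module is only loaded the first time it is actually touched. A minimal sketch of the LazyImport idea (not the library's actual implementation, just an illustration of the deferred-import pattern):

    import importlib

    class LazyImport:
        """Defer importing a module until one of its attributes is accessed."""
        def __init__(self, module_name):
            self.module_name = module_name
            self._module = None

        def __getattr__(self, name):
            # __getattr__ only fires for attributes not found normally,
            # so the real import happens on first use of the module.
            if self._module is None:
                self._module = importlib.import_module(self.module_name)
            return getattr(self._module, name)

    ipex = LazyImport("intel_extension_for_pytorch")  # nothing imported yet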
@@ -1033,7 +1029,7 @@ def _get_quantizable_ops(self, model):
 
 
         # get bf16 capability
-        if self.use_bf16 and (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
+        if (CpuInfo().bf16 or os.getenv('FORCE_BF16') == '1') and \
             (self.version.release >= Version("1.11.0").release):
             self.bf16_ops = self.query_handler.get_op_types_by_precision("bf16")
             bf16_ops = []
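
The gate above only queries bf16 op types when the CPU reports bf16 support (or the FORCE_BF16=1 override is set) and torch is at least 1.11.0. A hedged sketch of the same condition in isolation (CpuInfo comes from neural_compressor's utils; here the inputs are plain parameters for clarity):

    import os
    from packaging.version import Version

    def bf16_capable(cpu_has_bf16: bool, torch_version: str) -> bool:
        # Mirrors the condition in _get_quantizable_ops: hardware bf16
        # (or the FORCE_BF16 override) plus torch >= 1.11.0.
        forced = os.getenv('FORCE_BF16') == '1'
        return (cpu_has_bf16 or forced) and \
            Version(torch_version).release >= Version("1.11.0").release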
@@ -2148,8 +2144,6 @@ class PyTorch_IPEXAdaptor(TemplateAdaptor):  # pragma: no cover
     """
     def __init__(self, framework_specific_info):
         super(PyTorch_IPEXAdaptor, self).__init__(framework_specific_info)
-
-        assert IPEX, "Please install intel-extension-for-pytorch."
         self.version = get_torch_version()
         query_config_file = "pytorch_ipex.yaml"
         self.query_handler = PyTorchQuery(
@@ -2226,20 +2220,30 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
             self.model_calibration(q_model, dataloader, iterations, None,
                                    tune_cfg.get('calib_sampling_size', 1))
         q_model.save_qconf_summary(qconf_summary=self.ipex_config_path)
-        q_model = ipex.quantization.convert(q_model)
-        with torch.no_grad():
-            try:
-                q_model = torch.jit.trace(q_model, example_inputs)
-                q_model = torch.jit.freeze(q_model.eval())
-            except:
-                q_model = torch.jit.trace(q_model, example_inputs, strict=False)
-                q_model = torch.jit.freeze(q_model.eval())
+        if self.use_bf16:
+            with torch.no_grad():
+                with torch.cpu.amp.autocast():
+                    q_model = ipex.quantization.convert(q_model)
+                    try:
+                        q_model = torch.jit.trace(q_model, example_inputs)
+                        q_model = torch.jit.freeze(q_model.eval())
+                    except:
+                        q_model = torch.jit.trace(q_model, example_inputs, strict=False)
+                        q_model = torch.jit.freeze(q_model.eval())
+        else:
+            with torch.no_grad():
+                try:
+                    q_model = torch.jit.trace(q_model, example_inputs)
+                    q_model = torch.jit.freeze(q_model.eval())
+                except:
+                    q_model = torch.jit.trace(q_model, example_inputs, strict=False)
+                    q_model = torch.jit.freeze(q_model.eval())
         # After freezing, run 1 time to warm up the profiling graph executor to insert prim::profile
         # At the 2nd run, the llga pass will be triggered and the model is turned into
         # an int8 model: prim::profile will be removed and will have LlgaFusionGroup in the graph
         q_model(*example_inputs)
         q_model(*example_inputs)
 
         assert self.approach != 'quant_aware_training', \
             "Intel PyTorch Extension didn't support quantization aware training mode"
         model_._model = q_model
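
The new branch converts and traces the calibrated model under torch.cpu.amp.autocast, so operators that the int8 recipe leaves in fp32 can execute in bf16 instead. A self-contained sketch of that path outside the adaptor (assuming torch >= 1.11 and intel-extension-for-pytorch >= 1.12; the toy model and variable names are illustrative):

    import torch
    import intel_extension_for_pytorch as ipex

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    example_inputs = torch.ones(1, 3, 224, 224)

    # Calibrate an IPEX static int8 recipe with one representative batch.
    qconfig = ipex.quantization.default_static_qconfig
    prepared = ipex.quantization.prepare(model, qconfig,
                                         example_inputs=example_inputs, inplace=False)
    with torch.no_grad():
        prepared(example_inputs)

    # The path this commit adds: convert and trace under bf16 autocast,
    # then freeze the traced graph.
    with torch.no_grad():
        with torch.cpu.amp.autocast():
            converted = ipex.quantization.convert(prepared)
            traced = torch.jit.trace(converted, example_inputs)
            traced = torch.jit.freeze(traced.eval())

    # Two warm-up runs, as in the adaptor above: the first inserts profiling
    # nodes; the second triggers the LLGA fusion pass.
    with torch.no_grad():
        traced(example_inputs)
        traced(example_inputs)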
neural_compressor/adaptor/pytorch_ipex.yaml (1 addition, 0 deletions)
@@ -90,6 +90,7 @@
     ops:
       int8: *ops_default_s8
       uint8: *ops_default_s8
+      bf16: []
       fp32: ['*'] # '*' means all op types
 
   capabilities: &1_10_capabilities
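
The empty bf16 entry matters because the capability query in pytorch.py (see _get_quantizable_ops above) now asks this yaml for bf16 op types; an explicit empty list means "no bf16 ops for this capability section" rather than a missing key. A sketch of how that lookup might be exercised directly (the local_config_file parameter name is an assumption; only the class and method names are taken from this diff):

    from neural_compressor.adaptor.pytorch import PyTorchQuery

    query_handler = PyTorchQuery(
        local_config_file="neural_compressor/adaptor/pytorch_ipex.yaml")
    # With this change, the queried section resolves bf16 to an empty op list.
    print(query_handler.get_op_types_by_precision("bf16"))  # expected: []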
test/ipex/test_adaptor_ipex.py (15 additions, 0 deletions)
@@ -135,6 +135,21 @@ def test_copy_prepared_model(self):
         copy_model = torch_utils.util.auto_copy(prepared_model)
         self.assertTrue(isinstance(copy_model, torch.nn.Module))
 
+
+    def test_bf16(self):
+        from neural_compressor.experimental import Quantization
+        model = M()
+        qconfig = ipex.quantization.default_static_qconfig
+        prepared_model = ipex.quantization.prepare(model, qconfig, example_inputs=torch.ones(1, 3, 224, 224), inplace=False)
+        config.quantization.use_bf16 = True
+        config.quantization.performance_only = True
+        quantizer = Quantization(config)
+        dataset = quantizer.dataset('dummy', (100, 3, 224, 224), label=True)
+        dataloader = torch.utils.data.DataLoader(dataset)
+        quantizer.model = model
+        quantizer.calib_dataloader = dataloader
+        quantizer.eval_dataloader = dataloader
+        nc_model = quantizer.fit()
 
 if __name__ == "__main__":
     unittest.main()
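
Assuming torch >= 1.11 and intel-extension-for-pytorch are installed, the new test can be run on its own with something like:

    python -m pytest test/ipex/test_adaptor_ipex.py -k test_bf16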
