From 029a6325748210e102a566603ad7220a0fc70eea Mon Sep 17 00:00:00 2001
From: xinhe
Date: Fri, 12 Aug 2022 13:56:48 +0800
Subject: [PATCH] support QAT PT to ONNX (#1143)

---
 neural_compressor/model/torch_model.py | 57 ++++++++++++++++---
 .../pytorch_adaptor/test_torch2onnx.py | 44 +++++++++++---
 2 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py
index 42c7f2391a3..6d1a59de469 100644
--- a/neural_compressor/model/torch_model.py
+++ b/neural_compressor/model/torch_model.py
@@ -409,14 +409,15 @@ def export_to_bf16_onnx(self,
         original_initializer = copy.deepcopy(model.graph.initializer)
         for tensor in original_initializer:
             if tensor.name in bf16_tensor_name_list:
-                bf16_tensor = helper.make_tensor(
-                    name=tensor.name,
-                    data_type=TensorProto.BFLOAT16,
-                    dims=tensor.dims,
-                    vals=numpy_helper.to_array(tensor),
-                )
-                model.graph.initializer.remove(tensor)
-                model.graph.initializer.append(bf16_tensor)
+                def fp32_to_bf16(fp32_np):
+                    assert(fp32_np.dtype==np.float32)
+                    int32_np = fp32_np.view(dtype=np.int32)
+                    int32_np = int32_np >> 16
+                    bf16_np = int32_np.astype(np.int16)
+                    return bf16_np
+                fp16_data = fp32_to_bf16(numpy_helper.to_array(tensor))
+                tensor.raw_data = fp16_data.tobytes()
+                tensor.data_type = TensorProto.BFLOAT16
         onnx.save(model, save_path)
         os.remove(fp32_path)
 
@@ -430,7 +431,6 @@ def export_to_int8_onnx(
         self,
         save_path='int8-model.onnx',
         example_inputs = torch.rand([1, 1, 1, 1]),
-        example_input_names = 'input',
         opset_version=14,
         dynamic_axes={"input": {0: "batch_size"},
                       "output": {0: "batch_size"}},
@@ -497,6 +497,45 @@ def export_to_int8_onnx(
             fp32_model=fp32_model
         )
         model = onnx.load(fp32_path)
+
+        if self.q_config['approach'] == 'quant_aware_training':
+            # collect weights, bias from int8 PT model
+            model_dict = self._model.state_dict()
+            int8_model_dict = {}
+            for name, param in model_dict.items():
+                # '_packed_params._packed_weight' is specific for quantized Embedding
+                if '_packed_params._packed_weight' in name:
+                    name = name.replace('._packed_params._packed_weight', '').split('.module')[0]
+                    int8_model_dict[name+'.weight'] = param.dequantize()
+                # '_packed_params._packed_params' is specific for quantized Linear
+                elif '_packed_params._packed_params' in name and isinstance(param, tuple):
+                    name = name.replace('._packed_params._packed_params', '').split('.module')[0]
+                    int8_model_dict[name+'.bias'] = param[1]
+                    int8_model_dict[name+'.weight'] = param[0].dequantize()
+                # '.weight' and '.bias' is specific for quantized Conv
+                elif '.weight' in name:
+                    int8_model_dict[name] = param.dequantize()
+                elif '.bias' in name:
+                    int8_model_dict[name] = param
+                else:
+                    int8_model_dict[name] = param
+
+            # replace weight and bias in onnx fp32 model for QAT
+            from onnx import helper
+            tensor_list = [tensor for tensor in model.graph.initializer]
+            for tensor in tensor_list:
+                if tensor.name in int8_model_dict:
+                    np_tensor = int8_model_dict[tensor.name].detach().cpu().numpy()
+                    new_tensor = helper.make_tensor(
+                        name=tensor.name,
+                        data_type=tensor.data_type,
+                        dims=tensor.dims,
+                        vals=np_tensor,
+                    )
+                    model.graph.initializer.remove(tensor)
+                    model.graph.initializer.append(new_tensor)
+            onnx.save(model, fp32_path)
+
         from neural_compressor.adaptor.onnxrt import ONNXRTAdaptor
         # pylint: disable=E1120
         inc_model = ONNXRTAdaptor._replace_gemm_with_matmul(model)
diff --git a/test/adaptor/pytorch_adaptor/test_torch2onnx.py b/test/adaptor/pytorch_adaptor/test_torch2onnx.py
index 1171774b86f..1453a6c05d3 100644
--- a/test/adaptor/pytorch_adaptor/test_torch2onnx.py
+++ b/test/adaptor/pytorch_adaptor/test_torch2onnx.py
@@ -5,6 +5,7 @@
 import unittest
 import os
 import onnx
+import numpy as np
 from neural_compressor.adaptor.pytorch import PyTorchVersionMode
 import neural_compressor.adaptor.pytorch as nc_torch
 from neural_compressor.experimental import Quantization, common
@@ -63,6 +64,13 @@ def build_pytorch_yaml():
     with open('dynamic_yaml.yaml', 'w', encoding="utf-8") as f:
         f.write(fake_dyn_yaml)
 
+    fake_qat_yaml = fake_ptq_yaml.replace(
+        'post_training_static_quant',
+        'quant_aware_training',
+    )
+    with open('qat_yaml.yaml', 'w', encoding="utf-8") as f:
+        f.write(fake_qat_yaml)
+
 
 def build_pytorch_fx_yaml():
     fake_fx_ptq_yaml = fake_ptq_yaml.replace('pytorch', 'pytorch_fx')
@@ -130,10 +138,10 @@ def setUpClass(self):
     def tearDownClass(self):
         os.remove('ptq_yaml.yaml')
         os.remove('dynamic_yaml.yaml')
+        os.remove('qat_yaml.yaml')
         shutil.rmtree('runs', ignore_errors=True)
         os.remove('fp32-model.onnx')
         os.remove('int8-model.onnx')
-        os.remove('int8-model.tmp')
 
     @unittest.skipIf(not BF16_MODE, "Unsupport BF16 Mode with ONNX Version Below 1.11")
     def test_bf16_onnx(self):
@@ -163,22 +171,37 @@ def test_eager_quant(self):
                           "output": {0: "batch_size"}},
             do_constant_folding=True,
         )
-        for fake_yaml in ['dynamic_yaml.yaml', ]:
+        for fake_yaml in ['dynamic_yaml.yaml', 'ptq_yaml.yaml', 'qat_yaml.yaml']:
             model = M()
             quantizer = Quantization(fake_yaml)
             quantizer.conf.usr_cfg.tuning.exit_policy['performance_only'] = True
             dataset = quantizer.dataset('dummy', (10, 3, 224, 224), label=True)
             quantizer.model = model
-            quantizer.calib_dataloader = common.DataLoader(dataset)
+            if fake_yaml == 'qat_yaml.yaml':
+                def train_func(model):
+                    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+                    optimizer.zero_grad()
+                    model.train()
+                    input = torch.randn([1, 3, 224, 224])
+                    # compute output
+                    output = model(input)
+                    loss = output.abs().sum()
+                    loss.backward()
+                    optimizer.step()
+                    return model
+                quantizer.q_func = train_func
+            elif fake_yaml == 'ptq_yaml.yaml':
+                quantizer.calib_dataloader = common.DataLoader(dataset)
             quantizer.eval_dataloader = common.DataLoader(dataset)
             q_model = quantizer.fit()
             int8_jit_model = q_model.export_to_jit(example_inputs)
             # INC will keep fallbacked fp32 modules when exporting onnx model
-            if 'ptq_yaml.yaml':
-                calib_dataloader = quantizer.calib_dataloader
-            else:
+            if fake_yaml == 'dynamic_yaml.yaml':
                 calib_dataloader = None
+            else:
+                quantizer.calib_dataloader = common.DataLoader(dataset)
+                calib_dataloader = quantizer.calib_dataloader
             q_model.export_to_int8_onnx(
                 save_path='int8-model.onnx',
                 example_inputs=example_inputs,
@@ -191,6 +214,14 @@ def test_eager_quant(self):
                 fp32_model=model,
                 calib_dataloader=calib_dataloader,
             )
+            if fake_yaml == 'qat_yaml.yaml':
+                model = onnx.load('int8-model.onnx')
+                tensor_list = {tensor.name:tensor for tensor in model.graph.initializer}
+                torch_data = q_model.model.conv.weight().dequantize().detach().cpu().numpy()
+                from onnx.numpy_helper import to_array
+                onnx_data = to_array(tensor_list['conv.weight_quantized'])
+                onnx_scale = to_array(tensor_list['conv.weight_scale'])
+                self.assertTrue(np.allclose(torch_data, onnx_data * onnx_scale, atol=0.001))
 
     def test_input_tuple(self):
         from neural_compressor.adaptor.torch_utils.util import input2tuple
@@ -214,7 +245,6 @@ def tearDownClass(self):
         os.remove('fx_dynamic_yaml.yaml')
         shutil.rmtree('runs', ignore_errors=True)
         os.remove('int8-model.onnx')
-        os.remove('int8-model.tmp')
 
     def test_fx_quant(self):
         for fake_yaml in ['fx_dynamic_yaml.yaml', 'fx_ptq_yaml.yaml']:
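
Below is a minimal usage sketch of the new quant_aware_training export path added by this patch. It is illustrative only and not part of the patch: the TinyModel class is a stand-in for the test's M model, the qat_yaml.yaml config and dummy input shapes are assumptions mirroring the test fixtures in test_torch2onnx.py, and only arguments visible in this diff are passed to export_to_int8_onnx.

    # Sketch only; assumes a QAT config equivalent to the test's generated
    # qat_yaml.yaml and the experimental Quantization API used in the test.
    import torch
    import torch.nn as nn
    from neural_compressor.experimental import Quantization, common

    class TinyModel(nn.Module):
        """Stand-in for the test's M: a single conv layer."""
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 1, 1)

        def forward(self, x):
            return self.conv(x)

    def train_func(model):
        # One small optimization step so the QAT observers see real activations.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        optimizer.zero_grad()
        model.train()
        loss = model(torch.randn([1, 3, 224, 224])).abs().sum()
        loss.backward()
        optimizer.step()
        return model

    example_inputs = torch.randn([1, 3, 224, 224])
    model = TinyModel()
    quantizer = Quantization('qat_yaml.yaml')   # approach: quant_aware_training
    quantizer.model = model
    quantizer.q_func = train_func
    dataset = quantizer.dataset('dummy', (10, 3, 224, 224), label=True)
    quantizer.eval_dataloader = common.DataLoader(dataset)
    q_model = quantizer.fit()

    # The QAT branch added to export_to_int8_onnx copies the dequantized
    # weights/biases from the PyTorch state_dict into the intermediate fp32
    # ONNX graph before ONNX Runtime quantization writes int8-model.onnx.
    q_model.export_to_int8_onnx(
        save_path='int8-model.onnx',
        example_inputs=example_inputs,
        opset_version=14,
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        fp32_model=model,
        calib_dataloader=common.DataLoader(dataset),
    )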