Commit

support QAT PT to ONNX (#1143)
xin3he committed Aug 12, 2022
1 parent 2657385 commit 029a632
Showing 2 changed files with 85 additions and 16 deletions.
57 changes: 48 additions & 9 deletions neural_compressor/model/torch_model.py
@@ -409,14 +409,15 @@ def export_to_bf16_onnx(self,
original_initializer = copy.deepcopy(model.graph.initializer)
for tensor in original_initializer:
if tensor.name in bf16_tensor_name_list:
bf16_tensor = helper.make_tensor(
name=tensor.name,
data_type=TensorProto.BFLOAT16,
dims=tensor.dims,
vals=numpy_helper.to_array(tensor),
)
model.graph.initializer.remove(tensor)
model.graph.initializer.append(bf16_tensor)
def fp32_to_bf16(fp32_np):
assert(fp32_np.dtype==np.float32)
int32_np = fp32_np.view(dtype=np.int32)
int32_np = int32_np >> 16
bf16_np = int32_np.astype(np.int16)
return bf16_np
fp16_data = fp32_to_bf16(numpy_helper.to_array(tensor))
tensor.raw_data = fp16_data.tobytes()
tensor.data_type = TensorProto.BFLOAT16
onnx.save(model, save_path)
os.remove(fp32_path)
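For reference, the truncation added above maps an IEEE float32 to bfloat16 by keeping only the top 16 bits (sign, 8-bit exponent, 7-bit mantissa). A minimal standalone sketch, not part of this commit, that also round-trips the conversion to show the precision loss (the bf16_to_fp32 helper is illustrative):

import numpy as np

def fp32_to_bf16(fp32_np):
    # Keep the high 16 bits of each float32 value: sign, exponent, 7 mantissa bits.
    assert fp32_np.dtype == np.float32
    return (fp32_np.view(np.int32) >> 16).astype(np.int16)

def bf16_to_fp32(bf16_np):
    # Illustrative inverse: put the stored 16 bits back into the high half of a 32-bit word.
    bits = bf16_np.astype(np.uint16).astype(np.uint32) << np.uint32(16)
    return bits.view(np.float32)

x = np.array([0.15625, -3.5, 1e-3], dtype=np.float32)
print(np.abs(x - bf16_to_fp32(fp32_to_bf16(x))))  # only small truncation error remains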

@@ -430,7 +431,6 @@ def export_to_int8_onnx(
self,
save_path='int8-model.onnx',
example_inputs = torch.rand([1, 1, 1, 1]),
example_input_names = 'input',
opset_version=14,
dynamic_axes={"input": {0: "batch_size"},
"output": {0: "batch_size"}},
@@ -497,6 +497,45 @@ def export_to_int8_onnx(
fp32_model=fp32_model
)
model = onnx.load(fp32_path)

if self.q_config['approach'] == 'quant_aware_training':
# collect weights, bias from int8 PT model
model_dict = self._model.state_dict()
int8_model_dict = {}
for name, param in model_dict.items():
# '_packed_params._packed_weight' is specific for quantized Embedding
if '_packed_params._packed_weight' in name:
name = name.replace('._packed_params._packed_weight', '').split('.module')[0]
int8_model_dict[name+'.weight'] = param.dequantize()
# '_packed_params._packed_params' is specific for quantized Linear
elif '_packed_params._packed_params' in name and isinstance(param, tuple):
name = name.replace('._packed_params._packed_params', '').split('.module')[0]
int8_model_dict[name+'.bias'] = param[1]
int8_model_dict[name+'.weight'] = param[0].dequantize()
# '.weight' and '.bias' are specific to quantized Conv
elif '.weight' in name:
int8_model_dict[name] = param.dequantize()
elif '.bias' in name:
int8_model_dict[name] = param
else:
int8_model_dict[name] = param

# replace weight and bias in onnx fp32 model for QAT
from onnx import helper
tensor_list = [tensor for tensor in model.graph.initializer]
for tensor in tensor_list:
if tensor.name in int8_model_dict:
np_tensor = int8_model_dict[tensor.name].detach().cpu().numpy()
new_tensor = helper.make_tensor(
name=tensor.name,
data_type=tensor.data_type,
dims=tensor.dims,
vals=np_tensor,
)
model.graph.initializer.remove(tensor)
model.graph.initializer.append(new_tensor)
onnx.save(model, fp32_path)

from neural_compressor.adaptor.onnxrt import ONNXRTAdaptor
# pylint: disable=E1120
inc_model = ONNXRTAdaptor._replace_gemm_with_matmul(model)
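The QAT branch added above depends on how PyTorch lays out quantized parameters in a state_dict: quantized Embedding stores its weight under '_packed_params._packed_weight', quantized Linear stores a (quantized weight, bias) tuple under '_packed_params._packed_params', and quantized Conv keeps plain '.weight'/'.bias' entries. A minimal sketch of inspecting that layout with a dynamically quantized Linear; the module below is illustrative and not taken from the commit:

import torch
from torch import nn

# Quantize a tiny Linear dynamically so its state_dict contains packed params.
float_model = nn.Sequential(nn.Linear(4, 2))
qmodel = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

for name, param in qmodel.state_dict().items():
    if '_packed_params._packed_params' in name and isinstance(param, tuple):
        qweight, bias = param                  # (quantized weight, fp32 bias)
        print(name, qweight.dequantize().shape, bias.shape)

Calling dequantize() on the packed tensors is how the new branch recovers fp32 weights to write back into the initializers of the fp32 ONNX model.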
44 changes: 37 additions & 7 deletions test/adaptor/pytorch_adaptor/test_torch2onnx.py
@@ -5,6 +5,7 @@
import unittest
import os
import onnx
import numpy as np
from neural_compressor.adaptor.pytorch import PyTorchVersionMode
import neural_compressor.adaptor.pytorch as nc_torch
from neural_compressor.experimental import Quantization, common
@@ -63,6 +64,13 @@ def build_pytorch_yaml():
with open('dynamic_yaml.yaml', 'w', encoding="utf-8") as f:
f.write(fake_dyn_yaml)

fake_qat_yaml = fake_ptq_yaml.replace(
'post_training_static_quant',
'quant_aware_training',
)
with open('qat_yaml.yaml', 'w', encoding="utf-8") as f:
f.write(fake_qat_yaml)


def build_pytorch_fx_yaml():
fake_fx_ptq_yaml = fake_ptq_yaml.replace('pytorch', 'pytorch_fx')
@@ -130,10 +138,10 @@ def setUpClass(self):
def tearDownClass(self):
os.remove('ptq_yaml.yaml')
os.remove('dynamic_yaml.yaml')
os.remove('qat_yaml.yaml')
shutil.rmtree('runs', ignore_errors=True)
os.remove('fp32-model.onnx')
os.remove('int8-model.onnx')
os.remove('int8-model.tmp')

@unittest.skipIf(not BF16_MODE, "Unsupport BF16 Mode with ONNX Version Below 1.11")
def test_bf16_onnx(self):
@@ -163,22 +171,37 @@ def test_eager_quant(self):
"output": {0: "batch_size"}},
do_constant_folding=True,
)
for fake_yaml in ['dynamic_yaml.yaml', ]:
for fake_yaml in ['dynamic_yaml.yaml', 'ptq_yaml.yaml', 'qat_yaml.yaml']:
model = M()
quantizer = Quantization(fake_yaml)
quantizer.conf.usr_cfg.tuning.exit_policy['performance_only'] = True
dataset = quantizer.dataset('dummy', (10, 3, 224, 224), label=True)
quantizer.model = model
quantizer.calib_dataloader = common.DataLoader(dataset)
if fake_yaml == 'qat_yaml.yaml':
def train_func(model):
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
model.train()
input = torch.randn([1, 3, 224, 224])
# compute output
output = model(input)
loss = output.abs().sum()
loss.backward()
optimizer.step()
return model
quantizer.q_func = train_func
elif fake_yaml == 'ptq_yaml.yaml':
quantizer.calib_dataloader = common.DataLoader(dataset)
quantizer.eval_dataloader = common.DataLoader(dataset)
q_model = quantizer.fit()

int8_jit_model = q_model.export_to_jit(example_inputs)
# INC keeps fallback fp32 modules when exporting the ONNX model
if 'ptq_yaml.yaml':
calib_dataloader = quantizer.calib_dataloader
else:
if fake_yaml == 'dynamic_yaml.yaml':
calib_dataloader = None
else:
quantizer.calib_dataloader = common.DataLoader(dataset)
calib_dataloader = quantizer.calib_dataloader
q_model.export_to_int8_onnx(
save_path='int8-model.onnx',
example_inputs=example_inputs,
@@ -191,6 +214,14 @@ def test_eager_quant(self):
fp32_model=model,
calib_dataloader=calib_dataloader,
)
if fake_yaml == 'qat_yaml.yaml':
model = onnx.load('int8-model.onnx')
tensor_list = {tensor.name:tensor for tensor in model.graph.initializer}
torch_data = q_model.model.conv.weight().dequantize().detach().cpu().numpy()
from onnx.numpy_helper import to_array
onnx_data = to_array(tensor_list['conv.weight_quantized'])
onnx_scale = to_array(tensor_list['conv.weight_scale'])
self.assertTrue(np.allclose(torch_data, onnx_data * onnx_scale, atol=0.001))

def test_input_tuple(self):
from neural_compressor.adaptor.torch_utils.util import input2tuple
@@ -214,7 +245,6 @@ def tearDownClass(self):
os.remove('fx_dynamic_yaml.yaml')
shutil.rmtree('runs', ignore_errors=True)
os.remove('int8-model.onnx')
os.remove('int8-model.tmp')

def test_fx_quant(self):
for fake_yaml in ['fx_dynamic_yaml.yaml', 'fx_ptq_yaml.yaml']:
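The new QAT check in test_eager_quant asserts that the exported ONNX weight dequantizes back to the PyTorch weight: with the weight zero-point implicitly taken as 0, int8_weight * scale should match the dequantized torch tensor within a small tolerance. A standalone sketch of that relation with made-up per-tensor values:

import numpy as np

# Symmetric weight quantization: fp32 ≈ int8 * scale (zero-point assumed to be 0).
int8_weight = np.array([-64, 0, 127], dtype=np.int8)
scale = np.float32(0.02)
dequantized = int8_weight.astype(np.float32) * scale
expected = np.array([-1.28, 0.0, 2.54], dtype=np.float32)
assert np.allclose(dequantized, expected, atol=1e-3)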
