Support smoothquant with MinMaxObserver (#1326)
Signed-off-by: changwangss <chang1.wang@intel.com>
Co-authored-by: Xin He <xin3.he@intel.com>
changwangss and xin3he committed Nov 1, 2023
1 parent e841a8d commit 45b4966
Showing 3 changed files with 56 additions and 7 deletions.
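For context, a minimal sketch of the user-facing flow this commit enables, adapted from the new test in test/ipex/test_adaptor_ipex.py below. The tiny model and random calibration input are illustrative placeholders; the key pieces are the "minmax" activation algorithm in op_name_dict and the smooth_quant recipe, which together route SmoothQuant calibration to MinMaxObserver instead of the default observer.

import torch
from neural_compressor import PostTrainingQuantConfig, quantization

# Placeholder model and calibration input, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(2, 2, bias=False))
example_input = torch.randn(1, 2)

def calib_func(model):
    model(example_input)

conf = PostTrainingQuantConfig(
    backend="ipex",
    example_inputs=example_input,
    # ".*" matches all ops; "minmax" selects MinMaxObserver for activations.
    op_name_dict={".*": {"activation": {"algorithm": "minmax"}}},
    recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}},
)
q_model = quantization.fit(model, conf, calib_func=calib_func)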
19 changes: 17 additions & 2 deletions neural_compressor/adaptor/pytorch.py
@@ -2622,6 +2622,7 @@ def __init__(self, framework_specific_info):
         self.op_infos_from_cfgs = None
         self.output_tensor_id_op_name = None
         self.ipex_config_path = os.path.join(self.workspace_path, "ipex_config_tmp.json")
+        self.sq_minmax_init = framework_specific_info.get("model_init_algo", "kl") == "minmax"

         try:
             os.remove(self.ipex_config_path)
@@ -3111,7 +3112,14 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             smooth_quant_args = self.recipes.get("smooth_quant_args", {})
             folding = smooth_quant_args.get("folding", False)
             if not folding:
-                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
+                if self.sq_minmax_init:
+                    from torch.ao.quantization.observer import MinMaxObserver
+
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
+                        alpha=0.5, act_observer=MinMaxObserver()
+                    )
+                else:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
                 if self.example_inputs is None:
                     self.example_inputs = get_example_inputs(model, self.q_dataloader)
                 if isinstance(self.example_inputs, dict):
@@ -3288,7 +3296,14 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
             # Checking save_qconf_summary is a workaround for an IPEX bug:
             # sometimes the prepared model from get_op_capablitiy loses this attribute.
             if not hasattr(model._model, "save_qconf_summary") or not hasattr(model._model, "load_qconf_summary"):
-                static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
+                if self.sq_minmax_init:
+                    from torch.ao.quantization.observer import MinMaxObserver
+
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
+                        alpha=0.5, act_observer=MinMaxObserver()
+                    )
+                else:
+                    static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
                 if isinstance(self.example_inputs, dict):
                     model._model = ipex.quantization.prepare(
                         model._model, static_qconfig, example_kwarg_inputs=self.example_inputs, inplace=inplace
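Standalone, the qconfig selection added above reduces to the sketch below, assuming intel_extension_for_pytorch >= 2.1.0, where get_smooth_quant_qconfig_mapping accepts an act_observer argument; build_sq_qconfig is a hypothetical name used here for illustration.

import intel_extension_for_pytorch as ipex
from torch.ao.quantization.observer import MinMaxObserver

def build_sq_qconfig(sq_minmax_init: bool, alpha: float = 0.5):
    # Mirrors the branch added in pytorch.py: use MinMaxObserver for
    # activations when the user asked for "minmax"; otherwise fall back
    # to the observer IPEX ships as its SmoothQuant default.
    if sq_minmax_init:
        return ipex.quantization.get_smooth_quant_qconfig_mapping(
            alpha=alpha, act_observer=MinMaxObserver()
        )
    return ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=alpha)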
7 changes: 7 additions & 0 deletions neural_compressor/strategy/strategy.py
@@ -1509,6 +1509,13 @@ def _set_framework_info(self, q_dataloader, q_func=None):
         if framework == "pytorch_ipex" or framework == "pytorch" or framework == "pytorch_fx":
             if self.config.backend == "ipex":
                 framework = "pytorch_ipex"
+                if self.config.recipes.get("smooth_quant", None) and (
+                    self.config.op_name_dict or self.config.op_type_dict
+                ):
+                    model_dict = self.config.op_type_dict if self.config.op_type_dict else self.config.op_name_dict
+                    model_algo = model_dict.get(".*", {}).get("activation", {}).get("algorithm", "")
+                    if "minmax" in model_algo:
+                        framework_specific_info.update({"model_init_algo": "minmax"})
             elif self.config.backend == "default":
                 framework = "pytorch_fx"
             if self.mixed_precision_mode:
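The detection added in strategy.py only consults the ".*" (match-all) entry, and op_type_dict takes precedence over op_name_dict. A sketch of the same lookup in isolation; wants_minmax_init is a hypothetical helper name:

def wants_minmax_init(op_type_dict, op_name_dict, smooth_quant_enabled):
    # Mirrors the lines added above: only the ".*" wildcard entry's
    # activation algorithm is inspected.
    if not (smooth_quant_enabled and (op_name_dict or op_type_dict)):
        return False
    model_dict = op_type_dict if op_type_dict else op_name_dict
    model_algo = model_dict.get(".*", {}).get("activation", {}).get("algorithm", "")
    return "minmax" in model_algo

# The configuration used by the new test below selects minmax init:
assert wants_minmax_init(None, {".*": {"activation": {"algorithm": "minmax"}}}, True)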
37 changes: 32 additions & 5 deletions test/ipex/test_adaptor_ipex.py
@@ -339,13 +339,11 @@ def test_tune_add_with_recipe(self):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(3, 1, 1)
-                self.linear = torch.nn.Linear(224 * 224, 5)
+                self.linear = torch.nn.Linear(224 * 224 * 3, 5)

-            def forward(self, a):
-                x = self.conv(a)
-                x = x.view(1, -1)
+            def forward(self, x):
+                x += x
+                x = x.view(1, -1)
                 x = self.linear(x)
                 return x
@@ -365,6 +363,35 @@ def fake_eval(model):
             q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader, eval_func=fake_eval)
             self.assertTrue(isinstance(q_model._model, torch.jit.ScriptModule))

+    def test_tune_minmax_obs(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2, False)
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = x + x
+                return x
+
+        example_input = torch.tensor([[torch.finfo(torch.float32).max, -torch.finfo(torch.float32).max]])
+        model = M()
+        model.linear.weight = torch.nn.Parameter(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))
+
+        def calib_func(model):
+            model(example_input)
+
+        from neural_compressor import PostTrainingQuantConfig, quantization
+
+        conf = PostTrainingQuantConfig(
+            backend="ipex",
+            example_inputs=example_input,
+            op_name_dict={".*": {"activation": {"algorithm": "minmax"}}},
+            recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}},
+        )
+        q_model = quantization.fit(model, conf, calib_func=calib_func)
+        self.assertTrue(isinstance(q_model._model, torch.jit.ScriptModule))
+
     @unittest.skipIf(
         IPEX_VERSION.release < Version("2.1.0").release,
         "Please use Intel Extension for PyTorch version 2.1.0 or higher",
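A note on the test design: test_tune_minmax_obs calibrates with activations at plus/minus torch.finfo(torch.float32).max, presumably so that only an observer tracking the raw running min/max (MinMaxObserver) handles the range cleanly, whereas a histogram-based default would have to bin those extremes; the assertion then simply checks that quantization completes and yields a torch.jit.ScriptModule.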
