From 83018ef28170f8d2659dd30bb28738857e5c0dec Mon Sep 17 00:00:00 2001 From: BiaoFangAIA <108742533+BiaoFangAIA@users.noreply.github.com> Date: Sat, 10 Dec 2022 00:08:52 +0800 Subject: [PATCH] Add hawq_v2 tuning strategy (#230) Signed-off-by: yiliu30 Co-authored-by: lvliang-intel Co-authored-by: chen, suyue Co-authored-by: xinhe Co-authored-by: Ray <106061964+yiliu30@users.noreply.github.com> --- examples/.config/model_params_pytorch.json | 18 + .../quantization/ptq/cpu/fx/conf.yaml | 2 +- neural_compressor/adaptor/pytorch.py | 29 +- .../adaptor/torch_utils/hawq_metric.py | 582 ++++++++++++++++++ neural_compressor/conf/config.py | 8 +- neural_compressor/config.py | 4 +- neural_compressor/strategy/hawq_v2.py | 176 ++++++ .../strategy/utils/tuning_sampler.py | 6 +- test/strategy/test_hawq_v2_2.x.py | 60 ++ 9 files changed, 876 insertions(+), 9 deletions(-) create mode 100644 neural_compressor/adaptor/torch_utils/hawq_metric.py create mode 100644 neural_compressor/strategy/hawq_v2.py create mode 100644 test/strategy/test_hawq_v2_2.x.py diff --git a/examples/.config/model_params_pytorch.json b/examples/.config/model_params_pytorch.json index fcd42576fcb..77990bbf42f 100644 --- a/examples/.config/model_params_pytorch.json +++ b/examples/.config/model_params_pytorch.json @@ -9,6 +9,24 @@ "batch_size": 100, "new_benchmark": false }, + "efficientnet_b0_fx": { + "model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/", + "dataset_location": "/tf_dataset/pytorch/ImageNet/raw", + "input_model": "", + "yaml": "conf.yaml", + "strategy": "hawq_v2", + "batch_size": 100, + "new_benchmark": false + }, + "efficientnet_b3_fx": { + "model_src_dir": "image_recognition/torchvision_models/quantization/ptq/cpu/fx/", + "dataset_location": "/tf_dataset/pytorch/ImageNet/raw", + "input_model": "", + "yaml": "conf.yaml", + "strategy": "hawq_v2", + "batch_size": 100, + "new_benchmark": false + }, "resnet18_fx": { "model_src_dir": 
"image_recognition/torchvision_models/quantization/ptq/cpu/fx/", "dataset_location": "/tf_dataset/pytorch/ImageNet/raw", diff --git a/examples/pytorch/image_recognition/torchvision_models/quantization/ptq/cpu/fx/conf.yaml b/examples/pytorch/image_recognition/torchvision_models/quantization/ptq/cpu/fx/conf.yaml index d1dab0d2f43..f11483acd16 100644 --- a/examples/pytorch/image_recognition/torchvision_models/quantization/ptq/cpu/fx/conf.yaml +++ b/examples/pytorch/image_recognition/torchvision_models/quantization/ptq/cpu/fx/conf.yaml @@ -77,4 +77,4 @@ tuning: relative: 0.01 # optional. default value is relative, other value is absolute. this example allows relative accuracy loss: 1%. exit_policy: timeout: 0 # optional. tuning timeout (seconds). default value is 0 which means early stop. combine with max_trials field to decide when to exit. - random_seed: 9527 # optional. random seed for deterministic tuning. + random_seed: 9527 # optional. random seed for deterministic tuning. \ No newline at end of file diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py index 9627ad2386a..b0b51040510 100644 --- a/neural_compressor/adaptor/pytorch.py +++ b/neural_compressor/adaptor/pytorch.py @@ -30,7 +30,6 @@ from .query import QueryBackendCapability from ..experimental.data.dataloaders.base_dataloader import BaseDataLoader - torch = LazyImport("torch") json = LazyImport("json") hvd = LazyImport("horovod.torch") @@ -1094,6 +1093,34 @@ def is_fused_module(self, module): return True else: return False + + def calculate_hessian_trace(self, + fp32_model, + dataloader, + q_model, + criterion, + enable_act = False + ): + """Calculate hessian trace. + + Args: + fp32_model: The original fp32 model. + criterion: The loss function for calculate the hessian trace. # loss = criterion(output, target) + dataloader: The dataloader for calculate the gradient. + q_model: The INT8 AMAP model. + enable_act: Enabling quantization error or not. 
+ + Return: + hessian_trace(Dict[Tuple, float]), key: (op_name, op_type); value: hessian trace. + """ + from .torch_utils.hawq_metric import hawq_top + op_to_traces=hawq_top(fp32_model=fp32_model, + dataloader=dataloader, + q_model=q_model, + criterion=criterion, + enable_act=enable_act) + return op_to_traces + pass unify_op_type_mapping = { diff --git a/neural_compressor/adaptor/torch_utils/hawq_metric.py b/neural_compressor/adaptor/torch_utils/hawq_metric.py new file mode 100644 index 00000000000..f68a1234164 --- /dev/null +++ b/neural_compressor/adaptor/torch_utils/hawq_metric.py @@ -0,0 +1,582 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from ...utils.utility import LazyImport +torch = LazyImport("torch") + +import copy +import numpy as np +from collections import OrderedDict +import torch.nn +from torch.quantization.quantize_fx import fuse_fx +import torch.nn.intrinsic.quantized as nniq +from torch.fx import symbolic_trace, graph_module +import torch.nn as nn +import logging +logger = logging.getLogger(__name__) +from typing import Dict, List, Optional, Any, Union, Callable, Set +# Define Collector based on hook, which is used to record the intermediate result +class Node_collector: + def __init__(self, m): + self.handle = m.register_forward_hook(self.hook_fn_act) + def hook_fn_act(self, m, inp, outp): + self.out_features = outp.clone() + self.in_features = inp + self.m = m + def remove(self): + self.handle.remove() +class HessianTrace: + """ + please refer to + Yao, Zhewei, et al. "Pyhessian: Neural networks through the lens of the hessian." + 2020 IEEE international conference on big data (Big data). IEEE, 2020. + Dong, Zhen, et al. "Hawq-v2: Hessian aware trace-weighted quantization of neural networks." + Advances in neural information processing systems 33 (2020): 18518-18529. 
+ https://github.com/openvinotoolkit/nncf/blob/develop/nncf/torch/quantization/hessian_trace.py + """ + + def __init__(self, model, dataloader,q_model,criterion=None): + self.unfused_model = model.model + self.q_model=q_model + tmp_model=model.model + if 'graph' in (str(dir(tmp_model))): #check the attribute and it's length + logger.info("This is aready fused model") + self.model=model.model + else: + logger.info("fusing model") + self.model = fuse_fx(model.model) ##TODO need to check whether model has been already fused + self.dataloader = dataloader + self.max_iter = 500 + self.tolerance = 1e-5 + self.eps = 1e-6 + self.index = 0 + self.device = self.get_device(self.model) + self.criterion = criterion + if self.criterion == None: + self.criterion = torch.nn.CrossEntropyLoss().to(self.device) ##TODO need to set in config + self.criterion = self.criterion.to(self.device) + self.weight_to_op, self.op_list = self.get_fused_mapping() + self.get_params() + + def is_fused_module(self, module): + """This is a helper function for `_propagate_qconfig_helper` to detecte + if this module is fused. + Args: + module (object): input module + Returns: + (bool): is fused or not + """ + op_type = str(type(module)) + if 'fused' in op_type: + return True + else: + return False + + def mapping_module_to_op(self, name): + # length = len("_model.") + # if len(name) < length: + # return name + # else: + return name + def mse_metric_gap(self,fp32_tensor, dequantize_tensor): + """Calculate the euclidean distance between fp32 tensor and int8 dequantize tensor + Args: + fp32_tensor (tensor): The FP32 tensor. + dequantize_tensor (tensor): The INT8 dequantize tensor. 
+ """ + fp32_max = np.max(fp32_tensor) + fp32_min = np.min(fp32_tensor) + dequantize_max = np.max(dequantize_tensor) + dequantize_min = np.min(dequantize_tensor) + fp32_tensor = (fp32_tensor - fp32_min) / (fp32_max - fp32_min) + dequantize_tensor = (dequantize_tensor - dequantize_min) / \ + (dequantize_max - dequantize_min) + diff_tensor = fp32_tensor - dequantize_tensor + euclidean_dist = np.sum(diff_tensor ** 2) + return euclidean_dist / fp32_tensor.size + def get_fused_mapping(self): + model = self.model + weights_info = dict(model.named_parameters()) + weight_to_op = {} + for op_name, child in model.named_modules(): + if self.is_fused_module(child): + for name, _ in child.named_children(): + if op_name + "." + name + ".weight" in weights_info: ##TODO check if this is right + + weight_to_op[op_name + "." + name + ".weight"] = self.mapping_module_to_op(op_name) + break + else: + name = op_name + ".weight" + if name in weights_info and name not in weight_to_op.keys(): + weight_to_op[op_name + ".weight"] = op_name + op_list = [] + for key in weight_to_op.keys(): + op_list.append(weight_to_op[key]) + return weight_to_op, op_list + + def get_device(self, model: torch.nn.Module): + for n, p in model.named_parameters(): + return p.data.device + + def _get_act_grad_hook(self, name): + def act_grad_hook(model, grad_input, grad_output): + ##print(name, grad_input[0].shape, grad_output[0].shape) + if type(model) == torch.nn.Linear: ##TODO very tricky + self.layer_acts_grads[name] = grad_input[1] + else: + self.layer_acts_grads[name] = grad_input[0] + + return act_grad_hook + + def _get_enable_act_grad_hook(self, name): + def enable_act_grad_hook(model, inputs, outputs): + input = inputs[0] + if input.requires_grad is False: + input.requires_grad = True + self.layer_acts[name] = input + + return enable_act_grad_hook + + # def _get_disable_input_grad_hook(self, name): + # def disable_input_grad_hook(model, inputs, outputs): + # try: + # input = inputs[0] ##TODO check whether 
this is right + # except: + # input = inputs + # if input.is_leaf == False:## you can only change requires_grad flags of leaf variables + # if input.requires_grad is True: + # input.requires_grad = False + # + # + # return disable_input_grad_hook + + def _unregister_hook(self): + for handel in self.hook_handles: + handel.remove() + + def register_act_grad_hooks(self, model): + for name, module in model.named_modules(): + if self.mapping_module_to_op(name) in self.op_list: + hook_handle = module.register_forward_hook(self._get_enable_act_grad_hook(name)) + self.hook_handles.append(hook_handle) + hook_handle = module.register_backward_hook(self._get_act_grad_hook(name)) + self.hook_handles.append(hook_handle) + + def reset_act_gradient_and_hooks(self): + # tmp_input = torch.zeros(self._input_shape, device=self.device) + # for name, module in self.model.named_modules(): + # if name in self.op_list: + # hook_handle = module.register_forward_hook(self._get_disable_input_grad_hook(name)) + # self.hook_handles.append(hook_handle) + # self.model(tmp_input) + self._unregister_hook() + + def get_params(self): + weight_names = [n for n, p in self.model.named_parameters() if + p.requires_grad and "bias" not in n] ##remove bias + params = [p for n, p in self.model.named_parameters() if p.requires_grad and "bias" not in n] ##remove bias + self.weight_names = weight_names + self.params = params + + def forward_backward(self, model, data, create_graph=False, return_w_grad=True): + model.zero_grad() + input = data[0].to(self.device) + ##self._input_shape = input.shape ## for resetting input activation + target = data[1].to(self.device) + input.requires_grad = True + output = model(input) + loss = self.criterion(output, target) + torch.autograd.backward(loss, create_graph=create_graph) + ##loss.backward(create_graph=create_graph) + if return_w_grad: + gradients = [] + for n, p in self.model.named_parameters(): + if p.grad != None and n in self.weight_names: + gradient = p.grad + 
gradients.append(gradient + 0.0) ## add 0 to create a copy + model.zero_grad() + return gradients + else: + model.zero_grad() + + # def get_params(self, model): + # parameters = [p for p in model.parameters() if p.requires_grad] + # return parameters + + def sample_rademacher(self, params): + samples = [] + for param in params: + r = torch.randint_like(param, high=2, device=self.device) + r.masked_fill_(r == 0, -1) + samples.append(r) + return samples + + def get_vtHv_weight(self, params, num_samples): + v = self.sample_rademacher(params) + H_v = [0] * len(v) + cnt = 0 + for step, data in enumerate(self.dataloader): + batch_size = data[0].shape[0] + cnt += batch_size + gradients = self.forward_backward(self.model, data, create_graph=True) + H_v_one = torch.autograd.grad(gradients, params, v, only_inputs=True, retain_graph=False) + H_v = [pre + cur * float(batch_size) for cur, pre in zip(H_v_one, H_v)] + if cnt >= num_samples: + break + if cnt > 0: + H_v = [item / cnt for item in H_v] + v_t_H_v = torch.stack([torch.mean(h_v * v_t) for (h_v, v_t) in zip(H_v, v)]) ##maybe sum is better + return v_t_H_v + + # def get_vtHv_act(self, params, num_samples): + # v = self.sample_rademacher(params) + # H_v = [0] * len(v) + # cnt = 0 + # for step, data in enumerate(self.dataloader): + # if cnt >= num_samples: + # break + # for i in range(self.dataloader.batchsize): ##force to batchsize to be 1 + # input = data[0][i:i + 1] + # target = data[1][i:i + 1] + + # self.get_gradients(self.model, (input, target), self.criterion, create_graph=True) + # layer_acts = [self.layer_acts[key] for key in self.layer_acts.keys()] + # layer_act_gradients = [self.layer_acts_grads[key] for key in self.layer_acts.keys()] + # hv_one = torch.autograd.grad(layer_act_gradients, layer_acts, v, + # only_inputs=True, retain_graph=False) + # cnt += 1 + # if cnt >= num_samples: + # break + + def get_weight_traces(self, num_samples): + import tqdm + layer_traces_per_iter = [] + prev_avg_model_trace = 0 + for 
iter in tqdm.tqdm(range(self.max_iter)): + layer_traces = self.get_vtHv_weight(self.params, num_samples) + layer_traces_per_iter.append(layer_traces) + layer_traces_estimate = torch.mean(torch.stack(layer_traces_per_iter), dim=0) + model_trace = torch.sum(layer_traces_estimate) + diff_ratio = abs(model_trace - prev_avg_model_trace) / (prev_avg_model_trace + self.eps) + if diff_ratio < self.tolerance and iter > 10: ##TODO magic number + break + # if iter == 20: ##TODO for debugging + # break + prev_avg_model_trace = model_trace + weight_name_to_traces = {} + layer_traces = layer_traces_estimate + for weight_name, trace in zip(self.weight_names, layer_traces): + weight_name_to_traces[weight_name] = float(trace)# tensor->float + op_name_to_trace = {} + for weight_name in self.weight_names: + op_name = self.weight_to_op[weight_name] + op_name_to_trace[op_name] = weight_name_to_traces[weight_name] + return op_name_to_trace + def get_act_traces(self, num_samples): + unfused_training = self.unfused_model.training + self.unfused_model.eval() + self.hook_handles = [] + self.layer_acts = {} + self.layer_acts_grads = {} + self.register_act_grad_hooks(self.unfused_model) + cnt = 0 + act_traces_per_sample = [] + for step, data in enumerate(self.dataloader): + if cnt >= num_samples: + break + bs = data[0].shape[0] + act_traces_sum = 0 + act_traces_per_iter = [] + prev_avg_model_trace = 0 + act_traces_sums = None + for i in range(bs): ##force the bs to be one + input = data[0][i:i + 1] + target = data[1][i:i + 1] + self.forward_backward(self.unfused_model, (input, target), create_graph=True, return_w_grad=False) + acts = [self.layer_acts[key] for key in self.layer_acts.keys()] + if act_traces_sums == None: + act_traces_sums = [0] * len(acts) + acts_grad = [self.layer_acts_grads[key] for key in self.layer_acts.keys()] ##same order with acts + vt_H_v_sum_per_act = [0] * len(acts) + + prev_model_act_trace = 0 + for iter in range(self.max_iter): + v = self.sample_rademacher(acts) + 
H_v = torch.autograd.grad(acts_grad, acts, v, only_inputs=True, retain_graph=True) + vt_H_v = [torch.mean(h_v * v_t) for (h_v, v_t) in zip(H_v, v)] + + vt_H_v_sum_per_act = [vt_H_v_sum_per_act[index] + vt_H_v[index] for index, item in + enumerate(vt_H_v_sum_per_act)] + vt_H_v_mean_per_act = [item / (iter + 1) for item in vt_H_v_sum_per_act] + current_model_act_trace = torch.mean(torch.stack(vt_H_v_mean_per_act)) + + diff_ratio = abs(current_model_act_trace - prev_model_act_trace) / ( + prev_model_act_trace + self.eps) + if diff_ratio < self.tolerance and iter > 10: ##TODO magic number + break + # if iter == 50: ##TODO for debug + # break + + prev_model_act_trace = current_model_act_trace + act_traces_per_sample.append(vt_H_v_mean_per_act) + cnt += 1 + if cnt >= num_samples: + break + + if unfused_training: + self.unfused_model.train() + self.reset_act_gradient_and_hooks() ##TODO have issues to reset the input grad to False + act_traces_stack = torch.stack([torch.stack(item) for item in act_traces_per_sample]) + act_traces = torch.mean(act_traces_stack, dim=0) + res_dict = {} + for index, key in enumerate(self.layer_acts.keys()): + res_dict[key] = act_traces[index] + + self.layer_acts = [] + self.layer_acts_grads = [] + return res_dict + def insert_hook(self, model, target_module_list): + intern_outputs = [] + for layer,module in model.named_modules(): + for target_module in target_module_list: + # print("layer:",layer) + # print("target_model:",target_module) + if layer == target_module: + logging.debug("Collect: %s" % (module)) + # print("Collect: %s" % (module)) + intern_outputs.append(Node_collector(module)) + + logging.info("Total %d hook inserted" % (len(intern_outputs))) + # print("Total %d hook inserted" % (len(intern_outputs))) + return model, intern_outputs + def insert_hook_quantize(self,model, target_module_list): + intern_outputs = [] + for layer,module in model.named_modules(): + for target_module in target_module_list: + # print("layer:",layer) + 
length = len("_model.") + new_key = layer[length:] + # print("target_model:",target_module) + if new_key == target_module: + logging.debug("Collect: %s" % (module)) + # print("Collect: %s" % (module)) + intern_outputs.append(Node_collector(module)) + logging.info("Total %d hook inserted" % (len(intern_outputs))) + # print("Total %d hook inserted" % (len(intern_outputs))) + return model, intern_outputs + def get_act_gap(self,fp32_model,q_model): + """ + Estimates each activation gap between quantized model and float model + """ + self.handle_acts=[] + fp32_model.eval() + # temp_model = fuse_fx(fp32_model.model) + temp_model=fp32_model + # target_module_list = [nn.ReLU] # Insert hook for FP32 model + target_module_list = self.op_list + temp_model, intern_outputs =self.insert_hook(temp_model, target_module_list) + # intern_outputs={} + for input, target in self.dataloader: + temp_model(input) + break + + fp32_act_out={} + for i, intern_output in enumerate(intern_outputs): + stat_features = intern_output.out_features.view(-1) + # print ("No.", i, " ", intern_output.out_features.shape) + # print ("Numpy No.", i, " ", intern_output.out_features.cpu().data.numpy().shape) + # print ("No.", i, " ", stat_features.cpu().data.numpy().shape) + # print ("Numpy No.", i, " ", stat_features.cpu().data.numpy()) + fp32_act_out[target_module_list[i]]=stat_features.cpu().data.numpy() + # break + for i in intern_outputs: + # print(i) + i.remove() + target_module_list = self.op_list + q_model, intern_outputs=self.insert_hook_quantize(q_model, target_module_list) + for input, target in self.dataloader: #only one sample + q_model(input) + break + qnt_act_out={} + intern_outputs={} + for i, intern_output in enumerate(intern_outputs): + stat_features = intern_output.out_features.view(-1) + qnt_act_out[target_module_list[i]]=stat_features.dequantize().cpu().data.numpy() + # break + for i in intern_outputs: + # print(i) + i.remove() + act_gap={} + mse_gap={} + for fp_i,int_i in 
zip(fp32_act_out,qnt_act_out): + activation_qnt_error=fp32_act_out[fp_i]-qnt_act_out[int_i] + mse_gap[fp_i]=self.mse_metric_gap(fp32_act_out[fp_i],qnt_act_out[int_i]) + act_gap[fp_i]=np.sum(activation_qnt_error)/activation_qnt_error.size + return act_gap,mse_gap + def get_avg_traces(self, enable_act=True, num_samples=32): + """ + Estimates average hessian trace for each parameter + """ + assert num_samples > 0 + traces = {} + weight_traces = self.get_weight_traces(num_samples) + traces['weight'] = weight_traces + act_trace={} + if enable_act: + act_gap,mse_gap=self.get_act_gap(self.model,self.q_model) + act_traces = self.get_act_traces(num_samples) + for i,j in zip(act_traces,mse_gap): + #currently use mse to analysis + act_trace[i]=float(act_traces[i])+float(mse_gap[j])# Tensor->float + traces['activation'] = act_traces + return traces + + +##copy from torch.quantization._numeric_suite +def _find_match( + str_list: Union[Dict[str, Any], List[str]], key_str: str, + postfix: str, +) -> Optional[str]: + split_str = key_str.split(".") + if split_str[-1] == postfix: + match_string = "".join(key_str.split(".")[0:-1]) + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + + # For matching "fc.weight" and "fc._packed_params._packed_params" + if postfix == "_packed_params": + match_string = "".join(key_str.split(".")[0:-2]) + if len(match_string) == 0: + return None + for s2 in str_list: + pattern1 = "".join(s2.split(".")[0:-1]) + pattern2 = "".join(s2.split(".")[0:-2]) + if match_string == pattern1: + return s2 + if match_string == pattern2: + return s2 + return None + else: + return None + + +##copy form torch.quantization._numeric_suite +def compare_weights( + float_dict: Dict[str, Any], quantized_dict: Dict[str, Any] +) -> Dict[str, Dict[str, torch.Tensor]]: + r"""Compare the weights of the float module with its corresponding 
quantized + module. Return a dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights. This dict can be used to compare and compute the quantization + error of the weights of float and quantized models. + + Example usage:: + + wt_compare_dict = compare_weights( + float_model.state_dict(), qmodel.state_dict()) + for key in wt_compare_dict: + print( + key, + compute_error( + wt_compare_dict[key]['float'], + wt_compare_dict[key]['quantized'].dequantize() + ) + ) + + Args: + float_dict: state dict of the float model + quantized_dict: state dict of the quantized model + + Return: + weight_dict: dict with key corresponding to module names and each entry being + a dictionary with two keys 'float' and 'quantized', containing the float and + quantized weights + """ + + weight_dict: Dict[str, Dict] = {} + for key in quantized_dict: + match_key = _find_match(float_dict, key, "weight") + if match_key is not None: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[match_key] + weight_dict[key]["quantized"] = quantized_dict[key] + continue + + # For matching "fc.weight" and "fc._packed_params._packed_params" + match_key = _find_match(float_dict, key, "_packed_params") + if match_key is not None: + weight_dict[match_key] = {} + weight_dict[match_key]["float"] = float_dict[match_key] + weight_dict[match_key]["quantized"] = quantized_dict[key][0] + ##TODO:should consider more models in further work + + # For LSTM + split_str = key.split(".") + if split_str[-1] == "param" and split_str[-3] == "_all_weight_values": + layer = split_str[-2] + module_name = ".".join(split_str[:-3]) + float_weight_ih_key = module_name + ".weight_ih_l" + layer + float_weight_hh_key = module_name + ".weight_hh_l" + layer + if float_weight_ih_key in float_dict and float_weight_hh_key in float_dict: + weight_dict[key] = {} + weight_dict[key]["float"] = float_dict[float_weight_ih_key] + 
weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][0].__getstate__()[0][0] + ) + weight_dict[key]["float"] = float_dict[float_weight_hh_key] + weight_dict[key]["quantized"] = ( + quantized_dict[key].__getstate__()[0][4][1].__getstate__()[0][0] + ) + + return weight_dict +def hawq_top(fp32_model,q_model,dataloader,criterion,enable_act): + orig_eval=True + if fp32_model.training: + orig_eval=False + fp32_model.eval() + ht=HessianTrace(fp32_model,dataloader=dataloader,q_model=q_model) + q_model_state_dict={} + for key in q_model.state_dict().keys(): + length=len("_model.") + new_key=key[length:] + q_model_state_dict[new_key]=q_model.state_dict()[key] + weight_quant_loss=compare_weights(ht.model.state_dict(),q_model_state_dict) + pertur_lst={} + for key in weight_quant_loss: + op_float_tensor=weight_quant_loss[key]['float'] + op_qnt_tensor=weight_quant_loss[key]['quantized'].dequantize() + diff_l2 = (torch.norm(op_float_tensor - op_qnt_tensor, p=2) ** 2) + pertur_lst[key]=diff_l2 + traces=ht.get_avg_traces(enable_act) + op_to_traces=traces['weight'] + if enable_act: + act_to_traces=traces['activation'] + for trace_i, pertur_i,act_i in zip(op_to_traces.keys(),pertur_lst.keys(),act_to_traces.keys()): + #Formula:Omig=Trace*L2+act_trace + op_to_traces[trace_i]=pertur_lst[pertur_i]*op_to_traces[trace_i]+act_to_traces[act_i] + else: + for trace_i, pertur_i in zip(op_to_traces.keys(),pertur_lst.keys()): + op_to_traces[trace_i]=pertur_lst[pertur_i]*op_to_traces[trace_i] #Formula:Omig=Trace*L2 + if orig_eval==False: + fp32_model.train() + return op_to_traces + + \ No newline at end of file diff --git a/neural_compressor/conf/config.py b/neural_compressor/conf/config.py index 5e889d34cec..dc9d11c3f92 100644 --- a/neural_compressor/conf/config.py +++ b/neural_compressor/conf/config.py @@ -851,7 +851,7 @@ def percent_to_float(data): Optional('model_conversion'): model_conversion_schema, Optional('tuning', default={ - 'strategy': {'name': 'basic'}, + 
'strategy': {'name': 'basic'}, 'accuracy_criterion': {'relative': 0.01, 'higher_is_better': True}, 'objective': 'performance', 'exit_policy': {'timeout': 0, 'max_trials': 100, 'performance_only': False}, @@ -866,7 +866,8 @@ def percent_to_float(data): Optional('sigopt_experiment_name', default='nc-tune'): str, Optional('accuracy_weight', default=1.0): float, Optional('latency_weight', default=1.0): float, - Optional('confidence_batches', default=2): int + Optional('confidence_batches', default=2): int, + Optional('hawq_v2_loss', default=None): object, } , Hook('accuracy_criterion', handler=_valid_accuracy_field): object, Optional('accuracy_criterion', default={'relative': 0.01}): { @@ -1360,7 +1361,8 @@ def map_pyconfig_to_cfg(self, pythonic_config): if pythonic_config.quantization.strategy_kwargs: st_kwargs = pythonic_config.quantization.strategy_kwargs for st_key in ['sigopt_api_token', 'sigopt_project_id', 'sigopt_experiment_name', \ - 'accuracy_weight', 'latency_weight']: + 'accuracy_weight', 'latency_weight', 'hawq_v2_loss']: + if st_key in st_kwargs: st_val = st_kwargs[st_key] mapping.update({'tuning.strategy.' 
+ st_key: st_val}) diff --git a/neural_compressor/config.py b/neural_compressor/config.py index 584d9050108..295890a61ca 100644 --- a/neural_compressor/config.py +++ b/neural_compressor/config.py @@ -419,7 +419,7 @@ def strategy(self): @strategy.setter def strategy(self, strategy): if check_value('strategy', strategy, str, - ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe', 'mse_v2']): + ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe', 'mse_v2', 'hawq_v2']): self._strategy = strategy @property @@ -562,7 +562,7 @@ def strategy(self): @strategy.setter def strategy(self, strategy): if check_value('strategy', strategy, str, - ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe', 'mse_v2']): + ['basic', 'mse', 'bayesian', 'random', 'exhaustive', 'sigopt', 'tpe', 'mse_v2', 'hawq_v2']): self._strategy = strategy @property diff --git a/neural_compressor/strategy/hawq_v2.py b/neural_compressor/strategy/hawq_v2.py new file mode 100644 index 00000000000..2f33bf39ba4 --- /dev/null +++ b/neural_compressor/strategy/hawq_v2.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from collections import OrderedDict
from copy import deepcopy

from .strategy import strategy_registry, TuneStrategy

from .utils.tuning_sampler import OpTypeWiseTuningSampler, FallbackTuningSampler, ModelWiseTuningSampler
from .utils.tuning_structs import OpTuningConfig
from .utils.tuning_space import TUNING_ITEMS_LST
from ..utils import logger


@strategy_registry
class HAWQ_V2TuneStrategy(TuneStrategy):
    """The HAWQ v2 tuning strategy.

    Args:
        model (object): The FP32 model specified for low precision tuning.
        conf (Class): The Conf class instance initialized from the user yaml
            config file.
        q_dataloader (generator): Data loader for calibration, mandatory for
            post-training quantization. It is iterable and should yield a
            tuple (input, label) for a calibration dataset containing labels,
            or (input, _) for a label-free calibration dataset. The input
            could be an object, list, tuple or dict, depending on the user
            implementation, as long as it can be taken as model input.
        q_func (function, optional): Reserved for future use.
        eval_dataloader (generator, optional): Data loader for evaluation. It
            is iterable and should yield a tuple of (input, label). The input
            could be an object, list, tuple or dict, depending on the user
            implementation, as long as it can be taken as model input. The
            label should be usable as input of the supported metrics. If this
            parameter is not None, the user needs to specify pre-defined
            evaluation metrics through the configuration file and should set
            the "eval_func" parameter to None. The tuner will combine model,
            eval_dataloader and pre-defined metrics to run the evaluation
            process.
        eval_func (function, optional): The evaluation function provided by
            the user. This function takes the model as parameter; the
            evaluation dataset and metrics should be encapsulated in this
            function, which outputs a higher-is-better accuracy scalar value.

            The pseudo code should be something like:

            def eval_func(model):
                input, label = dataloader()
                output = model(input)
                accuracy = metric(output, label)
                return accuracy
        dicts (dict, optional): The dict containing resume information.
            Defaults to None.
        q_hooks (optional): Forwarded unchanged to ``TuneStrategy.__init__``.
            Defaults to None.
    """

    def __init__(self, model, conf, q_dataloader, q_func=None,
                 eval_dataloader=None, eval_func=None, dicts=None, q_hooks=None):
        """Forward every argument to the base ``TuneStrategy`` constructor."""
        super().__init__(
            model,
            conf,
            q_dataloader,
            q_func,
            eval_dataloader,
            eval_func,
            dicts,
            q_hooks)

    def next_tune_cfg(self):
        """Generate the tuning configurations to try, one at a time.

        Two phases are produced:

        1. Op-type-wise configurations from ``OpTypeWiseTuningSampler``. This
           phase stops as soon as more than ``stage1_max`` configurations have
           been drawn.
        2. After asking the adaptor for a per-op hessian trace, configurations
           from an accumulating ``FallbackTuningSampler`` whose op order is
           derived from those traces.

        Yields:
            dict: An op tuning configuration with ``'calib_sampling_size'``
            set to the first calibration sampling size option.
        """
        tuning_space = self.tuning_space
        calib_size = tuning_space.root_item.get_option_by_name('calib_sampling_size').options[0]

        # Initialize the tuning config for each op according to the quantization approach.
        # The mode-wise item mapping returned here is not needed by this strategy.
        op_item_dtype_dict, _, initial_op_tuning_cfg = self.initial_tuning_cfg()

        # Stage 1: optype-wise tuning items, i.e. the algorithm/scheme/granularity
        # of activation (weight).
        early_stop_tuning = True
        stage1_max = 1  # TODO set a more appropriate value
        op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
                                                         op_item_dtype_dict, initial_op_tuning_cfg)
        for stage1_cnt, op_tuning_cfg in enumerate(op_wise_tuning_sampler, start=1):
            if early_stop_tuning and stage1_cnt > stage1_max:
                # NOTE(review): on this path `op_tuning_cfg` stays bound to the config
                # that triggered the break (never yielded); it is the base of stage 2.
                logger.info("Early stopping the stage 1.")
                break
            op_tuning_cfg['calib_sampling_size'] = calib_size
            yield op_tuning_cfg

        # Stage 2: compute the hessian trace of each op.
        logger.info("************** Start compute the hessian trace *****************")
        target_dtype = "int8"
        # NOTE(review): the loss is read from the strategy config and forwarded as-is;
        # nothing here checks that it was provided (an assertion requiring it was
        # previously disabled) -- confirm the adaptor copes with a missing loss.
        hawq_v2_criterion = self.cfg.tuning.strategy.hawq_v2_loss
        op_to_traces = self.adaptor.calculate_hessian_trace(fp32_model=self._fp32_model,
                                                            dataloader=self.calib_dataloader,
                                                            q_model=self.q_model,
                                                            criterion=hawq_v2_criterion,
                                                            enable_act=False)
        sorted_op_to_traces = dict(sorted(op_to_traces.items(), key=lambda item: item[1], reverse=True))
        logger.info("************** Hessian Trace *****************")
        for op_name, trace in sorted_op_to_traces.items():
            logger.info(f"*** op: {op_name}, hessian trace : {trace}")
        logger.info("************************************************")

        # WA for op mapping: a trace key is attributed to an op when it starts with the
        # op's name. While walking the ops, also remember op_name -> (op_name, op_type).
        op_name_to_trace = {}
        op_info_map = {}
        for op_info in initial_op_tuning_cfg:
            op_name, _ = op_info
            op_info_map[op_name] = op_info
            for op_trace_name, trace in op_to_traces.items():
                if isinstance(op_trace_name, str) and op_trace_name.startswith(op_name):
                    if op_name in op_name_to_trace:
                        # Several trace keys share this prefix; the last one wins.
                        logger.info(f"*** Already assigned the hessian trace to {op_name}, "
                                    f"update it with the value of {op_trace_name}")
                    op_name_to_trace[op_name] = trace

        # NOTE(review): the sort direction follows `self.higher_is_better` rather than a
        # fixed trace ordering -- confirm this is intended.
        ordered_op_names = sorted(op_name_to_trace,
                                  key=op_name_to_trace.get,
                                  reverse=self.higher_is_better)
        # WA for add op type: restore the (op_name, op_type) keys expected by the sampler.
        ordered_ops = [op_info_map[name] for name in ordered_op_names]
        op_dtypes = OrderedDict((op_info, target_dtype) for op_info in ordered_ops)

        logger.info(f"Start to accumulate fallback to {target_dtype}.")
        # Hand the sampler its own copy so later rebinding/mutation of the loop
        # variable cannot leak into the sampler's starting point.
        initial_op_tuning_cfg = deepcopy(op_tuning_cfg)
        fallback_sampler = FallbackTuningSampler(tuning_space, tuning_order_lst=[],
                                                 initial_op_tuning_cfg=initial_op_tuning_cfg,
                                                 op_dtypes=op_dtypes, accumulate=True,
                                                 skip_first=False)
        for op_tuning_cfg in fallback_sampler:
            op_tuning_cfg['calib_sampling_size'] = calib_size
            yield op_tuning_cfg

    def initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg: OpTuningConfig):
        """Derive a 'dynamic' quant-mode config for an op from its static config.

        For every ('weight' | 'activation', item) pair of the static state that is
        listed in ``TUNING_ITEMS_LST``, the static value is kept when the tuning
        space accepts it for the dynamic mode; otherwise the first option offered
        by the dynamic-mode tuning item is used, or None when no such item exists.

        Args:
            op_static_cfg (OpTuningConfig): The op's existing tuning config.

        Returns:
            OpTuningConfig: A config for the same op with quant mode 'dynamic'.
        """
        op_state = op_static_cfg.get_state()
        op_name = op_static_cfg.op_name
        op_type = op_static_cfg.op_type
        op_quant_mode = 'dynamic'
        tuning_space = self.tuning_space
        dynamic_state = {}
        for att in ['weight', 'activation']:
            if att not in op_state:
                continue
            for item_name, item_val in op_state[att].items():
                att_item = (att, item_name)
                if att_item not in TUNING_ITEMS_LST:
                    continue
                if tuning_space.query_item_option((op_name, op_type), op_quant_mode, att_item, item_val):
                    # The static choice is also valid in dynamic mode: reuse it.
                    dynamic_state[att_item] = item_val
                else:
                    quant_mode_item = tuning_space.query_quant_mode_item((op_name, op_type), op_quant_mode)
                    tuning_item = quant_mode_item.get_option_by_name(att_item)
                    dynamic_state[att_item] = tuning_item.options[0] if tuning_item else None
        return OpTuningConfig(op_name, op_type, op_quant_mode, tuning_space, kwargs=dynamic_state)
# Tests for the HAWQ v2 tuning strategy.
import copy
import shutil
import unittest

from neural_compressor.utils import logger


def hawq_v2_loss(output, target):
    """Loss function handed to the HAWQ v2 strategy: cross entropy of output vs. target."""
    import torch
    criterion = torch.nn.CrossEntropyLoss()
    return criterion(output, target)


class TestHAWQV2TuningStrategy(unittest.TestCase):
    """Runs the HAWQ v2 strategy end to end on a torchvision ResNet-18."""

    @classmethod
    def setUpClass(cls):
        """Create the model once; each test works on a deep copy of it."""
        import torchvision
        cls.model = torchvision.models.resnet18()

    @classmethod
    def tearDownClass(cls):
        """Remove the directories cleaned up after the run, ignoring missing ones."""
        for workdir in ('saved', 'nc_workspace'):
            shutil.rmtree(workdir, ignore_errors=True)

    def test_hawq_v2_pipeline(self):
        """Tune with a strictly decreasing fake accuracy and expect no model back."""
        logger.info("*** Test: HAWQ v2 with pytorch model.")
        from neural_compressor.quantization import fit
        from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
        from neural_compressor.data import DATASETS, DATALOADERS

        # Work on a private copy of the shared model.
        fp32_model = copy.deepcopy(self.model)

        # Fake evaluation function: every call reports a lower accuracy than the last.
        self.test_hawq_v2_pipeline_fake_acc = 0

        def _fake_eval(model):
            self.test_hawq_v2_pipeline_fake_acc -= 1
            return self.test_hawq_v2_pipeline_fake_acc

        # Dummy dataset and dataloader.
        dummy_dataset_cls = DATASETS("pytorch")["dummy"]
        dummy_dataset = dummy_dataset_cls((1, 3, 224, 224))
        dummy_loader = DATALOADERS["pytorch"](dummy_dataset)

        # Tuning criterion: HAWQ v2 with the loss above, at most 5 trials.
        tuning_criterion = TuningCriterion(
            strategy='hawq_v2',
            strategy_kwargs={'hawq_v2_loss': hawq_v2_loss},
            max_trials=5,
        )
        quant_conf = PostTrainingQuantConfig(approach="static", tuning_criterion=tuning_criterion)

        # Run the tuning loop.
        tuned_model = fit(
            model=fp32_model,
            conf=quant_conf,
            calib_dataloader=dummy_loader,
            eval_dataloader=dummy_loader,
            eval_func=_fake_eval,
        )
        self.assertIsNone(tuned_model)


if __name__ == "__main__":
    unittest.main()