Commit 7144235
fix conflicts and enable bias_shifting for input_maxes before tuning
Signed-off-by: Lu, Yintong <yintong.lu@intel.com>
yintong-lu committed Sep 13, 2023
1 parent aedd54c commit 7144235
Showing 3 changed files with 139 additions and 33 deletions.
10 changes: 9 additions & 1 deletion neural_compressor/adaptor/pytorch.py
@@ -1880,9 +1880,17 @@ def qdq_quantize(self, model, tune_cfg):
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
scale = torch.clip(input_power / weight_power, min=1e-5)
if 'os_bias' in info:  # a bias shift was collected for this layer before tuning
bias_alpha = info['os_bias']
else:
bias_alpha = None
for op_name in absorbed_layer:
module = fetch_module(q_model, op_name)
new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha)
new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha, bias_alpha=bias_alpha, layer_name=_)
if bias_alpha is not None:
logger.debug(f"Bias shifting enabled for layer {_}: shift shape {bias_alpha.size()}, absorbed_layer: {absorbed_layer}")
else:
logger.debug(f"Bias shifting disabled for layer {_}, absorbed_layer: {absorbed_layer}")
set_module(q_model, op_name, new_module)
logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")

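For context on the hunk above: the per-channel SmoothQuant scale is computed from calibration statistics, and the new `os_bias` entry only adds an optional input shift on top of it. Below is a minimal standalone sketch of that scale computation; `abs_input_max`, `weight_max`, and `alpha` are made-up illustrative values, not statistics from a real tuning run.

```python
import torch

# Illustrative per-channel calibration statistics (hypothetical values).
abs_input_max = torch.tensor([12.0, 0.5, 30.0, 2.0])  # max |activation| per input channel
weight_max = torch.tensor([0.8, 1.2, 0.4, 1.0])       # max |weight| per input channel
alpha = 0.5                                           # migration strength

# Same formula as in qdq_quantize above: balance activation and weight ranges.
input_power = torch.pow(abs_input_max, alpha)
weight_power = torch.pow(weight_max, 1 - alpha)
scale = torch.clip(input_power / weight_power, min=1e-5)

print(scale)        # per-channel smoothing scale
print(1.0 / scale)  # the value passed to SQLinearWrapper as input_scale
```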
19 changes: 18 additions & 1 deletion neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -106,11 +106,13 @@ def _wrap_lwq_layer(model, lwq_layers, op_cfgs):


class SQLinearWrapper(torch.nn.Module):
def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8):
def __init__(self, module, input_scale, input_minmax, alpha=0.5, dtype=torch.quint8, bias_alpha=None, layer_name=None):
super().__init__()
self.register_buffer("input_scale", input_scale)
self.alpha = alpha
self.dtype = dtype
self.bias_alpha = bias_alpha  # optional per-channel input shift, compensated through the layer bias
self.layer_name = layer_name  # kept only for logging
# calculate and only save scale, zero_point to avoid memory usage
self.scale, self.zero_point = self._calculate_qparams(input_scale, input_minmax, dtype)
self.add_module("sq_linear", module)
@@ -122,6 +124,8 @@ def weight(self):
return self.sq_linear.weight

def forward(self, X):
if self.bias_alpha is not None:  # shift the input before smoothing and quantization
X = X - self.bias_alpha
if self.ipex:
X = self.sq_linear(X)
else:
@@ -157,7 +161,20 @@ def _update_sq_linear(self):
# remove mul and reset sq_linear for ipex inference
scale = self.input_scale.view(1, self.input_scale.shape[0])
with torch.no_grad():
import copy
layer_weight = copy.deepcopy(self.sq_linear.weight)  # snapshot of the weight before it is divided by the scale
self.sq_linear.weight /= scale
if self.bias_alpha is not None:
res = torch.matmul(self.bias_alpha, layer_weight.transpose(0, 1))  # compensation term: bias_alpha @ W^T
layer_bias = copy.deepcopy(self.sq_linear.bias)
if self.sq_linear.bias is not None:
bias_tmp = self.sq_linear.bias + res
self.sq_linear.bias.data.copy_(bias_tmp)
logger.debug(f"Bias of {self.layer_name} shifted by the compensation term, shape: {self.sq_linear.bias.size()}, unchanged: {torch.equal(layer_bias, self.sq_linear.bias)}")
else:
self.sq_linear.bias = torch.nn.Parameter(res)
logger.debug(f"Bias created for {self.layer_name} from the compensation term, shape: {self.sq_linear.bias.size()}, previous bias was None: {layer_bias is None}")


def _recover_sq_linear(self):
# remove mul and reset sq_linear for ipex inference
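To make the intent of the two new branches concrete, here is a small self-contained check of the algebra they rely on, leaving aside the SmoothQuant scale folding that `_update_sq_linear` also performs: subtracting a per-channel shift from the input, as the new `forward` branch does, is compensated by adding `bias_alpha @ W^T` to the bias, as the new `_update_sq_linear` branch does. The layer sizes and tensors below are made up for illustration.

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)
bias_alpha = torch.randn(4)  # hypothetical per-channel shift

# Reference output of the unmodified layer.
ref = linear(x)

# Shift the input, as SQLinearWrapper.forward does when bias_alpha is set ...
shifted = linear(x - bias_alpha)

# ... and add the compensation term bias_alpha @ W^T (equivalent to folding it
# into linear.bias, as _update_sq_linear does, here without the scale division).
compensation = torch.matmul(bias_alpha, linear.weight.transpose(0, 1))
restored = shifted + compensation

assert torch.allclose(ref, restored, atol=1e-6)
```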
