[regression fix] sq enhance calibration part (#1276)
Signed-off-by: Lu, Yintong <yintong.lu@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: xinhe <xin3.he@intel.com>
(cherry picked from commit e6eda31)
yintong-lu authored and chensuyue committed Sep 27, 2023
1 parent bd9f093 commit 49e950e
Showing 2 changed files with 102 additions and 95 deletions.
178 changes: 83 additions & 95 deletions neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -60,7 +60,10 @@ def forward_wrapper(model, input, device=torch.device("cpu")):
     if isinstance(input, dict) or isinstance(input, UserDict):
         output = model(**input)
     elif isinstance(input, list) or isinstance(input, tuple):
-        output = model(*input)
+        try:
+            output = model(*input)
+        except:
+            output = model(input)
     else:
         output = model(input)
     return output
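The new `try`/`except` exists because some models take a list or tuple as one positional argument, so unpacking it with `model(*input)` raises a `TypeError`. A minimal standalone sketch of the failure and the fallback (the module and shapes are illustrative, not from this repo; the committed code catches everything with a bare `except:`, while this sketch narrows it to `TypeError`):

```python
import torch


class TakesOneList(torch.nn.Module):
    def forward(self, x):  # expects the list itself, not its unpacked elements
        return torch.stack(x).sum()


model = TakesOneList()
inputs = [torch.ones(2), torch.zeros(2)]
try:
    output = model(*inputs)  # TypeError: forward() receives two tensors instead of one list
except TypeError:
    output = model(inputs)  # the fallback path the patched wrapper takes
```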
@@ -295,8 +298,6 @@ def __init__(self, model, dataloader, example_inputs=None, q_func=None, traced_m
         self.dataloader = dataloader
         self.example_inputs = example_inputs
         self.q_func = q_func
-        self.input_values = {}
-        self.output_values = {}
         self.input_maxes = {}
         self.input_mins = {}
         self.input_maxes_abs = {}
@@ -325,10 +326,6 @@ def _save_input_pc_hook(self, name, percentile=100):
         :return: A hook function."""

         def save_input_hook(module, inputs, outputs):
-            if name not in self.input_maxes.keys():
-                self.input_maxes[name] = []
-                self.input_mins[name] = []
-                self.input_maxes_abs[name] = []
             input = inputs[0]
             ##TODO check input channel is correct
             if len(module.weight.shape) == 4:  ##conv3d or conv1d not supported now, need better way
@@ -339,43 +336,16 @@ def save_input_hook(module, inputs, outputs):
                 k_index = int(input.shape[0] * percentile / 100)
                 res, _ = torch.kthvalue(torch.abs(input), k_index, dim=0)
                 ##res = torch.max(torch.abs(input),dim=0)[0]
-            self.input_maxes_abs[name].append(res)
-            self.input_maxes[name].append(max_tensor)
-            self.input_mins[name].append(min_tensor)
+            if name not in self.input_maxes.keys():
+                self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor
+                self.input_maxes_abs[name] = res
+            else:
+                self.input_mins[name] = torch.min(self.input_mins[name], min_tensor)
+                self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor)
+                self.input_maxes_abs[name] = torch.max(self.input_maxes_abs[name], res)

         return save_input_hook

-    def _save_input_output_hook(self, name):
-        """
-        A forward hook to save input and output values of a module
-        param name: the module name
-        return: A hook function
-        """
-
-        def save_input_output_hook(module, inputs, outputs):
-            input = inputs[0]
-            cnt = 32
-            if name in self.input_values.keys() and len(self.input_values[name]) < cnt:
-                self.input_values[name].append(input)
-                self.output_values[name].append(outputs)
-            if name not in self.input_values.keys():
-                self.input_values[name] = [input]  ##TODO save more,like 8
-                self.output_values[name] = [outputs]  ##TODO do not save output
-
-        return save_input_output_hook
-
-    def _add_input_output_observer(self):
-        input_output_modules = {}
-        hook_layer_names = []
-        for key in self.absorb_to_layer:
-            hook_layer_names += self.absorb_to_layer[key]
-        for name in hook_layer_names:
-            input_output_modules[name] = get_module(self.model, name)
-        for key in input_output_modules.keys():
-            hook_func = self._save_input_output_hook(key)
-            hook_handle = input_output_modules[key].register_forward_hook(hook_func)
-            self.hook_handles.append(hook_handle)
-
     def _add_min_max_observer(self, modules, percentile=100):
         """
         :param modules: the modules which the observer will insert to
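The hook rewrite above is the core of the memory fix: per-layer statistics are now folded into one running tensor per layer on every forward pass instead of being appended to an ever-growing list. A self-contained sketch of the same pattern (names, shapes, and the `stats` dict are illustrative, not the repo's API):

```python
import torch

stats = {}  # layer name -> (running min, running max); constant memory per layer


def make_minmax_hook(name):
    def hook(module, inputs, outputs):
        x = inputs[0].reshape(-1, inputs[0].shape[-1])
        mn, mx = torch.min(x, dim=0)[0], torch.max(x, dim=0)[0]
        if name not in stats:
            stats[name] = (mn, mx)
        else:
            old_mn, old_mx = stats[name]
            stats[name] = (torch.min(old_mn, mn), torch.max(old_mx, mx))

    return hook


layer = torch.nn.Linear(4, 4)
handle = layer.register_forward_hook(make_minmax_hook("fc"))
for _ in range(3):  # memory use no longer grows with the number of calibration batches
    layer(torch.randn(8, 4))
handle.remove()
```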
@@ -393,7 +363,7 @@ def _remove_observer(self):
         for hook_handle in self.hook_handles:
             hook_handle.remove()

-    def _calibrate(self, absorb_to_layer, calib_iter, percentile, save_input_output=False):
+    def _calibrate(self, absorb_to_layer, calib_iter, percentile):
         """
         :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer
         :param calib_iter: Data size for calibration
@@ -406,8 +376,6 @@ def _calibrate(self, absorb_to_layer, calib_iter, percentile, save_input_output=
                 hook_modules[n] = module

         self._add_min_max_observer(hook_modules, percentile)
-        if save_input_output:
-            self._add_input_output_observer()

         self._dump_min_max(calib_iter=calib_iter)
         self._remove_observer()
@@ -423,15 +391,6 @@ def _dump_min_max(self, calib_iter=100):
         else:
             assert self.dataloader, "Please set dataloader for calibration."
             model_forward(self.model, self.dataloader, calib_iter, self.device)
-        ##stack
-        for key in self.input_maxes.keys():
-            max_val = self.input_maxes[key]
-            max_val = torch.stack(max_val, dim=0)
-            min_val = self.input_mins[key]
-            min_val = torch.stack(min_val, dim=0)
-            self.input_maxes[key] = torch.max(max_val, dim=0)[0]
-            self.input_mins[key] = torch.min(min_val, dim=0)[0]
-            self.input_maxes_abs[key] = torch.max(torch.stack(self.input_maxes_abs[key], dim=0), dim=0)[0]

     def _reshape_in_channel_to_last(self, layer_name):
         """Move the input channel to the last dim
@@ -877,43 +836,76 @@ def _auto_tune_alpha_new(
         if not self.dataloader:
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
-
-        for idx, input in enumerate(self.dataloader):
-            if isinstance(input, (tuple, list)):
-                input = input[0]
-            best_alphas_per_module = best_alphas
-            if isinstance(best_alphas, dict):
-                for key in self.absorb_to_layer.keys():
-                    layer_names = self.absorb_to_layer[key]
-                    for layer_name in layer_names:
-                        best_alphas_per_module[layer_name] = best_alphas_per_module[key]
-
-            loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
-            if loss_alphas == {}:
-                loss_alphas = loss_tmp
-            else:
-                for key in loss_alphas.keys():
-                    cur_loss = loss_alphas[key]
-                    for alpha_key in cur_loss.keys():
-                        cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-            if isinstance(input, list):
-                input = move_input_to_device(input, self.device)
-                for inp in input:
-                    cnt += inp.shape[0]
-            else:
-                cnt += input.shape[0]
-
-            if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
-                best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
-                for key in best_alphas.keys():
-                    logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
-                absorb_input_scales, weight_scales = self._cal_scales(
-                    self.absorb_to_layer, input_maxes, best_alphas, tuning=True
-                )
-                self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                loss_alphas = {}  ##TODO check need to remove this one
-            if cnt >= auto_calib_iter:
-                break
+        try:
+            for input, label in self.dataloader:
+                best_alphas_per_module = best_alphas
+                if isinstance(best_alphas, dict):
+                    for key in self.absorb_to_layer.keys():
+                        layer_names = self.absorb_to_layer[key]
+                        for layer_name in layer_names:
+                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]
+
+                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                if loss_alphas == {}:
+                    loss_alphas = loss_tmp
+                else:
+                    for key in loss_alphas.keys():
+                        cur_loss = loss_alphas[key]
+                        for alpha_key in cur_loss.keys():
+                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
+                if isinstance(input, list):
+                    input = move_input_to_device(input, self.device)
+                    for inp in input:
+                        cnt += inp.shape[0]
+                else:
+                    cnt += input.shape[0]
+
+                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
+                    for key in best_alphas.keys():
+                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                    absorb_input_scales, weight_scales = self._cal_scales(
+                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
+                    )
+                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                    loss_alphas = {}  ##TODO check need to remove this one
+                if cnt >= auto_calib_iter:
+                    break
+        except:
+            for input in self.dataloader:
+                best_alphas_per_module = best_alphas
+                if isinstance(best_alphas, dict):
+                    for key in self.absorb_to_layer.keys():
+                        layer_names = self.absorb_to_layer[key]
+                        for layer_name in layer_names:
+                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]
+
+                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                if loss_alphas == {}:
+                    loss_alphas = loss_tmp
+                else:
+                    for key in loss_alphas.keys():
+                        cur_loss = loss_alphas[key]
+                        for alpha_key in cur_loss.keys():
+                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
+                if isinstance(input, list):
+                    input = move_input_to_device(input, self.device)
+                    for inp in input:
+                        cnt += inp.shape[0]
+                else:
+                    cnt += input.shape[0]
+
+                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
+                    for key in best_alphas.keys():
+                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                    absorb_input_scales, weight_scales = self._cal_scales(
+                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
+                    )
+                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                    loss_alphas = {}  ##TODO check need to remove this one
+                if cnt >= auto_calib_iter:
+                    break

         best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
         for key in best_alphas.keys():
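The loop body is duplicated only because the dataloader may yield `(input, label)` pairs or bare inputs; the commit probes the first format and falls back on failure via a bare `except:`. A compact sketch of the same dispatch (the helper name and data are ours; it catches the specific unpacking errors rather than everything):

```python
import torch


def calib_inputs(dataloader):
    """Return only the model inputs, whether batches carry labels or not."""
    try:
        return [inp for inp, label in dataloader]  # (input, label) batches
    except (TypeError, ValueError):  # bare inputs fail to unpack into two names
        return [inp for inp in dataloader]


with_labels = [(torch.randn(4, 8), torch.zeros(4))]
inputs_only = [torch.randn(4, 8)]
assert calib_inputs(with_labels)[0].shape == (4, 8)
assert calib_inputs(inputs_only)[0].shape == (4, 8)
```

One caveat of this probing style: a bare tensor batch whose leading dimension happens to be 2 unpacks without raising, so the fallback never triggers for it.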
@@ -995,11 +987,8 @@ def transform(
                 "you could set torchscript to True "
             )
             return self.model
-        save_input_output = False if alpha == "auto" else True
-        # if alpha == "auto":
-        #     save_input_output = True

-        input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile, save_input_output)
+        input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile)

         # Check if input_maxes match self.absorb_to_layer
         # (due to self._get_all_layer_names use layer tree instead of forward_path)
@@ -1042,7 +1031,6 @@ def transform(
         else:
             logger.warning(" Could not get example input, equivelancy check is skipped")

-        self.input_values, self.output_values = {}, {}
         return self.model

     def output_is_equal(self, out1, out2, atol=1e-04):
19 changes: 19 additions & 0 deletions test/algorithm/test_smooth_quant.py
@@ -1265,5 +1265,24 @@ def calib_func(prepared_model):
         self.assertEqual(indices[2], torch.tensor([504]))


+class TestMemoryUsage(unittest.TestCase):
+    def test_sq_auto_mem_usage(self):
+        import psutil
+
+        data = psutil.virtual_memory()
+        cpu_process = psutil.Process()
+        p = psutil.Process(cpu_process.pid)
+        mem_use0 = p.memory_info().rss / (1024**3)
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            "facebook/opt-125m",
+            torchscript=True,
+        )
+        sq = TorchSmoothQuant(model, LLMCalibDataloader())
+        sq.transform(alpha="auto", calib_iter=0, folding=False)
+        mem_use1 = p.memory_info().rss / (1024**3)
+        logger.info(f"The memory usage of this ut is {mem_use1 - mem_use0} GBs.")
+        assert (mem_use1 - mem_use0) <= 2.0
+
+
 if __name__ == "__main__":
     unittest.main()
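The new test gates the fix on resident-set-size growth of the current process. A minimal sketch of that measurement idiom, runnable on its own (the helper name and allocation size are arbitrary):

```python
import psutil


def rss_gb():
    """Resident set size of the current process, in GiB."""
    return psutil.Process().memory_info().rss / (1024**3)


before = rss_gb()
buffer = bytearray(512 * 1024 * 1024)  # hold ~0.5 GiB so the delta is visible
after = rss_gb()
print(f"RSS grew by about {after - before:.2f} GiB")
```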
