support bias shift in outlier suppression+ #1231

Closed
wants to merge 66 commits into from
Commits
453125a
[Algo] fix conflicts
yintong-lu Nov 17, 2023
f72ab8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
5410d54
Merge branch 'master' into lyt/os
yintong-lu Nov 20, 2023
6e5c2a6
[Algo] code update
yintong-lu Nov 20, 2023
6b81f7b
[Algo] code update 1120
yintong-lu Nov 20, 2023
52a5a2e
[Algo] add new RMSNorm class
yintong-lu Nov 20, 2023
e9223e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2023
567f2f7
[Algo] update RMSnorm
yintong-lu Nov 21, 2023
3a0a7d4
[Algo] fix conflicts
yintong-lu Nov 23, 2023
2b5392f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2023
c820dc6
[Algo] log update
yintong-lu Nov 23, 2023
84e283e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2023
61275ed
[Algo] fix bug and support mistral models
yintong-lu Nov 29, 2023
d8f1f83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2023
6999a6a
[Algo] update log
yintong-lu Dec 4, 2023
29f70e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
9ee80e5
[Algo] log update
yintong-lu Dec 4, 2023
5be85db
[Algo] log update
yintong-lu Dec 4, 2023
eba857c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
8ede118
[Algo] log update
yintong-lu Dec 4, 2023
9b4f3da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 4, 2023
378d2e2
[Algo] log update
yintong-lu Dec 12, 2023
97edb69
[Algo] code update
yintong-lu Dec 12, 2023
ee400a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2023
96858c7
[Algo] log update
yintong-lu Dec 12, 2023
798290d
[Algo] update comment
yintong-lu Dec 12, 2023
e4d8ce2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2023
51ddcf1
[Algo] fix conflicts w.r.t blockwise
yintong-lu Dec 13, 2023
d15d6be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2023
ff76a2d
[Algo] op-replacement for llama and mistral
yintong-lu Dec 18, 2023
2d2d79a
[Algo] fix bug
yintong-lu Dec 18, 2023
d625d6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
462cef2
[Algo] reconfigure bias_shift argument
yintong-lu Dec 18, 2023
98525ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
bc69940
[Algo] reconfigure bias_shift argument
yintong-lu Dec 18, 2023
b99b90b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
410bf4c
minor change
yintong-lu Dec 18, 2023
480435a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
6082a7c
minor change
yintong-lu Dec 18, 2023
33ca458
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
9731f60
Merge branch 'master' into lyt/os
yintong-lu Dec 19, 2023
5dbd8e5
[Algo] add ut
yintong-lu Dec 19, 2023
092ef69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
7e325a3
[Algo] format comments
yintong-lu Dec 19, 2023
18b106d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2023
4a81253
fix bug
yintong-lu Dec 19, 2023
806d4b1
minor change
yintong-lu Dec 19, 2023
2663c0c
move code
yintong-lu Dec 19, 2023
38b068a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2023
a9f2541
remove comments
yintong-lu Dec 19, 2023
df2bcf5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2023
dd46b16
fix bug
yintong-lu Dec 19, 2023
a185add
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2023
15fcb53
fix bug
yintong-lu Dec 19, 2023
d309e1c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2023
fcdf10b
fix bug
yintong-lu Dec 20, 2023
d147589
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2023
fbca88b
minor change
yintong-lu Dec 20, 2023
5535121
fix issues
yintong-lu Dec 20, 2023
9f893ed
fix issues
yintong-lu Dec 20, 2023
ed5e1c8
minor change
yintong-lu Dec 20, 2023
44fecd4
code enhance
yintong-lu Dec 20, 2023
f569b02
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2023
6386961
codestyle fix
yintong-lu Dec 21, 2023
6059af6
rename blockwise arg to avoid itrex ut error
yintong-lu Dec 21, 2023
7be9a66
minor change
yintong-lu Dec 21, 2023
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/onnxrt.py
@@ -181,6 +181,7 @@ def smooth_quant(
"alpha_step": 0.1,
"shared_criterion": "mean",
"do_blockwise": False,
"enable_bias_shift": False,
},
default_alpha=0.5,
):
@@ -201,6 +202,7 @@ def smooth_quant(
auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with step_size 0.1.
do_blockwise: Whether to do blockwise auto-tuning.
enable_bias_shift: Whether to do bias-shifting.
default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.

Returns:
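For context, a minimal sketch of how a caller might turn the new flag on through the SmoothQuant recipe dictionary that feeds these defaults. The `PostTrainingQuantConfig`/`quantization.fit` entry point and the placeholder `model`/`calib_dataloader` objects are assumptions for illustration, not part of this diff:

```python
from neural_compressor import PostTrainingQuantConfig, quantization

# Sketch only: enable bias shift inside the SQ auto-alpha search space.
conf = PostTrainingQuantConfig(
    recipes={
        "smooth_quant": True,
        "smooth_quant_args": {
            "alpha": "auto",
            "auto_alpha_args": {
                "alpha_min": 0.0,
                "alpha_max": 1.0,
                "alpha_step": 0.1,
                "shared_criterion": "mean",
                "do_blockwise": False,
                "enable_bias_shift": True,  # new flag introduced by this PR
            },
        },
    }
)
# model and calib_dataloader are assumed to be defined by the caller.
q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
```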
33 changes: 30 additions & 3 deletions neural_compressor/adaptor/pytorch.py
@@ -1743,6 +1743,7 @@ def smooth_quant(
"alpha_step": 0.1,
"shared_criterion": "mean",
"do_blockwise": False,
"enable_bias_shift": False,
},
default_alpha=0.5,
):
@@ -1763,8 +1764,10 @@ def smooth_quant(
auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with step_size 0.1.
do_blockwise determines whether to do blockwise auto-tuning.
enable_bias_shift determines whether to do bias-shifting.
default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.


Returns:
model: A modified fp32 model, inplace=True.
"""
@@ -2017,6 +2020,10 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):

# For smoothquant optimized model
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if "smooth_quant_args" in recipe_cfgs and "auto_alpha_args" in recipe_cfgs["smooth_quant_args"]:
enable_bias_shift = recipe_cfgs["smooth_quant_args"]["auto_alpha_args"].get("enable_bias_shift", False)
else:
enable_bias_shift = False
if (
recipe_cfgs
and recipe_cfgs.get("smooth_quant", False)
@@ -2025,7 +2032,12 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
):
return self.qdq_quantize(q_model, tune_cfg)

if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and recipe_cfgs["smooth_quant_args"]["folding"]:
if (
recipe_cfgs
and recipe_cfgs.get("smooth_quant", False)
and recipe_cfgs["smooth_quant_args"]["folding"]
and not enable_bias_shift
):
self._apply_pre_optimization(q_model, tune_cfg)

# For tensorboard display
@@ -2671,6 +2683,10 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):

# check smoothquant folding value
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if "smooth_quant_args" in recipe_cfgs and "auto_alpha_args" in recipe_cfgs["smooth_quant_args"]:
enable_bias_shift = recipe_cfgs["smooth_quant_args"]["auto_alpha_args"].get("enable_bias_shift", False)
else:
enable_bias_shift = False
if "smooth_quant_args" in recipe_cfgs and "folding" in recipe_cfgs["smooth_quant_args"]:
if recipe_cfgs["smooth_quant_args"]["folding"] is None:
if self.version.release < Version("2.1").release:
@@ -2679,6 +2695,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
folding = False
else:
folding = recipe_cfgs["smooth_quant_args"]["folding"]
logger.debug(f"SQ Ipex whether to perform bias_shift: {enable_bias_shift}, folding: {folding}")

# Update model parameter when smoothquant folding = False
if (
recipe_cfgs
@@ -2688,7 +2706,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
):
return self.qdq_quantize(model, q_model, tune_cfg, dataloader, q_func)
# Update model parameter when smoothquant folding = True
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding:
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and folding and not enable_bias_shift:
self._apply_pre_optimization(model, tune_cfg)

assert (
@@ -3514,14 +3532,23 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):

# For smoothquant optimized model
recipe_cfgs = tune_cfg.get("recipe_cfgs", None)
if "smooth_quant_args" in recipe_cfgs and "auto_alpha_args" in recipe_cfgs["smooth_quant_args"]:
enable_bias_shift = recipe_cfgs["smooth_quant_args"]["auto_alpha_args"].get("enable_bias_shift", False)
else:
enable_bias_shift = False
if (
recipe_cfgs
and recipe_cfgs.get("smooth_quant", False)
and not recipe_cfgs["smooth_quant_args"]["folding"]
and self.approach != "post_training_dynamic_quant"
):
return self.qdq_quantize(q_model, tune_cfg)
if recipe_cfgs and recipe_cfgs.get("smooth_quant", False) and recipe_cfgs["smooth_quant_args"]["folding"]:
if (
recipe_cfgs
and recipe_cfgs.get("smooth_quant", False)
and recipe_cfgs["smooth_quant_args"]["folding"]
and not enable_bias_shift
):
self._apply_pre_optimization(q_model, tune_cfg)

self.tune_cfg = tune_cfg
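The `enable_bias_shift` lookup above is repeated verbatim in three `quantize` paths; purely as an illustration (not code from this PR), the guard boils down to a small helper like the following:

```python
def _get_enable_bias_shift(recipe_cfgs):
    """Hypothetical helper mirroring the repeated lookup: read enable_bias_shift
    from smooth_quant_args/auto_alpha_args, defaulting to False when absent."""
    if not recipe_cfgs:
        return False
    auto_alpha_args = recipe_cfgs.get("smooth_quant_args", {}).get("auto_alpha_args", {})
    return auto_alpha_args.get("enable_bias_shift", False)
```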
2 changes: 2 additions & 0 deletions neural_compressor/adaptor/tensorflow.py
@@ -1839,6 +1839,7 @@ def smooth_quant(
"alpha_step": 0.1,
"shared_criterion": "mean",
"do_blockwise": False,
"enable_bias_shift": False,
},
default_alpha=0.5,
):
@@ -1859,6 +1860,7 @@ def smooth_quant(
auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning.
By default the search space is 0.0-1.0 with step_size 0.1.
do_blockwise: Whether to do blockwise auto-tuning.
enable_bias_shift: Whether to do bias-shifting.
default_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5.

Returns:
44 changes: 44 additions & 0 deletions neural_compressor/adaptor/torch_utils/model_wrapper.py
@@ -593,3 +593,47 @@ def _recover_linear(self):
scale = self.input_scale.view(1, self.input_scale.shape[0])
with torch.no_grad():
self.linear.weight *= scale


class LlamaRMSNorm_bias(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, bias=None):
"""LlamaRMSNorm is equivalent to T5LayerNorm.

Add bias attribute and modify forward function for bias-shifting.
"""
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.bias = bias

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if self.bias is not None:
return self.weight * hidden_states.to(input_dtype) + self.bias.to(input_dtype)
else:
return self.weight * hidden_states.to(input_dtype)


class MistralRMSNorm_bias(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, bias=None):
"""MistralRMSNorm is equivalent to T5LayerNorm.

Add bias attribute and modify forward function for bias-shifting.
"""
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.bias = bias

def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
if hasattr(self, "bias") and self.bias is not None:
return self.weight * hidden_states.to(input_dtype) + self.bias.to(input_dtype)
else:
return self.weight * hidden_states.to(input_dtype)
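
To show how the wrappers above are meant to be used, here is a hedged sketch of swapping a model's RMSNorm modules for the bias-capable variant; the `replace_rmsnorm_with_bias` helper and the `bias_dict` mapping are assumptions for the example, not the PR's actual replacement logic:

```python
from neural_compressor.adaptor.torch_utils.model_wrapper import LlamaRMSNorm_bias

def replace_rmsnorm_with_bias(model, bias_dict):
    """Hypothetical helper: swap named LlamaRMSNorm modules for the bias-capable
    variant so an absorbed bias shift can be re-added after normalization.
    bias_dict maps module names to 1-D bias tensors (or None)."""
    for name, module in list(model.named_modules()):
        if type(module).__name__ == "LlamaRMSNorm":
            new_norm = LlamaRMSNorm_bias(
                hidden_size=module.weight.shape[0],
                eps=module.variance_epsilon,
                bias=bias_dict.get(name),
            )
            new_norm.weight = module.weight  # reuse the original scale parameter
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, new_norm)
    return model
```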