Commit e8460e2: Merge branch 'master' into quality

yiliu30 committed May 24, 2024
2 parents 1b40feb + 54f039d
Showing 38 changed files with 769 additions and 327 deletions.
2 changes: 2 additions & 0 deletions docs/source/quantization_weight_only.md
@@ -87,6 +87,8 @@ Notes:
| use_max_length | False | Whether to align all calibration data to a fixed length, which equals pad_max_length. |
| block_size | 128 | Execute GPTQ quantization per block, block shape = [$C_{out}$, block_size] |
| static_groups | False | Whether to calculate group-wise quantization parameters in advance. This option mitigates actorder's extra computational requirements. |
+| true_sequential | False | Whether to quantize layers within a transformer block in their original order. This can lead to higher accuracy but a slower overall quantization process. |
+| lm_head | False | Whether to quantize the lm_head (the linear prediction layer at the end of the language model). |
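As an illustration (not part of this diff), a minimal sketch of passing the two new options through the GPTQ recipe interface, assuming the Neural Compressor 2.x `PostTrainingQuantConfig` weight-only path; the `model` object, `dataloader`, and the 4-bit / group-128 settings are placeholders:

```python
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # match all supported op types
            "weight": {"bits": 4, "group_size": 128, "scheme": "sym", "algorithm": "GPTQ"},
        },
    },
    recipes={
        "gptq_args": {
            "act_order": False,
            "use_max_length": True,
            "pad_max_length": 2048,
            "static_groups": False,
            "true_sequential": True,  # quantize layers inside each block in their original order
            "lm_head": True,          # also quantize the final lm_head linear layer
        },
    },
)
# model and dataloader are placeholders for a HF causal LM and a calibration dataloader
q_model = quantization.fit(model, conf, calib_dataloader=dataloader)
```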

**Note:** Neural Compressor provides `Unsigned integer for asymmetric quantization` and `Signed integer for symmetric quantization`. Please follow the section below to compress to the low-bit data type for saving.

@@ -26,7 +26,6 @@
import onnxruntime as ort
from torch.nn.functional import pad
from torch.utils.data import DataLoader
-from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer

@@ -198,28 +197,33 @@ def replace_architectures(json_path):
        json.dump(data, file, indent=4)

def eval_func(model):
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+
    model_dir = model
    if isinstance(model, str) and model.endswith(".onnx"):
        model_dir = os.path.dirname(model)

    replace_architectures(os.path.join(model_dir, "config.json"))

-    results = evaluate(
-        model="hf-causal",
-        model_args="pretrained=" + model_dir + ",tokenizer="+ args.tokenizer,
+    eval_args = LMEvalParser(
+        model="hf",
+        model_args="pretrained=" + model_dir + ",tokenizer=" + args.tokenizer + ",model_format=onnx",
        batch_size=args.batch_size,
-        tasks=args.tasks,
-        model_format="onnx",
+        tasks=','.join(args.tasks),
+        device="cpu",
    )
+    results = evaluate(eval_args)

    eval_acc = 0
    for task_name in args.tasks:
        if task_name == "wikitext":
-            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
-            eval_acc += results["results"][task_name]["word_perplexity"]
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["word_perplexity,none"]))
+            eval_acc += results["results"][task_name]["word_perplexity,none"]
        else:
-            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
-            eval_acc += results["results"][task_name]["acc"]
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["acc,none"]))
+            eval_acc += results["results"][task_name]["acc,none"]

    if len(args.tasks) != 0:
        eval_acc /= len(args.tasks)
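For reference (not part of the diff), a standalone sketch of the new lm-eval 0.4.x entry point used above, via the `intel_extension_for_transformers` wrapper; the model directory, tokenizer name, and task list are illustrative placeholders:

```python
from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser

eval_args = LMEvalParser(
    model="hf",
    # model_format=onnx tells the wrapper to load the exported ONNX decoder from this directory
    model_args="pretrained=./llama-2-7b-onnx,tokenizer=meta-llama/Llama-2-7b-hf,model_format=onnx",
    tasks="lambada_openai,hellaswag",  # 0.4.x expects a comma-separated string rather than a list
    batch_size=8,
    device="cpu",
)
results = evaluate(eval_args)

# metric keys in lm-eval 0.4.x carry a filter suffix such as ",none"
print(results["results"]["lambada_openai"]["acc,none"])
```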
@@ -7,6 +7,6 @@ onnxruntime-extensions; python_version < '3.11'
datasets
optimum
evaluate
-intel-extension-for-transformers
+intel-extension-for-transformers >= 1.4.1
peft
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
+lm-eval==0.4.2
@@ -27,7 +27,6 @@
import onnxruntime as ort
from torch.nn.functional import pad
from torch.utils.data import DataLoader
-from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer

@@ -135,28 +134,33 @@ def replace_architectures(json_path):
        json.dump(data, file, indent=4)

def eval_func(model):
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+
    model_dir = model
    if isinstance(model, str) and model.endswith(".onnx"):
        model_dir = os.path.dirname(model)

    replace_architectures(os.path.join(model_dir, "config.json"))

-    results = evaluate(
-        model="hf-causal",
-        model_args="pretrained=" + model_dir + ",tokenizer="+ args.tokenizer,
+    eval_args = LMEvalParser(
+        model="hf",
+        model_args="pretrained=" + model_dir + ",tokenizer=" + args.tokenizer + ",model_format=onnx",
        batch_size=args.batch_size,
-        tasks=args.tasks,
-        model_format="onnx",
+        tasks=','.join(args.tasks),
+        device="cpu",
    )
+    results = evaluate(eval_args)

    eval_acc = 0
    for task_name in args.tasks:
        if task_name == "wikitext":
-            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
-            eval_acc += results["results"][task_name]["word_perplexity"]
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["word_perplexity,none"]))
+            eval_acc += results["results"][task_name]["word_perplexity,none"]
        else:
-            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
-            eval_acc += results["results"][task_name]["acc"]
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["acc,none"]))
+            eval_acc += results["results"][task_name]["acc,none"]

    if len(args.tasks) != 0:
        eval_acc /= len(args.tasks)
@@ -7,6 +7,6 @@ onnxruntime-extensions; python_version < '3.11'
datasets
optimum
evaluate
-intel-extension-for-transformers
+intel-extension-for-transformers >= 1.4.1
peft
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
+lm-eval==0.4.2
@@ -77,6 +77,8 @@
this should align with your model config, \
and your dataset builder args: args.pad_max_length')
parser.add_argument('--gptq_static_groups', action='store_true', help='Use determined group to do quantization')
+parser.add_argument('--gptq_true_sequential', action='store_true', help="Whether to run in true_sequential mode.")
+parser.add_argument('--gptq_lm_head', action='store_true', help="Whether to use GPTQ to quantize the output layer of the LLMs.")
# ==============code generation args===========
parser.add_argument("--code_generation", action="store_true")
parser.add_argument("--n_samples", default=200, type=int)
@@ -278,7 +280,8 @@ def calib_func(prepared_model):
                'use_max_length': args.gptq_use_max_length,
                'pad_max_length': args.gptq_pad_max_length,
                'static_groups': args.gptq_static_groups,
-                "enable_mse_search": args.woq_enable_mse_search,
+                "true_sequential": args.gptq_true_sequential,
+                "lm_head": args.gptq_lm_head,
            }
            # GPTQ: use assistive functions to modify calib_dataloader and calib_func
            # TEQ: set calib_func=None, use default training func as calib_func
49 changes: 27 additions & 22 deletions neural_compressor/adaptor/ox_utils/weight_only.py
@@ -97,9 +97,8 @@ def make_matmul_weight_only_node(
    op_type = "MatMulNBits"

    # pack quantized weight
-    for i in range(q_weight.shape[0]):
-        for k in range(0, group_size, 2):
-            packed[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4
+    q_weight_pairs = q_weight[:, ::2] | q_weight[:, 1::2] << 4
+    packed[:, :] = q_weight_pairs[:, :blob_size]
    packed = np.reshape(packed, (-1, k_blocks, blob_size))

    # build scale tensor
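The rewritten packing replaces the per-element loop with a single vectorized expression. A small self-contained equivalence check (toy shapes, not the real `blob_size`/`k_blocks` values used in the adaptor):

```python
import numpy as np

rng = np.random.default_rng(0)
q_weight = rng.integers(0, 16, size=(4, 32), dtype=np.uint8)  # 4-bit values, even column count
group_size, blob_size = q_weight.shape[1], q_weight.shape[1] // 2

# reference: the original scalar loop (even column -> low nibble, odd column -> high nibble)
packed_loop = np.zeros((q_weight.shape[0], blob_size), dtype=np.uint8)
for i in range(q_weight.shape[0]):
    for k in range(0, group_size, 2):
        packed_loop[i][k // 2] = q_weight[i][k] | q_weight[i][k + 1] << 4

# vectorized form introduced by this change
packed_vec = q_weight[:, ::2] | q_weight[:, 1::2] << 4

assert np.array_equal(packed_loop, packed_vec[:, :blob_size])
```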
@@ -120,15 +119,14 @@
            packed_zp = np.reshape(zero_point, (1, -1)).astype("uint8")
        else:
            packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
-            for i in range(zero_point.shape[0] // k_blocks):
-                for j in range(k_blocks):
-                    idx = i * k_blocks + j
-                    zp = zero_point[idx]
-                    packed_zp[idx // 2] = (
-                        ((packed_zp[idx // 2] & 0x0F) | (zp << 4))
-                        if (idx & 1)
-                        else ((packed_zp[idx // 2] & 0xF0) | zp)
-                    )
+            # create an index array
+            idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
+            # separate odd and even indices
+            even_idx = idx[::2]
+            odd_idx = idx[1::2]
+            # vectorized operation for even and odd indices
+            packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
+            packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)

        zp_tensor = onnx.helper.make_tensor(
            name=node.input[1] + "_zp", data_type=2, dims=packed_zp.shape, vals=packed_zp.tobytes(), raw=True
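Similarly, the zero-point packing now writes the low nibbles (even indices) and high nibbles (odd indices) in two vectorized assignments. A toy equivalence check against the old loop (sizes and the 1-D zero-point layout here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
zero_point = rng.integers(0, 16, size=8, dtype=np.uint8)  # one 4-bit zero point per group (toy layout)
k_blocks = 4

# reference: the original scalar loop, two zero points per packed byte
packed_ref = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
for idx in range(zero_point.shape[0] // k_blocks * k_blocks):
    zp = zero_point[idx]
    packed_ref[idx // 2] = (
        ((packed_ref[idx // 2] & 0x0F) | (zp << 4)) if (idx & 1) else ((packed_ref[idx // 2] & 0xF0) | zp)
    )

# vectorized form introduced by this change
packed_zp = np.full((zero_point.shape[0] + 1) // 2, 136, dtype="uint8")
idx = np.arange(zero_point.shape[0] // k_blocks * k_blocks).reshape(-1)
even_idx, odd_idx = idx[::2], idx[1::2]
packed_zp[even_idx // 2] = (packed_zp[even_idx // 2] & 0xF0) | zero_point[even_idx].ravel()
packed_zp[odd_idx // 2] = (packed_zp[odd_idx // 2] & 0x0F) | (zero_point[odd_idx].ravel() << 4)

assert np.array_equal(packed_ref, packed_zp)
```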
@@ -224,9 +222,8 @@ def quant_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
    if scheme == "sym":
        max_range = np.maximum(np.abs(rmin), np.abs(rmax))
        scale = np.ones(rmax.shape)
-        scale[max_range > 0] = np.array(
-            [float(i) / (maxq - minq) for i in (max_range[max_range > 0] * 2.0).flatten().tolist()]
-        )
+        mask = max_range > 0
+        scale[mask] = (max_range[mask] * 2.0).astype(np.float64) / (maxq - minq)
        zero_point = (
            np.zeros(scale.shape) if dtype == "int" else np.ones(rmax.shape, dtype="uint8") * (1 << (num_bits - 1))
        )
@@ -240,7 +237,14 @@
            if dtype == "int"
            else np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")
        )
-    return np.clip((data / scale + zero_point).round(), minq, maxq), scale, zero_point
+
+    q_weight = np.empty_like(data, dtype=scale.dtype)
+    np.divide(data, scale, out=q_weight)
+    np.add(q_weight, zero_point, out=q_weight)
+    np.round(q_weight, out=q_weight)
+    np.clip(q_weight, minq, maxq, out=q_weight)
+
+    return q_weight, scale, zero_point


def qdq_tensor(data, num_bits=4, group_size=32, scheme="asym", dtype="int", ratio=1.0):
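To show what the rewritten tail of `quant_tensor` computes, here is a minimal sketch on a single toy group (asymmetric uint4; the rmin/rmax and zero-point derivation is simplified from the surrounding context, and the values are made up):

```python
import numpy as np

data = np.array([[-0.31, 0.12, 0.58, -0.07, 0.44, -0.52, 0.09, 0.25]], dtype=np.float32)
minq, maxq = 0, 15  # uint4 asymmetric range

rmin = data.min(axis=1, keepdims=True)
rmax = data.max(axis=1, keepdims=True)
scale = (rmax - rmin) / (maxq - minq)
zero_point = np.maximum(0, np.minimum(maxq, ((np.zeros(scale.shape) - rmin) / scale).round())).astype("uint8")

# same in-place pipeline as above: divide, add, round, clip without extra temporaries
q_weight = np.empty_like(data, dtype=scale.dtype)
np.divide(data, scale, out=q_weight)
np.add(q_weight, zero_point, out=q_weight)
np.round(q_weight, out=q_weight)
np.clip(q_weight, minq, maxq, out=q_weight)

dequant = scale * (q_weight - zero_point)
print(np.abs(dequant - data).max())  # quantization error is on the order of scale / 2
```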
@@ -756,6 +760,7 @@ def awq_quantize(
    model.remove_tensors_from_outputs([i.name for i in org_output])

    output_names = []
+
    for node in model.nodes():
        if (
            node.op_type in ["MatMul"]
@@ -927,8 +932,8 @@ def find_params(weight):
        perm = np.argsort(np.diag(H))[::-1]
        W = W[perm, :]
        H = H[perm, :][:, perm]
-        Losses = np.zeros(W.shape)
-        Q = np.zeros(W.shape)
+        Losses = np.zeros_like(W)
+        Q = np.zeros_like(W)
        damp = percdamp * np.mean(np.diag(H))
        diag = np.arange(shape[0])
        H[diag, diag] += damp  # add an average value of
@@ -939,9 +944,9 @@
            count = i2 - i1

            W1 = copy.deepcopy(W[i1:i2, :])
-            Q1 = np.zeros(W1.shape)
-            Err1 = np.zeros(W1.shape)
-            Losses1 = np.zeros(W1.shape)
+            Q1 = np.zeros_like(W1)
+            Err1 = np.zeros_like(W1)
+            Losses1 = np.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):  # within a block, channel wise
@@ -952,7 +957,7 @@
                    if (i1 + i) % group_size == 0:
                        scale, zp = find_params(W[(i1 + i) : (i1 + i + group_size), :])

-                q = (scale * (np.clip(np.round(np.expand_dims(w, axis=1) / scale) + zp, 0, maxq) - zp)).flatten()
+                q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()
                Q1[i, :] = q
                Losses1[i, :] = (w - q) ** 2 / d**2
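The changed line is the per-channel fake-quantization step (quantize, clip, dequantize) inside the GPTQ inner loop. A toy illustration, with made-up scale and zero-point values shaped to broadcast as in the source:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=8).astype(np.float32)  # one row of the block, as in W1[i, :]
scale = np.full((8, 1), 0.02, dtype=np.float32)       # per-channel parameters (placeholders for find_params output)
zp = np.full((8, 1), 8.0, dtype=np.float32)
maxq = 15

# w[:, np.newaxis] is an equivalent, cheaper spelling of np.expand_dims(w, axis=1)
q = (scale * (np.clip(np.round(w[:, np.newaxis] / scale) + zp, 0, maxq) - zp)).flatten()

squared_err = (w - q) ** 2  # feeds Losses1[i, :] after dividing by d**2
print(q[:4], squared_err.max())
```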

2 changes: 2 additions & 0 deletions neural_compressor/adaptor/pytorch.py
@@ -4722,6 +4722,8 @@ def gptq_quantize(self, model, tune_cfg, dataloader):
                "act_order": self.recipes["gptq_args"].get("act_order", False),
                "block_size": self.recipes["gptq_args"].get("block_size", True),
                "static_groups": self.recipes["gptq_args"].get("static_groups", False),
+                "true_sequential": self.recipes["gptq_args"].get("true_sequential", False),
+                "lm_head": self.recipes["gptq_args"].get("lm_head", False),
            }
            nsamples = self.recipes["gptq_args"].get("nsamples", 128)
            use_max_length = self.recipes["gptq_args"].get("use_max_length", False)