@@ -1,21 +1,21 @@
import argparse

from neural_compressor.adaptor.torch_utils.autoround import AutoRound, AutoOPTRound, AutoAdamRound
import sys
from neural_compressor.adaptor.torch_utils.autoround import (AutoRound,
AutoOPTRound,
AutoAdamRound)

parser = argparse.ArgumentParser()
import torch
import os
import re
import json

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True, warn_only=True)
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

from transformers import set_seed

from eval import eval_model

import re

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -44,13 +44,13 @@
parser.add_argument("--sym", action='store_true',
help=" sym quantization")

parser.add_argument("--iters", default=400, type=int,
parser.add_argument("--iters", default=200, type=int,
help=" iters")

parser.add_argument("--use_quant_input", action='store_true',
help="whether to use the output of quantized block to tune the next block")

parser.add_argument("--lr", default=0.05, type=float,
parser.add_argument("--lr", default=0.005, type=float,
help="step size")

parser.add_argument("--minmax_lr", default=None, type=float,
@@ -83,6 +83,9 @@

parser.add_argument("--enable_minmax_tuning", action='store_true',
help="whether enable weight minmax tuning")

parser.add_argument("--use_optimum_format", default=True,
help="whether use HuggingFace format.")

# parser.add_argument("--tasks", default=["lambada_openai", "hellaswag", "winogrande", "piqa"],
# help="lm-eval tasks")
@@ -186,9 +189,17 @@

optq = round(model, tokenizer, args.num_bits, args.group_size, scheme, bs=args.train_bs,
seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,
minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input,
amp=args.amp, n_samples=args.n_samples, low_gpu_mem_usage=args.low_gpu_mem_usage, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps) ##TODO args pass
optq.quantize()
use_quant_input=args.use_quant_input, amp=args.amp, n_samples=args.n_samples,
low_gpu_mem_usage=args.low_gpu_mem_usage, minmax_lr=args.minmax_lr,
seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps) ##TODO args pass
q_model, q_config = optq.quantize()
if args.use_optimum_format:
output_dir = args.output_dir + "_" + args.model_name.split('/')[-1] + "/"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
q_config_path = os.path.join(output_dir, "qconfig.json")
with open(q_config_path, "w") as f:
json.dump(q_config, f, indent=4)

torch.cuda.empty_cache()
model.eval()
@@ -202,3 +213,4 @@
eval_model(output_dir=output_dir, model=model, tokenizer=tokenizer, tasks=args.tasks, \
eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage, device=cuda_device, excel_file=excel_name,
limit=None)

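A usage sketch (not part of this diff, relying only on names defined above): once the block above has written qconfig.json, the export helper added by this PR can pack the tuned model into WeightOnlyLinear modules. The snippet continues the run script after q_model, q_config = optq.quantize().

# Sketch only: continues the run script above; q_model and q_config_path are the
# variables created there after optq.quantize() and the qconfig.json dump.
from neural_compressor.adaptor.torch_utils.autoround import export_compressed_model

# weight_config accepts either the in-memory q_config dict or the path to qconfig.json.
compressed_model = export_compressed_model(q_model, weight_config=q_config_path)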
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .autoround import AutoRound, AutoOPTRound, AutoAdamRound
from .export import export_compressed_model
neural_compressor/adaptor/torch_utils/autoround/autoround.py (57 additions, 6 deletions)
@@ -142,11 +142,13 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", v=0, min_scal
weight = weight.reshape(-1, group_size)
if isinstance(v, torch.Tensor):
v = v.reshape(-1, group_size)

weight, scale, zp = quant_weight_actor(
weight, num_bits, scheme=scheme, v=v, min_scale=min_scale, max_scale=max_scale
)
weight = weight.reshape(orig_shape)
scale = scale.reshape(orig_shape[0], -1) # TODO: validate the feasibility on conv1d
if zp is not None:
zp = zp.reshape(orig_shape[0], -1)
return weight, scale, zp

else:
@@ -160,11 +162,49 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", v=0, min_scal
weight_new, num_bits, scheme=scheme, v=v, min_scale=min_scale, max_scale=max_scale
)
weight_new = weight_new.reshape(orig_shape[0], -1)

scale = scale.reshape(orig_shape[0], -1)
if zp is not None:
zp = zp.reshape(orig_shape[0], -1)
weight_new = weight_new[:, :-pad_len]
scale = scale[:, :-pad_len]
zp = zp[:, :-pad_len]
return weight_new, scale, zp


def quant_weight_w_scale(weight, scale, zp, group_size=-1):
"""Quant and dequant tensor with group size.

Args:
weight: input weight
scale: scale
zp: zero point
group_size (int, optional): how many elements share one scale/zp. Defaults to -1.

Returns:
output: int weight.
"""
device = weight.device
scale = scale.to(device)
if zp is not None:
zp = zp.to(device)
if group_size == -1:
return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
int_weight = torch.zeros(weight.shape).to(device)
leng = weight.shape[1] // group_size
tail_flag = False if weight.shape[1] % group_size == 0 else True
for i in range(leng):
int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
if zp is not None:
int_weight_tmp += zp[:, i].unsqueeze(1)
int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
if tail_flag:
int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
if zp is not None:
int_weight_tmp += zp[:, -1].unsqueeze(1)
int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
return int_weight

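A minimal sketch of calling the new quant_weight_w_scale helper on a toy tensor; the shapes and values below are illustrative and not taken from this PR.

# Illustrative only: a 2x8 weight with group_size=4 needs 2 scales and 2 zero points per row.
import torch

from neural_compressor.adaptor.torch_utils.autoround.autoround import quant_weight_w_scale

weight = torch.randn(2, 8)
scale = torch.full((2, 2), 0.1)  # one scale per (output channel, group)
zp = torch.full((2, 2), 8.0)     # one zero point per (output channel, group)

int_weight = quant_weight_w_scale(weight, scale, zp, group_size=4)
# Each group is computed as round(weight / scale + zp); the result keeps the weight's
# shape and float dtype, and export_compressed_model later casts it to int32.
print(int_weight.shape)  # torch.Size([2, 8])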

def round_ste(x: torch.Tensor):
"""Straight-Through Estimator for rounding.
This function is adapted from omniquant.
@@ -819,6 +859,7 @@ def get_block_names(model):
for n, m in model.named_modules():
if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
target_m = (n, m)
break
for n, m in target_m[1].named_children():
block_names.append(target_m[0] + "." + n)
return block_names
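A toy illustration of the added break (the ToyModel class below is invented for this sketch): only the first ModuleList found in the model is treated as the stack of blocks, and the names of its children are returned.

# Invented example; get_block_names stops at the first ModuleList thanks to the new break.
import torch

from neural_compressor.adaptor.torch_utils.autoround.autoround import get_block_names

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])
        self.extra = torch.nn.ModuleList([torch.nn.Linear(4, 4)])  # ignored after the break

print(get_block_names(ToyModel()))  # ['blocks.0', 'blocks.1']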
@@ -976,7 +1017,6 @@ def __init__(
self.amp = amp
self.use_quant_input = use_quant_input
self.enable_minmax_tuning = enable_minmax_tuning
self.n_samples = n_samples
self.n_blocks = n_blocks
self.bits = bits
self.group_size = group_size
@@ -997,6 +1037,8 @@ def __init__(
self.tokenizer = tokenizer
self.seqlen = seqlen
self.train_bs = bs
self.n_samples = bs * (n_samples // bs)
assert self.n_samples > 0, f"n_samples must be at least the train batch size {self.train_bs}; a multiple of it is recommended"
self.n_blocks = n_blocks
self.device = device
self.amp_dtype = torch.float16
@@ -1393,20 +1435,29 @@ def quantize(self):
if n in self.weight_config.keys():
if hasattr(m, "scale"):
self.weight_config[n]["scale"] = m.scale
# self.weight_config[n]["scale_dtype"] = m.scale.dtype
self.weight_config[n]["zp"] = m.zp
# self.weight_config[n]["zp_dtype"] = m.zp.dtype
delattr(m, "scale")
delattr(m, "zp")
else:
self.weight_config[n]["data_type"] = "float"
if self.amp_dtype == torch.bfloat16:
self.weight_config[n]["data_type"] = "bfloat"
self.weight_config[n]["bits"] = 16
self.weight_config[n]["bits"] = 32
if self.amp:
self.weight_config[n]["bits"] = 16
if self.amp_dtype == torch.bfloat16:
self.weight_config[n]["data_type"] = "bfloat"
self.weight_config[n]["group_size"] = None
self.weight_config[n]["sym"] = None

for k, v in self.weight_config.items():
for m, n in v.items():
if isinstance(n, torch.Tensor):
self.weight_config[k][m] = n.tolist()
end_time = time.time()
cost_time = end_time - start_time
logger.info(f"quantization runtime {cost_time}")

return self.model, self.weight_config


neural_compressor/adaptor/torch_utils/autoround/export.py (99 additions, 0 deletions)
@@ -0,0 +1,99 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
from typing import Union

try:
from neural_compressor.utils.utility import LazyImport

torch = LazyImport("torch")
from neural_compressor.utils import logger
except: # pragma: no cover
import logging

import torch

logger = logging.getLogger()


def export_compressed_model(
model,
weight_config: Union[str, dict],
enable_full_range=False,
compression_dtype=torch.int32,
compression_dim=1,
scale_dtype=torch.float32,
device="cpu",
use_optimum_format=True,
):
"""Convert Linear to WeightOnlyLinear for low memory inference.
Args:
weight_config (str|dict): qconfig dict or Path of qconfig.json.
enable_full_range (bool, optional): Whether to leverage the full compression range
under symmetric quantization. Defaults to False.
compression_dtype (torch.Tensor, optional): The target dtype after compression.
Defaults to torch.int32.
compression_dim (int, optional): Select from [0, 1], 0 is output channel,
1 is input channel. Defaults to 1.
scale_dtype (torch.Tensor, optional): Use float32 or float16.
Defaults to torch.float32.
device (str, optional): choose device for compression. Defaults to cpu.
use_optimum_format (bool, optional): use the popular HuggingFace compression format.
1. compression_dim: weight = 1, zeros = 0, and both are transposed.
2. zeros -= 1 before compression (the reason for this offset is unclear).
3. g_idx: use the same number for one group instead of recording the channel order.
4. parameter names changed, e.g. 'packed_weight' -> 'qweight'.
5. zeros is always needed, even for sym.
"""
from .autoround import get_module, quant_weight_w_scale, set_module
from .model_wrapper import WeightOnlyLinear

compressed_model = copy.deepcopy(model)
if isinstance(weight_config, str):
with open(weight_config, "r") as f:
q_config = json.load(f)
else:
q_config = weight_config
for k, v in q_config.items():
logger.info(f"Compressing {k} on device {device}")
if v["data_type"] == "float":
continue
else:
dtype = v["data_type"]
num_bits = v["bits"]
group_size = v["group_size"]
scheme = v["scheme"]
m = get_module(compressed_model, k)
fp_weight = m.weight.data
scale = torch.tensor(v["scale"], dtype=torch.float32) # a dtype mismatch problem may exist here
zp = None if scheme == "sym" else torch.tensor(v["zp"], dtype=torch.int32)
int_weight = quant_weight_w_scale(fp_weight, scale, zp, group_size)
int_weight = int_weight.type(torch.int32)
new_module = WeightOnlyLinear(
m.in_features,
m.out_features,
num_bits,
group_size,
dtype=dtype,
zp=zp is not None,
bias=m.bias is not None,
device=device,
use_optimum_format=True,
)
new_module.pack(int_weight, scale, zp, m.bias)
set_module(compressed_model, k, new_module)
return compressed_model
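For orientation, one entry of the weight_config consumed by this function could look roughly like the snippet below. Only the field names come from the code above; the layer name, shapes, and numbers are invented for illustration.

# Invented example entry; only the field names are taken from export_compressed_model above.
example_q_config = {
    "model.decoder.layers.0.self_attn.q_proj": {  # hypothetical layer name
        "data_type": "int",        # any value other than "float" gets packed
        "bits": 4,
        "group_size": 128,
        "scheme": "asym",          # "sym" means the zero point is ignored
        "scale": [[0.0123, 0.0098]],  # per output channel, per group (serialized via tensor.tolist())
        "zp": [[7, 8]],               # per output channel, per group zero points
    }
}
# compressed = export_compressed_model(model, weight_config=example_q_config)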