diff --git a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/autoround/main.py b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/autoround/main.py
index 4217ada9976..f699d3e298c 100644
--- a/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/autoround/main.py
+++ b/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/autoround/main.py
@@ -1,21 +1,21 @@
 import argparse
-
-from neural_compressor.adaptor.torch_utils.autoround import AutoRound, AutoOPTRound, AutoAdamRound
+import sys
+from neural_compressor.adaptor.torch_utils.autoround import (AutoRound,
+                                                             AutoOPTRound,
+                                                             AutoAdamRound)
 
 parser = argparse.ArgumentParser()
 import torch
 import os
+import re
+import json
 
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 torch.use_deterministic_algorithms(True, warn_only=True)
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
-
 from transformers import set_seed
-
 from eval import eval_model
-import re
-
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -44,13 +44,13 @@
 parser.add_argument("--sym", action='store_true',
                     help=" sym quantization")
 
-parser.add_argument("--iters", default=400, type=int,
+parser.add_argument("--iters", default=200, type=int,
                     help=" iters")
 
 parser.add_argument("--use_quant_input", action='store_true',
                     help="whether to use the output of quantized block to tune the next block")
 
-parser.add_argument("--lr", default=0.05, type=float,
+parser.add_argument("--lr", default=0.005, type=float,
                     help="step size")
 
 parser.add_argument("--minmax_lr", default=None, type=float,
@@ -83,6 +83,9 @@
 parser.add_argument("--enable_minmax_tuning", action='store_true',
                     help="whether enable weight minmax tuning")
+
+parser.add_argument("--use_optimum_format", default=True,
+                    help="whether to use the HuggingFace (optimum) format.")
 
 # parser.add_argument("--tasks", default=["lambada_openai", "hellaswag", "winogrande", "piqa"],
 #                     help="lm-eval tasks")
@@ -186,9 +189,17 @@
     optq = round(model, tokenizer, args.num_bits, args.group_size, scheme, bs=args.train_bs,
                  seqlen=seqlen, n_blocks=args.n_blocks, iters=args.iters, lr=args.lr,
-                 minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input,
-                 amp=args.amp, n_samples=args.n_samples, low_gpu_mem_usage=args.low_gpu_mem_usage, seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps) ##TODO args pass
-    optq.quantize()
+                 use_quant_input=args.use_quant_input, amp=args.amp, n_samples=args.n_samples,
+                 low_gpu_mem_usage=args.low_gpu_mem_usage, minmax_lr=args.minmax_lr,
+                 seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps) ##TODO args pass
+    q_model, q_config = optq.quantize()
+    if args.use_optimum_format:
+        output_dir = args.output_dir + "_" + args.model_name.split('/')[-1] + "/"
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        q_config_path = os.path.join(output_dir, "qconfig.json")
+        with open(q_config_path, "w") as f:
+            json.dump(q_config, f, indent=4)
 
     torch.cuda.empty_cache()
     model.eval()
@@ -202,3 +213,4 @@
     eval_model(output_dir=output_dir, model=model, tokenizer=tokenizer, tasks=args.tasks, \
                eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage, device=cuda_device,
                excel_file=excel_name, limit=None)
+
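For reference, the qconfig.json written above is simply the weight_config dictionary returned by quantize(), serialized with json.dump. A rough, hypothetical sketch of one entry is given below; the actual layer names and values depend on the model and tuning settings, but the keys mirror what export_compressed_model() reads later in this patch (data_type/bits/group_size/scheme/scale/zp for quantized layers, data_type/bits/group_size/sym for layers kept in float).

# Hypothetical illustration only; real entries are produced by AutoRound.quantize().
example_q_config = {
    "transformer.h.0.attn.k_proj": {
        "data_type": "int",
        "bits": 4,
        "group_size": 128,
        "scheme": "asym",
        "scale": [[0.012, ...]],  # per-group scales, serialized via tensor.tolist()
        "zp": [[8, ...]],         # per-group zero points (absent for sym)
    },
    "lm_head": {"data_type": "float", "bits": 16, "group_size": None, "sym": None},
}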
diff --git a/neural_compressor/adaptor/torch_utils/autoround/__init__.py b/neural_compressor/adaptor/torch_utils/autoround/__init__.py
index 96c727d97ef..060e02e45c9 100644
--- a/neural_compressor/adaptor/torch_utils/autoround/__init__.py
+++ b/neural_compressor/adaptor/torch_utils/autoround/__init__.py
@@ -12,3 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .autoround import AutoRound, AutoOPTRound, AutoAdamRound
+from .export import export_compressed_model
diff --git a/neural_compressor/adaptor/torch_utils/autoround/autoround.py b/neural_compressor/adaptor/torch_utils/autoround/autoround.py
index fa66a25e0ea..0c9b55906a1 100644
--- a/neural_compressor/adaptor/torch_utils/autoround/autoround.py
+++ b/neural_compressor/adaptor/torch_utils/autoround/autoround.py
@@ -142,11 +142,13 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", v=0, min_scal
         weight = weight.reshape(-1, group_size)
         if isinstance(v, torch.Tensor):
             v = v.reshape(-1, group_size)
-
         weight, scale, zp = quant_weight_actor(
             weight, num_bits, scheme=scheme, v=v, min_scale=min_scale, max_scale=max_scale
         )
         weight = weight.reshape(orig_shape)
+        scale = scale.reshape(orig_shape[0], -1)  # TODO validating the feasibility on conv1d
+        if zp is not None:
+            zp = zp.reshape(orig_shape[0], -1)
 
         return weight, scale, zp
     else:
@@ -160,11 +162,49 @@ def quant_weight(weight, num_bits=4, group_size=-1, scheme="asym", v=0, min_scal
             weight_new, num_bits, scheme=scheme, v=v, min_scale=min_scale, max_scale=max_scale
         )
         weight_new = weight_new.reshape(orig_shape[0], -1)
-
+        scale = scale.reshape(orig_shape[0], -1)
+        if zp is not None:
+            zp = zp.reshape(orig_shape[0], -1)
         weight_new = weight_new[:, :-pad_len]
+        scale = scale[:, :-pad_len]
+        if zp is not None:
+            zp = zp[:, :-pad_len]
         return weight_new, scale, zp
+
+
+def quant_weight_w_scale(weight, scale, zp, group_size=-1):
+    """Quant and dequant tensor with group size.
+
+    Args:
+        weight: input weight
+        scale: scale
+        zp: zero point
+        group_size (int, optional): how many elements share one scale/zp. Defaults to -1.
+
+    Returns:
+        output: int weight.
+    """
+    device = weight.device
+    scale = scale.to(device)
+    if zp is not None:
+        zp = zp.to(device)
+    if group_size == -1:
+        return torch.round(weight / scale) if zp is None else torch.round(weight / scale + zp)
+    int_weight = torch.zeros(weight.shape).to(device)
+    leng = weight.shape[1] // group_size
+    tail_flag = False if weight.shape[1] % group_size == 0 else True
+    for i in range(leng):
+        int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size] / scale[:, i].unsqueeze(1)
+        if zp is not None:
+            int_weight_tmp += zp[:, i].unsqueeze(1)
+        int_weight[:, i * group_size : (i + 1) * group_size] = torch.round(int_weight_tmp)
+    if tail_flag:
+        int_weight_tmp = weight[:, leng * group_size :] / scale[:, -1].unsqueeze(1)
+        if zp is not None:
+            int_weight_tmp += zp[:, -1].unsqueeze(1)
+        int_weight[:, leng * group_size :] = torch.round(int_weight_tmp)
+    return int_weight
+
+
 def round_ste(x: torch.Tensor):
     """Straight-Through Estimator for rounding.
 
     This function is adapted from omniquant.
@@ -819,6 +859,7 @@ def get_block_names(model):
     for n, m in model.named_modules():
         if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
             target_m = (n, m)
+            break
     for n, m in target_m[1].named_children():
         block_names.append(target_m[0] + "." + n)
     return block_names
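To make the group-wise rounding in quant_weight_w_scale concrete, here is a small self-contained check assuming this patch is installed. The values are illustrative only; in the real flow the scales and zero points come out of AutoRound's tuning.

import torch
from neural_compressor.adaptor.torch_utils.autoround.autoround import quant_weight_w_scale

# 1 output channel, 4 input elements, group_size=2 -> two groups, each with its own scale/zp
weight = torch.tensor([[0.10, 0.21, -0.42, 0.05]])
scale = torch.tensor([[0.1, 0.2]])   # one scale per group
zp = torch.tensor([[8, 8]])          # one zero point per group (asym)
int_w = quant_weight_w_scale(weight, scale, zp, group_size=2)
# group 0: round([0.10, 0.21] / 0.1 + 8) -> [9, 10]
# group 1: round([-0.42, 0.05] / 0.2 + 8) -> [6, 8]
print(int_w)  # tensor([[ 9., 10.,  6.,  8.]])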
@@ -976,7 +1017,6 @@ def __init__(
         self.amp = amp
         self.use_quant_input = use_quant_input
         self.enable_minmax_tuning = enable_minmax_tuning
-        self.n_samples = n_samples
         self.n_blocks = n_blocks
         self.bits = bits
         self.group_size = group_size
@@ -997,6 +1037,8 @@ def __init__(
         self.tokenizer = tokenizer
         self.seqlen = seqlen
         self.train_bs = bs
+        self.n_samples = bs * (n_samples // bs)
+        assert self.n_samples > 0, f"n_samples must be at least the train batch size {self.train_bs}; recommend setting it to a multiple of the batch size"
         self.n_blocks = n_blocks
         self.device = device
         self.amp_dtype = torch.float16
@@ -1393,20 +1435,29 @@ def quantize(self):
             if n in self.weight_config.keys():
                 if hasattr(m, "scale"):
                     self.weight_config[n]["scale"] = m.scale
+                    # self.weight_config[n]["scale_dtype"] = m.scale.dtype
                     self.weight_config[n]["zp"] = m.zp
+                    # self.weight_config[n]["zp_dtype"] = m.zp.dtype
                     delattr(m, "scale")
                     delattr(m, "zp")
                 else:
                     self.weight_config[n]["data_type"] = "float"
-                    if self.amp_dtype == torch.bfloat16:
-                        self.weight_config[n]["data_type"] = "bfloat"
-                    self.weight_config[n]["bits"] = 16
+                    self.weight_config[n]["bits"] = 32
+                    if self.amp:
+                        self.weight_config[n]["bits"] = 16
+                        if self.amp_dtype == torch.bfloat16:
+                            self.weight_config[n]["data_type"] = "bfloat"
                     self.weight_config[n]["group_size"] = None
                     self.weight_config[n]["sym"] = None
 
+        for k, v in self.weight_config.items():
+            for m, n in v.items():
+                if isinstance(n, torch.Tensor):
+                    self.weight_config[k][m] = n.tolist()
         end_time = time.time()
         cost_time = end_time - start_time
         logger.info(f"quantization runtime {cost_time}")
+        return self.model, self.weight_config
diff --git a/neural_compressor/adaptor/torch_utils/autoround/export.py b/neural_compressor/adaptor/torch_utils/autoround/export.py
new file mode 100644
index 00000000000..d47ed343022
--- /dev/null
+++ b/neural_compressor/adaptor/torch_utils/autoround/export.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import json
+from typing import Union
+
+try:
+    from neural_compressor.utils.utility import LazyImport
+
+    torch = LazyImport("torch")
+    from neural_compressor.utils import logger
+except:  # pragma: no cover
+    import logging
+
+    import torch
+
+    logger = logging.getLogger()
+
+
+def export_compressed_model(
+    model,
+    weight_config: Union[str, dict],
+    enable_full_range=False,
+    compression_dtype=torch.int32,
+    compression_dim=1,
+    scale_dtype=torch.float32,
+    device="cpu",
+    use_optimum_format=True,
+):
+    """Convert Linear to WeightOnlyLinear for low memory inference.
+
+    Args:
+        model (torch.nn.Module): fp32 model to be compressed.
+        weight_config (str|dict): qconfig dict or path to qconfig.json.
+        enable_full_range (bool, optional): Whether to leverage the full compression range
+                                            under symmetric quantization. Defaults to False.
+        compression_dtype (torch.Tensor, optional): The target dtype after compression.
+                                                    Defaults to torch.int32.
+        compression_dim (int, optional): Select from [0, 1], 0 is output channel,
+                                         1 is input channel. Defaults to 1.
+        scale_dtype (torch.Tensor, optional): Use float32 or float16.
+                                              Defaults to torch.float32.
+        device (str, optional): choose device for compression. Defaults to cpu.
+        use_optimum_format (bool, optional): use the popular HuggingFace compression format.
+            1. compression_dim: weight = 1, zeros = 0, and both are transposed.
+            2. zeros -= 1 before compression. Why is it needed?
+            3. g_idx: use the same number for one group instead of recording the channel order.
+            4. parameter names changed, e.g. 'packed_weight' -> 'qweight'.
+            5. zeros is always needed, even for sym.
+    """
+    from .autoround import get_module, quant_weight_w_scale, set_module
+    from .model_wrapper import WeightOnlyLinear
+
+    compressed_model = copy.deepcopy(model)
+    if isinstance(weight_config, str):
+        with open(weight_config, "r") as f:
+            q_config = json.load(f)
+    else:
+        q_config = weight_config
+    for k, v in q_config.items():
+        logger.info(f"Compressing {k} on device {device}")
+        if v["data_type"] == "float":
+            continue
+        else:
+            dtype = v["data_type"]
+            num_bits = v["bits"]
+            group_size = v["group_size"]
+            scheme = v["scheme"]
+            m = get_module(compressed_model, k)
+            fp_weight = m.weight.data
+            scale = torch.tensor(v["scale"], dtype=torch.float32)  # note: a dtype mismatch may occur here
+            zp = None if scheme == "sym" else torch.tensor(v["zp"], dtype=torch.int32)
+            int_weight = quant_weight_w_scale(fp_weight, scale, zp, group_size)
+            int_weight = int_weight.type(torch.int32)
+            new_module = WeightOnlyLinear(
+                m.in_features,
+                m.out_features,
+                num_bits,
+                group_size,
+                dtype=dtype,
+                zp=zp is not None,
+                bias=m.bias is not None,
+                scale_dtype=scale_dtype,
+                compression_dtype=compression_dtype,
+                compression_dim=compression_dim,
+                device=device,
+                use_optimum_format=use_optimum_format,
+            )
+            new_module.pack(int_weight, scale, zp, m.bias)
+            set_module(compressed_model, k, new_module)
+    return compressed_model
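A minimal usage sketch of the new export path, mirroring the unit test added later in this patch; the tiny GPT-J checkpoint is just a stand-in, and any causal LM with Linear layers should work the same way.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.adaptor.torch_utils.autoround import AutoRound, export_compressed_model

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

rounder = AutoRound(model, tokenizer, n_samples=20, device="cpu", amp=False, seqlen=10, iters=50)
q_model, weight_config = rounder.quantize()  # quantize() now also returns the config

# Replace tuned Linear layers with packed WeightOnlyLinear modules; weight_config may
# also be the path to the qconfig.json dumped by the example script above.
compressed_model = export_compressed_model(q_model, weight_config)
out = compressed_model(torch.ones([1, 10], dtype=torch.long))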
diff --git a/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py b/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py
new file mode 100644
index 00000000000..bd73fddd94d
--- /dev/null
+++ b/neural_compressor/adaptor/torch_utils/autoround/model_wrapper.py
@@ -0,0 +1,345 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Torch.nn.Module Class Definition."""
+import logging
+
+# Note: Do not import this file unless you have already imported torch,
+# since the model classes inherit torch.nn.Module.
+import math + +import torch +from packaging.version import Version +from torch.autograd import Function +from torch.nn import functional as F + +logger = logging.getLogger() + + +NF4 = [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, +] +FP4_BNB = [-12.0, -8.0, -6.0, -4.0, -3.0, -2.0, -0.0625, 0, 0.0625, 2.0, 3.0, 4.0, 6.0, 8.0, 12.0] +FP4_E2M1 = [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.0625, 0, 0.0625, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + +# the order is the same as float list, bit value range is [-7, 7] +# 1111 = -1, 1110 = -2, 1101= -3, ... + +NF4_BIT = [7, 1, 2, 3, 4, 5, 6, 0, -8, -7, -6, -5, -4, -3, -2, -1] +FP4_BNB_BIT = [-5, -6, -3, -4, -1, -2, -7, 0, 1, 6, 7, 4, 5, 2, 3] +FP4_E2M1_BIT = [-1, -2, -3, -4, -5, -6, -7, 0, 1, 2, 3, 4, 5, 6, 7] + +FLOAT_MAPPING = {"nf4": NF4, "fp4": FP4_BNB, "fp4_e2m1_bnb": FP4_BNB, "fp4_e2m1": FP4_E2M1} +INT_MAPPING = {"nf4": NF4_BIT, "fp4": FP4_BNB_BIT, "fp4_e2m1_bnb": FP4_BNB_BIT, "fp4_e2m1": FP4_E2M1_BIT} + + +def get_torch_version(): + try: + torch_version = torch.__version__.split("+")[0] + except ValueError as e: # pragma: no cover + assert False, "Got an unknown version of torch: {}".format(e) + version = Version(torch_version) + return version + + +PT_VERSION = get_torch_version().release + + +class WeightOnlyLinear(torch.nn.Module): + def __init__( + self, + in_features, + out_features, + bits, + groupsize, + dtype="int", + zp=False, + bias=False, + scale_dtype=torch.float32, + compression_dtype=torch.int32, + compression_dim=1, + device="cpu", + use_optimum_format=True, + ): + super().__init__() + self.use_optimum_format = use_optimum_format + self.dtype = dtype + if "int" not in self.dtype: # for nf4, fp4 + float_list = FLOAT_MAPPING[self.dtype] + int_list = INT_MAPPING[self.dtype] + self.int2float_mapping = {} + for k, v in zip(int_list, float_list): + self.int2float_mapping[k] = v + self.device = device + self.in_features = in_features + self.out_features = out_features + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else in_features + self.compression_dim = compression_dim + assert compression_dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ], "Only support torch.int8|16|32|64 as compressed dtype." + dtype_bits_mapping = {torch.int8: 8, torch.int16: 16, torch.int32: 32, torch.int64: 64} + self.compress_bits = dtype_bits_mapping[compression_dtype] + self.n_pack = self.compress_bits // self.bits + # K is input channel, N is output channel + assert compression_dim in [0, 1], ( + "Only support 0 or 1 as compression dimension, " + "0 is output channel, 1 is input channel." 
+ ) + if self.use_optimum_format: + self.float_type = torch.float16 + self.compression_dtype = torch.int32 + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(in_features / self.groupsize), out_features), + dtype=self.float_type, + ).to(device), + ) + self.scales = self.scales.T + self.register_buffer( + "qweight", + torch.zeros( + (math.ceil(in_features / self.n_pack), out_features), + dtype=self.compression_dtype, + ).to(device), + ) + self.qweight = self.qweight.T + self.register_buffer( + "qzeros", + torch.zeros( + (math.ceil(self.in_features / self.groupsize), math.ceil(self.out_features / self.n_pack)), + dtype=self.compression_dtype, + ).to(device), + ) + self.qzeros = self.qzeros.T + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.compression_dtype = compression_dtype + self.float_type = scale_dtype + self.register_buffer( + "scales", + torch.zeros( + (out_features, math.ceil(in_features / self.groupsize)), + dtype=self.float_type, + ).to(device), + ) + if compression_dim == 1: + self.register_buffer( + "qweight", + torch.zeros( + (out_features, math.ceil(in_features / self.n_pack)), + dtype=self.compression_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "qzeros", + torch.zeros( + (self.out_features, math.ceil(self.in_features / self.groupsize / self.n_pack)), + dtype=self.compression_dtype, + ).to(device), + ) + else: + self.register_buffer( + "qweight", + torch.zeros( + (math.ceil(out_features / self.n_pack), in_features), + dtype=self.compression_dtype, + ).to(device), + ) + if zp: + self.register_buffer( + "qzeros", + torch.zeros( + (math.ceil(self.out_features / self.n_pack), math.ceil(self.in_features / self.groupsize)), + dtype=self.compression_dtype, + ).to(device), + ) + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=self.float_type).to(device)) + else: + self.bias = None + + def pack(self, int_weight, scale, zp, bias): + int_weight = int_weight.to(self.device) + if self.use_optimum_format and zp is None: + # to avoid overflow + int_weight = int_weight.type(torch.int32) + shift_bias = 2 ** (self.bits - 1) + int_weight += shift_bias + zp = torch.zeros_like(scale, dtype=torch.uint8) + shift_bias + if bias is not None: + assert hasattr(self, "bias"), "bias is not set when initializing." + self.bias = bias.type(self.float_type).to(self.device) + assert scale.shape == self.scales.shape, "Scale shape is mismatched." + self.scales = scale.type(self.float_type).to(self.device) + if not self.use_optimum_format and self.compression_dim == 0: + int_weight = int_weight.T + self.qweight = self.qweight.T + origin_shape = int_weight.shape + target_shape = self.qweight.shape + assert origin_shape[0] == target_shape[0], "output channels mismatch, please check." 
+ mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device) + + # pack weight + for j in range(target_shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = int_weight[:, start:end].type(self.compression_dtype) + for e in range(tmp.shape[1]): + tmp[:, e] &= mask + tmp[:, e] = tmp[:, e] << (self.bits * e) + self.qweight[:, j] |= tmp[:, e] + if not self.use_optimum_format and self.compression_dim == 0: + self.qweight = self.qweight.T + + if zp is not None: + zp = zp.to(self.device) + if self.use_optimum_format: + zp -= 1 + if self.use_optimum_format or self.compression_dim == 0: + zp = zp.T + self.qzeros = self.qzeros.T + assert hasattr(self, "qzeros"), "zp is not set when initializing." + target_shape = self.qzeros.shape + for j in range(target_shape[1]): + start = self.n_pack * j + end = self.n_pack * (j + 1) + tmp = zp[:, start:end].type(self.compression_dtype) + for e in range(tmp.shape[1]): + tmp[:, e] &= mask + tmp[:, e] = tmp[:, e] << (self.bits * e) + self.qzeros[:, j] |= tmp[:, e] + if self.use_optimum_format or self.compression_dim == 0: + self.qzeros = self.qzeros.T + if self.use_optimum_format: + self.scales = self.scales.T + self.qweight = self.qweight.T + self.qzeros = self.qzeros.T + + def recover(self): + logger.debug(f"Recovering {self} weight") + scales = self.scales.T if self.use_optimum_format else self.scales + qweight = self.qweight.T if self.use_optimum_format else self.qweight + + device = scales.device + fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device) + mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device) + if hasattr(self, "qzeros"): + weight_dtype = torch.uint8 + else: + weight_dtype = torch.int8 + # unpack weight + weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device) + if not self.use_optimum_format and self.compression_dim == 0: + weight = weight.T + qweight = qweight.T + origin_shape = weight.shape + target_shape = qweight.shape + for j in range(target_shape[1]): + for e in range(self.n_pack): + index = j * self.n_pack + e + if index >= origin_shape[1]: + continue + tmp = qweight[:, j] + tmp = tmp << (self.compress_bits - self.bits * (e + 1)) + tmp = tmp >> self.compress_bits - self.bits + if weight_dtype == torch.uint8: + tmp &= mask # remove sign bit + weight[:, index] = tmp.type(weight_dtype) + if not self.use_optimum_format and self.compression_dim == 0: + weight = weight.T + if "int" not in self.dtype: + new_weight = torch.zeros(self.out_features, self.in_features).to(device) + for k, v in self.int2float_mapping.items(): + new_weight += torch.where(weight == k, v, 0) + weight = new_weight + # unpack zero_point + if hasattr(self, "qzeros"): + zp_dtype = self.compression_dtype # to avoid overflow when weight-zp + zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device) + qzeros = self.qzeros.T if self.use_optimum_format else self.qzeros + if self.use_optimum_format or self.compression_dim == 0: + zp = zp.T + qzeros = qzeros.T + origin_shape = zp.shape + target_shape = qzeros.shape + for j in range(target_shape[1]): + for e in range(self.n_pack): + index = j * self.n_pack + e + if index >= origin_shape[1]: + continue + tmp = qzeros[:, j] + tmp = tmp << (self.compress_bits - self.bits * (e + 1)) + tmp = tmp >> self.compress_bits - self.bits + tmp &= mask + zp[:, index] = tmp.type(zp_dtype) + if self.use_optimum_format or self.compression_dim == 0: + zp = zp.T + if self.use_optimum_format: + # zp -= 1 may 
cause zp == -1, after recover it becomes 2**self.bits - 1 + zp += 1 + zp = torch.where(zp > (2**self.bits - 1), 0, zp) + # recover fp32 weight with int_weight, scale, and zero_point + for idx in range(self.in_features): + g_idx = idx // self.groupsize + fp32_weight[:, idx] = (weight[:, idx] - zp[:, g_idx]) * scales[:, g_idx] + else: + # recover fp32 weight with int_weight, scale + for idx in range(self.in_features): + g_idx = idx // self.groupsize + fp32_weight[:, idx] = weight[:, idx] * scales[:, g_idx] + return fp32_weight + + def forward(self, input): + weight = self.recover() + device = self.scales.device + if weight.dtype == torch.float16 and device.type == "cpu": + weight = weight.float() + self.bias = self.bias.float() if self.bias is not None else None + input = input.type(weight.dtype) + return F.linear(input, weight, self.bias) + + def extra_repr(self) -> str: + tmp_str = "in_features={}, out_features={}, bits={}, group_size={}, bias={}".format( + self.in_features, + self.out_features, + self.bits, + self.groupsize, + self.bias is not None, + ) + if self.use_optimum_format: + tmp_str += ", use_optimum_format=True" + return tmp_str diff --git a/test/adaptor/pytorch_adaptor/test_autoround.py b/test/adaptor/pytorch_adaptor/test_autoround.py new file mode 100644 index 00000000000..2081729f7ad --- /dev/null +++ b/test/adaptor/pytorch_adaptor/test_autoround.py @@ -0,0 +1,87 @@ +import copy +import os +import shutil +import unittest + +import torch +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer + +from neural_compressor.adaptor.torch_utils.autoround import ( + AutoAdamRound, + AutoOPTRound, + AutoRound, + export_compressed_model, +) + + +class SimpleDataLoader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.randn([1, 30]) + + +class LLMDataLoader: + def __init__(self): + self.batch_size = 1 + + def __iter__(self): + for i in range(2): + yield torch.ones([1, 10], dtype=torch.long) + + +class TestPytorchWeightOnlyAdaptor(unittest.TestCase): + approach = "weight_only" + + @classmethod + def setUpClass(self): + self.dataloader = SimpleDataLoader() + self.gptj = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + torchscript=True, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True + ) + self.gptj_no_jit = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + ) + self.llm_dataloader = LLMDataLoader() + self.lm_input = torch.ones([1, 10], dtype=torch.long) + + @classmethod + def tearDownClass(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_RTN_int_quant(self): + model = copy.deepcopy(self.gptj) + out1 = model(self.lm_input) + round = AutoRound + optq_1 = round(model, self.tokenizer, n_samples=20, device="cpu", amp=False, seqlen=10, iters=50) + q_model, weight_config1 = optq_1.quantize() + compressed_model = export_compressed_model(q_model, weight_config1) + out2 = model(self.lm_input) + out3 = compressed_model(self.lm_input) + self.assertTrue(torch.all(torch.isclose(out1[0], out2[0], atol=1e-1))) + self.assertFalse(torch.all(out1[0] == out2[0])) + self.assertTrue(torch.all(torch.isclose(out2[0], out3[0], atol=1e-3))) + self.assertTrue("transformer.h.0.attn.k_proj.qzeros" in compressed_model.state_dict().keys()) + + # model = 
copy.deepcopy(self.gptj) + # out6 = model(self.lm_input) + # optq_2 = round(model, self.tokenizer, n_samples=20, amp=False, seqlen=10) + # q_model, weight_config2 = optq_2.quantize() + # out4 = q_model(self.lm_input) + # out5 = model(self.lm_input) + + # self.assertTrue(torch.all(out1[0] == out6[0])) + # self.assertTrue(torch.all(out4[0] == out5[0])) + # self.assertTrue(torch.all(torch.isclose(out6[0], out5[0], atol=1e-1))) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/algorithm/test_autoround.py b/test/algorithm/test_autoround.py index 499ce0b282e..dc78186af08 100644 --- a/test/algorithm/test_autoround.py +++ b/test/algorithm/test_autoround.py @@ -23,12 +23,12 @@ def setUpClass(self): @classmethod def test_signround(self): - round = AutoRound(self.model, self.tokenizer, device="cpu", iters=5, seqlen=8, n_samples=1, group_size=7) + round = AutoRound(self.model, self.tokenizer, device="cpu", iters=5, seqlen=8, n_samples=8, group_size=7) round.quantize() @classmethod def test_Adamround(self): - round = AutoOPTRound(self.model, self.tokenizer, device="cpu", iters=2, seqlen=8, n_samples=1, scheme="sym") + round = AutoOPTRound(self.model, self.tokenizer, device="cpu", iters=2, seqlen=8, n_samples=8, scheme="sym") round.quantize() @@ -49,12 +49,12 @@ def setUpClass(self): @classmethod def test_signround(self): - round = AutoRound(self.model, self.tokenizer, device="cpu", iters=5, seqlen=8, n_samples=1, n_blocks=2) + round = AutoRound(self.model, self.tokenizer, device="cpu", iters=5, seqlen=8, n_samples=8, n_blocks=2) round.quantize() @classmethod def test_Adamround(self): - round = AutoAdamRound(self.model, self.tokenizer, device="cpu", iters=5, seqlen=8, n_samples=1) + round = AutoAdamRound(self.model, self.tokenizer, device="cpu", iters=5, seqlen=8, n_samples=8) round.quantize()
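For reference, a short sketch of inspecting the packed buffers that WeightOnlyLinear.pack() produces in the optimum format, continuing the export sketch earlier in this patch. The buffer names follow the new test's check on transformer.h.0.attn.k_proj.qzeros; the save path is hypothetical.

import torch

sd = compressed_model.state_dict()  # compressed_model from the export sketch above
layer = "transformer.h.0.attn.k_proj"
# In the optimum format, qweight/qzeros/scales are stored transposed; with the default
# 4-bit setting, eight values are packed into each int32 along the input-channel dimension.
print(sd[f"{layer}.qweight"].shape, sd[f"{layer}.qweight"].dtype)  # packed int32 weight
print(sd[f"{layer}.qzeros"].shape)                                 # packed zero points (zp - 1)
print(sd[f"{layer}.scales"].dtype)                                 # float16 scales
torch.save(sd, "compressed_gptj_state_dict.pt")                    # hypothetical output path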