From ccf142006e10717fa9749d40346c38e5249c453d Mon Sep 17 00:00:00 2001
From: Jincheng Miao
Date: Fri, 22 Dec 2023 16:17:30 -0500
Subject: [PATCH 1/2] [tools/quantization] Add a guide for xFT with AWQ+AutoGPTQ on Llama2-7b

---
 tools/quantization/README.md                 |  85 +++
 tools/quantization/autogptq-cpu.patch        | 266 ++++++
 tools/quantization/awq-cpu.patch             | 658 ++++++++++++++++++
 tools/{ => quantization}/gptq/README.md      |   0
 tools/{ => quantization}/gptq/gptq.py        |   0
 tools/{ => quantization}/gptq/quantizer.py   |   0
 .../gptq/run_model_quant.py                  |   0
 tools/quantization/llama2_acc_xft.py         | 198 ++++++
 tools/quantization/llama_autogptq_convert.py | 281 ++++++++
 9 files changed, 1488 insertions(+)
 create mode 100644 tools/quantization/README.md
 create mode 100644 tools/quantization/autogptq-cpu.patch
 create mode 100644 tools/quantization/awq-cpu.patch
 rename tools/{ => quantization}/gptq/README.md (100%)
 rename tools/{ => quantization}/gptq/gptq.py (100%)
 rename tools/{ => quantization}/gptq/quantizer.py (100%)
 rename tools/{ => quantization}/gptq/run_model_quant.py (100%)
 create mode 100644 tools/quantization/llama2_acc_xft.py
 create mode 100644 tools/quantization/llama_autogptq_convert.py

diff --git a/tools/quantization/README.md b/tools/quantization/README.md
new file mode 100644
index 00000000..b515ae5a
--- /dev/null
+++ b/tools/quantization/README.md
@@ -0,0 +1,85 @@
+# xFT quantization with AWQ + AutoGPTQ
+
+A guide to xFT quantization with AWQ + AutoGPTQ on Llama2-7b.
+
+## description
+xFT can use AWQ and AutoGPTQ to quantize models; the two quantization
+techniques can be used individually or together.
+Typically, AWQ and AutoGPTQ only work on GPUs. We apply a few hacks to make
+them run on the CPU.
+AWQ performs an activation-aware quantization: it searches for the scales and
+zeros of the weights. We added a modification to dump the AWQ-searched model.
+Then this model can be quantized again by AutoGPTQ.
+
+ ┌─────────┐ Float┌──────────┐Int ┌─────────────────┐
+ │  Model  ├─────►│   AWQ    ├───►│ quantized model │
+ └─────────┘      └────┬─────┘    └─────────────────┘
+                       │
+                       │
+                       │
+                       │Float     Better accuracy
+                  ┌────▼─────┐Int ┌─────────────────┐
+                  │ AutoGPTQ ├───►│ quantized model │
+                  └──────────┘    └─────────────────┘
+
+In our test, AWQ + AutoGPTQ will improve the accuracy of Llama2-7B on Lambada.
+
+## prepare AWQ
+clone the llm-awq source code.
+```bash
+cd 3rdparty
+git clone https://github.com/mit-han-lab/llm-awq
+cd llm-awq
+git reset --hard 398b9661415e6a1f89f65c393a13b7f7047b582a
+```
+
+## AWQ on CPU
+The llm-awq is targeted for GPU. We have a patch to make it works on CPU.
+```bash
+git apply ../tools/awq/awq-cpu.patch
+pip install -e .
+```
+
+## quantize model (llama2-7b as an example)
+run awq search on llama2-7b to get scales, zeros and dump model to `awq_model`
+```bash
+python -m awq.entry --model_path --w_bit 8 --q_group_size 128 --run_awq --dump_awq awq_cache/llama2_7b_w8.pt --dump_model awq_model/
+```
+
+# AutoGPTQ
+clone AutoGPTQ source code
+```bash
+cd 3rdparty
+git clone https://github.com/PanQiWei/AutoGPTQ.git
+cd AutoGPTQ
+git reset --hard e4b2493733d69a6e60e22cebc64b619be39feb0e
+```
+## AutoGPTQ on CPU
+install AutoGPTQ from source
+```bash
+git apply ../tools/awq/autogptq-cpu.patch
+pip install -v .
+``` + +## quantize model with awq searched model +change the `pretrained_model_dir` to `awq_model` in examples/quantization/basic_usage_wikitext2.py +```bash +cd examples/quantization +python basic_usage_wikitext2.py +``` +After that the GPTQ quantized model would be stored to `quantized_model_dir` according to this script. + +## convert quantized model into xFT IR +set `MODEL_PATH` to `quantized_model_dir` from AutoGPTQ output. +set `IR_PATH` to converted xFT IR path. +```bash +python llama_autogptq_convert.py -in_file=${MODEL_PATH} -saved_dir=${IR_PATH} -processes 8 +``` + +## check the accuracy on lambada test set +set `TOKEN_PATH` to llama2-7b path. +set `IR_PATH` to converted xFT IR path. +```bash +cd tools/awq/ +python llama2_acc_xft.py --dataset_path lambada_test.jsonl --token_path=${TOKEN_PATH} --model_path=${IR_PATH} --show_progress --dtype=int8 +``` diff --git a/tools/quantization/autogptq-cpu.patch b/tools/quantization/autogptq-cpu.patch new file mode 100644 index 00000000..22d42be7 --- /dev/null +++ b/tools/quantization/autogptq-cpu.patch @@ -0,0 +1,266 @@ +diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py +index 84e92bf..34f03d5 100644 +--- a/auto_gptq/modeling/_base.py ++++ b/auto_gptq/modeling/_base.py +@@ -242,10 +242,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + class LayerHijacker(nn.Module): + """hijack layer's forward pass to cache data""" + +- def __init__(self, m, device): ++ #def __init__(self, m, device): ++ def __init__(self, m): + super().__init__() + self.module = m +- self.data_device = device if cache_examples_on_gpu else CPU ++ #self.data_device = device if cache_examples_on_gpu else CPU ++ self.data_device = CPU + + def forward(self, inp=None, **kwargs): + if inp is None: # some models use all key-value arguments in forward pass call +@@ -275,42 +277,44 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + layers = get_module_by_name_prefix(self.model, self.layers_block_name) + + force_layer_back_to_cpu = False +- if get_device(layers[0]) == CPU: +- layers[0] = layers[0].to(CUDA_0) +- force_layer_back_to_cpu = True ++ #if get_device(layers[0]) == CPU: ++ # layers[0] = layers[0].to(CUDA_0) ++ # force_layer_back_to_cpu = True + +- cur_layer_device = get_device(layers[0]) +- ori_outside_layer_module_devices = {} +- for module_name in self.outside_layer_modules: +- module = get_module_by_name_prefix(self.model, module_name) ++ #cur_layer_device = get_device(layers[0]) ++ #ori_outside_layer_module_devices = {} ++ #for module_name in self.outside_layer_modules: ++ # module = get_module_by_name_prefix(self.model, module_name) + +- if module is None: +- continue ++ # if module is None: ++ # continue + +- ori_outside_layer_module_devices[module_name] = get_device(module) +- if module is not None: +- move_to_device(module, cur_layer_device) ++ # ori_outside_layer_module_devices[module_name] = get_device(module) ++ # if module is not None: ++ # move_to_device(module, cur_layer_device) + + # get inputs for first layer +- layers[0] = LayerHijacker(layers[0], cur_layer_device) ++ #layers[0] = LayerHijacker(layers[0], cur_layer_device) ++ layers[0] = LayerHijacker(layers[0]) + for example in examples: + for k, v in example.items(): + if len(v.shape) == 1: + v = v.unsqueeze(0) +- example[k] = move_to_device(v, cur_layer_device) ++ #example[k] = move_to_device(v, cur_layer_device) ++ example[k] = v + try: + self.model(**example) + except ValueError: + pass + layers[0] = layers[0].module + +- move_to_device(layers[0], CPU if 
force_layer_back_to_cpu else cur_layer_device) +- for module_name in self.outside_layer_modules: +- module = get_module_by_name_prefix(self.model, module_name) +- if module is not None: +- move_to_device(module, ori_outside_layer_module_devices[module_name]) ++ #move_to_device(layers[0], CPU if force_layer_back_to_cpu else cur_layer_device) ++ #for module_name in self.outside_layer_modules: ++ # module = get_module_by_name_prefix(self.model, module_name) ++ # if module is not None: ++ # move_to_device(module, ori_outside_layer_module_devices[module_name]) + +- torch.cuda.empty_cache() ++ #torch.cuda.empty_cache() + + # resize attention mask and position ids for some special models + attention_masks = self._resize_attention_mask(attention_masks) +@@ -323,11 +327,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + for i in range(len(layers)): + logger.info(f"Start quantizing layer {i + 1}/{len(layers)}") + layer = layers[i] +- force_layer_back_to_cpu = False +- if get_device(layer) == CPU: +- move_to_device(layer, CUDA_0) +- force_layer_back_to_cpu = True +- cur_layer_device = get_device(layer) ++ #force_layer_back_to_cpu = False ++ #if get_device(layer) == CPU: ++ # move_to_device(layer, CUDA_0) ++ # force_layer_back_to_cpu = True ++ #cur_layer_device = get_device(layer) + + full = find_layers(layer) + for names in inside_layer_modules: +@@ -352,17 +356,21 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(num_batches): +- layer_input = move_to_device(layer_inputs[j], cur_layer_device) +- layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device) ++ #layer_input = move_to_device(layer_inputs[j], cur_layer_device) ++ layer_input = layer_inputs[j] ++ #layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device) ++ layer_attention_mask = attention_masks[j] + additional_layer_inputs = { + "attention_mask": layer_attention_mask + } +- layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device) ++ layer_position_ids = None if not position_ids else position_ids[j] ++ #layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device) + if layer_position_ids is not None: + additional_layer_inputs["position_ids"] = layer_position_ids + for k, v in layer_input_kwargs[j].items(): + if isinstance(v, torch.Tensor): +- additional_layer_inputs[k] = move_to_device(v, cur_layer_device) ++ additional_layer_inputs[k] = v ++ #additional_layer_inputs[k] = move_to_device(v, cur_layer_device) + else: + additional_layer_inputs[k] = v + layer(layer_input, **additional_layer_inputs) +@@ -378,39 +386,49 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + static_groups=self.quantize_config.static_groups + ) + quantizers[f'{self.layers_block_name}.{i}.{name}'] = ( +- gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device), +- move_to_device(scale, CPU if force_layer_back_to_cpu else cur_layer_device), +- move_to_device(zero, CPU if force_layer_back_to_cpu else cur_layer_device), +- move_to_device(g_idx, CPU if force_layer_back_to_cpu else cur_layer_device) ++ #gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device), ++ #move_to_device(scale, CPU if force_layer_back_to_cpu else cur_layer_device), ++ #move_to_device(zero, CPU if force_layer_back_to_cpu else cur_layer_device), ++ #move_to_device(g_idx, CPU if force_layer_back_to_cpu else 
cur_layer_device) ++ gptq[name].quantizer, ++ scale, ++ zero, ++ g_idx + ) + gptq[name].free() + + for j in range(num_batches): +- layer_input = move_to_device(layer_inputs[j], cur_layer_device) +- layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device) ++ #layer_input = move_to_device(layer_inputs[j], cur_layer_device) ++ layer_input = layer_inputs[j] ++ #layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device) ++ layer_attention_mask = attention_masks[j] + additional_layer_inputs = { + "attention_mask": layer_attention_mask + } +- layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device) ++ #layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device) ++ layer_position_ids = None if not position_ids else position_ids[j] + if layer_position_ids is not None: + additional_layer_inputs["position_ids"] = layer_position_ids + for k, v in layer_input_kwargs[j].items(): + if isinstance(v, torch.Tensor): +- additional_layer_inputs[k] = move_to_device(v, cur_layer_device) ++ #additional_layer_inputs[k] = move_to_device(v, cur_layer_device) ++ additional_layer_inputs[k] = v + else: + additional_layer_inputs[k] = v +- layer_output = move_to_device( +- layer(layer_input, **additional_layer_inputs)[0], +- cur_layer_device if cache_examples_on_gpu else CPU +- ) ++ #layer_output = move_to_device( ++ # layer(layer_input, **additional_layer_inputs)[0], ++ # cur_layer_device if cache_examples_on_gpu else CPU ++ #) ++ layer_output = layer(layer_input, **additional_layer_inputs)[0] + layer_outputs.append(layer_output) + +- layers[i] = move_to_device(layer, CPU if force_layer_back_to_cpu else cur_layer_device) ++ #layers[i] = move_to_device(layer, CPU if force_layer_back_to_cpu else cur_layer_device) ++ layers[i] = layer + del layer + del gptq + del layer_inputs + layer_inputs, layer_outputs = layer_outputs, [] +- torch.cuda.empty_cache() ++ #torch.cuda.empty_cache() + + pack_model( + model=self.model, +@@ -434,6 +452,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + + @property + def device(self): ++ return CPU ++ + if not self.hf_device_map: + return self.model.device + else: +@@ -603,8 +623,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + ): + """load un-quantized pretrained model to cpu""" + +- if not torch.cuda.is_available(): +- raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") ++ #if not torch.cuda.is_available(): ++ # raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass +@@ -669,7 +689,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + +- torch.cuda.empty_cache() ++ #torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) +diff --git a/auto_gptq/modeling/_utils.py b/auto_gptq/modeling/_utils.py +index 524ac86..e1a71ce 100644 +--- a/auto_gptq/modeling/_utils.py ++++ b/auto_gptq/modeling/_utils.py +@@ -365,7 +365,7 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in + if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2": + device = submodule.qweight.device + submodule.post_init(temp_dq = model.device_tensors[device]) +- torch.cuda.empty_cache() ++ 
#torch.cuda.empty_cache() + + return model + +diff --git a/auto_gptq/quantization/gptq.py b/auto_gptq/quantization/gptq.py +index 71e66b0..a1912d1 100644 +--- a/auto_gptq/quantization/gptq.py ++++ b/auto_gptq/quantization/gptq.py +@@ -12,8 +12,8 @@ from .quantizer import Quantizer + + logger = getLogger(__name__) + +-torch.backends.cuda.matmul.allow_tf32 = False +-torch.backends.cudnn.allow_tf32 = False ++#torch.backends.cuda.matmul.allow_tf32 = False ++#torch.backends.cudnn.allow_tf32 = False + + + class GPTQ: +@@ -160,7 +160,7 @@ class GPTQ: + logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2)) + logger.debug(torch.sum(Losses)) + +- torch.cuda.synchronize() ++ #torch.cuda.synchronize() + logger.info(f'duration: {(time.time() - tick)}') + logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}') + +@@ -194,7 +194,7 @@ class GPTQ: + self.H = None + self.Losses = None + self.Trace = None +- torch.cuda.empty_cache() ++ #torch.cuda.empty_cache() + + + __all__ = ["GPTQ"] diff --git a/tools/quantization/awq-cpu.patch b/tools/quantization/awq-cpu.patch new file mode 100644 index 00000000..e450f6b8 --- /dev/null +++ b/tools/quantization/awq-cpu.patch @@ -0,0 +1,658 @@ +diff --git a/README.md b/README.md +index 7a742fe..6e44d6a 100644 +--- a/README.md ++++ b/README.md +@@ -154,3 +154,20 @@ If you find AWQ useful or relevant to your research, please kindly cite our pape + + [LLaVA: Large Language and Vision Assistant](https://github.com/haotian-liu/LLaVA) + ++## Changqing ++ ++```bash ++$ python -m awq.entry --model_path /data/opt-1.3b-hf/ --w_bit 8 --run_awq --dump_awq awq_cache/opt-1.3b-w8.pt ++Quantization config: {'zero_point': True, 'q_group_size': -1} ++* Building model /data/opt-1.3b-hf/ ++ * Split into 60 blocks ++Running AWQ...: 0%| | 0/24 [00:00>> awq_results["scale"]:', awq_results["scale"]) + + # Clear GPU memory +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() + + if mse_range: + clip_list = auto_clip_block(layer, +@@ -160,12 +161,14 @@ def run_awq( + apply_clip(layer, clip_list) + # append prefix to make names global + awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".") ++ print('>>> awq_results["clip"]: ', awq_results["clip"]) ++ print('>>> awq_results["clip"][0][1].shape: ', awq_results["clip"][0][1].shape) + +- layer = layer.cpu() ++ # layer = layer.cpu() + # Haotian: check activation replacement + del input_feat + gc.collect() +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() + + return awq_results + +diff --git a/awq/quantize/qmodule.py b/awq/quantize/qmodule.py +index 94ca4b5..08fa907 100644 +--- a/awq/quantize/qmodule.py ++++ b/awq/quantize/qmodule.py +@@ -1,7 +1,7 @@ + import math + import torch + import torch.nn as nn +-import awq_inference_engine # with CUDA kernels ++# import awq_inference_engine # with CUDA kernels + + + def make_divisible(c, divisor): +@@ -37,8 +37,8 @@ class WQLinear(nn.Module): + def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): + super().__init__() + +- if w_bit not in [4]: +- raise NotImplementedError("Only 4-bit are supported for now.") ++ if w_bit not in [4, 8]: ++ raise NotImplementedError("Only 4-bit and 8-bit are supported for now.") + + self.in_features = in_features + self.out_features = out_features +@@ -51,8 +51,8 @@ class WQLinear(nn.Module): + pack_num = (32 // self.w_bit) + # TODO (Haotian): a function for buffer shape calculation + self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, 
device=dev)) +- self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev)) +- self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev)) ++ self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size, pack_num)), dtype=torch.int32, device=dev)) ++ self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size, pack_num) * pack_num), dtype=torch.float16, device=dev)) + if bias: + self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev)) + else: +@@ -70,7 +70,7 @@ class WQLinear(nn.Module): + + pack_num = 32 // awq_linear.w_bit + qscales = torch.zeros( +- (scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num), ++ (scales.shape[0], calculate_zeros_width(linear.in_features, group_size, pack_num) * pack_num), + dtype=torch.float16, + device=scales.device + ) +@@ -92,8 +92,10 @@ class WQLinear(nn.Module): + if awq_linear.w_bit == 4: + # order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_map = [0, 1, 2, 3, 4, 5, 6, 7] ++ elif awq_linear.w_bit == 8: ++ order_map = [0, 1, 2, 3] + else: +- raise NotImplementedError("Only 4-bit are supported for now.") ++ raise NotImplementedError("Only 4-bit and 8-bit are supported for now.") + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * awq_linear.w_bit) +@@ -101,7 +103,7 @@ class WQLinear(nn.Module): + + zeros = zeros.to(dtype=torch.int32) + qzeros = torch.zeros( +- (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)), ++ (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size, pack_num)), + dtype=torch.int32, + device=zeros.device, + ) +@@ -110,8 +112,10 @@ class WQLinear(nn.Module): + if awq_linear.w_bit == 4: + # order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_map = [0, 1, 2, 3, 4, 5, 6, 7] ++ elif awq_linear.w_bit == 8: ++ order_map = [0, 1, 2, 3] + else: +- raise NotImplementedError("Only 4-bit are supported for now.") ++ raise NotImplementedError("Only 4-bit and 8-bit are supported for now.") + for i in range(pack_num): + if col * pack_num + order_map[i] >= zeros.shape[1]: + continue +@@ -124,10 +128,11 @@ class WQLinear(nn.Module): + def forward(self, x): + out_shape = x.shape[:-1] + (self.out_features, ) + inputs = x.reshape(-1, x.shape[-1]) +- if inputs.shape[0] > 8: +- out = awq_inference_engine.gemm_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters) +- else: +- out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size) ++ # if inputs.shape[0] > 8: ++ # out = awq_inference_engine.gemm_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters) ++ # else: ++ # out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size) ++ raise NotImplementedError("Need to replace cuda kernel with torch cpu kernel.") + out = out + self.bias if self.bias is not None else out + #print(out) + #assert 0 +diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py +index 0437fc0..4dd2861 100644 +--- a/awq/quantize/quantizer.py ++++ b/awq/quantize/quantizer.py +@@ -54,6 +54,8 @@ def pseudo_quantize_tensor(w, n_bit=8, + if q_group_size > 0: + assert 
org_w_shape[-1] % q_group_size == 0 + w = w.reshape(-1, q_group_size) ++ else: ++ w = w.reshape(-1, w.shape[-1]) + assert w.dim() == 2 + if zero_point: + max_val = w.amax(dim=1, keepdim=True) +@@ -98,9 +100,9 @@ def pseudo_quantize_model_weight( + for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."): + named_linears = get_named_linears(layers[i]) + for n, m in named_linears.items(): +- m.cuda() ++ # m.cuda() + m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config) +- m.cpu() ++ # m.cpu() + + + @torch.no_grad() +@@ -125,17 +127,17 @@ def real_quantize_model_weight( + q_linear.to(next(layer.parameters()).device) + set_op_by_name(layer, name, q_linear) + else: +- module.cuda() ++ # module.cuda() + module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config) + # scales = scales.t().contiguous() + # zeros = zeros.t().contiguous() + q_linear = WQLinear.from_linear( + module, w_bit, q_config['q_group_size'], False, scales, zeros) +- module.cpu() ++ # module.cpu() + q_linear.to(next(layer.parameters()).device) + set_op_by_name(layer, name, q_linear) +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() + gc.collect() + +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() + gc.collect() +diff --git a/awq/utils/calib_data.py b/awq/utils/calib_data.py +index 02a75c4..1f5fd74 100644 +--- a/awq/utils/calib_data.py ++++ b/awq/utils/calib_data.py +@@ -1,10 +1,12 @@ + import torch + from datasets import load_dataset + ++import os ++os.environ['CURL_CA_BUNDLE'] = '' + + def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512): + if data == "pileval": +- dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") ++ dataset = load_dataset("./mit-han-lab/pile-val-backup", split="validation") + else: + raise NotImplementedError + dataset = dataset.shuffle(seed=42) +diff --git a/awq/utils/lm_eval_adaptor.py b/awq/utils/lm_eval_adaptor.py +index a0170bf..03dae2e 100644 +--- a/awq/utils/lm_eval_adaptor.py ++++ b/awq/utils/lm_eval_adaptor.py +@@ -65,7 +65,8 @@ class LMEvalAdaptor(BaseLM): + + @property + def device(self): +- return "cuda" ++ #return "cuda" ++ return "cpu" + + def tok_encode(self, string: str): + return self.tokenizer.encode(string, add_special_tokens=False) +diff --git a/tinychat/benchmark.py b/tinychat/benchmark.py +index 53d56c2..317622e 100644 +--- a/tinychat/benchmark.py ++++ b/tinychat/benchmark.py +@@ -99,7 +99,7 @@ def main(): + print("Benchmarking...") + with torch.inference_mode(): + for i in range(gen_length): +- torch.cuda.synchronize() ++ # torch.cuda.synchronize() + t_st = time.time() + + if i == 0: +@@ -109,7 +109,7 @@ def main(): + out = model(inputs, start_pos=start_pos) + start_pos += out.shape[1] + +- torch.cuda.synchronize() ++ # torch.cuda.synchronize() + t_ed = time.time() + time_lis.append(t_ed - t_st) + token = out[:, -1].max(1)[1].unsqueeze(1) +diff --git a/tinychat/models/falcon.py b/tinychat/models/falcon.py +index e005c38..8364ca7 100644 +--- a/tinychat/models/falcon.py ++++ b/tinychat/models/falcon.py +@@ -105,7 +105,7 @@ class FalconAttentionFused(nn.Module): + self.head_dim, + ) + ) +- .cuda() ++ # .cuda() + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 +@@ -119,7 +119,7 @@ class FalconAttentionFused(nn.Module): + 8, + ) + ) +- .cuda() ++ # .cuda() + .half() + ) # added to half + +diff --git a/tinychat/models/llama.py b/tinychat/models/llama.py +index 19f7d5f..c4ce680 100644 +--- 
a/tinychat/models/llama.py ++++ b/tinychat/models/llama.py +@@ -118,7 +118,7 @@ class LlamaAttentionFused(nn.Module): + self.head_dim, + ) + ) +- .cuda() ++ # .cuda() + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 +@@ -133,7 +133,7 @@ class LlamaAttentionFused(nn.Module): + 8, + ) + ) +- .cuda() ++ # .cuda() + .half() + ) # added to half + +diff --git a/tinychat/models/mpt.py b/tinychat/models/mpt.py +index 7e27b15..5610b81 100644 +--- a/tinychat/models/mpt.py ++++ b/tinychat/models/mpt.py +@@ -123,7 +123,7 @@ class MPTAttentionFused(nn.Module): + self.head_dim, + ) + ) +- .cuda() ++ # .cuda() + .half() + ) # added to half + # 8: pack 8 fp16 in FT, if fp32 then use 4 +@@ -137,7 +137,7 @@ class MPTAttentionFused(nn.Module): + 8, + ) + ) +- .cuda() ++ # .cuda() + .half() + ) # added to half + +diff --git a/tinychat/modules/fused_attn.py b/tinychat/modules/fused_attn.py +index 61ca8ce..ceadc8c 100644 +--- a/tinychat/modules/fused_attn.py ++++ b/tinychat/modules/fused_attn.py +@@ -305,7 +305,7 @@ def make_quant_attn(model, dev): + """ + Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. + """ +- model = model.cpu() ++ # model = model.cpu() + for name, m in model.named_modules(): + if not m.__class__.__name__ in ["LlamaAttention", "LlamaAttentionFused"]: + continue +@@ -366,5 +366,5 @@ def make_quant_attn(model, dev): + # print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") + setattr(parent, child_name, attn) + gc.collect() +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() + model = model.to(dev) +diff --git a/tinychat/modules/fused_mlp.py b/tinychat/modules/fused_mlp.py +index 87e993e..4ab0528 100644 +--- a/tinychat/modules/fused_mlp.py ++++ b/tinychat/modules/fused_mlp.py +@@ -2,7 +2,7 @@ import numpy as np + import torch + import torch.nn as nn + import torch.nn.functional as F +-from torch.cuda.amp import custom_bwd, custom_fwd ++# from torch.cuda.amp import custom_bwd, custom_fwd + from transformers.models.llama.modeling_llama import LlamaMLP + + import awq_inference_engine +diff --git a/tinychat/stream_generators/falcon_stream_gen.py b/tinychat/stream_generators/falcon_stream_gen.py +index 15e66a9..891aa71 100644 +--- a/tinychat/stream_generators/falcon_stream_gen.py ++++ b/tinychat/stream_generators/falcon_stream_gen.py +@@ -143,4 +143,4 @@ def FalconStreamGenerator( + + # clean + gc.collect() +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() +diff --git a/tinychat/stream_generators/stream_gen.py b/tinychat/stream_generators/stream_gen.py +index 424ef10..5c45700 100644 +--- a/tinychat/stream_generators/stream_gen.py ++++ b/tinychat/stream_generators/stream_gen.py +@@ -62,7 +62,7 @@ def StreamGenerator( + max_new_tokens = gen_params.n_predict + start_pos = 0 + for i in range(max_new_tokens): +- torch.cuda.synchronize() ++ # torch.cuda.synchronize() + t_st = time.time() + + if i == 0: +@@ -91,7 +91,7 @@ def StreamGenerator( + out = model(inputs, start_pos=start_pos) + start_pos += out.shape[1] + logits = out +- torch.cuda.synchronize() ++ # torch.cuda.synchronize() + t_ed = time.time() + + # Processing the logits +@@ -185,6 +185,6 @@ def StreamGenerator( + + del past_key_values, out + gc.collect() +- torch.cuda.empty_cache() ++ # torch.cuda.empty_cache() + + # return context_tokens, context_time, total_tokens, generation_time_list +diff --git a/tinychat/utils/tune.py b/tinychat/utils/tune.py +index 2ca231c..cf27a17 100644 +--- a/tinychat/utils/tune.py ++++ 
b/tinychat/utils/tune.py +@@ -19,10 +19,10 @@ def _time_module(module, inputs, measure_iters=1000): + for i in range(measure_iters): + module(inputs) + for i in range(measure_iters): +- torch.cuda.synchronize() ++ # torch.cuda.synchronize() + st = time.time() + module(inputs) +- torch.cuda.synchronize() ++ # torch.cuda.synchronize() + ed = time.time() + time_lis.append((ed - st)) + return np.median(time_lis) diff --git a/tools/gptq/README.md b/tools/quantization/gptq/README.md similarity index 100% rename from tools/gptq/README.md rename to tools/quantization/gptq/README.md diff --git a/tools/gptq/gptq.py b/tools/quantization/gptq/gptq.py similarity index 100% rename from tools/gptq/gptq.py rename to tools/quantization/gptq/gptq.py diff --git a/tools/gptq/quantizer.py b/tools/quantization/gptq/quantizer.py similarity index 100% rename from tools/gptq/quantizer.py rename to tools/quantization/gptq/quantizer.py diff --git a/tools/gptq/run_model_quant.py b/tools/quantization/gptq/run_model_quant.py similarity index 100% rename from tools/gptq/run_model_quant.py rename to tools/quantization/gptq/run_model_quant.py diff --git a/tools/quantization/llama2_acc_xft.py b/tools/quantization/llama2_acc_xft.py new file mode 100644 index 00000000..00076be7 --- /dev/null +++ b/tools/quantization/llama2_acc_xft.py @@ -0,0 +1,198 @@ +import json +import tqdm +import torch +import transformers +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from typing import Optional, Tuple, Union, Dict +import argparse +import time + +import importlib.util + +xft_spec = importlib.util.find_spec("xfastertransformer") + +if xft_spec is None: + import sys + + sys.path.append("/root/xFasterTransformer/src") + print("[INFO] xfastertransformer is not installed in pip, using source code.") +else: + print("[INFO] xfastertransformer is installed, using pip installed package.") + +import xfastertransformer + +class LambadaDataset(torch.utils.data.Dataset): + """ LAMBADA dataset class. """ + + def __init__(self, + path: str, + tokenizer: transformers.PreTrainedTokenizerBase): + self.tokenizer = tokenizer + with open(path, 'r') as f: + inputs, targets = zip(*[ + json.loads(line)["text"] .strip('\n').rsplit(' ', 1) + for line in f.readlines()]) + # This whitespace preprocessing (additional space to the target) + # is required. 
+ #targets = [' ' + tgt for tgt in targets] + self.encodings = self.tokenizer(list(inputs), + targets, + padding=True, + add_special_tokens=False, + return_token_type_ids=True, + return_tensors='pt') + + def __len__(self): + return len(self.encodings['input_ids']) + + def __getitem__(self, idx): + return dict( + input_ids=self.encodings['input_ids'][idx], + attention_mask=self.encodings['attention_mask'][idx], + token_type_ids=self.encodings['token_type_ids'][idx] + ) + + +class Timer: + + def __init__(self): + self._start_times = {} + self._total_elapsed_times = {} + + def start(self, tag='__default'): + self._start_times[tag] = time.time() + + def stop(self, tag='__default'): + elapsed_time = time.time() - self._start_times[tag] + if tag not in self._total_elapsed_times: + self._total_elapsed_times[tag] = 0 + self._total_elapsed_times[tag] += elapsed_time + return elapsed_time + + def elapsed_time_in_sec(self, tag='__default'): + if tag not in self._total_elapsed_times: + return None + return self._total_elapsed_times[tag] + + def reset(self): + self._start_times.clear() + self._total_elapsed_times.clear() + +def split_inputs_and_targets(entries: Dict[str, torch.LongTensor], + pad_token_id: int, + pad_to_left=False): + input_ids = entries['input_ids'] + attn_mask = entries['attention_mask'] + token_type_ids = entries['token_type_ids'] + + # Split inputs and labels by token_type_ids. + input_token_ids = [ + ids[(mask == 1) & (type_ids == 0)] + for ids, mask, type_ids in zip(input_ids, attn_mask, token_type_ids)] + # FT allows int32 tensors. + input_lengths = torch.tensor( + [len(input_tokens) for input_tokens in input_token_ids]).int() + max_length = input_lengths.max() + input_token_ids = torch.stack([ + torch.nn.functional.pad( + token_ids, + pad=[max_length - len(token_ids), 0] + if pad_to_left else [0, max_length - len(token_ids)], + mode='constant', + value=pad_token_id + ) for token_ids in input_token_ids]) + target_token_ids = [ + ids[(mask == 1) & (type_ids == 1)] + for ids, mask, type_ids in zip(input_ids, attn_mask, token_type_ids)] + return input_token_ids, input_lengths, target_token_ids + +def get_args(): + DTYPE_LIST = ["fp16", "bf16", "int8", "w8a8", "bf16_fp16", "bf16_int8"] + + parser = argparse.ArgumentParser( + 'Evaluation: LAMBADA Task', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + group = parser.add_argument_group('LAMBADA Task Parameters') + group.add_argument( + '--dataset_path', type=str, metavar='PATH', required=True, + help="A file path to LAMBADA task dataset.") + group.add_argument( + "--token_path", type=str, metavar='DIR_OR_PATH', default=None, + help='A file path of a pretrained tokenizer or a checkpoint directory ' + 'of HF pretrained model.') + group.add_argument("--model_path", type=str, default="/data/chatglm-6b-cpu", help="Path to model file") + group.add_argument("--dtype", type=str, choices=DTYPE_LIST, default="fp16", help="Data type") + group.add_argument("--batch", default=1, type=int) + group.add_argument("--beam", default=1, type=int) + group.add_argument( + '--show_progress', action='store_true', + help='Show evaluation progress') + args = parser.parse_args() + + print('\n=================== Arguments ===================') + for k, v in vars(args).items(): + print(f' - {k.ljust(25, ".")}: {v}') + print('=================================================') + + return args + +def main(): + args = get_args() + + model = xfastertransformer.AutoModel.from_pretrained(args.model_path, dtype=args.dtype) + + if model.rank == 0: + # Master 
+ tokenizer = AutoTokenizer.from_pretrained(args.token_path, padding_side="left") + + dataset = LambadaDataset(args.dataset_path, tokenizer=tokenizer) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch) + + num_requests = 0 + num_corrects = 0 + + timer = Timer() + if args.show_progress: + data_loader = tqdm.tqdm(data_loader) + + for entries in data_loader: + input_token_ids, input_lengths, target_token_ids = \ + split_inputs_and_targets(entries, tokenizer.pad_token_id, True) + + batch_size = input_token_ids.shape[0] + output_length = max([len(target) for target in target_token_ids]) + + timer.start() + outputs = model.generate(input_token_ids, max_length=input_lengths+output_length, num_beams=args.beam) + timer.stop() + output_token_ids = outputs[:, input_token_ids.shape[1]:] + output_token_ids = [ + out[:len(tgt)].cpu() + for out, tgt in zip(output_token_ids, target_token_ids)] + + output_texts = tokenizer.batch_decode(output_token_ids) + target_texts = tokenizer.batch_decode(target_token_ids) + print('\n', output_texts, target_texts, flush=True) + + for i in range(batch_size): + out = output_token_ids[i] + tgt = target_token_ids[i] + is_correct = (tgt == out).all() + num_corrects += int(is_correct) + + num_requests += batch_size + + accuracy = num_corrects * 100 / num_requests + print(f'Accuracy: {accuracy:0.4f}% ({num_corrects}/{num_requests}) ' + f'(elapsed time: {timer.elapsed_time_in_sec():.4f} sec)') + else: + # Slave + while True: + model.generate() + + +if __name__ == "__main__": + main() + diff --git a/tools/quantization/llama_autogptq_convert.py b/tools/quantization/llama_autogptq_convert.py new file mode 100644 index 00000000..be719571 --- /dev/null +++ b/tools/quantization/llama_autogptq_convert.py @@ -0,0 +1,281 @@ +""" +Convert huggingface ChatGLM model. Use https://huggingface.co/meta-llama +""" + +import argparse +import configparser +import multiprocessing +import numpy as np +import os +import sys +import torch + +from datetime import datetime +from pathlib import Path +from tqdm import tqdm +from transformers import LlamaForCausalLM, LlamaTokenizer + +from auto_gptq import AutoGPTQForCausalLM + +dir_path = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(dir_path + "/../../../..") +sys.path.append(dir_path) + + +def get_weight_data_type(data_type): + if data_type == "fp32": + return np.float32 + elif data_type == "fp16": + return np.float16 + else: + assert False, f"Invalid weight data type {data_type}" + + +def split_and_convert_process(i, saved_dir, factor, key, args, val, old_name, dtype): + def save_val(val, key, tp_num=None): + if key.startswith("model."): + path = saved_dir + "/" + key + else: + path = saved_dir + "/model." + key + + if tp_num is not None: + path += "." 
+ str(tp_num) + path += ".bin" + + val.tofile(path) + + if ( + "input_layernorm.weight" in key + or "input_layernorm.bias" in key + or "attention.dense.bias" in key + or "post_attention_layernorm.weight" in key + or "post_attention_layernorm.bias" in key + or "mlp.dense_4h_to_h.bias" in key + or "final_layernorm.weight" in key + or "final_layernorm.bias" in key + ): + # shared weights, only need to convert the weights of rank 0 + if i == 0: + save_val(val, key) + + elif "qweight" in key or "zeros" in key or "scales" in key: + save_val(val, key, 0) + else: + print("[ERROR] cannot find key '{}'".format(key)) + + """ + + elif "mlp.gate_proj.weight" in key or "mlp.up_proj.weight" in key or "mlp.down_proj.weight" in key: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + save_val(split_vals[j], key, i * factor + j) + + elif "attention.query_key_value.weight" in key: + hidden_dim = val.shape[0] + local_dim = (int)(val.shape[-1] / 3) + + val = val.reshape(hidden_dim, 3, local_dim) + + split_vals = np.split(val, factor, axis=-1) + for j in range(factor): + save_val(split_vals[j], key, i * factor + j) + + elif "attention.dense.weight" in key: + split_vals = np.split(val, factor, axis=0) + for j in range(factor): + save_val(split_vals[j], key, i * factor + j) + + else: + print("[ERROR] cannot find key '{}'".format(key)) + """ + +def split_and_convert(args): + saved_dir = args.saved_dir + + # create directory if not exist + if not os.path.exists(saved_dir): + os.makedirs(saved_dir) + + # load the model + model = AutoGPTQForCausalLM.from_quantized(args.in_file) + + hf_config = vars(model.config) + quantize_config = vars(model.quantize_config) + + print(hf_config) + print(quantize_config) + + layer_names = [name for name, param in model.named_parameters()] + + # save parameters to config file + config = configparser.ConfigParser() + config["llama"] = {} + has_post_decoder_layernorm = True + try: + config["llama"]["model_name"] = "llama" if hf_config["_name_or_path"] == "" else hf_config["_name_or_path"] + config["llama"]["head_num"] = str(hf_config["num_attention_heads"]) + hidden_size = hf_config["hidden_size"] + config["llama"]["size_per_head"] = str(hidden_size // hf_config["num_attention_heads"]) + config["llama"]["inter_size"] = str(hf_config["intermediate_size"]) + config["llama"]["max_pos_seq_len"] = str(hf_config["max_position_embeddings"]) + config["llama"]["num_layer"] = str(hf_config["num_hidden_layers"]) + config["llama"]["rms_norm_eps"] = "1e-6" + config["llama"]["layernorm_type"] = "pre_layernorm" + config["llama"]["activation_type"] = "silu" + config["llama"]["has_post_decoder_layernorm"] = "1" if has_post_decoder_layernorm else "0" + config["llama"]["vocab_size"] = str(hf_config["vocab_size"]) + config["llama"]["start_id"] = str(hf_config["bos_token_id"]) + config["llama"]["end_id"] = str(hf_config["eos_token_id"]) + config["llama"]["weight_data_type"] = args.weight_data_type + + config["llama"]["quant_decoder_weights"] = str(True) + wbits = quantize_config["bits"] + assert wbits == 8, "Only 8bits quantization is supported" + config["llama"]["quant_wbits"] = str(wbits) + assert quantize_config["group_size"] == -1, "Only column wise quantization is supported." 
+ config["llama"]["quant_groupsize"] = str(quantize_config["group_size"]) + #config["llama"]["quant_scheme"] = "sym" if quantize_config["sym"] == True else "asym" + + with open(saved_dir + "/config.ini", "w") as configfile: + config.write(configfile) + except Exception as e: + print("Fail to save the config in config.ini.", str(e)) + + np_weight_data_type = get_weight_data_type(args.weight_data_type) + + + hf_model_name_pattern = [ + "input_layernorm.weight", + "self_attn.qkv_proj.qweight", + "self_attn.qkv_proj.qzeros", + "self_attn.qkv_proj.scales", + "self_attn.o_proj.qweight", + "self_attn.o_proj.qzeros", + "self_attn.o_proj.scales", + "post_attention_layernorm.weight", + "mlp.gate_proj.qweight", + "mlp.gate_proj.qzeros", + "mlp.gate_proj.scales", + "mlp.up_proj.qweight", + "mlp.up_proj.qzeros", + "mlp.up_proj.scales", + "mlp.down_proj.qweight", + "mlp.down_proj.qzeros", + "mlp.down_proj.scales", + ] + + ft_model_name_pattern = [ + "input_layernorm.weight", + "attention.query_key_value.qweight", + "attention.query_key_value.zeros", + "attention.query_key_value.scales", + "attention.dense.qweight", + "attention.dense.zeros", + "attention.dense.scales", + "post_attention_layernorm.weight", + "mlp.gate_proj.qweight", + "mlp.gate_proj.zeros", + "mlp.gate_proj.scales", + "mlp.up_proj.qweight", + "mlp.up_proj.zeros", + "mlp.up_proj.scales", + "mlp.down_proj.qweight", + "mlp.down_proj.zeros", + "mlp.down_proj.scales", + ] + + state_dict = model.state_dict() + + model_named_parameters = dict() + for name, param in state_dict.items(): + if name.startswith("model."): + name = name[6:] + wf = torch.tensor(list(range(0, 32, wbits)), dtype=torch.int32).unsqueeze(0) + + if "embed" in name: + model_named_parameters[name] = param + elif "lm_head" in name: + model_named_parameters[name] = param + elif "scales" in name: + model_named_parameters[name] = param.float() + elif "qzeros" in name: + qzeros = param + qzeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // wbits), wf.unsqueeze(0)).to(torch.int16) + qzeros = torch.bitwise_and(qzeros, (2 ** wbits) - 1) + qzeros = qzeros + 1 - 128 # uint to int + qzeros = torch.flatten(qzeros).float() + scales = state_dict["model." 
+ name.replace("qzeros", "scales")].float() + zeros = - scales * qzeros + model_named_parameters[name] = zeros + elif "qweight" in name: + # qweight is not transposed + param = torch.bitwise_right_shift(torch.unsqueeze(param, 1).expand(-1, 32 // wbits, -1), wf.unsqueeze(-1)).to(torch.int16) + param = torch.bitwise_and(param, (2 ** wbits) - 1) + param = param.reshape(-1, param.shape[2]) + param = param - 128 # uint to int + model_named_parameters[name] = param.to(torch.int8) + else: + model_named_parameters[name] = param.permute(1, 0) if len(param.shape) == 2 else param + + pool = multiprocessing.Pool(args.processes) + for name, param in model_named_parameters.items(): + if name == "model.embed_tokens.weight": + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.bin") + elif name == "model.norm.weight": + param.detach().cpu().numpy().astype(np_weight_data_type).tofile( + saved_dir + "model.final_layernorm.weight.bin" + ) + # elif name == 'model.final_layernorm.bias': + # param.detach().cpu().numpy().astype(np_weight_data_type).tofile( + # saved_dir + "model.final_layernorm.bias.bin") + elif name == "lm_head.weight": + param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin") + else: + starmap_args = [] + dtype = np_weight_data_type + if "qweight" in name: + dtype = np.int8 + if "qzero" in name or "scales" in name: + dtype = np.float32 + for i in range(len(hf_model_name_pattern)): + if hf_model_name_pattern[i] in name: + factor = 1 + new_name = name.replace(hf_model_name_pattern[i], ft_model_name_pattern[i]) + starmap_args.append( + ( + 0, + saved_dir, + factor, + new_name, + args, + param.detach().cpu().numpy().astype(dtype), + name, + dtype, + ) + ) + pool.starmap_async(split_and_convert_process, starmap_args) + pool.close() + pool.join() + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_sharing_strategy("file_system") + + parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True) + parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True) + parser.add_argument("-processes", "-p", type=int, help="processes to spawn for conversion (default: 8)", default=8) + parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]) + + args = parser.parse_args() + print("\n=============== Argument ===============") + for key in vars(args): + print(f"{key}: {vars(args)[key]}") + print("========================================") + + start_time = datetime.now() + split_and_convert(args) + stop_time = datetime.now() + run_time = stop_time - start_time + print(f"[INFO] Spend {run_time} (h:m:s) to convert the model") From 5c02b9ac5350f19ab7472d61364ba8fe08042536 Mon Sep 17 00:00:00 2001 From: Jincheng Miao Date: Wed, 10 Apr 2024 17:20:44 +0800 Subject: [PATCH 2/2] [tools/quantization] rebase to main, update quantization guide --- tools/quantization/README.md | 32 ++- tools/quantization/gptq/README.md | 36 --- tools/quantization/gptq/gptq.py | 134 --------- tools/quantization/gptq/quantizer.py | 115 -------- tools/quantization/gptq/run_model_quant.py | 167 ----------- tools/quantization/llama_autogptq_convert.py | 281 ------------------- tools/quantization/requirements.txt | 138 +++++++++ 7 files changed, 156 insertions(+), 747 deletions(-) delete mode 100644 
tools/quantization/gptq/README.md delete mode 100644 tools/quantization/gptq/gptq.py delete mode 100644 tools/quantization/gptq/quantizer.py delete mode 100644 tools/quantization/gptq/run_model_quant.py delete mode 100644 tools/quantization/llama_autogptq_convert.py create mode 100644 tools/quantization/requirements.txt diff --git a/tools/quantization/README.md b/tools/quantization/README.md index b515ae5a..d6023fad 100644 --- a/tools/quantization/README.md +++ b/tools/quantization/README.md @@ -24,7 +24,7 @@ Then this model can be quantized again by AutoGPTQ. In our test, AWQ + AutoGPTQ will improve the accuracy of Llama2-7B on Lambada. -## prepare AWQ +## AWQ clone the llm-awq source code. ```bash cd 3rdparty @@ -33,20 +33,20 @@ cd llm-awq git reset --hard 398b9661415e6a1f89f65c393a13b7f7047b582a ``` -## AWQ on CPU +### AWQ on CPU The llm-awq is targeted for GPU. We have a patch to make it works on CPU. ```bash -git apply ../tools/awq/awq-cpu.patch +git apply ../tools/quantization/awq-cpu.patch pip install -e . ``` -## quantize model (llama2-7b as an example) +### quantize model (llama2-7b as an example) run awq search on llama2-7b to get scales, zeros and dump model to `awq_model` ```bash python -m awq.entry --model_path --w_bit 8 --q_group_size 128 --run_awq --dump_awq awq_cache/llama2_7b_w8.pt --dump_model awq_model/ ``` -# AutoGPTQ +## AutoGPTQ clone AutoGPTQ source code ```bash cd 3rdparty @@ -54,14 +54,14 @@ git clone https://github.com/PanQiWei/AutoGPTQ.git cd AutoGPTQ git reset --hard e4b2493733d69a6e60e22cebc64b619be39feb0e ``` -## AutoGPTQ on CPU +### AutoGPTQ on CPU install AutoGPTQ from source ```bash -git apply ../tools/awq/autogptq-cpu.patch +git apply ../tools/quantization/autogptq-cpu.patch pip install -v . ``` -## quantize model with awq searched model +### quantize model with awq searched model change the `pretrained_model_dir` to `awq_model` in examples/quantization/basic_usage_wikitext2.py ```bash cd examples/quantization @@ -69,17 +69,21 @@ python basic_usage_wikitext2.py ``` After that the GPTQ quantized model would be stored to `quantized_model_dir` according to this script. -## convert quantized model into xFT IR +## convert quantized model into xFT format set `MODEL_PATH` to `quantized_model_dir` from AutoGPTQ output. -set `IR_PATH` to converted xFT IR path. -```bash -python llama_autogptq_convert.py -in_file=${MODEL_PATH} -saved_dir=${IR_PATH} -processes 8 +set `XFT_MODEL_PATH` to converted xFT model path. +```python +import xfastertransformer as xft + +MODEL_PATH="/data/model/llama2-7b-int8" +XFT_MODEL_PATH="/data/model/llama2-7b-int8-xft" +print(xft.LlamaConvert().convert(MODEL_PATH, XFT_MODEL_PATH, from_quantized_model="gptq")) ``` ## check the accuracy on lambada test set set `TOKEN_PATH` to llama2-7b path. -set `IR_PATH` to converted xFT IR path. +set `XFT_MODEL_PATH` to converted xFT model path. ```bash cd tools/awq/ -python llama2_acc_xft.py --dataset_path lambada_test.jsonl --token_path=${TOKEN_PATH} --model_path=${IR_PATH} --show_progress --dtype=int8 +python llama2_acc_xft.py --dataset_path lambada_test.jsonl --token_path=${TOKEN_PATH} --model_path=${XFT_MODEL_PATH} --show_progress --dtype=int8 ``` diff --git a/tools/quantization/gptq/README.md b/tools/quantization/gptq/README.md deleted file mode 100644 index 2a19edef..00000000 --- a/tools/quantization/gptq/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# GPTQ - -A demo for gptq. 
- -## How to use in model - -```bash -python run_model_quant.py --input_model_path=/data/Llama-2-7b-cpu --output_model_path=/data/Llama-2-quantized-7b-cpu --model_type=llama2 --wbits=8 - -``` - -## How to use in only weight - -```python -import os -import torch - -from gptq import * - -# TODO: move to tests/ut folder in xFT -if __name__ == '__main__': - weight = torch.tensor([ - [0.123456, 1.234567, 3.456789], - [4.567891, 5.678912, 6.789123], - [7.891234, 8.912345, 9.123456] - ]).float() - - llm_gptq = LLM_GPTQ(weight, 8, False) - quantized_weight, scale, zero = llm_gptq.fasterquant() - print("quantized weight is ") - print(quantized_weight) - print("scale is ") - print(scale) - print("zero is ") - print(zero) -``` \ No newline at end of file diff --git a/tools/quantization/gptq/gptq.py b/tools/quantization/gptq/gptq.py deleted file mode 100644 index 97899d4e..00000000 --- a/tools/quantization/gptq/gptq.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import math -import time - -import torch -import torch.nn as nn -# import transformers - -from quantizer import * - -DEBUG = False - -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False - -class LLM_GPTQ: - - def __init__(self, weight, wbits, sym): - self.weight = weight - self.wbits = wbits - self.rows = self.weight.shape[0] - self.columns = self.weight.shape[1] - self.H = torch.zeros((self.rows, self.rows)) - self.nsamples = 0 - self.quantizer = Quantizer() - self.quantizer.configure(wbits, perchannel=True, sym=sym, mse=False) - - def add_batch(self, inp, out): - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - # inp = inp.float() - inp = math.sqrt(2 / self.nsamples) * inp.float() - # self.H += 2 / self.nsamples * inp.matmul(inp.t()) - self.H += inp.matmul(inp.t()) - - def fasterquant( - self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False - ): - W = self.weight.clone() - W = W.float() - tick = time.time() - - if not self.quantizer.ready(): - self.quantizer.find_params(W) - - H = self.H - del self.H - dead = torch.diag(H) == 0 - H[dead, dead] = 1 - - if actorder: - perm = torch.argsort(torch.diag(H), descending=True) - W = W[perm, :] - H = H[perm][:, perm] - - damp = percdamp * torch.mean(torch.diag(H)) - diag = torch.arange(self.rows) - H[diag, diag] += damp - H = torch.linalg.cholesky(H) - H = torch.cholesky_inverse(H) - H = torch.linalg.cholesky(H) - Hinv = H - - Losses = torch.zeros_like(W) - Q = torch.zeros_like(W) - Qint = torch.zeros_like(W) - - for i1 in range(0, self.rows, blocksize): - i2 = min(i1 + blocksize, self.rows) - count = i2 - i1 - - W1 = W[i1:i2, :].clone() - Err1 = torch.zeros_like(W1) - - for i in range(count): - w = W1[i, :] - d = Hinv[i1 + i, i1 + i] - - if groupsize != -1: - if (i1 + i) % groupsize == 0: - 
self.quantizer.find_params(W[(i1 + i):(i1 + i + groupsize), :]) - - q_int = quantize_to_int(w, self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq) - Qint[i + i1, :] = q_int - q = dequantize_to_float(q_int, self.quantizer.scale, self.quantizer.zero) - Q[i + i1, :] = q - - Losses[i + i1, :] = (w - q) ** 2 / ( 2 * d ** 2) - err1 = (w - q) / d - Err1[i, :] = err1 - - h = Hinv[(i1 + i):i2, i1 + i] - err1 = err1.reshape([1, err1.shape[0]]) - h = h.reshape([h.shape[0], 1]) - W1[i:, :] -= h.matmul(err1) - - W[i2:, :] -= Hinv[i2:, i1:i2].matmul(Err1) - - if actorder: - invperm = torch.argsort(perm) - Q = Q[invperm, :] - - # if isinstance(self.layer, transformers.Conv1D): - # Q = Q.t() - - print('time %.2f' % (time.time() - tick)) - print('error', torch.sum(Losses).item()) - - print("Input Weight (float):", self.weight) - print("Output quantized weight (int{})".format(self.wbits), Qint) - print("Output quantized weight (float)", Q) - return Qint, self.quantizer.scale, self.quantizer.zero - - def free(self): - self.H = None - self.Losses = None - self.Trace = None - torch.cuda.empty_cache() diff --git a/tools/quantization/gptq/quantizer.py b/tools/quantization/gptq/quantizer.py deleted file mode 100644 index 11e51820..00000000 --- a/tools/quantization/gptq/quantizer.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -import numpy as np -import torch -import torch.nn as nn - -def quantize_to_int(x, scale, zero, maxq): - return torch.clamp(torch.round(x / scale) + zero, 0, maxq) - -def dequantize_to_float(x, scale, zero): - return scale * (x - zero) - -class Quantizer(nn.Module): - - def __init__(self, shape=1): - super(Quantizer, self).__init__() - self.register_buffer('maxq', torch.tensor(0)) - self.register_buffer('scale', torch.zeros(shape)) - self.register_buffer('zero', torch.zeros(shape)) - - def configure( - self, - bits, perchannel=True, sym=False, - mse=False, norm=2.4, grid=100, maxshrink=.8 - ): - self.maxq = torch.tensor(2 ** bits - 1) - self.perchannel = perchannel - self.sym = sym - self.mse = mse - self.norm = norm - self.grid = grid - self.maxshrink = maxshrink - - def find_params(self, x): - dev = x.device - self.maxq = self.maxq.to(dev) - - shape = x.shape - if not self.perchannel: - x = x.flatten().unsqueeze(0) - - tmp = torch.zeros(x.shape[1], device=dev) - xmin = torch.minimum(x.min(0)[0], tmp) - xmax = torch.maximum(x.max(0)[0], tmp) - - if self.sym: - xmax = torch.maximum(torch.abs(xmin), xmax) - tmp = xmin < 0 - if torch.any(tmp): - xmin[tmp] = -xmax[tmp] - tmp = (xmin == 0) & (xmax == 0) - xmin[tmp] = -1 - xmax[tmp] = +1 - - if self.maxq < 0: - self.scale = xmax - self.zero = xmin - else: - self.scale = (xmax - xmin) / self.maxq - if self.sym: - self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) - else: - self.zero = torch.round(-xmin / self.scale) - - if self.mse: - best = torch.full([x.shape[0]], float('inf'), device=dev) - for i in range(int(self.maxshrink * self.grid)): - p = 1 - i / self.grid - xmin1 = p * xmin - xmax1 = p * xmax - scale1 = (xmax1 - xmin1) / self.maxq - zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero - q = quantize_to_int(x, scale1, zero1, self.maxq) - q = dequantize_to_float(q, scale1, zero1) - q -= x - q.abs_() - q.pow_(self.norm) - err = torch.sum(q, 1) - tmp = err < best - if torch.any(tmp): - best[tmp] = err[tmp] - self.scale[tmp] = scale1[tmp] - self.zero[tmp] = zero1[tmp] - if not self.perchannel: - tmp = shape[0] - self.scale = self.scale.repeat(tmp) - self.zero = self.zero.repeat(tmp) - - def quantize_to_int(self, x): - if self.ready(): - return quantize_to_int(x, self.scale, self.zero, self.maxq) - return x - - def dequantize_to_float(self, x): - if self.ready(): - return dequantize_to_float(x, self.scale, self.zero, self.maxq) - return x - - def enabled(self): - return self.maxq > 0 - - def ready(self): - return torch.all(self.scale != 0) diff --git a/tools/quantization/gptq/run_model_quant.py b/tools/quantization/gptq/run_model_quant.py deleted file mode 100644 index 55f6fc04..00000000 --- a/tools/quantization/gptq/run_model_quant.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================
-import os
-import shutil
-import argparse
-
-import struct
-import configparser
-import torch
-
-from gptq import *
-
-def get_model_configs(config_file_path):
-    config = configparser.ConfigParser()
-    config.read(config_file_path)
-    first_section = config.sections()[0]
-    model_configs = {}
-    model_configs['head_num'] = int(config[first_section]['head_num'])
-    model_configs['size_per_head'] = int(config[first_section]['size_per_head'])
-    model_configs['inter_size'] = int(config[first_section]['inter_size'])
-    model_configs['num_layer'] = int(config[first_section]['num_layer'])
-    model_configs['weight_data_type'] = config[first_section]['weight_data_type']
-    return model_configs
-
-def read_bin(bin_file_path, rows, columns):
-    weight = []
-    bin_f = open(bin_file_path, "rb")
-    for h in range(rows):
-        weight.append([])
-        for w in range(columns):
-            binary_value = bin_f.read(4)
-            fp32_value = struct.unpack("f", binary_value)[0]
-            weight[-1].append(fp32_value)
-    weight = torch.tensor(weight)
-    return weight
-
-def write_bin(bin_file_path, weight, wbits):
-    if os.path.exists(bin_file_path):
-        os.remove(bin_file_path)
-    bin_f = open(bin_file_path, "ab+")
-    weight = weight.flatten()
-    if wbits == 8:
-        for i in range(weight.shape[0]):
-            int8_value = int(weight[i])
-            binary_value = int8_value.to_bytes(1, 'big')
-            bin_f.write(binary_value)
-    elif wbits == 4:
-        pos = 0
-        int8_value = 0
-        for i in range(weight.shape[0]):
-            if pos == 0:
-                int8_value = int(weight[i])
-                pos = 1
-            else:
-                int8_value += 16 * int(weight[i])
-                binary_value = int8_value.to_bytes(1, 'big')
-                bin_f.write(binary_value)
-                pos = 0
-        if pos == 1:
-            binary_value = int8_value.to_bytes(1, 'big')
-
-            bin_f.write(binary_value)
-
-    bin_f.close()
-
-def quantize_weight(input_model_path, output_model_path, bin_file, rows, columns, wbits, sym):
-    print("Start reading " + os.path.join(input_model_path, bin_file))
-    weight = read_bin(os.path.join(input_model_path, bin_file), rows, columns)
-    llm_gptq = LLM_GPTQ(weight, wbits, sym)
-    quantized_weight, scale, zero = llm_gptq.fasterquant()
-
-    output_bin_prefix = bin_file[:-5]
-    write_bin(os.path.join(output_model_path, output_bin_prefix + "quantized.0.bin"), quantized_weight, wbits)
-    write_bin(os.path.join(output_model_path, output_bin_prefix + "scale.0.bin"), scale, wbits)
-    write_bin(os.path.join(output_model_path, output_bin_prefix + "zero.0.bin"), zero, wbits)
-    print("Finish quantization for " + os.path.join(input_model_path, bin_file))
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        '--model_type', type=str, default="llama",
-        choices=["llama", "llama2", "chatglm", "chatglm2"]
-    )
-    parser.add_argument(
-        '--input_model_path', type=str, default="/data/llama-13b-cpu"
-    )
-    parser.add_argument(
-        '--output_model_path', type=str, default="/data/llama-quantized-13b-cpu"
-    )
-    parser.add_argument(
-        '--wbits', type=int, default=8, choices=[4, 8]
-    )
-    args = parser.parse_args()
-
-    # TODO: read from config.ini in xFT
-    model_type = args.model_type
-    input_model_path = args.input_model_path
-    output_model_path = args.output_model_path
-    model_configs = get_model_configs(os.path.join(input_model_path, "config.ini"))
-    wbits = args.wbits # Support 4 bits or 8 bits
-    sym = False
-
-    attHeadNum = model_configs["head_num"]
-    kvHeadNum = model_configs["head_num"]
-    size_per_head = model_configs["size_per_head"]
-    imSize = model_configs["inter_size"]
-    layers = model_configs["num_layer"]
-    hiddenSize = attHeadNum * size_per_head
-    qSize = hiddenSize
-    attHeadSize = int(hiddenSize / attHeadNum)
-    kvSize = attHeadSize * kvHeadNum
-    qkvSize = qSize + kvSize + kvSize
-
-    attention_qkv_weight_rows = hiddenSize
-    attention_qkv_weight_columns = qkvSize
-    attention_dense_weight_rows = hiddenSize
-    attention_dense_weight_columns = hiddenSize
-    # For llama and llama2
-    mlp_down_weight_rows = imSize
-    mlp_down_weight_columns = hiddenSize
-    mlp_up_weight_rows = hiddenSize
-    mlp_up_weight_columns = imSize
-    mlp_gate_weight_rows = hiddenSize
-    mlp_gate_weight_columns = imSize
-    # For chatglm and chatglm2
-    mlp_dense_h_to_4h_rows = hiddenSize
-    mlp_dense_h_to_4h_columns = imSize
-    mlp_dense_4h_to_h_rows = imSize
-    mlp_dense_4h_to_h_columns = hiddenSize
-
-    prefix = "model.layers"
-    suffix = "0.bin"
-    for layer_index in range(layers):
-        quantize_weight(input_model_path, output_model_path, "model.layers.{}.attention.query_key_value.weight.0.bin".format(layer_index), attention_qkv_weight_rows, attention_qkv_weight_columns, wbits, sym)
-        quantize_weight(input_model_path, output_model_path, "model.layers.{}.attention.dense.weight.0.bin".format(layer_index), attention_dense_weight_rows, attention_dense_weight_columns, wbits, sym)
-        if "llama" in model_type:
-            quantize_weight(input_model_path, output_model_path, "model.layers.{}.mlp.down_proj.weight.0.bin".format(layer_index), mlp_down_weight_rows, mlp_down_weight_columns, wbits, sym)
-            quantize_weight(input_model_path, output_model_path, "model.layers.{}.mlp.up_proj.weight.0.bin".format(layer_index), mlp_up_weight_rows, mlp_up_weight_columns, wbits, sym)
-            quantize_weight(input_model_path, output_model_path, "model.layers.{}.mlp.gate_proj.weight.0.bin".format(layer_index), mlp_gate_weight_rows, mlp_gate_weight_columns, wbits, sym)
-        elif "chatglm" in model_type:
-            quantize_weight(input_model_path, output_model_path, "model.layers.{}.mlp.dense_h_to_4h.weight.0.bin".format(layer_index), mlp_dense_h_to_4h_rows, mlp_dense_h_to_4h_columns, wbits, sym)
-            quantize_weight(input_model_path, output_model_path, "model.layers.{}.mlp.dense_4h_to_h.weight.0.bin".format(layer_index), mlp_dense_4h_to_h_rows, mlp_dense_4h_to_h_columns, wbits, sym)
-
-    quantized_bin_files = ["attention.query_key_value.weight", "attention.dense.weight", "mlp.down_proj.weight", "mlp.up_proj.weight", "mlp.gate_proj.weight", "mlp.dense_h_to_4h.weight", "mlp.dense_4h_to_h.weight"]
-    for bin_file in os.listdir(input_model_path):
-        quantized = False
-        for quantized_bin_file in quantized_bin_files:
-            if quantized_bin_file in bin_file:
-                quantized = True
-                break
-        if quantized:
-            quantized = False
-            continue
-        else:
-            shutil.copyfile(os.path.join(input_model_path, bin_file), os.path.join(output_model_path, bin_file))
diff --git a/tools/quantization/llama_autogptq_convert.py b/tools/quantization/llama_autogptq_convert.py
deleted file mode 100644
index be719571..00000000
--- a/tools/quantization/llama_autogptq_convert.py
+++ /dev/null
@@ -1,281 +0,0 @@
-"""
-Convert huggingface ChatGLM model.
-Use https://huggingface.co/meta-llama
-"""
-
-import argparse
-import configparser
-import multiprocessing
-import numpy as np
-import os
-import sys
-import torch
-
-from datetime import datetime
-from pathlib import Path
-from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
-
-from auto_gptq import AutoGPTQForCausalLM
-
-dir_path = os.path.dirname(os.path.realpath(__file__))
-sys.path.append(dir_path + "/../../../..")
-sys.path.append(dir_path)
-
-
-def get_weight_data_type(data_type):
-    if data_type == "fp32":
-        return np.float32
-    elif data_type == "fp16":
-        return np.float16
-    else:
-        assert False, f"Invalid weight data type {data_type}"
-
-
-def split_and_convert_process(i, saved_dir, factor, key, args, val, old_name, dtype):
-    def save_val(val, key, tp_num=None):
-        if key.startswith("model."):
-            path = saved_dir + "/" + key
-        else:
-            path = saved_dir + "/model." + key
-
-        if tp_num is not None:
-            path += "." + str(tp_num)
-        path += ".bin"
-
-        val.tofile(path)
-
-    if (
-        "input_layernorm.weight" in key
-        or "input_layernorm.bias" in key
-        or "attention.dense.bias" in key
-        or "post_attention_layernorm.weight" in key
-        or "post_attention_layernorm.bias" in key
-        or "mlp.dense_4h_to_h.bias" in key
-        or "final_layernorm.weight" in key
-        or "final_layernorm.bias" in key
-    ):
-        # shared weights, only need to convert the weights of rank 0
-        if i == 0:
-            save_val(val, key)
-
-    elif "qweight" in key or "zeros" in key or "scales" in key:
-        save_val(val, key, 0)
-    else:
-        print("[ERROR] cannot find key '{}'".format(key))
-
-    """
-
-    elif "mlp.gate_proj.weight" in key or "mlp.up_proj.weight" in key or "mlp.down_proj.weight" in key:
-        split_vals = np.split(val, factor, axis=0)
-        for j in range(factor):
-            save_val(split_vals[j], key, i * factor + j)
-
-    elif "attention.query_key_value.weight" in key:
-        hidden_dim = val.shape[0]
-        local_dim = (int)(val.shape[-1] / 3)
-
-        val = val.reshape(hidden_dim, 3, local_dim)
-
-        split_vals = np.split(val, factor, axis=-1)
-        for j in range(factor):
-            save_val(split_vals[j], key, i * factor + j)
-
-    elif "attention.dense.weight" in key:
-        split_vals = np.split(val, factor, axis=0)
-        for j in range(factor):
-            save_val(split_vals[j], key, i * factor + j)
-
-    else:
-        print("[ERROR] cannot find key '{}'".format(key))
-    """
-
-def split_and_convert(args):
-    saved_dir = args.saved_dir
-
-    # create directory if not exist
-    if not os.path.exists(saved_dir):
-        os.makedirs(saved_dir)
-
-    # load the model
-    model = AutoGPTQForCausalLM.from_quantized(args.in_file)
-
-    hf_config = vars(model.config)
-    quantize_config = vars(model.quantize_config)
-
-    print(hf_config)
-    print(quantize_config)
-
-    layer_names = [name for name, param in model.named_parameters()]
-
-    # save parameters to config file
-    config = configparser.ConfigParser()
-    config["llama"] = {}
-    has_post_decoder_layernorm = True
-    try:
-        config["llama"]["model_name"] = "llama" if hf_config["_name_or_path"] == "" else hf_config["_name_or_path"]
-        config["llama"]["head_num"] = str(hf_config["num_attention_heads"])
-        hidden_size = hf_config["hidden_size"]
-        config["llama"]["size_per_head"] = str(hidden_size // hf_config["num_attention_heads"])
-        config["llama"]["inter_size"] = str(hf_config["intermediate_size"])
-        config["llama"]["max_pos_seq_len"] = str(hf_config["max_position_embeddings"])
-        config["llama"]["num_layer"] = str(hf_config["num_hidden_layers"])
-        config["llama"]["rms_norm_eps"] = "1e-6"
-        config["llama"]["layernorm_type"] = "pre_layernorm"
-        config["llama"]["activation_type"] = "silu"
-        config["llama"]["has_post_decoder_layernorm"] = "1" if has_post_decoder_layernorm else "0"
-        config["llama"]["vocab_size"] = str(hf_config["vocab_size"])
-        config["llama"]["start_id"] = str(hf_config["bos_token_id"])
-        config["llama"]["end_id"] = str(hf_config["eos_token_id"])
-        config["llama"]["weight_data_type"] = args.weight_data_type
-
-        config["llama"]["quant_decoder_weights"] = str(True)
-        wbits = quantize_config["bits"]
-        assert wbits == 8, "Only 8bits quantization is supported"
-        config["llama"]["quant_wbits"] = str(wbits)
-        assert quantize_config["group_size"] == -1, "Only column wise quantization is supported."
-        config["llama"]["quant_groupsize"] = str(quantize_config["group_size"])
-        #config["llama"]["quant_scheme"] = "sym" if quantize_config["sym"] == True else "asym"
-
-        with open(saved_dir + "/config.ini", "w") as configfile:
-            config.write(configfile)
-    except Exception as e:
-        print("Fail to save the config in config.ini.", str(e))
-
-    np_weight_data_type = get_weight_data_type(args.weight_data_type)
-
-
-    hf_model_name_pattern = [
-        "input_layernorm.weight",
-        "self_attn.qkv_proj.qweight",
-        "self_attn.qkv_proj.qzeros",
-        "self_attn.qkv_proj.scales",
-        "self_attn.o_proj.qweight",
-        "self_attn.o_proj.qzeros",
-        "self_attn.o_proj.scales",
-        "post_attention_layernorm.weight",
-        "mlp.gate_proj.qweight",
-        "mlp.gate_proj.qzeros",
-        "mlp.gate_proj.scales",
-        "mlp.up_proj.qweight",
-        "mlp.up_proj.qzeros",
-        "mlp.up_proj.scales",
-        "mlp.down_proj.qweight",
-        "mlp.down_proj.qzeros",
-        "mlp.down_proj.scales",
-    ]
-
-    ft_model_name_pattern = [
-        "input_layernorm.weight",
-        "attention.query_key_value.qweight",
-        "attention.query_key_value.zeros",
-        "attention.query_key_value.scales",
-        "attention.dense.qweight",
-        "attention.dense.zeros",
-        "attention.dense.scales",
-        "post_attention_layernorm.weight",
-        "mlp.gate_proj.qweight",
-        "mlp.gate_proj.zeros",
-        "mlp.gate_proj.scales",
-        "mlp.up_proj.qweight",
-        "mlp.up_proj.zeros",
-        "mlp.up_proj.scales",
-        "mlp.down_proj.qweight",
-        "mlp.down_proj.zeros",
-        "mlp.down_proj.scales",
-    ]
-
-    state_dict = model.state_dict()
-
-    model_named_parameters = dict()
-    for name, param in state_dict.items():
-        if name.startswith("model."):
-            name = name[6:]
-        wf = torch.tensor(list(range(0, 32, wbits)), dtype=torch.int32).unsqueeze(0)
-
-        if "embed" in name:
-            model_named_parameters[name] = param
-        elif "lm_head" in name:
-            model_named_parameters[name] = param
-        elif "scales" in name:
-            model_named_parameters[name] = param.float()
-        elif "qzeros" in name:
-            qzeros = param
-            qzeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // wbits), wf.unsqueeze(0)).to(torch.int16)
-            qzeros = torch.bitwise_and(qzeros, (2 ** wbits) - 1)
-            qzeros = qzeros + 1 - 128 # uint to int
-            qzeros = torch.flatten(qzeros).float()
-            scales = state_dict["model." + name.replace("qzeros", "scales")].float()
-            zeros = - scales * qzeros
-            model_named_parameters[name] = zeros
-        elif "qweight" in name:
-            # qweight is not transposed
-            param = torch.bitwise_right_shift(torch.unsqueeze(param, 1).expand(-1, 32 // wbits, -1), wf.unsqueeze(-1)).to(torch.int16)
-            param = torch.bitwise_and(param, (2 ** wbits) - 1)
-            param = param.reshape(-1, param.shape[2])
-            param = param - 128 # uint to int
-            model_named_parameters[name] = param.to(torch.int8)
-        else:
-            model_named_parameters[name] = param.permute(1, 0) if len(param.shape) == 2 else param
-
-    pool = multiprocessing.Pool(args.processes)
-    for name, param in model_named_parameters.items():
-        if name == "model.embed_tokens.weight":
-            param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.wte.bin")
-        elif name == "model.norm.weight":
-            param.detach().cpu().numpy().astype(np_weight_data_type).tofile(
-                saved_dir + "model.final_layernorm.weight.bin"
-            )
-        # elif name == 'model.final_layernorm.bias':
-        #     param.detach().cpu().numpy().astype(np_weight_data_type).tofile(
-        #         saved_dir + "model.final_layernorm.bias.bin")
-        elif name == "lm_head.weight":
-            param.detach().cpu().numpy().astype(np_weight_data_type).tofile(saved_dir + "model.lm_head.weight.bin")
-        else:
-            starmap_args = []
-            dtype = np_weight_data_type
-            if "qweight" in name:
-                dtype = np.int8
-            if "qzero" in name or "scales" in name:
-                dtype = np.float32
-            for i in range(len(hf_model_name_pattern)):
-                if hf_model_name_pattern[i] in name:
-                    factor = 1
-                    new_name = name.replace(hf_model_name_pattern[i], ft_model_name_pattern[i])
-                    starmap_args.append(
-                        (
-                            0,
-                            saved_dir,
-                            factor,
-                            new_name,
-                            args,
-                            param.detach().cpu().numpy().astype(dtype),
-                            name,
-                            dtype,
-                        )
-                    )
-            pool.starmap_async(split_and_convert_process, starmap_args)
-    pool.close()
-    pool.join()
-
-
-if __name__ == "__main__":
-    torch.multiprocessing.set_start_method("spawn")
-    torch.multiprocessing.set_sharing_strategy("file_system")
-
-    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
-    parser.add_argument("-saved_dir", "-o", type=str, help="file name of output file", required=True)
-    parser.add_argument("-in_file", "-i", type=str, help="file name of input checkpoint file", required=True)
-    parser.add_argument("-processes", "-p", type=int, help="processes to spawn for conversion (default: 8)", default=8)
-    parser.add_argument("-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"])
-
-    args = parser.parse_args()
-    print("\n=============== Argument ===============")
-    for key in vars(args):
-        print(f"{key}: {vars(args)[key]}")
-    print("========================================")
-
-    start_time = datetime.now()
-    split_and_convert(args)
-    stop_time = datetime.now()
-    run_time = stop_time - start_time
-    print(f"[INFO] Spend {run_time} (h:m:s) to convert the model")
diff --git a/tools/quantization/requirements.txt b/tools/quantization/requirements.txt
new file mode 100644
index 00000000..19cf01c1
--- /dev/null
+++ b/tools/quantization/requirements.txt
@@ -0,0 +1,138 @@
+absl-py==2.0.0
+accelerate==0.25.0
+aiohttp==3.9.1
+aiosignal==1.3.1
+annotated-types==0.6.0
+anyio==4.2.0
+async-timeout==4.0.3
+attributedict==0.3.0
+attrs==23.2.0
+auto_gptq @ file:///root/AutoGPTQ
+-e /root/llm-awq
+blessings==1.7
+blinker==1.4
+cachetools==5.3.2
+certifi==2023.11.17
+chardet==5.2.0
+charset-normalizer==3.3.2
+click==8.1.7
+codecov==2.1.13
+colorama==0.4.6
+coloredlogs==15.0.1
+colour-runner==0.1.1
+coverage==7.4.0
+cryptography==3.4.8
+DataProperty==1.0.1
+datasets==2.14.6
+dbus-python==1.2.18
+deepdiff==6.7.1
+dill==0.3.7
+distlib==0.3.8
+distro==1.7.0
+distro-info==1.1+ubuntu0.2
+evaluate==0.4.1
+exceptiongroup==1.2.0
+filelock==3.13.1
+frozenlist==1.4.1
+fsspec==2023.4.0
+gekko==1.0.6
+h11==0.14.0
+httpcore==1.0.2
+httplib2==0.20.2
+httpx==0.26.0
+huggingface-hub==0.17.3
+humanfriendly==10.0
+idna==3.6
+importlib-metadata==4.6.4
+importlib-resources==5.13.0
+inspecta==0.1.3
+jeepney==0.7.1
+Jinja2==3.1.2
+joblib==1.3.2
+jsonlines==4.0.0
+keyring==23.5.0
+launchpadlib==1.10.16
+lazr.restfulclient==0.14.4
+lazr.uri==1.0.6
+lm-eval==0.3.0
+lxml==5.0.0
+MarkupSafe==2.1.3
+mbstrdecoder==1.1.3
+more-itertools==8.10.0
+mpmath==1.3.0
+multidict==6.0.4
+multiprocess==0.70.15
+networkx==3.0
+nltk==3.8.1
+numexpr==2.8.6
+numpy==1.24.4
+oauthlib==3.2.0
+openai==1.6.1
+ordered-set==4.1.0
+packaging==23.2
+pandas==2.0.3
+pathvalidate==3.2.0
+peft==0.7.1
+Pillow==9.3.0
+platformdirs==4.1.0
+pluggy==1.3.0
+portalocker==2.8.2
+protobuf==4.25.1
+psutil==5.9.7
+pyarrow==14.0.2
+pybind11==2.11.1
+pycountry==23.12.11
+pydantic==2.5.3
+pydantic_core==2.14.6
+Pygments==2.17.2
+PyGObject==3.42.1
+PyJWT==2.3.0
+pyparsing==2.4.7
+pyproject-api==1.6.1
+pytablewriter==1.2.0
+python-apt==2.4.0+ubuntu2
+python-dateutil==2.8.2
+pytz==2023.3.post1
+PyYAML==6.0.1
+regex==2023.12.25
+requests==2.31.0
+responses==0.18.0
+rootpath==0.1.1
+rouge==1.0.1
+rouge_score==0.1.2
+sacrebleu==1.5.0
+safetensors==0.4.1
+scikit-learn==1.3.2
+scipy==1.10.1
+SecretStorage==3.3.1
+sentencepiece==0.1.99
+six==1.16.0
+sniffio==1.3.0
+sqlitedict==2.1.0
+sympy==1.12
+tabledata==1.3.3
+tabulate==0.9.0
+tcolorpy==0.1.4
+termcolor==2.4.0
+texttable==1.7.0
+threadpoolctl==3.2.0
+tokenizers==0.14.1
+toml==0.10.2
+tomli==2.0.1
+torch==2.1.0+cpu
+torchvision==0.16.0+cpu
+tox==4.11.4
+tqdm==4.66.1
+tqdm-multiprocess==0.0.11
+transformers==4.34.1
+typepy==1.3.2
+typing_extensions==4.9.0
+tzdata==2023.4
+unattended-upgrades==0.1
+urllib3==2.1.0
+virtualenv==20.25.0
+wadllib==1.3.6
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.17.0
+zstandard==0.22.0