This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit eed9b30

[GPTQ Enhance] Support GPTQ for Baichuan2-13B & Falcon 7B & Phi-1.5 (#169)
1 parent e76a58e commit eed9b30

10 files changed: +819 −351 lines

docs/gptq_and_awq.md

Lines changed: 4 additions & 1 deletion
@@ -12,8 +12,11 @@ Validated GPTQ & AWQ models directly from the HuggingFace:
 * [Mixtral-8x7B-Instruct-v0.1-GPTQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ) & [Mixtral-8x7B-Instruct-v0.1-AWQ](https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ)
 * [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ) & [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4)
 * [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ)
+* [Baichuan2-13B-Chat-GPTQ](https://hf-mirror.com/TheBloke/Baichuan2-13B-Chat-GPTQ)
+* [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b/tree/main)
+* [onlinex/phi-1_5-gptq-4bit](https://hf-mirror.com/onlinex/phi-1_5-gptq-4bit)
 
-Please check more validated GPTQ & AWQ models in the list of [supported_models](./supported_models.md).
+For more details, please check the list of [supported_models](./supported_models.md).
 
 ## Examples
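A quick way to confirm that a checkpoint from the list above is a GPTQ one, and to see the fields the converters in this commit read (`bits`, `group_size`, `desc_act`, `sym`), is to inspect its config. A minimal sketch, assuming the repo's `config.json` carries a `quantization_config` block as the TheBloke GPTQ uploads do; the model name is just one entry from the list:

```python
# Minimal sketch: inspect the GPTQ quantization settings of a validated checkpoint.
# Assumes a "quantization_config" section exists in the checkpoint's config.json.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("TheBloke/Baichuan2-13B-Chat-GPTQ", trust_remote_code=True)
quant_cfg = getattr(config, "quantization_config", None)
print(config.model_type)   # expected: "baichuan"
print(quant_cfg)           # bits, group_size, desc_act, sym, ... as used by the converters below
```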

docs/supported_models.md

Lines changed: 6 additions & 6 deletions
@@ -235,13 +235,13 @@ Neural Speed supports the following models:
     <td><a href="https://huggingface.co/baichuan-inc/Baichuan-13B-Chat" target="_blank" rel="noopener noreferrer">Baichuan-13B-Chat</a>,
     <a href="https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat" target="_blank" rel="noopener noreferrer">Baichuan2-13B-Chat</a></td>
     <td>✅</td>
-    <td> </td>
-    <td> </td>
-    <td> </td>
     <td>✅</td>
-    <td> </td>
-    <td> </td>
-    <td> </td>
+    <td>✅</td>
+    <td>✅</td>
+    <td>✅</td>
+    <td>✅</td>
+    <td>✅</td>
+    <td>✅</td>
     <td>4.33.1</td>
     <td>4096</td>
   </tr>

neural_speed/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -24,7 +24,6 @@
 
 
 class Model:
-
     def __init__(self):
         self.module = None
         self.model = None
@@ -83,6 +82,15 @@ def get_model_type(model_config):
         model_type = model_maps.get(model_config.model_type, model_config.model_type)
         if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
             model_type = "chatglm2"
+
+        # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
+        if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
+            model_type = "falcon"
+
+        # for TheBloke/phi-2-GPTQ
+        if model_type == "phi-msft":
+            model_type = "phi"
+
         return model_type
 
     def init(self,
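The new branches only normalize the Hugging Face `model_type` string before dispatch. A standalone re-run of the same logic for illustration; the `SimpleNamespace` stub configs are not part of the commit:

```python
# Illustrative only: show how GPTQ Falcon/Phi checkpoints are normalized to the
# model types Neural Speed dispatches on. Replicates the logic of get_model_type().
from types import SimpleNamespace

model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper", "qwen2": "qwen"}

def get_model_type(model_config):
    model_type = model_maps.get(model_config.model_type, model_config.model_type)
    if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
        model_type = "chatglm2"
    # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
    if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
        model_type = "falcon"
    # for TheBloke/phi-2-GPTQ
    if model_type == "phi-msft":
        model_type = "phi"
    return model_type

falcon_cfg = SimpleNamespace(model_type="RefinedWeb", _name_or_path="TheBloke/Falcon-7B-Instruct-GPTQ")
phi_cfg = SimpleNamespace(model_type="phi-msft", _name_or_path="TheBloke/phi-2-GPTQ")
print(get_model_type(falcon_cfg))  # -> "falcon"
print(get_model_type(phi_cfg))     # -> "phi"
```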

neural_speed/convert/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -18,7 +18,15 @@
 from pathlib import Path
 import subprocess
 
-model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper", "qwen2": "qwen"}
+model_maps = {
+    "gpt_neox": "gptneox",
+    "gpt_bigcode": "starcoder",
+    "whisper": "whisper",
+    "qwen2": "qwen",
+    "RefinedWebModel": "falcon",
+    "RefinedWeb": "falcon",
+    "phi-msft": "phi"
+}
 
 
 def convert_model(model, outfile, outtype="f32", model_hub="huggingface", use_quantized_model=False):
@@ -28,6 +36,7 @@ def convert_model(model, outfile, outtype="f32", model_hub="huggingface", use_quantized_model=False):
     else:
         from transformers import AutoConfig
         config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+
     model_type = model_maps.get(config.model_type, config.model_type)
 
     if use_quantized_model:
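With the widened `model_maps`, GPTQ Falcon and Phi checkpoints now resolve to the existing falcon/phi converter scripts. A hedged usage sketch of the entry point shown above; the checkpoint path and output filename are placeholders, not values from the commit:

```python
# Illustrative call of convert_model() from neural_speed.convert, using the
# signature shown above. Paths are placeholders for a local GPTQ checkpoint
# directory and an arbitrary output file.
from neural_speed.convert import convert_model

convert_model("/path/to/Baichuan2-13B-Chat-GPTQ",
              "baichuan2-13b-chat-gptq-f32.bin",
              outtype="f32",
              use_quantized_model=True)  # convert the GPTQ weights directly, no re-quantization from fp32
```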

neural_speed/convert/common.py

Lines changed: 83 additions & 0 deletions
@@ -516,3 +516,86 @@ def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head,
                                          compute_dtype="int8")
    dst.flatten()[:byte_size].tofile(fout)
    print(f"converting {dst_name} quantized tensor to bestla q4 block")


def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
    # unpack the GPTQ weight and repack it into the 3-bit / 4-bit BesTLA format
    import neural_speed.llama_cpp as cpp_model
    if ".weight" in src_name:
        src_name = src_name.replace(".weight", "")
    qzeros = model[f"{src_name}.qzeros"]
    zeros = qzeros_to_zeros(qzeros)
    scales = model[f"{src_name}.scales"]
    qweight = model[f"{src_name}.qweight"]

    int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
    int_weight = int_weight.view(-1, int_weight.shape[-1])

    # shuffle weight in GPTQ when act order is on
    if 'desc_act' in q_config and q_config['desc_act']:
        g_idx = model[f"{src_name}.g_idx"]
        int_weight2 = int_weight.clone()
        group_size = q_config['group_size']
        group_dict = {}
        for i in range(len(g_idx)):
            group_idx = g_idx[i].item()
            if group_idx not in group_dict:
                target_idx = group_idx * group_size
                group_dict[group_idx] = 0
            else:
                group_dict[group_idx] = group_dict[group_idx] + 1
                target_idx = group_idx * group_size + group_dict[group_idx]
            int_weight2[target_idx] = int_weight[i]
        int_weight = int_weight2

    shape = int_weight.shape[::-1]
    # write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)
    n_dims = len(shape)
    str = dst_name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(str), GGML_QJBLAS_TYPE))
    for i in range(n_dims):
        fout.write(struct.pack("i", shape[n_dims - 1 - i]))
    fout.write(str)

    # INC stores signed int4 values as u4 (range 0~15) by adding an offset;
    # BesTLA requires s4_clip ((-8,7)*16), so we subtract the offset and then multiply by 16.
    # Int3 is the same, but with offset=4 and scale multiplier 32.
    weight_dtype = "int8"
    if q_config['bits'] == 4:
        int_weight = (int_weight - 8) * 16
        gptq_scales = gptq_scales / 16
        gptq_zeros = (gptq_zeros - 8) * 16
        weight_dtype = "int4"
    elif q_config['bits'] == 3:
        int_weight = (int_weight - 4) * 32
        gptq_scales = gptq_scales / 32
        gptq_zeros = (gptq_zeros - 4) * 32
        weight_dtype = "int3"
    else:
        raise ValueError(f"Unsupported q_config[bits]: {q_config['bits']}")

    dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
    int_weight = np.ascontiguousarray(int_weight.numpy())
    gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
    if q_config['sym']:
        gptq_zeros = np.empty(0, dtype=np.int8)
    else:
        gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
    if 'desc_act' in q_config and q_config['desc_act']:
        g_idx = np.ascontiguousarray(g_idx.numpy())
    else:
        g_idx = np.empty(0, dtype=np.int32)

    # repack int weight in BesTLA format
    byte_size = cpp_model.Model.np_bestla_qpack(int_weight,
                                                gptq_scales,
                                                gptq_zeros,
                                                g_idx,
                                                dst,
                                                weight_dtype=weight_dtype,
                                                group_size=q_config['group_size'],
                                                alg="sym" if q_config['sym'] else "asym",
                                                compute_dtype="int8")
    dst.flatten()[:byte_size].tofile(fout)
    print(f"convert_to_qx_bestla_tensor: {src_name:>40} -> {dst_name:<40} shape: {shape}, byte_size: {byte_size:<10}")
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import sys
import re
import argparse
from common import *
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_vocab_for_baichuan(path: Path) -> SentencePieceVocab:
    # Be extra-friendly and accept either a file or a directory. Also, if it's
    # a directory, it might be the model directory, and tokenizer.model might
    # be in the parent of that.
    if path.is_dir():
        path2 = path / "tokenizer.model"
        # Use `.parent` instead of /.. to handle the symlink case better.
        path3 = path.parent / "tokenizer.model"
        if path2.exists():
            path = path2
        elif path3.exists():
            path = path3
        else:
            raise FileNotFoundError(
                f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \
                pass the directory as --vocab-dir")
    added_tokens_path = path.parent / "added_tokens.json"
    print(f"Loading vocab file {path}")
    return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)


def main(args_in: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
    parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("--model_hub",
                        choices=["huggingface", "modelscope"],
                        default="huggingface",
                        help="hub to load model")
    parser.add_argument("model", type=Path, help="directory containing model file")
    args = parser.parse_args(args_in)

    out_path = args.outfile.as_posix()
    model_path = args.model.as_posix()

    model, hparams, quantize_config = load_quantized_safetensors(model_path)
    list_vars = model

    print(hparams)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    fout = open(out_path, "wb")

    # possible data types
    #   ftype == 0 -> float32, ftype == 1 -> float16
    ftype = 0
    if args.outtype == "f16":
        ftype = 1

    # 1. write hparams
    print(hparams)
    ne_file_magic = 0x67676d66
    fout.write(struct.pack("i", ne_file_magic))  # magic: ne in hex
    fout.write(struct.pack("i", 1))

    fout.write(struct.pack("i", hparams["vocab_size"]))
    fout.write(struct.pack("i", hparams["hidden_size"]))
    fout.write(struct.pack("i", 0))
    fout.write(struct.pack("i", hparams["num_attention_heads"]))
    fout.write(struct.pack("i", 0))
    fout.write(struct.pack("i", hparams["num_hidden_layers"]))
    fout.write(struct.pack("i", 0))
    fout.write(struct.pack("i", ftype))
    fout.write(struct.pack("i", hparams["model_max_length"]))
    fout.write(struct.pack("f", 0))
    fout.write(struct.pack("f", 0))
    fout.write(struct.pack("i", 0))

    fout.write(struct.pack("i", 0))  # word_embed_proj_dim (for opt)
    fout.write(struct.pack("i", 0))  # do_layer_norm_before (for opt)

    fout.write(struct.pack("i", 0))
    fout.write(struct.pack("i", 0))
    fout.write(struct.pack("i", hparams["intermediate_size"]))
    fout.write(struct.pack("i", 0))  # n_experts
    fout.write(struct.pack("i", 0))  # n_expert_used
    fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6)))  # rms norm eps
    fout.write(struct.pack("f", 10000.0))  # freq_base
    fout.write(struct.pack("f", 1.0))  # rope_factor

    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] == "yarn" else 0))

    fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
    fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
    fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
    fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))

    # 2. vocab
    tokenizer_path = Path(tokenizer.vocab_file).parent
    vocab = load_vocab_for_baichuan(Path(tokenizer_path))
    counter = 0
    for text, score in vocab.all_tokens():
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
        fout.write(struct.pack("f", score))
        counter += 1

    while counter < hparams["vocab_size"]:
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
        fout.write(struct.pack("f", 0))
        counter += 1

    def convert_qwen_to_fp32_tensor(src_name, dst_name, model, fout):
        # qwen-gptq is torch.bfloat16 mostly.
        if model[src_name].dtype == torch.float32:
            data = model[src_name].squeeze().numpy()
        else:
            data = model[src_name].squeeze().to(torch.float32).numpy()
        data = data.astype(np.float32)
        shape = data.shape
        n_dims = len(shape)
        print("convert_qwen_to_fp32_tensor: %40s" % src_name + "-> %-40s" % dst_name + " shape: ", shape, " type: ",
              data.dtype)

        # ftype_cur = {torch.float16: 1, torch.float32: 0}[data.dtype]
        # default type is fp32
        ftype_cur = 0
        if ftype == 1 and n_dims > 1:
            data = data.astype(np.float16)
            ftype_cur = 1
        else:
            data = data.astype(np.float32)

        # header
        # write_header(fout, shape, dst_name, ftype_cur)
        str = src_name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
        for i in range(n_dims):
            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
        fout.write(str)

        # data
        data.tofile(fout)

    # 3. write tensors
    convert_qwen_to_fp32_tensor("model.embed_tokens.weight", "model.embed_tokens.weight", list_vars, fout)
    convert_qwen_to_fp32_tensor("model.norm.weight", "model.norm.weight", list_vars, fout)
    convert_qwen_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)

    for i in range(hparams["num_hidden_layers"]):
        prefix = "model.layers." + str(i)

        convert_qwen_to_fp32_tensor(f"{prefix}.input_layernorm.weight", f"{prefix}.input_layernorm.weight", list_vars,
                                    fout)
        convert_qwen_to_fp32_tensor(f"{prefix}.post_attention_layernorm.weight",
                                    f"{prefix}.post_attention_layernorm.weight", list_vars, fout)
        # qkv GEMM
        convert_to_qx_bestla_tensor(f"{prefix}.self_attn.W_pack.weight", f"{prefix}.self_attn.W_pack.weight",
                                    list_vars, fout, quantize_config)
        convert_to_qx_bestla_tensor(f"{prefix}.self_attn.o_proj.weight", f"{prefix}.self_attn.o_proj.weight",
                                    list_vars, fout, quantize_config)

        # ffn GEMM
        convert_to_qx_bestla_tensor(f"{prefix}.mlp.gate_proj", f"{prefix}.mlp.gate_proj.weight", list_vars, fout,
                                    quantize_config)
        convert_to_qx_bestla_tensor(f"{prefix}.mlp.down_proj", f"{prefix}.mlp.down_proj.weight", list_vars, fout,
                                    quantize_config)
        convert_to_qx_bestla_tensor(f"{prefix}.mlp.up_proj", f"{prefix}.mlp.up_proj.weight", list_vars, fout,
                                    quantize_config)

    fout.close()
    print(f"Success! saved as {out_path}")


if __name__ == '__main__':
    main()
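For completeness, a hedged usage sketch of the converter defined above. This extract does not show the new file's path, so the invocation below drives `main()` with an explicit argument list; both paths are placeholders, not values from the commit:

```python
# Illustrative only: exercise the converter's argparse interface directly.
# "/path/to/Baichuan2-13B-Chat-GPTQ" stands in for a local GPTQ checkpoint
# directory; the output filename is likewise arbitrary.
main([
    "--outfile", "ne-baichuan2-13b-chat-gptq-f32.bin",
    "--outtype", "f32",
    "/path/to/Baichuan2-13B-Chat-GPTQ",
])
```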

0 commit comments