
Commit eb41b91
[Runtime Enhance] Extend long input tokens length (#157)
1 parent 863859b

40 files changed: +443 additions, -165 deletions

developer_document.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -8,7 +8,7 @@ For simplicity, we take [polyglot](https://huggingface.co/EleutherAI/polyglot-ko
 
 Firstly, we need to add its temp buffer in its [related model-arch header file](neural_speed/models/gptneox/gptneox.h) and [re-compile](README.md#Install).
 ```diff
-static const model_scratch gptneox_mem_req(int n_layers) {
+static const model_scratch gptneox_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
 switch (n_layers) {
 case 44:
 return {2048ull * MB, 2048ull * MB, 4096ull * MB};
@@ -167,7 +167,7 @@ and update [model_name_to_arch()](neural_speed/models/model_utils/model_types.h#
 + NEW_MODEL_13B,
 +};
 
-+static const model_scratch new_model_mem_req(int n_layers) {
++static const model_scratch new_model_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
 + switch (n_layers) {
 + case N:
 + return {8192ull * MB, 8192ull * MB, 8192ull * MB};
@@ -390,7 +390,7 @@ We recommend to use continuous batching way since it has no padding effect and c
 + ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_merged_contiguous), ne_element_size(KQV_merged_contiguous) * off_sl)));
 + off_sl += head_size * n_head * attn_sl * attn_bs;
 ```
->Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`model_scratch_enlarge_scale`](neural_speed/models/llama/llama.h#29) values if out of memory when the inputs' batch size becomes larger.
+>Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`scratch_size_ratio`](neural_speed/models/llama/llama.h#29) values if out of memory when the inputs' batch size becomes larger.
 
 ## 2.3. Application
 - Q4_0 quant : We can quantize the model generated by convert by adding a quant layer class to quantize it into an int4 low-bit file, so as to obtain better inference performance. Register the quant layer class in your new_model_utils.cpp, just like [gptneox_utils.cpp](neural_speed/models/gptneox/gptneox_utils.cpp#L163), and replace `gptneox_quant_layer` with your `new_model_quant_layer`.
````
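The note above also applies when driving long inputs through the Python API. The sketch below is illustrative only (the model id, token counts, and the `scratch_size_ratio` value are placeholders, not part of this commit) and assumes the `neural_speed.Model` interface shown later in `neural_speed/__init__.py`:

```python
# Hedged usage sketch: long-context generation with Neural Speed.
# Assumes the neural_speed.Model API shown in neural_speed/__init__.py below;
# model id, ctx_size and scratch_size_ratio values are illustrative only.
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = Model()
model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")

prompt = "..."  # imagine a prompt several thousand tokens long
inputs = tokenizer(prompt, return_tensors="pt").input_ids

# ctx_size > 2048 triggers the automatic scratch_size_ratio selection added in this
# commit; passing scratch_size_ratio explicitly overrides that heuristic.
outputs = model.generate(inputs, max_new_tokens=128, ctx_size=4096, scratch_size_ratio=2.0)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```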

docs/gptq_and_awq.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ Validated GPTQ & AWQ models directly from the HuggingFace:
 * [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ) & * [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4)
 * [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ)
 
-Please check more validated GPTQ & AWQ models in the list of [supported_models](./docs/supported_models.md).
+Please check more validated GPTQ & AWQ models in the list of [supported_models](./supported_models.md).
 
 ## Examples
```
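As a rough illustration of how such validated checkpoints are loaded, the sketch below uses one model id from the list above and the `use_gptq` flag from the `init()` signature shown later in `neural_speed/__init__.py`; treat it as a hedged example rather than the documented recipe:

```python
# Hedged sketch: loading a pre-quantized GPTQ checkpoint through neural_speed.Model.
# The use_gptq flag corresponds to the init() parameter visible in this commit's
# neural_speed/__init__.py diff; the checkpoint name comes from the validated list above.
from neural_speed import Model

model = Model()
model.init("TheBloke/Qwen-7B-Chat-GPTQ", use_gptq=True)
```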

docs/supported_models.md

Lines changed: 19 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@ Neural Speed supports the following models:
 <th colspan="4">INT8</th>
 <th colspan="4">INT4</th>
 <th rowspan="2">Transformer Version</th>
+<th rowspan="2">Max tokens length</th>
 </tr>
 <tr>
 <th>RTN</th>
@@ -36,6 +37,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>4096</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/decapoda-research/llama-7b-hf" target="_blank" rel="noopener noreferrer">LLaMA-7B</a>,
@@ -49,6 +51,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <td><a href="https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf" target="_blank" rel="noopener noreferrer">CodeLlama-7b</a></td>
 <td>✅</td>
@@ -60,6 +63,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>16384</td>
 </tr>
 </tr>
 <td><a href="https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0" target="_blank" rel="noopener noreferrer">Solar-10.7B</a></td>
@@ -72,6 +76,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>4096</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/Intel/neural-chat-7b-v3-1" target="_blank" rel="noopener noreferrer">Neural-Chat-7B-v3-1</a>,
@@ -85,6 +90,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>32768</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/mistralai/Mistral-7B-v0.1" target="_blank" rel="noopener noreferrer">Mistral-7B</a>,
@@ -98,6 +104,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>4.36.0 or newer</td>
+<td>32768</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/Qwen/Qwen-7B-Chat" target="_blank" rel="noopener noreferrer">Qwen-7B</a>,
@@ -113,6 +120,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>8192 / 32768</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/EleutherAI/gpt-j-6b" target="_blank" rel="noopener noreferrer">GPT-J-6B</a></td>
@@ -125,6 +133,7 @@ Neural Speed supports the following models:
 <td>✅</td>
 <td>✅</td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/EleutherAI/gpt-neox-20b" target="_blank" rel="noopener noreferrer">GPT-NeoX-20B</a></td>
@@ -137,6 +146,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/databricks/dolly-v2-3b" target="_blank" rel="noopener noreferrer">Dolly-v2-3B</a></td>
@@ -149,6 +159,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>4.28.1 or newer</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/mosaicml/mpt-7b" target="_blank" rel="noopener noreferrer">MPT-7B</a>,
@@ -162,6 +173,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/tiiuae/falcon-7b" target="_blank" rel="noopener noreferrer">Falcon-7B</a>,
@@ -175,6 +187,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/bigscience/bloomz-7b1" target="_blank" rel="noopener noreferrer">BLOOM-7B</a></td>
@@ -187,6 +200,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/facebook/opt-125m" target="_blank" rel="noopener noreferrer">OPT-125m</a>,
@@ -201,6 +215,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank" rel="noopener noreferrer">ChatGLM-6B</a>,
@@ -214,6 +229,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>4.33.1</td>
+<td>2048 / 32768</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/baichuan-inc/Baichuan-13B-Chat" target="_blank" rel="noopener noreferrer">Baichuan-13B-Chat</a>,
@@ -227,6 +243,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>4.33.1</td>
+<td>4096</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/microsoft/phi-2" target="_blank" rel="noopener noreferrer">phi-2</a>,
@@ -241,6 +258,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>2048</td>
 </tr>
 <tr>
 <td><a href="https://huggingface.co/openai/whisper-tiny" target="_blank" rel="noopener noreferrer">Whisper-tiny</a>,
@@ -257,6 +275,7 @@ Neural Speed supports the following models:
 <td> </td>
 <td> </td>
 <td>Latest</td>
+<td>448</td>
 </tr>
 </tbody>
 </table>
```
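The new "Max tokens length" values largely mirror each model's Hugging Face config. A small hedged sketch of how such a limit can be looked up (the key order follows `get_max_seq_length()` in the `neural_speed/__init__.py` diff below; the model id is illustrative):

```python
# Hedged sketch: read a model's maximum sequence length from its HF config,
# probing the same keys as get_max_seq_length() in this commit.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1", trust_remote_code=True).to_dict()

for key in ("seq_length", "max_position_embeddings", "model_max_length",
            "n_positions", "max_seq_len", "max_sequence_length", "max_length"):
    if key in config:
        print(f"max tokens length ({key}): {config[key]}")
        break
else:
    print("No known max-length key found; Neural Speed would fall back to 512.")
```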

neural_speed/__init__.py

Lines changed: 116 additions & 20 deletions
```diff
@@ -24,6 +24,7 @@
 
 
 class Model:
+
     def __init__(self):
         self.module = None
         self.model = None
@@ -84,9 +85,19 @@ def get_model_type(model_config):
             model_type = "chatglm2"
         return model_type
 
-    def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_autoround=False,
-             weight_dtype="int4", alg="sym", group_size=32,
-             scale_dtype="fp32", compute_dtype="int8", use_ggml=False, model_hub="huggingface"):
+    def init(self,
+             model_name,
+             use_quant=True,
+             use_gptq=False,
+             use_awq=False,
+             use_autoround=False,
+             weight_dtype="int4",
+             alg="sym",
+             group_size=32,
+             scale_dtype="fp32",
+             compute_dtype="int8",
+             use_ggml=False,
+             model_hub="huggingface"):
         if model_hub == "modelscope":
             from modelscope import AutoConfig
             self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -124,24 +135,28 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
             self.bin_file = quant_bin
 
         if os.path.exists(self.bin_file):
-            print("{} existed, will use cache file. Otherwise please remove the file".
-                  format(self.bin_file))
+            print("{} existed, will use cache file. Otherwise please remove the file".format(self.bin_file))
             return
 
         if use_gptq or use_awq or use_autoround:
             convert_model(model_name, quant_bin, use_quantized_model=True)
             return
 
         if not os.path.exists(fp32_bin):
-            convert_model(model_name, fp32_bin, "f32", model_hub = model_hub)
+            convert_model(model_name, fp32_bin, "f32", model_hub=model_hub)
         assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
 
         if not use_quant:
             print("FP32 model will be used.")
             return
-        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin,
-                                      weight_dtype=weight_dtype, alg=alg, group_size=group_size,
-                                      scale_dtype=scale_dtype, compute_dtype=compute_dtype, use_ggml=use_ggml)
+        self.module.Model.quant_model(model_path=fp32_bin,
+                                      out_path=quant_bin,
+                                      weight_dtype=weight_dtype,
+                                      alg=alg,
+                                      group_size=group_size,
+                                      scale_dtype=scale_dtype,
+                                      compute_dtype=compute_dtype,
+                                      use_ggml=use_ggml)
         assert os.path.exists(quant_bin), "Fail to quantize model"
 
         # clean
@@ -150,9 +165,11 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
     def init_from_bin(self, model_type, model_path, **generate_kwargs):
         self.__import_package(model_type)
         self.model = self.module.Model()
+
         if self.max_request_num == -1:
-            self.max_request_num = max(generate_kwargs.get("max_request_num",
-                max_request_num_default), generate_kwargs.get("batch_size", 1))
+            self.max_request_num = max(generate_kwargs.get("max_request_num", max_request_num_default),
+                                       generate_kwargs.get("batch_size", 1))
+
         if "threads" not in generate_kwargs:
             threads = os.getenv("OMP_NUM_THREADS")
             import platform
@@ -165,29 +182,107 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
                 generate_kwargs["threads"] = len(os.sched_getaffinity(0))
             else:
                 generate_kwargs["threads"] = int(threads)
-        self.model.init_model(model_path, **generate_kwargs)
 
+        # Setting scratch_size_ratio according to the ctx_size & tokens_length
+        # If scratch_size_ratio has been set, will not enter this branch.
+        if generate_kwargs.get("ctx_size") is not None and generate_kwargs.get(
+                "ctx_size") > 2048 and generate_kwargs.get("scratch_size_ratio") is None:
+
+            def get_max_seq_length():
+                config = self.config.to_dict()
+                # chatglm2, bloom
+                if 'seq_length' in config:
+                    return config['seq_length']
+                # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi
+                elif 'max_position_embeddings' in config:
+                    return config['max_position_embeddings']
+                # baichuan, baichuan2
+                elif 'model_max_length' in config:
+                    return config['model_max_length']
+                # gptj
+                elif 'n_positions' in config:
+                    return config['n_positions']
+                # mpt
+                elif 'max_seq_len' in config:
+                    return config['max_seq_len']
+                # chatglm
+                elif 'max_sequence_length' in config:
+                    return config['max_sequence_length']
+                # whisper
+                elif 'max_length' in config:
+                    return config['max_length']
+                # Falcon does not have these parameters.
+                elif model_type == "falcon":
+                    return 2048
+                else:
+                    print("Not found max seq length, setting to default 512")
+                    return 512
+
+            # when tokens less than 10240
+            def get_scratch_size_ratio(size):
+                if size > 2048 and size <= 4096:
+                    return 2
+                elif size > 4096 and size <= 8192:
+                    return 4
+                elif size > 8192 and size <= 10240:
+                    return 8
+                else:
+                    # more than 10240
+                    return -1
+
+            max_seq_length = get_max_seq_length()
+            ctx_size = generate_kwargs.get("ctx_size")
+
+            if ctx_size > max_seq_length:
+                print(f'max_seq_length is {max_seq_length}, but ctx_size is {ctx_size}. Please reduce ctx_size.')
+                exit(0)
+
+            if max_seq_length > 2048 and max_seq_length <= 4096:
+                generate_kwargs["scratch_size_ratio"] = 2
+            elif max_seq_length > 4096 and max_seq_length <= 8192:
+                generate_kwargs["scratch_size_ratio"] = 4
+            elif max_seq_length > 8192:
+                if get_scratch_size_ratio(ctx_size) != -1:
+                    generate_kwargs["scratch_size_ratio"] = get_scratch_size_ratio(ctx_size)
+                else:
+                    if max_seq_length == 16384:
+                        generate_kwargs["scratch_size_ratio"] = 12
+                    elif max_seq_length == 32768:
+                        if ctx_size < 20480:
+                            generate_kwargs["scratch_size_ratio"] = 20
+                        else:
+                            generate_kwargs["scratch_size_ratio"] = 35
+
+        self.model.init_model(model_path, **generate_kwargs)
 
     def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
         self.__import_package(model_type)
         self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)
 
-    def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False,
-                 stopping_criteria=None, **generate_kwargs):
+    def generate(self,
+                 input_ids,
+                 streamer=None,
+                 interactive=False,
+                 ignore_prompt=False,
+                 stopping_criteria=None,
+                 **generate_kwargs):
+        batch_size = input_ids.shape[0]
+
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
-        input_bs = input_ids.shape[0]
         max_request_num = generate_kwargs.pop("max_request_num", max_request_num_default)
         reinit_from_bin = False
-        if max_request_num > self.max_request_num or input_bs > self.max_request_num:
+        if max_request_num > self.max_request_num or batch_size > self.max_request_num:
             reinit_from_bin = True
             if self.max_request_num > 0:
                 print("Will start to reinit model from bin due to different max request num.")
-            self.max_request_num = max(input_bs, max_request_num)
+            self.max_request_num = max(batch_size, max_request_num)
 
         if self.model is None or reinit_from_bin:
-            self.init_from_bin(self.model_type, self.bin_file, batch_size=input_bs,
-                               max_request_num = self.max_request_num, **generate_kwargs)
+            self.init_from_bin(self.model_type,
+                               self.bin_file,
+                               batch_size=batch_size,
+                               max_request_num=self.max_request_num,
+                               **generate_kwargs)
             self.generate_round = 0
         elif not interactive:
             self.model.reinit()
@@ -208,6 +303,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
             assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
             assert beam_search == False, "ERROR, can not use streamer when use beam search for generation! \
                 Make sure that `num_beams` is set to 1."
+
             if self.generate_round == 0 and not ignore_prompt:
                 streamer.put(input_ids)
 
@@ -284,6 +380,6 @@ def _cont_batching_input(self, input_ids, pad_token_id=None):
         for il in range(len(input_list)):
             count = input_list[il].count(pti)
             # padding left
-            del input_list[il][0: count]
+            del input_list[il][0:count]
             assert input_list[il] != [], "there are all pad tokens in batch {}.".format(il)
         return input_list
```
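For reference, the scratch-size selection added to `init_from_bin` above boils down to the following pure function. This is an editorial restatement of the committed logic (the helper name is ours, and the abort path for `ctx_size > max_seq_length` is omitted), not part of the API:

```python
# Condensed restatement of the scratch_size_ratio heuristic added above.
# The helper name is illustrative; neural_speed applies this logic inline in init_from_bin().
def pick_scratch_size_ratio(max_seq_length: int, ctx_size: int) -> int:
    """Return the scratch_size_ratio Neural Speed would auto-select."""
    if ctx_size <= 2048:
        return 1                      # default scratch size is sufficient
    if 2048 < max_seq_length <= 4096:
        return 2
    if 4096 < max_seq_length <= 8192:
        return 4
    # max_seq_length > 8192: size by the requested ctx_size when it is modest
    if ctx_size <= 4096:
        return 2
    if ctx_size <= 8192:
        return 4
    if ctx_size <= 10240:
        return 8
    # very long contexts: fall back to the model's maximum length
    if max_seq_length == 16384:
        return 12
    if max_seq_length == 32768:
        return 20 if ctx_size < 20480 else 35
    return 1

# Example: a 32768-token model (e.g. Mistral-7B) run with a 16K context budget.
assert pick_scratch_size_ratio(32768, 16384) == 20
```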
