class Model:
+
    def __init__(self):
        self.module = None
        self.model = None
@@ -84,9 +85,19 @@ def get_model_type(model_config):
            model_type = "chatglm2"
        return model_type

-    def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_autoround=False,
-             weight_dtype="int4", alg="sym", group_size=32,
-             scale_dtype="fp32", compute_dtype="int8", use_ggml=False, model_hub="huggingface"):
+    def init(self,
+             model_name,
+             use_quant=True,
+             use_gptq=False,
+             use_awq=False,
+             use_autoround=False,
+             weight_dtype="int4",
+             alg="sym",
+             group_size=32,
+             scale_dtype="fp32",
+             compute_dtype="int8",
+             use_ggml=False,
+             model_hub="huggingface"):
        if model_hub == "modelscope":
            from modelscope import AutoConfig
            self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -124,24 +135,28 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
        self.bin_file = quant_bin

        if os.path.exists(self.bin_file):
-            print("{} existed, will use cache file. Otherwise please remove the file".
-                  format(self.bin_file))
+            print("{} existed, will use cache file. Otherwise please remove the file".format(self.bin_file))
            return

        if use_gptq or use_awq or use_autoround:
            convert_model(model_name, quant_bin, use_quantized_model=True)
            return

        if not os.path.exists(fp32_bin):
-            convert_model(model_name, fp32_bin, "f32", model_hub=model_hub)
+            convert_model(model_name, fp32_bin, "f32", model_hub=model_hub)
        assert os.path.exists(fp32_bin), "Fail to convert pytorch model"

        if not use_quant:
            print("FP32 model will be used.")
            return
-        self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin,
-                                      weight_dtype=weight_dtype, alg=alg, group_size=group_size,
-                                      scale_dtype=scale_dtype, compute_dtype=compute_dtype, use_ggml=use_ggml)
+        self.module.Model.quant_model(model_path=fp32_bin,
+                                      out_path=quant_bin,
+                                      weight_dtype=weight_dtype,
+                                      alg=alg,
+                                      group_size=group_size,
+                                      scale_dtype=scale_dtype,
+                                      compute_dtype=compute_dtype,
+                                      use_ggml=use_ggml)
        assert os.path.exists(quant_bin), "Fail to quantize model"

        # clean
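
For reference, a minimal sketch of how the reformatted init() keyword list above is typically driven; the import path and model id are assumptions, only the Model class and parameter names come from this diff:

# Hypothetical usage sketch; package path and model id are assumptions.
from neural_speed import Model

model = Model()
# Converts the checkpoint to an fp32 bin, then quantizes it into quant_bin
# with the same keyword arguments the diff above lays out one per line.
model.init("meta-llama/Llama-2-7b-hf",
           use_quant=True,
           weight_dtype="int4",
           alg="sym",
           group_size=32,
           scale_dtype="fp32",
           compute_dtype="int8",
           use_ggml=False,
           model_hub="huggingface")
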
@@ -150,9 +165,11 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
    def init_from_bin(self, model_type, model_path, **generate_kwargs):
        self.__import_package(model_type)
        self.model = self.module.Model()
+
        if self.max_request_num == -1:
-            self.max_request_num = max(generate_kwargs.get("max_request_num",
-                                       max_request_num_default), generate_kwargs.get("batch_size", 1))
+            self.max_request_num = max(generate_kwargs.get("max_request_num", max_request_num_default),
+                                       generate_kwargs.get("batch_size", 1))
+
        if "threads" not in generate_kwargs:
            threads = os.getenv("OMP_NUM_THREADS")
            import platform
@@ -165,29 +182,107 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
                    generate_kwargs["threads"] = len(os.sched_getaffinity(0))
            else:
                generate_kwargs["threads"] = int(threads)
-        self.model.init_model(model_path, **generate_kwargs)

+        # Setting scratch_size_ratio according to the ctx_size & tokens_length
+        # If scratch_size_ratio has been set, will not enter this branch.
+        if generate_kwargs.get("ctx_size") is not None and generate_kwargs.get(
+                "ctx_size") > 2048 and generate_kwargs.get("scratch_size_ratio") is None:
+
+            def get_max_seq_length():
+                config = self.config.to_dict()
+                # chatglm2, bloom
+                if 'seq_length' in config:
+                    return config['seq_length']
+                # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi
+                elif 'max_position_embeddings' in config:
+                    return config['max_position_embeddings']
+                # baichuan, baichuan2
+                elif 'model_max_length' in config:
+                    return config['model_max_length']
+                # gptj
+                elif 'n_positions' in config:
+                    return config['n_positions']
+                # mpt
+                elif 'max_seq_len' in config:
+                    return config['max_seq_len']
+                # chatglm
+                elif 'max_sequence_length' in config:
+                    return config['max_sequence_length']
+                # whisper
+                elif 'max_length' in config:
+                    return config['max_length']
+                # Falcon does not have these parameters.
+                elif model_type == "falcon":
+                    return 2048
+                else:
+                    print("Not found max seq length, setting to default 512")
+                    return 512
+
+            # when tokens less than 10240
+            def get_scratch_size_ratio(size):
+                if size > 2048 and size <= 4096:
+                    return 2
+                elif size > 4096 and size <= 8192:
+                    return 4
+                elif size > 8192 and size <= 10240:
+                    return 8
+                else:
+                    # more than 10240
+                    return -1
+
+            max_seq_length = get_max_seq_length()
+            ctx_size = generate_kwargs.get("ctx_size")
+
+            if ctx_size > max_seq_length:
+                print(f'max_seq_length is {max_seq_length}, but ctx_size is {ctx_size}. Please reduce ctx_size.')
+                exit(0)
+
+            if max_seq_length > 2048 and max_seq_length <= 4096:
+                generate_kwargs["scratch_size_ratio"] = 2
+            elif max_seq_length > 4096 and max_seq_length <= 8192:
+                generate_kwargs["scratch_size_ratio"] = 4
+            elif max_seq_length > 8192:
+                if get_scratch_size_ratio(ctx_size) != -1:
+                    generate_kwargs["scratch_size_ratio"] = get_scratch_size_ratio(ctx_size)
+                else:
+                    if max_seq_length == 16384:
+                        generate_kwargs["scratch_size_ratio"] = 12
+                    elif max_seq_length == 32768:
+                        if ctx_size < 20480:
+                            generate_kwargs["scratch_size_ratio"] = 20
+                        else:
+                            generate_kwargs["scratch_size_ratio"] = 35
+
+        self.model.init_model(model_path, **generate_kwargs)

    def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
        self.__import_package(model_type)
        self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)

+    def generate(self,
+                 input_ids,
+                 streamer=None,
+                 interactive=False,
+                 ignore_prompt=False,
+                 stopping_criteria=None,
+                 **generate_kwargs):
+        batch_size = input_ids.shape[0]

-    def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False,
-                 stopping_criteria=None, **generate_kwargs):
        max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
-        input_bs = input_ids.shape[0]
        max_request_num = generate_kwargs.pop("max_request_num", max_request_num_default)
        reinit_from_bin = False
-        if max_request_num > self.max_request_num or input_bs > self.max_request_num:
+        if max_request_num > self.max_request_num or batch_size > self.max_request_num:
            reinit_from_bin = True
            if self.max_request_num > 0:
                print("Will start to reinit model from bin due to different max request num.")
-            self.max_request_num = max(input_bs, max_request_num)
+            self.max_request_num = max(batch_size, max_request_num)

        if self.model is None or reinit_from_bin:
-            self.init_from_bin(self.model_type, self.bin_file, batch_size=input_bs,
-                               max_request_num=self.max_request_num, **generate_kwargs)
+            self.init_from_bin(self.model_type,
+                               self.bin_file,
+                               batch_size=batch_size,
+                               max_request_num=self.max_request_num,
+                               **generate_kwargs)
            self.generate_round = 0
        elif not interactive:
            self.model.reinit()
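
As a side note, the scratch_size_ratio selection added above boils down to the standalone rule below; this is an illustrative restatement only, not part of the patch, and it assumes ctx_size <= max_seq_length (the patch exits otherwise):

# Illustrative restatement of the scratch_size_ratio rule introduced above.
def pick_scratch_size_ratio(max_seq_length, ctx_size):
    if ctx_size <= 2048:
        return None                  # the branch above is skipped entirely
    if max_seq_length <= 4096:
        return 2
    if max_seq_length <= 8192:
        return 4
    # max_seq_length > 8192: size by the requested ctx_size first
    if ctx_size <= 4096:
        return 2
    if ctx_size <= 8192:
        return 4
    if ctx_size <= 10240:
        return 8
    # ctx_size > 10240: fall back to model-specific defaults
    if max_seq_length == 16384:
        return 12
    if max_seq_length == 32768:
        return 20 if ctx_size < 20480 else 35
    return None                      # the patch leaves scratch_size_ratio unset

# e.g. a 32k model: ctx_size=8192 -> 4, ctx_size=16384 -> 20, ctx_size=24576 -> 35
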
@@ -208,6 +303,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
            assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
            assert beam_search == False, "ERROR, can not use streamer when use beam search for generation! \
                Make sure that `num_beams` is set to 1."
+
            if self.generate_round == 0 and not ignore_prompt:
                streamer.put(input_ids)
@@ -284,6 +380,6 @@ def _cont_batching_input(self, input_ids, pad_token_id=None):
        for il in range(len(input_list)):
            count = input_list[il].count(pti)
            # padding left
-            del input_list[il][0: count]
+            del input_list[il][0:count]
            assert input_list[il] != [], "there are all pad tokens in batch {}.".format(il)
        return input_list
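
Put together, the reworked generate() path might be exercised end-to-end as sketched here; the tokenizer, package path, and model id are assumptions, only Model.init()/generate() and their keyword names come from the diff:

# Hypothetical end-to-end sketch; tokenizer, package path and model id are assumptions.
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = Model()
model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")

# batch_size is read from input_ids.shape[0]; if it exceeds the current
# max_request_num, generate() re-initializes the model from the cached bin.
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=32, ctx_size=4096)
# Decoding assumes generate() returns token id sequences.
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))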