@@ -145,8 +145,9 @@ def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
 
     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
+        self.batch_size = input_ids.shape[0]
         if self.model is None:
-            self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
+            self.init_from_bin(self.model_type, self.bin_file, batch_size=self.batch_size,
                                **generate_kwargs)
             self.generate_round = 0
         elif not interactive:
@@ -160,9 +161,6 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
         beam_search = False
         if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
             beam_search = True
-        if not beam_search:
-            # TODO support multi batch
-            assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."
 
         if streamer:
             assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
@@ -190,9 +188,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
             if stopping_criteria is not None:
                 if stopping_criteria(torch.tensor(ret), None):
                     break
-            elif ret[0][-1] == self.eos_token_id() or \
-                    (max_new_tokens != -1 and out_count >= max_new_tokens):
+            elif (max_new_tokens != -1 and out_count >= max_new_tokens):
                 break
+            else:
+                all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret]
+                if False not in all_done:
+                    break
 
         if streamer:
             streamer.end()
 
@@ -206,6 +207,15 @@ def eos_token_id(self):
         if self.model_type == 'qwen':
             return self.tokenizer.special_tokens['<|endoftext|>']
         return self.tokenizer.eos_token_id
+
+    def pad_token_id(self):
+        if self.tokenizer.pad_token_id == None:
+            if self.batch_size == 1:
+                return None
+            else:
+                raise ValueError("Please set pad_token_id when doing multi batch inference" \
+                    " with padding!")
+        return self.tokenizer.pad_token_id
 
     def __call__(self, input_ids, reinit=False, **kwargs):
         if self.model is None:
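
For reference, a minimal usage sketch of the multi-batch stopping behavior introduced above. It assumes `model` is an already-initialized instance of this class and `tokenizer` is the Hugging Face tokenizer it wraps; those names and the prompts are illustrative, not part of the diff.

```python
prompts = ["What is the capital of France?", "Write a haiku about autumn."]

# pad_token_id() raises a ValueError for batch_size > 1 when the tokenizer
# defines no pad token, so set one before padding a batch of prompts.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Batched, padded prompt tensor of shape (batch_size, seq_len).
input_ids = tokenizer(prompts, padding=True, return_tensors="pt").input_ids

# generate() now records batch_size from input_ids.shape[0] and, absent a
# stopping_criteria, only stops once every sequence in the batch has emitted
# EOS (or the pad token), or max_new_tokens is reached.
outputs = model.generate(input_ids, max_new_tokens=32)
texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
```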