This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit c9fb9d1

zhentaoyu authored and VincyZhang committed
[LLM Runtime] ChatGLM-V1 multi-batch infer and batched greedy search generation (#700)
1 parent 2ee9fec commit c9fb9d1

15 files changed: +575 −224 lines changed

graph/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
@@ -91,7 +91,11 @@ option(NE_GELU_VEC "neural_engine: enable vec in gelu"
 if (NE_GELU_VEC)
     add_compile_definitions(NE_GELU_USE_VEC)
 endif()
-option(NE_PYTHON_API "neural_engine: use python api" OFF)
+option(NE_PYTHON_API "neural_engine: use python api" OFF)
+option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
+if (NE_SIMD_VEC_DOT_F16)
+    add_compile_definitions(NE_SIMD_VEC_DOT_F16)
+endif()
 
 if(NE_BUILD_TESTS)
     enable_testing()

graph/__init__.py

Lines changed: 16 additions & 6 deletions
@@ -145,8 +145,9 @@ def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
 
     def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
         max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
+        self.batch_size = input_ids.shape[0]
         if self.model is None:
-            self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
+            self.init_from_bin(self.model_type, self.bin_file, batch_size=self.batch_size,
                                **generate_kwargs)
             self.generate_round = 0
         elif not interactive:
@@ -160,9 +161,6 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
         beam_search = False
         if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
             beam_search = True
-        if not beam_search:
-            # TODO support multi batch
-            assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."
 
         if streamer:
             assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
@@ -190,9 +188,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
             if stopping_criteria is not None:
                 if stopping_criteria(torch.tensor(ret), None):
                     break
-            elif ret[0][-1] == self.eos_token_id() or \
-                    (max_new_tokens != -1 and out_count >= max_new_tokens):
+            elif (max_new_tokens != -1 and out_count >= max_new_tokens):
                 break
+            else:
+                all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret]
+                if False not in all_done:
+                    break
         if streamer:
             streamer.end()
 
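The hunk above replaces the old single-sequence check (ret[0][-1] == eos) with a per-sequence test: greedy generation stops only once every sequence in the batch ends in the EOS or PAD token, padding marking sequences that finished earlier. A minimal stand-alone sketch of that check, with eos_id and pad_id standing in for self.eos_token_id() and self.pad_token_id():

def batch_finished(ret, eos_id, pad_id):
    # ret holds one list of generated token ids per sequence in the batch;
    # a sequence counts as done once its last token is EOS or PAD.
    all_done = [seq[-1] in (eos_id, pad_id) for seq in ret]
    return all(all_done)  # same effect as `if False not in all_done`

# With eos_id=2 and pad_id=0, only the first batch below has every sequence done:
print(batch_finished([[5, 7, 2], [5, 9, 0]], eos_id=2, pad_id=0))   # True
print(batch_finished([[5, 7, 2], [5, 9, 11]], eos_id=2, pad_id=0))  # False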

@@ -206,6 +207,15 @@ def eos_token_id(self):
         if self.model_type == 'qwen':
             return self.tokenizer.special_tokens['<|endoftext|>']
         return self.tokenizer.eos_token_id
+
+    def pad_token_id(self):
+        if self.tokenizer.pad_token_id == None:
+            if self.batch_size == 1:
+                return None
+            else:
+                raise ValueError("Please set pad_token_id when doing multi batch inference"\
+                                 " with padding!")
+        return self.tokenizer.pad_token_id
 
     def __call__(self, input_ids, reinit=False, **kwargs):
         if self.model is None:
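The new pad_token_id() helper makes the padding requirement explicit: batched inference with padded prompts needs tokenizer.pad_token_id to be set, otherwise a ValueError is raised. A hypothetical usage sketch (not part of this commit), assuming model is an already-initialized instance of this Model class and model.tokenizer is its Hugging Face tokenizer; the prompts and the choice of pad token are illustrative:

def batched_greedy_generate(model, prompts, max_new_tokens=32):
    # Hypothetical helper built on the API shown in this diff.
    tok = model.tokenizer
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token  # avoids the ValueError for batch sizes > 1
    inputs = tok(prompts, padding=True, return_tensors="pt")
    # The greedy path is taken when num_beams == 1 and do_sample is False (the defaults).
    return model.generate(inputs.input_ids, max_new_tokens=max_new_tokens)

# e.g. batched_greedy_generate(model, ["What is a GPU?", "Define KV cache in one sentence."])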

0 commit comments