This repository was archived by the owner on Oct 25, 2024. It is now read-only.
Merged
@@ -91,7 +91,11 @@ option(NE_GELU_VEC "neural_engine: enable vec in gelu"
if (NE_GELU_VEC)
add_compile_definitions(NE_GELU_USE_VEC)
endif()
option(NE_PYTHON_API "neural_engine: use python api" OFF)
option(NE_SIMD_VEC_DOT_F16 "neural_engine: enable vec_dot_fp16 SIMD optimization" ON)
if (NE_SIMD_VEC_DOT_F16)
add_compile_definitions(NE_SIMD_VEC_DOT_F16)
endif()

if(NE_BUILD_TESTS)
enable_testing()
22 changes: 16 additions & 6 deletions intel_extension_for_transformers/llm/runtime/graph/__init__.py
@@ -145,8 +145,9 @@ def quant_model(self, model_type, model_path, out_path, **quant_kwargs):

def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
self.batch_size = input_ids.shape[0]
if self.model is None:
self.init_from_bin(self.model_type, self.bin_file, batch_size=input_ids.shape[0],
self.init_from_bin(self.model_type, self.bin_file, batch_size=self.batch_size,
**generate_kwargs)
self.generate_round = 0
elif not interactive:
@@ -160,9 +161,6 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
beam_search = False
if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False):
beam_search = True
if not beam_search:
# TODO support multi batch
assert input_ids.shape[0] == 1, "Unsupport multi-batch input ids."

if streamer:
assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
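Note: with the single-batch assertion dropped above, `generate()` can now take a padded batch of prompts. A minimal usage sketch, under assumptions not confirmed by this diff: the wrapper class in this file is exposed as `Model`, its tokenizer is reachable as `model.tokenizer`, and `init()` accepts a Hugging Face model id; the checkpoint name and padding choices below are placeholders.

```python
# Hedged usage sketch for batched generation after this change (not part of the diff).
# Assumptions: the wrapper class is `Model`, its tokenizer is `model.tokenizer`,
# and `init()` takes a Hugging Face model id.
from intel_extension_for_transformers.llm.runtime.graph import Model

model = Model()
model.init("meta-llama/Llama-2-7b-chat-hf")        # placeholder checkpoint id

tok = model.tokenizer
if tok.pad_token_id is None:                       # the pad_token_id() helper added later in this diff raises otherwise
    tok.pad_token_id = tok.eos_token_id
tok.padding_side = "left"                          # left-pad so generation continues from real tokens

prompts = ["Tell me a joke.", "Summarize Hamlet in one sentence."]
input_ids = tok(prompts, return_tensors="pt", padding=True).input_ids   # shape [batch, seq]

outputs = model.generate(input_ids, max_new_tokens=32)   # one list of generated token ids per row
```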
@@ -190,9 +188,12 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, stopping_criteria=None, **generate_kwargs):
if stopping_criteria is not None:
if stopping_criteria(torch.tensor(ret), None):
break
elif ret[0][-1] == self.eos_token_id() or \
(max_new_tokens != -1 and out_count >= max_new_tokens):
elif (max_new_tokens != -1 and out_count >= max_new_tokens):
break
else:
all_done = [(r[-1] in [self.eos_token_id(), self.pad_token_id()]) for r in ret]
if False not in all_done:
break
if streamer:
streamer.end()

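The rewritten stop condition ends the decoding loop only once every row of the batch has produced EOS (or pad). A self-contained toy illustration of that check, with placeholder token ids:

```python
# Toy illustration of the batched stop condition added above.
eos, pad = 2, 0                                    # placeholder special-token ids

# Last generated token per sequence after some decoding step:
ret = [[15, 7, eos],                               # sequence 0 finished (ends in EOS)
       [15, 7, 42]]                                # sequence 1 still generating
all_done = [(r[-1] in [eos, pad]) for r in ret]
print(all_done)                                    # [True, False]
print(False not in all_done)                       # False -> keep generating

ret = [[15, 7, eos],
       [15, 7, 42, eos]]                           # now every row ends in EOS or pad
all_done = [(r[-1] in [eos, pad]) for r in ret]
print(False not in all_done)                       # True -> break out of the loop
```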
@@ -206,6 +207,15 @@ def eos_token_id(self):
if self.model_type == 'qwen':
return self.tokenizer.special_tokens['<|endoftext|>']
return self.tokenizer.eos_token_id

def pad_token_id(self):
if self.tokenizer.pad_token_id == None:
if self.batch_size == 1:
return None
else:
raise ValueError("Please set pad_token_id when doing multi batch inference"\
" with padding!")
return self.tokenizer.pad_token_id

def __call__(self, input_ids, reinit=False, **kwargs):
if self.model is None:
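The new `pad_token_id()` helper encodes a simple contract: a missing pad token is tolerated only for batch size 1; otherwise the caller must set one. A standalone restatement of those branches, using a stand-in tokenizer object so it runs without loading a model (the helper name and stand-in are illustrative only):

```python
# Standalone restatement of the pad_token_id() branches added above.
from types import SimpleNamespace

def pad_token_id(tokenizer, batch_size):
    if tokenizer.pad_token_id is None:
        if batch_size == 1:
            return None                     # single prompt: padding never needed
        raise ValueError("Please set pad_token_id when doing multi batch inference with padding!")
    return tokenizer.pad_token_id

tok = SimpleNamespace(pad_token_id=None)
print(pad_token_id(tok, batch_size=1))      # None -- tolerated for batch size 1
try:
    pad_token_id(tok, batch_size=2)         # raises: multi-batch needs a pad token
except ValueError as e:
    print(e)

tok.pad_token_id = 0
print(pad_token_id(tok, batch_size=2))      # 0 -- explicit pad token is returned
```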