diff --git a/modules/text_generation.py b/modules/text_generation.py
index 1dc8dcf499..8072fdf987 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -9,8 +9,10 @@
 
 import numpy as np
 import torch
+from tqdm import tqdm
 import transformers
 from transformers import LogitsProcessorList, is_torch_xpu_available
+from transformers.generation import TextIteratorStreamer
 
 import modules.shared as shared
 from modules.callbacks import (
@@ -378,6 +380,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
             import intel_extension_for_pytorch
             shared.model = shared.model.to("xpu")
 
+    streamer = TextIteratorStreamer(shared.tokenizer, skip_prompt=True)
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:
@@ -391,41 +395,31 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
             starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
             yield get_reply_from_output_ids(output, state, starting_from=starting_from)
 
+            output_tokens = len(output)
+
         # Stream the reply 1 token at a time.
-        # This is based on the trick of using 'stopping_criteria' to create an iterator.
+        # This relies on transformers' TextIteratorStreamer, which queues decoded text as generate() runs.
         else:
+            with torch.no_grad():
+                output = shared.model.generate(**generate_params, streamer=streamer)
+
+            cumulative_reply = ''
+            for new_content in tqdm(streamer, "Generating Tokens", unit="token"):
+                # Skip fragments containing a partial unicode character (U+FFFD)
+                if chr(0xfffd) in new_content:
+                    continue
 
-            def generate_with_callback(callback=None, *args, **kwargs):
-                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
-                clear_torch_cache()
-                with torch.no_grad():
-                    shared.model.generate(**kwargs)
-
-            def generate_with_streaming(**kwargs):
-                return Iteratorize(generate_with_callback, [], kwargs, callback=None)
-
-            with generate_with_streaming(**generate_params) as generator:
-                cumulative_reply = ''
-                starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
-                for output in generator:
-                    if output[-1] in eos_token_ids:
-                        break
-
-                    new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
-                    # check the partial unicode character
-                    if chr(0xfffd) in new_content:
-                        continue
+                cumulative_reply += new_content
+                yield cumulative_reply
 
-                    cumulative_reply += new_content
-                    starting_from = len(output)
-                    yield cumulative_reply
+            output_tokens = output.shape[1]
 
     except Exception:
         traceback.print_exc()
     finally:
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
+        new_tokens = output_tokens - (original_tokens if not shared.is_seq2seq else 0)
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 
@@ -456,4 +450,4 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
         original_tokens = len(encode(original_question)[0])
         new_tokens = len(encode(original_question + reply)[0]) - original_tokens
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
-        return
+        return
\ No newline at end of file
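
For reviewers unfamiliar with `TextIteratorStreamer`: it pushes decoded text fragments onto an internal queue as `generate()` produces tokens, and iterating the streamer drains that queue. Because this patch calls `generate()` synchronously, the tqdm loop only starts consuming once generation has finished; the pattern shown in the transformers documentation instead runs `generate()` on a background thread so fragments can be consumed as they arrive. Below is a minimal sketch of that threaded pattern — the model name, prompt, and `max_new_tokens` value are placeholders for illustration, not taken from this codebase:

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming works by", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until it finishes, so it runs on a worker thread while
# the main thread drains the streamer's queue as fragments arrive.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=50),
)
thread.start()

cumulative_reply = ''
for new_content in streamer:
    # Mirrors the patch's check: U+FFFD marks a partially decoded multi-byte
    # character. Note that `continue` drops the fragment outright rather than
    # deferring it to the next iteration.
    if chr(0xfffd) in new_content:
        continue

    cumulative_reply += new_content
    print(cumulative_reply)

thread.join()
```

The synchronous call in the patch still terminates cleanly, because the streamer enqueues a stop sentinel when generation ends; the difference is only in when the fragments become available to the consuming loop.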