Merge pull request #37 from sgwhat/text-stream-opt
Optimize streaming-chat performance
sgwhat committed Apr 18, 2024
2 parents 5d9ad12 + 268d090 commit ca0413d
Showing 1 changed file with 19 additions and 25 deletions.
44 changes: 19 additions & 25 deletions modules/text_generation.py
@@ -9,8 +9,10 @@
 
 import numpy as np
 import torch
+from tqdm import tqdm
 import transformers
 from transformers import LogitsProcessorList, is_torch_xpu_available
+from transformers.generation import TextIteratorStreamer
 
 import modules.shared as shared
 from modules.callbacks import (
@@ -378,6 +380,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
         import intel_extension_for_pytorch
         shared.model = shared.model.to("xpu")
 
+    streamer = TextIteratorStreamer(shared.tokenizer, skip_prompt=True)
+
     t0 = time.time()
     try:
         if not is_chat and not shared.is_seq2seq:
@@ -391,41 +395,31 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
             starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
             yield get_reply_from_output_ids(output, state, starting_from=starting_from)
 
+            output_tokens = len(output)
+
         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator.
         else:
+            with torch.no_grad():
+                output = shared.model.generate(**generate_params, streamer=streamer)
+
+            cumulative_reply = ''
+            for new_content in tqdm(streamer, "Generating Tokens", unit="token"):
+                # check the partial unicode character
+                if chr(0xfffd) in new_content:
+                    continue
 
-            def generate_with_callback(callback=None, *args, **kwargs):
-                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
-                clear_torch_cache()
-                with torch.no_grad():
-                    shared.model.generate(**kwargs)
-
-            def generate_with_streaming(**kwargs):
-                return Iteratorize(generate_with_callback, [], kwargs, callback=None)
-
-            with generate_with_streaming(**generate_params) as generator:
-                cumulative_reply = ''
-                starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
-                for output in generator:
-                    if output[-1] in eos_token_ids:
-                        break
-
-                    new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
-                    # check the partial unicode character
-                    if chr(0xfffd) in new_content:
-                        continue
+                cumulative_reply += new_content
+                yield cumulative_reply
 
-                    cumulative_reply += new_content
-                    starting_from = len(output)
-                    yield cumulative_reply
+            output_tokens = output.shape[1]
 
     except Exception:
         traceback.print_exc()
     finally:
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
+        new_tokens = output_tokens - (original_tokens if not shared.is_seq2seq else 0)
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
 
@@ -456,4 +450,4 @@ def generate_reply_custom(question, original_question, seed, state, stopping_str
     original_tokens = len(encode(original_question)[0])
     new_tokens = len(encode(original_question + reply)[0]) - original_tokens
     print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
-    return
\ No newline at end of file
+    return
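A note on the API this commit adopts: `TextIteratorStreamer` pushes decoded text chunks onto an internal queue as `generate()` produces tokens, and iterating the streamer pops them off. In the Transformers documentation the canonical pattern runs `generate()` on a worker thread so the consuming loop can yield text while generation is still in progress; the sketch below shows that pattern, not the diff's exact code (the `gpt2` checkpoint and the `max_new_tokens` value are placeholders):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# skip_prompt=True drops the echoed input tokens, as in the diff above.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until generation finishes, so it runs on a worker
# thread while the main thread consumes decoded chunks from the streamer.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=50))
thread.start()

cumulative_reply = ''
for new_content in streamer:  # iteration ends when generation finishes
    cumulative_reply += new_content
    print(cumulative_reply)

thread.join()
```

The diff instead calls `generate(**generate_params, streamer=streamer)` synchronously and keeps the returned ids, which also explains the bookkeeping change at the end: `generate()` returns a 2-D `[batch, seq_len]` tensor in that branch, hence `output_tokens = output.shape[1]` rather than `len(output)` when computing the tokens/s summary.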

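Why the loop skips chunks containing `chr(0xfffd)`: when the streamer decodes token ids whose bytes stop partway through a multi-byte UTF-8 character, the decoder emits U+FFFD, the Unicode replacement character, and yielding it would flash '�' in the streamed output. A minimal standalone illustration (not from the repository):

```python
# '€' is three bytes in UTF-8; decoding only the first two of them
# cannot form a complete character, so the decoder substitutes U+FFFD.
euro = '€'.encode('utf-8')                      # b'\xe2\x82\xac'
partial = euro[:2].decode('utf-8', 'replace')   # '�'
assert partial == chr(0xfffd)

# The streaming loop above simply waits for the next, complete chunk
# instead of yielding the placeholder.
```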