# [HF][streaming][1/n] Text Summarization
TSIA

Adding streaming functionality to text summarization model parser

## Test Plan
Rebase onto 11ace0a and test with it.

Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```

Then in the AIConfig Editor, run the prompt (it will use the streaming format by default).
Rossdan Craig <rossdan@lastmileai.dev> committed Jan 10, 2024
1 parent 0c9e704 commit c962624
Showing 2 changed files with 10 additions and 3 deletions.
```diff
@@ -251,7 +251,7 @@ async def run_inference(
             not "stream" in completion_data or completion_data.get("stream") != False
         )
         if should_stream:
-            tokenizer : AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
+            tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
             streamer = TextIteratorStreamer(tokenizer)
             completion_data["streamer"] = streamer
```
@@ -128,13 +128,18 @@ def construct_stream_output( (region after the change)

```python
            "metadata": {},
        }
    )

    accumulated_message = ""
    for new_text in streamer:
        if isinstance(new_text, str):
            # For some reason these symbols aren't filtered out by the streamer
            new_text = new_text.replace("</s>", "")
            new_text = new_text.replace("<s>", "")

            accumulated_message += new_text
            options.stream_callback(new_text, accumulated_message, 0)

    output.data = accumulated_message
    return output
```
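For context, a `TextIteratorStreamer` only yields text while a generation call is feeding it, so the usual pattern is to run `generate()` on a worker thread and consume the streamer on the caller's thread. A minimal standalone sketch of that pattern (the model name and prompt are illustrative placeholders, not values from this commit):

```python
# Minimal sketch of driving a TextIteratorStreamer; "t5-small" and the
# prompt are illustrative placeholders, not values from this commit.
from threading import Thread

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("summarize: The quick brown fox jumps over the lazy dog.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)

# generate() blocks until completion, so it runs on a worker thread while
# the streamer is iterated incrementally on this thread.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 40})
thread.start()

accumulated_message = ""
for new_text in streamer:
    accumulated_message += new_text
    print(new_text, end="", flush=True)
thread.join()
```

Passing `skip_special_tokens=True` when constructing the `TextIteratorStreamer` should also suppress markers like `<s>`/`</s>`, which is why the manual `replace()` calls above read as a workaround.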


```diff
@@ -245,7 +250,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
 
         # if stream enabled in runtime options and config, then stream. Otherwise don't stream.
         streamer = None
-        should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
+        should_stream = (options.stream if options else False) and (
+            not "stream" in completion_data or completion_data.get("stream") != False
+        )
         if should_stream:
             tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
             streamer = TextIteratorStreamer(tokenizer)
```
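Read literally, this condition streams whenever the runtime options request it, unless the prompt's completion config explicitly sets `"stream": False`. A tiny self-contained illustration of that truth table (the `should_stream` helper below is hypothetical, written only to exercise the predicate):

```python
# Hypothetical helper mirroring the should_stream condition from the diff.
def should_stream(options_stream: bool, completion_data: dict) -> bool:
    return options_stream and (
        "stream" not in completion_data or completion_data.get("stream") != False
    )

assert should_stream(True, {}) is True                   # no "stream" key: default to streaming
assert should_stream(True, {"stream": True}) is True     # explicit opt-in
assert should_stream(True, {"stream": False}) is False   # explicit opt-out wins
assert should_stream(False, {"stream": True}) is False   # runtime options disable streaming
```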
