Skip to content

Commit

Permalink
[HF][streaming][2/n] Text Translation
Browse files Browse the repository at this point in the history
TSIA

Adds streaming output support to the text translation model parser. I also fixed a bug where we didn't pass the `"translation"` task name into the pipeline.

## Test Plan
Rebase onto 5b74344 and test it.

Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```

With Streaming

https://github.com/lastmile-ai/aiconfig/assets/151060367/d7bc9df2-2993-4709-bf9b-c5b7979fb00f

Without Streaming

https://github.com/lastmile-ai/aiconfig/assets/151060367/71eb6ab3-5d6f-4c5d-8b82-f3daf4c5e610
  • Loading branch information
Rossdan Craig rossdan@lastmileai.dev committed Jan 10, 2024
1 parent cd0e811 commit 5888ac3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
5 changes: 3 additions & 2 deletions extensions/HuggingFace/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ huggingface_hub

#Hugging Face Libraries - Local Inference Tranformers & Diffusors
accelerate # Used to help speed up image generation
diffusers # Used for image + audio generation
diffusers # Used for image generation
scipy # array -> wav file, text-speech. torchaudio.save seems broken.
sentencepiece # Used for text translation
torch
torchvision
torchaudio
scipy # array -> wav file, text-speech. torchaudio.save seems broken.
transformers # Used for text generation

#Other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,19 @@ def construct_stream_output(
"metadata": {},
}
)

accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
# For some reason these symbols aren't filtered out by the streamer
new_text = new_text.replace("</s>", "")
new_text = new_text.replace("<s>", "")
new_text = new_text.replace("<pad>", "")

accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
output.data = accumulated_message

return output


Expand Down Expand Up @@ -240,19 +247,26 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio

model_name: str = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.translators:
self.translators[model_name] = pipeline(model_name)
self.translators[model_name] = pipeline("translation", model_name)
translator = self.translators[model_name]

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
should_stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
if should_stream:
raise NotImplementedError("Streaming is not supported for HuggingFace Text Translation")
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer

def _translate():
return translator(inputs, **completion_data)

outputs: List[Output] = []
output = None
if not should_stream:
response: List[Any] = translator(inputs, **completion_data)
response: List[Any] = _translate()
for count, result in enumerate(response):
output = construct_regular_output(result, count)
outputs.append(output)
Expand All @@ -263,7 +277,7 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
raise ValueError("Stream option is selected but streamer is not initialized")

# For streaming, cannot call `translator` directly otherwise response will be blocking
thread = threading.Thread(target=translator, kwargs=completion_data)
thread = threading.Thread(target=_translate)
thread.start()
output = construct_stream_output(streamer, options)
if output is not None:
Expand Down

0 comments on commit 5888ac3

Please sign in to comment.