diff --git a/extensions/HuggingFace/python/requirements.txt b/extensions/HuggingFace/python/requirements.txt
index 6388e1c9e..79e5db10b 100644
--- a/extensions/HuggingFace/python/requirements.txt
+++ b/extensions/HuggingFace/python/requirements.txt
@@ -10,11 +10,12 @@ huggingface_hub
 
 #Hugging Face Libraries - Local Inference Tranformers & Diffusors
 accelerate # Used to help speed up image generation
-diffusers # Used for image + audio generation
+diffusers # Used for image generation
+scipy # array -> wav file, text-speech. torchaudio.save seems broken.
+sentencepiece # Used for text translation
 torch
 torchvision
 torchaudio
-scipy # array -> wav file, text-speech. torchaudio.save seems broken.
 transformers # Used for text generation
 
 #Other
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py
index 9ee8bb357..860a11e46 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py
@@ -129,12 +129,19 @@ def construct_stream_output(
             "metadata": {},
         }
     )
+    accumulated_message = ""
     for new_text in streamer:
         if isinstance(new_text, str):
+            # For some reason these symbols aren't filtered out by the streamer
+            new_text = new_text.replace("</s>", "")
+            new_text = new_text.replace("<s>", "")
+            new_text = new_text.replace("<pad>", "")
+
             accumulated_message += new_text
             options.stream_callback(new_text, accumulated_message, 0)
 
     output.data = accumulated_message
+
     return output
 
 
@@ -240,19 +247,26 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
         model_name: str = aiconfig.get_model_name(prompt)
 
         if isinstance(model_name, str) and model_name not in self.translators:
-            self.translators[model_name] = pipeline(model_name)
+            self.translators[model_name] = pipeline("translation", model_name)
         translator = self.translators[model_name]
 
         # if stream enabled in runtime options and config, then stream. Otherwise don't stream.
         streamer = None
-        should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
+        should_stream = (options.stream if options else False) and (
+            not "stream" in completion_data or completion_data.get("stream") != False
+        )
         if should_stream:
-            raise NotImplementedError("Streaming is not supported for HuggingFace Text Translation")
+            tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
+            streamer = TextIteratorStreamer(tokenizer)
+            completion_data["streamer"] = streamer
+
+        def _translate():
+            return translator(inputs, **completion_data)
 
         outputs: List[Output] = []
         output = None
         if not should_stream:
-            response: List[Any] = translator(inputs, **completion_data)
+            response: List[Any] = _translate()
             for count, result in enumerate(response):
                 output = construct_regular_output(result, count)
                 outputs.append(output)
 
@@ -263,7 +277,7 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
             raise ValueError("Stream option is selected but streamer is not initialized")
 
         # For streaming, cannot call `translator` directly otherwise response will be blocking
-        thread = threading.Thread(target=translator, kwargs=completion_data)
+        thread = threading.Thread(target=_translate)
         thread.start()
         output = construct_stream_output(streamer, options)
         if output is not None:
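
The streaming changes above follow the standard transformers streaming recipe: build a TextIteratorStreamer from the model's tokenizer, pass it to the pipeline call as a generation kwarg, run the (blocking) pipeline call on a background thread, and consume decoded text from the streamer on the main thread. Below is a minimal, self-contained sketch of that pattern outside AIConfig; the checkpoint name and input sentence are placeholder assumptions, and any checkpoint accepted by pipeline("translation", ...) should work the same way:

    # Minimal sketch of the streaming pattern introduced in the diff above.
    import threading

    from transformers import AutoTokenizer, TextIteratorStreamer, pipeline

    model_name = "Helsinki-NLP/opus-mt-en-fr"  # assumed example checkpoint
    translator = pipeline("translation", model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    streamer = TextIteratorStreamer(tokenizer)

    # The pipeline call blocks until generation finishes, so it runs on a
    # background thread while the main thread drains the streamer -- the
    # same reason run_inference wraps the call in _translate.
    thread = threading.Thread(
        target=translator,
        kwargs={
            "inputs": "Streaming keeps the UI responsive.",
            "streamer": streamer,
            "num_beams": 1,  # streamers only support non-beam decoding
        },
    )
    thread.start()

    accumulated_message = ""
    for new_text in streamer:
        if isinstance(new_text, str):
            # Mirrors the replace() calls in construct_stream_output: with
            # the default skip_special_tokens=False, markers like </s> leak
            # into the streamed text.
            for special in ("</s>", "<s>", "<pad>"):
                new_text = new_text.replace(special, "")
            accumulated_message += new_text

    thread.join()
    print(accumulated_message)

One design note on the replace() scrubbing: TextIteratorStreamer forwards extra constructor kwargs to tokenizer.decode, so constructing the streamer as TextIteratorStreamer(tokenizer, skip_special_tokens=True) is an alternative way to keep special tokens like </s> and <pad> out of the streamed output.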