feat: Add a streaming generator for streaming pipelines
* Add a prompt-dropping feature to HFTokenStreamingHandler

* Return a streaming generator when a pipeline containing a streaming-mode
  PromptNode runs (see the usage sketch below).

* Add a streaming interface to rest_api
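
A minimal usage sketch of the consumer-side flow. Illustrative only: pipeline is assumed to be an already-loaded query pipeline whose PromptNode is registered under the hypothetical name "prompt_node".

# Sketch only: "pipeline" is an existing Haystack query pipeline containing a
# PromptNode named "prompt_node"; the query is an example.
result = pipeline.run(
    query="What is the capital of Germany?",
    params={"prompt_node": {"stream": True}},
)
for token in result["generator"]:
    print(token, end="", flush=True)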

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu2017 committed May 12, 2023
1 parent 6c84a05 commit f44ec6d
Showing 6 changed files with 120 additions and 15 deletions.
54 changes: 49 additions & 5 deletions haystack/nodes/prompt/invocation_layer/handlers.py
@@ -1,5 +1,6 @@
from abc import abstractmethod, ABC
from typing import Union
from typing import Union, Optional
from queue import Queue

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, TextStreamer

@@ -24,6 +25,9 @@ def __call__(self, token_received: str, **kwargs) -> str:


class DefaultTokenStreamingHandler(TokenStreamingHandler):
def __init__(self):
self.q = Queue()

def __call__(self, token_received, **kwargs) -> str:
"""
This callback method is called when a new token is received from the stream.
@@ -32,17 +36,57 @@ def __call__(self, token_received, **kwargs) -> str:
:param kwargs: Additional keyword arguments passed to the handler.
:return: The token to be sent to the stream.
"""
print(token_received, flush=True, end="")
self.q.put(token_received)
return token_received

def generator(self):
while True:
next_token = self.q.get(True) # Blocks until an input is available
if next_token == TokenStreamingHandler.DONE_MARKER:
break
yield next_token


class HFTokenStreamingHandler(TextStreamer):
def __init__(
self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stream_handler: TokenStreamingHandler
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: TokenStreamingHandler,
dropping_prompt: Optional[str] = None,
):
super().__init__(tokenizer=tokenizer)
self.token_handler = stream_handler
self.answer_start = False

self.dropping_prompt = ""
if dropping_prompt and dropping_prompt != "":
self.dropping_prompt = dropping_prompt.strip().replace(" ", "")

self.special_tokens = ["<s>", "</s>", "<pad>"]
if tokenizer.eos_token and tokenizer.eos_token != "" and tokenizer.eos_token not in self.special_tokens:
self.special_tokens.append(tokenizer.eos_token)
if tokenizer.bos_token and tokenizer.bos_token != "" and tokenizer.bos_token not in self.special_tokens:
self.special_tokens.append(tokenizer.bos_token)
if tokenizer.pad_token and tokenizer.pad_token != "" and tokenizer.pad_token not in self.special_tokens:
self.special_tokens.append(tokenizer.pad_token)

def on_finalized_text(self, token: str, stream_end: bool = False):
token_to_send = token + "\n" if stream_end else token
self.token_handler(token_received=token_to_send, **{})
if self.dropping_prompt != "":
if not self.answer_start:
token_no_space = token.strip()
token_no_space = token_no_space.replace(" ", "")
if token_no_space in self.dropping_prompt:
self.dropping_prompt = self.dropping_prompt.replace(token_no_space, "")
else:
self.answer_start = True
else:
self.answer_start = True

if self.answer_start:
token_to_send = token
for special_token in self.special_tokens:
token_to_send = token_to_send.replace(special_token, "")
self.token_handler(token_received=token_to_send, **{})
if stream_end:
self.token_handler(token_received=TokenStreamingHandler.DONE_MARKER, **{})
self.answer_start = False
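
The queue-backed DefaultTokenStreamingHandler decouples token production from consumption. A small standalone sketch of that behavior, with illustrative tokens and a worker thread standing in for the model:

from threading import Thread

from haystack.nodes.prompt.invocation_layer.handlers import (
    DefaultTokenStreamingHandler,
    TokenStreamingHandler,
)

handler = DefaultTokenStreamingHandler()

def produce():
    # Stand-in for the generation thread: emit a few tokens, then signal completion.
    for token in ["Berlin", " is", " the", " capital", "."]:
        handler(token_received=token)
    handler(token_received=TokenStreamingHandler.DONE_MARKER)

Thread(target=produce).start()
for token in handler.generator():  # blocks on the queue, stops at DONE_MARKER
    print(token, end="", flush=True)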
12 changes: 11 additions & 1 deletion haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -1,6 +1,7 @@
from typing import Optional, Union, List, Dict
import logging
import os
from threading import Thread

import torch

@@ -160,6 +161,10 @@ def __init__(
self.pipe.tokenizer.model_max_length,
)

def _run_pipe_task(self, prompt, model_input_kwargs):
output = self.pipe(prompt, **model_input_kwargs)
logger.info("The final output: %s", output)

def invoke(self, *args, **kwargs):
"""
It takes a prompt and returns a list of generated texts using the local Hugging Face transformers model
@@ -223,7 +228,12 @@ def invoke(self, *args, **kwargs):

if stream:
stream_handler: TokenStreamingHandler = kwargs.pop("stream_handler", DefaultTokenStreamingHandler())
model_input_kwargs["streamer"] = HFTokenStreamingHandler(self.pipe.tokenizer, stream_handler)
model_input_kwargs["streamer"] = HFTokenStreamingHandler(
self.pipe.tokenizer, stream_handler, dropping_prompt=prompt
)
thread = Thread(target=self._run_pipe_task, args=[prompt, model_input_kwargs])
thread.start()
return stream_handler.generator()

output = self.pipe(prompt, **model_input_kwargs)
generated_texts = [o["generated_text"] for o in output if "generated_text" in o]
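
The streaming branch above follows a standard producer/consumer layout: _run_pipe_task produces tokens on a background thread through the streamer callback, while the caller consumes stream_handler.generator(). A generic, hypothetical sketch of that control flow (the names are illustrative, not Haystack APIs):

from queue import Queue
from threading import Thread
from typing import Callable, Iterator

def stream_from_worker(run_generation: Callable[[Callable[[str], None]], None]) -> Iterator[str]:
    # Run the blocking generation function on a worker thread and yield its
    # tokens as they arrive; a sentinel marks the end of the stream.
    queue: Queue = Queue()
    done = object()

    def worker():
        run_generation(queue.put)  # the callback pushes each new token
        queue.put(done)

    Thread(target=worker).start()
    while True:
        token = queue.get()  # blocks until the worker produces something
        if token is done:
            break
        yield token

# Example with a fake "model" that emits three tokens through the callback.
for token in stream_from_worker(lambda emit: [emit(t) for t in ("Ber", "lin", "\n")]):
    print(token, end="")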
16 changes: 15 additions & 1 deletion haystack/nodes/prompt/prompt_node.py
@@ -162,6 +162,7 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, *

# kwargs override model kwargs
kwargs = {**self._prepare_model_kwargs(), **kwargs}
stream = kwargs.get("stream", False)
template_to_fill = self.get_prompt_template(prompt_template)
if template_to_fill:
# prompt template used, yield prompts from inputs args
@@ -172,6 +173,9 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, *
prompt_collector.append(prompt)
logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s", prompt, kwargs_copy)
output = self.prompt_model.invoke(prompt, **kwargs_copy)
if stream:
return output

results.extend(output)

kwargs["prompts"] = prompt_collector
@@ -184,6 +188,9 @@ def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, *
prompt_collector.append(prompt)
logger.debug("Prompt being sent to LLM with prompt %s and kwargs %s ", prompt, kwargs_copy)
output = self.prompt_model.invoke(prompt, **kwargs_copy)
if stream:
return output

results.extend(output)
return results

@@ -310,6 +317,7 @@ def run(
meta: Optional[dict] = None,
invocation_context: Optional[Dict[str, Any]] = None,
prompt_template: Optional[Union[str, PromptTemplate]] = None,
stream: Optional[bool] = False,
) -> Tuple[Dict, str]:
"""
Runs the PromptNode on these input parameters. Returns the output of the prompt model.
@@ -358,13 +366,19 @@
if "prompt_template" not in invocation_context.keys():
invocation_context["prompt_template"] = self.get_prompt_template(prompt_template)

results = self(prompt_collector=prompt_collector, **invocation_context)
if "stream" not in invocation_context.keys():
invocation_context["stream"] = stream

kwargs = {**invocation_context}
results = self(prompt_collector=prompt_collector, **kwargs)

prompt_template_resolved: PromptTemplate = invocation_context.pop("prompt_template")
output_variable = self.output_variable or prompt_template_resolved.output_variable or "results"
invocation_context[output_variable] = results
invocation_context["prompts"] = prompt_collector
final_result: Dict[str, Any] = {output_variable: results, "invocation_context": invocation_context}
if stream:
final_result = {"generator": results, "invocation_context": invocation_context}

if self.debug:
final_result["_debug"] = {"prompts_used": prompt_collector}
5 changes: 3 additions & 2 deletions rest_api/rest_api/controller/feedback.py
@@ -4,7 +4,7 @@
import logging

from fastapi import FastAPI, APIRouter
from haystack.schema import Label
from haystack.schema import Label, Span
from haystack.document_stores import BaseDocumentStore
from rest_api.schema import FilterRequest, CreateLabelSerialized
from rest_api.utils import get_app, get_pipelines
@@ -113,7 +113,8 @@ def export_feedback(

offset_start_in_document = 0
if label.answer and label.answer.offsets_in_document:
offset_start_in_document = label.answer.offsets_in_document[0].start
if isinstance(label.answer.offsets_in_document[0], Span):
offset_start_in_document = label.answer.offsets_in_document[0].start

if full_document_context:
context = label.document.content
37 changes: 36 additions & 1 deletion rest_api/rest_api/controller/search.py
@@ -5,9 +5,12 @@
import json

from pydantic import BaseConfig
from fastapi import FastAPI, APIRouter
from fastapi import FastAPI, APIRouter, HTTPException
from fastapi.responses import StreamingResponse

import haystack
from haystack import Pipeline
from haystack.nodes.prompt import PromptNode

from rest_api.utils import get_app, get_pipelines
from rest_api.config import LOG_LEVEL
@@ -58,6 +61,21 @@ def query(request: QueryRequest):
return result


@router.post("/query-streaming", response_model=StreamingResponse)
def query_streaming(request: QueryRequest):
"""
This streaming endpoint receives the question as a string and allows the requester to set
additional parameters that will be passed on to the Haystack pipeline.
"""
with concurrency_limiter.run():
generator = _get_streaming_generator(query_pipeline, request)
if generator is None:
raise HTTPException(
status_code=501, detail="The pipeline cannot support the streaming mode. The PromptNode is not found!"
)
return StreamingResponse(generator, media_type="text/event-stream")


def _process_request(pipeline, request) -> Dict[str, Any]:
start_time = time.time()

@@ -74,3 +92,20 @@ def _process_request(pipeline, request) -> Dict[str, Any]:
json.dumps({"request": request, "response": result, "time": f"{(time.time() - start_time):.2f}"}, default=str)
)
return result


def _get_streaming_generator(pipeline, request=None):
params = request.params or {}
components = pipeline.components
node_name = None
generator = None
for name in components.keys():
if isinstance(components[name], PromptNode):
node_name = name

if node_name is not None:
streaming_param = {"stream": True}
params.setdefault(node_name, {}).update(streaming_param)
generator = pipeline.run(query=request.query, params=params)["generator"]

return generator
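
A hypothetical client for the new /query-streaming endpoint; the host, port, and payload are assumptions, not part of this commit:

import requests

payload = {"query": "What is the capital of Germany?", "params": {}}
with requests.post("http://localhost:8000/query-streaming", json=payload, stream=True) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)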
11 changes: 6 additions & 5 deletions test/prompt/test_prompt_node.py
@@ -347,11 +347,12 @@ def test_prompt_node_hf_model_streaming():
# the required HF type transformers.generation.streamers.TextStreamer

pn = PromptNode(model_kwargs={"stream_handler": DefaultTokenStreamingHandler()})
with patch.object(pn.prompt_model.model_invocation_layer.pipe, "run_single", MagicMock()) as mock_call:
pn("Irrelevant prompt")
args, kwargs = mock_call.call_args
assert "streamer" in args[2]
assert isinstance(args[2]["streamer"], TextStreamer)
# with patch.object(pn.prompt_model.model_invocation_layer, "_run_pipe_task", MagicMock()) as mock_call:
generator = pn("What is the capital of Germany?")
answer = ""
for token in generator:
answer = answer + token
assert "berlin" in answer


@pytest.mark.unit
