Description
Is your feature request related to a problem? Please describe.
Implementing streaming in an API-facing app is not straightforward. Although the streaming callback provides a very versatile interface, building a standard SSE streaming app needs some heavy lifting:
- API frameworks like FastAPI expect a generator for StreamingResponse, whereas Haystack does not offer one (already discussed in Support both callback and generator-based streaming in all Chat Generators #8742); the mismatch is sketched below
- rendering tool calls is error-prone and highly depends on the model provider's API
- as pipelines can have multiple streaming components, identifying which component produced a chunk/message requires additional custom code
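
For illustration, the mismatch from the first point looks roughly like this (a minimal sketch, not taken from a real app; the component, endpoint, and media type are arbitrary):

```python
from collections.abc import AsyncIterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import StreamingChunk

app = FastAPI()


# Haystack *pushes* chunks into a callback ...
def on_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)


llm = OpenAIChatGenerator(streaming_callback=on_chunk)


# ... while FastAPI *pulls* chunks from a (sync or async) generator:
@app.post("/query")
async def query() -> StreamingResponse:
    async def sse() -> AsyncIterator[str]:
        # The hard part: there is no built-in way to turn the callback above
        # into this generator, which is what this feature request is about.
        yield "data: ...\n\n"

    return StreamingResponse(sse(), media_type="text/event-stream")
```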
Describe the solution you'd like
A simpler way of writing FastAPI-based LLM apps that support streaming.
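
The kind of interface this asks for might look like the sketch below. None of this exists in Haystack today; stream_pipeline and all endpoint details are invented purely to show the desired shape.

```python
import json
from collections.abc import AsyncIterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from haystack import Pipeline
from haystack.dataclasses import StreamingChunk

app = FastAPI()
pipeline = Pipeline()  # components and connections omitted


async def stream_pipeline(pipeline: Pipeline, data: dict) -> AsyncIterator[StreamingChunk]:
    """Hypothetical helper this issue asks for: run the pipeline and yield its StreamingChunks."""
    raise NotImplementedError
    yield StreamingChunk("")  # unreachable; marks this stub as an async generator


@app.post("/chat")
async def chat(question: str) -> StreamingResponse:
    async def sse() -> AsyncIterator[str]:
        async for chunk in stream_pipeline(pipeline, {"prompt_builder": {"question": question}}):
            # chunk.meta would ideally identify the component that produced the chunk
            payload = {"delta": chunk.content, "component": chunk.meta.get("component")}
            yield f"data: {json.dumps(payload)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(sse(), media_type="text/event-stream")
```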
Additional context
If you want to use FastAPI's StreamingResponse, you need to make several tweaks to get streaming to work. For example, the snippet below uses a Queue as an adapter between Haystack's streaming_callback approach and FastAPI's generator approach. Also note that adding extra metadata and rendering tool calls relies on a decorator pattern for streaming callbacks. Ideally there would be a simpler way to do this; any means that facilitate this custom logic would be highly appreciated:
```python
# Imports used throughout this snippet (the function below and the callback classes that follow).
# App-specific helpers such as to_pipeline_input, get_outputs, _invoke_pipeline, from_pipeline,
# _validate_query_response, error_to_stream, to_stream_message, and logger are defined elsewhere
# in the app and omitted here.
from asyncio import Queue
from collections.abc import AsyncGenerator, Callable
from time import time
from typing import Any
from uuid import UUID, uuid4

import asyncer
from fastapi import BackgroundTasks
from haystack.core.pipeline.base import PipelineBase  # import path may differ between Haystack versions
from haystack.dataclasses import StreamingChunk


async def run_pipeline_streaming(
    pipeline: PipelineBase,
    request: QueryStreamRequest,
    streaming_generators: list[str],
    background_tasks: BackgroundTasks,
) -> AsyncGenerator[str, None]:
    start = time()
    component_names = set(name for name, _ in pipeline.walk())
    pipeline_input = to_pipeline_input(request.model_dump())
    outputs = get_outputs(pipeline, request)
    query_id = uuid4()

    # Every streaming component gets the same queue-backed callback, wrapped in decorators
    # that add metadata, render tool calls, and measure time to first token.
    streaming_callback = CustomStreamingCallback(query_id)
    for streaming_generator in streaming_generators:
        streaming_generator_params = pipeline_input.setdefault(streaming_generator, {})
        decorated_callback = TimeToFirstTokenDecorator(
            ToolCallRenderingCallbackDecorator(
                ComponentCallbackDecorator(streaming_callback, component=streaming_generator)
            )
        )
        streaming_generator_params["streaming_callback"] = decorated_callback

    async def pipeline_run_task() -> dict[str, Any]:
        result = await _invoke_pipeline(pipeline, pipeline_input, outputs)
        streaming_callback.mark_as_done()
        return result

    # Start pipeline execution in background task group
    # Pipeline.run will be called in a separate thread while AsyncPipeline.run will be called on the main event loop
    try:
        async with asyncer.create_task_group() as task_group:
            task = task_group.soonify(pipeline_run_task)()
            async for chunk in streaming_callback.get_chunks():
                yield chunk
    except ExceptionGroup as exc_group:
        for error in exc_group.exceptions:
            # Make sure exc_info is set to the exception and not the ExceptionGroup
            # This can be removed once ExceptionGroups are supported in structlog or dc-unified-logging
            # See https://github.com/hynek/structlog/issues/676
            logger.error(f"Pipeline run failed. {error!s}", exc_info=error)
            yield error_to_stream(query_id, str(error))
        return

    try:
        haystack_result = task.value
        result = from_pipeline(haystack_result, mapping=PIPELINE_MAPPING_OUTPUT)
        query_response = _validate_query_response(result, request, haystack_result, query_id)
    except WrongQueryResponseFormatError as error:
        logger.exception(
            "Pipeline returned wrong format. Failed parsing the response. Please check your pipelines return values.",
            result=error.result,
        )
        yield error_to_stream(
            query_id,
            f"Pipeline returned wrong format. Failed parsing the response: {error!s}\n"
            f"Please check your pipelines return values: {error.result}",
        )
        return

    if request.include_result:
        payload = {"query_id": str(query_id), "result": query_response.serialize(), "type": "result"}
        yield to_stream_message(payload)
```
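
The endpoint wiring is not part of the snippet; for context, it might look roughly like this (the route path, the load_pipeline helper, and the streaming_generators value are assumptions):

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/query/stream")
async def query_stream(request: QueryStreamRequest, background_tasks: BackgroundTasks) -> StreamingResponse:
    pipeline = load_pipeline()  # hypothetical helper that loads/caches the deployed pipeline
    return StreamingResponse(
        run_pipeline_streaming(
            pipeline=pipeline,
            request=request,
            streaming_generators=["llm"],  # names of the streaming components in the pipeline
            background_tasks=background_tasks,
        ),
        media_type="text/event-stream",
    )
```

The CustomStreamingCallback bridges the callback and generator worlds via a queue: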
```python
class CustomStreamingCallback:
    """
    A custom streaming callback that stores the tokens in a queue and provides a method to get the tokens as chunks.
    """

    DONE_MARKER = StreamingChunk("[DONE]", meta={"is_done": True})

    def __init__(self, query_id: UUID) -> None:
        self.query_id = query_id
        self.queue: Queue = Queue()

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        self.queue.put_nowait(chunk_received)

    def mark_as_done(self) -> None:
        self.queue.put_nowait(self.DONE_MARKER)

    async def get_chunks(self) -> AsyncGenerator[str, None]:
        while True:
            next_chunk: StreamingChunk = await self.queue.get()
            if next_chunk == self.DONE_MARKER:
                break
            if next_chunk.content:
                payload = {
                    "query_id": str(self.query_id),
                    "delta": {"text": next_chunk.content, "meta": next_chunk.meta},
                    "type": "delta",
                }
                yield to_stream_message(payload)
```
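
The to_stream_message and error_to_stream helpers used above are not shown in the issue; a minimal sketch of what they might look like, assuming standard SSE framing:

```python
import json


def to_stream_message(payload: dict[str, Any]) -> str:
    # One Server-Sent Events frame: a "data:" line terminated by a blank line.
    return f"data: {json.dumps(payload)}\n\n"


def error_to_stream(query_id: UUID, message: str) -> str:
    return to_stream_message({"query_id": str(query_id), "type": "error", "error": message})
```

The ComponentCallbackDecorator tags each chunk with the component that produced it: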
```python
class ComponentCallbackDecorator:
    """
    Decorator to augment the StreamingChunk's meta with the component's name.
    """

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None], component: str) -> None:
        self.streaming_callback = streaming_callback
        self.component = component

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        chunk_received.meta["deepset_cloud"] = {"component": self.component}
        self.streaming_callback(chunk_received)
```
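
TimeToFirstTokenDecorator is referenced in run_pipeline_streaming but not included in the issue; a plausible sketch that records the time to the first non-empty chunk in the chunk's meta:

```python
from time import time


class TimeToFirstTokenDecorator:
    """Decorator to record the time until the first non-empty chunk arrives."""

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None]) -> None:
        self.streaming_callback = streaming_callback
        self._start = time()
        self._first_token_seen = False

    def __call__(self, chunk_received: StreamingChunk) -> None:
        if not self._first_token_seen and chunk_received.content:
            self._first_token_seen = True
            chunk_received.meta["time_to_first_token"] = time() - self._start
        self.streaming_callback(chunk_received)
```

Finally, the ToolCallRenderingCallbackDecorator turns provider-specific tool-call metadata into readable text: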
```python
class ToolCallRenderingCallbackDecorator:
    """
    Decorator to augment the StreamingChunk's content with the tool call data from meta.
    """

    TOOL_START = '\n\n**Tool Use:**\n```json\n{{\n "name": "{tool_name}",\n "arguments": '
    TOOL_END = "\n}\n```\n"

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None]) -> None:
        self.streaming_callback = streaming_callback
        self._openai_tool_call_index = 0

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        chunk_received = self._render_anthropic_tool_call(chunk_received)
        chunk_received = self._render_openai_tool_call(chunk_received)
        self.streaming_callback(chunk_received)

    def _render_openai_tool_call(self, chunk_received: StreamingChunk) -> StreamingChunk:
        tool_calls = chunk_received.meta.get("tool_calls") or []
        for tool_call in tool_calls:
            if not tool_call.function:
                continue
            # multiple tool calls (distinguished by index) can be concatenated without finish_reason in between
            if self._openai_tool_call_index < tool_call.index:
                chunk_received.content += self.TOOL_END
                self._openai_tool_call_index = tool_call.index
            if tool_name := tool_call.function.name:
                chunk_received.content += self.TOOL_START.format(tool_name=tool_name)
            if arguments := tool_call.function.arguments:
                chunk_received.content += arguments
        if chunk_received.meta.get("finish_reason") == "tool_calls":
            chunk_received.content += self.TOOL_END
        return chunk_received

    def _render_anthropic_tool_call(self, chunk_received: StreamingChunk) -> StreamingChunk:
        content_block = chunk_received.meta.get("content_block") or {}
        if content_block.get("type") == "tool_use":
            tool_name = content_block.get("name") or ""
            content = self.TOOL_START.format(tool_name=tool_name)
            chunk_received.content += content
        delta = chunk_received.meta.get("delta") or {}
        if delta.get("type") == "input_json_delta":
            partial_json = delta.get("partial_json") or ""
            chunk_received.content += partial_json
        if delta.get("stop_reason") == "tool_use":
            content = self.TOOL_END
            chunk_received.content += content
        return chunk_received
```
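
A small usage example of the OpenAI path, assuming the chunks carry the OpenAI SDK's ChoiceDeltaToolCall objects in meta["tool_calls"] (which is what the attribute access above implies):

```python
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction

rendered: list[StreamingChunk] = []
decorator = ToolCallRenderingCallbackDecorator(rendered.append)

# The tool name arrives first, the arguments are streamed in pieces, then finish_reason closes the block.
decorator(
    StreamingChunk(
        "",
        meta={
            "tool_calls": [
                ChoiceDeltaToolCall(
                    index=0,
                    function=ChoiceDeltaToolCallFunction(name="search", arguments='{"query": "haystack"'),
                )
            ]
        },
    )
)
decorator(
    StreamingChunk(
        "",
        meta={"tool_calls": [ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments="}"))]},
    )
)
decorator(StreamingChunk("", meta={"finish_reason": "tool_calls"}))

# Joins into a "**Tool Use:**" block containing the streamed JSON arguments.
print("".join(chunk.content for chunk in rendered))
```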