enhancement: make streaming more convenient #9347

@tstadel

Description

Is your feature request related to a problem? Please describe.
Implementing streaming in a API facing app is not straight forward. Although streaming callback provides a very versatile interface, building a standard SSE streaming app needs some heavy lifting:

  • API frameworks like FastAPI expect a generator for StreamingResponse, whereas Haystack does not offer one (already discussed in Support both callback and generator-based streaming in all Chat Generators #8742); see the sketch after this list
  • rendering tool calls is error-prone and depends heavily on the model provider's API
  • since pipelines can have multiple streaming components, identifying which component produced a chunk/message requires additional custom code
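
To make the first point concrete, here is a minimal sketch of the two interfaces that currently have to be bridged (the endpoint path and handler names are illustrative only): FastAPI's StreamingResponse consumes an async generator, while Haystack's chat generators push chunks into a streaming_callback.

from collections.abc import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from haystack.dataclasses import StreamingChunk

app = FastAPI()


# What FastAPI expects: an (async) generator it can iterate and forward to the client.
async def sse_events() -> AsyncGenerator[str, None]:
    yield "data: hello\n\n"  # SSE-formatted frame


@app.get("/stream")
async def stream() -> StreamingResponse:
    return StreamingResponse(sse_events(), media_type="text/event-stream")


# What Haystack offers: a callback that gets *pushed* chunks while the pipeline runs,
# so there is no generator to hand to StreamingResponse without an adapter.
def on_chunk(chunk: StreamingChunk) -> None:
    print(chunk.content, end="", flush=True)

Bridging this push-based callback with the pull-based generator is exactly what the Queue-based adapter under Additional context does.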

Describe the solution you'd like
A simpler way of writing FastAPI-based LLM apps that support streaming.

Additional context
If you want to use FastAPI's StreamingResponse, you need several tweaks to get streaming to work. For example, look at this snippet that uses a Queue as an adapter between Haystack's streaming_callback approach and FastAPI's generator approach. Also note that adding additional metadata and parsing tool calls relies on a decorator pattern for streaming callbacks. Ideally there would be a simpler way to do this; any means that facilitates this custom logic would be highly appreciated:

# Imports used by the snippet below. Application-specific helpers (QueryStreamRequest,
# to_pipeline_input, get_outputs, TimeToFirstTokenDecorator, _invoke_pipeline,
# error_to_stream, to_stream_message, from_pipeline, PIPELINE_MAPPING_OUTPUT,
# _validate_query_response, WrongQueryResponseFormatError, logger) live elsewhere
# in the app and are omitted here.
from asyncio import Queue
from time import time
from typing import Any, AsyncGenerator, Callable
from uuid import UUID, uuid4

import asyncer
from fastapi import BackgroundTasks
from haystack.core.pipeline.base import PipelineBase  # import path may vary by Haystack version
from haystack.dataclasses import StreamingChunk


async def run_pipeline_streaming(
    pipeline: PipelineBase,
    request: QueryStreamRequest,
    streaming_generators: list[str],
    background_tasks: BackgroundTasks,
) -> AsyncGenerator[str, None]:
    start = time()
    component_names = set(name for name, _ in pipeline.walk())
    pipeline_input = to_pipeline_input(request.model_dump())
    outputs = get_outputs(pipeline, request)
    query_id = uuid4()
    streaming_callback = CustomStreamingCallback(query_id)

    for streaming_generator in streaming_generators:
        streaming_generator_params = pipeline_input.setdefault(streaming_generator, {})
        decorated_callback = TimeToFirstTokenDecorator(
            ToolCallRenderingCallbackDecorator(
                ComponentCallbackDecorator(streaming_callback, component=streaming_generator)
            )
        )
        streaming_generator_params["streaming_callback"] = decorated_callback

    async def pipeline_run_task() -> dict[str, Any]:
        result = await _invoke_pipeline(pipeline, pipeline_input, outputs)
        streaming_callback.mark_as_done()
        return result

    # Start pipeline execution in background task group
    # Pipeline.run will be called in a separate thread while AsyncPipeline.run will be called on the main event loop
    try:
        async with asyncer.create_task_group() as task_group:
            task = task_group.soonify(pipeline_run_task)()
            async for chunk in streaming_callback.get_chunks():
                yield chunk
    except ExceptionGroup as exc_group:
        for error in exc_group.exceptions:
            # Make sure exc_info is set to the exception and not the ExceptionGroup
            # This can be removed once ExceptionGroups are supported in structlog or dc-unified-logging
            # See https://github.com/hynek/structlog/issues/676
            logger.error(f"Pipeline run failed. {error!s}", exc_info=error)
            yield error_to_stream(query_id, str(error))
        return

    try:
        haystack_result = task.value
        result = from_pipeline(haystack_result, mapping=PIPELINE_MAPPING_OUTPUT)
        query_response = _validate_query_response(result, request, haystack_result, query_id)
    except WrongQueryResponseFormatError as error:
        logger.exception(
            "Pipeline returned wrong format. Failed parsing the response. Please check your pipelines return values.",
            result=error.result,
        )
        yield error_to_stream(
            query_id,
            f"Pipeline returned wrong format. Failed parsing the response: {error!s}\n"
            f"Please check your pipelines return values: {error.result}",
        )
        return

    if request.include_result:
        payload = {"query_id": str(query_id), "result": query_response.serialize(), "type": "result"}
        yield to_stream_message(payload)


class CustomStreamingCallback:
    """
    A custom streaming callback that stores the tokens in a queue and provides a method to get the tokens as chunks.
    """

    DONE_MARKER = StreamingChunk("[DONE]", meta={"is_done": True})

    def __init__(self, query_id: UUID) -> None:
        self.query_id = query_id
        self.queue: Queue = Queue()

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        self.queue.put_nowait(chunk_received)

    def mark_as_done(self) -> None:
        self.queue.put_nowait(self.DONE_MARKER)

    async def get_chunks(self) -> AsyncGenerator[str, None]:
        while True:
            next_chunk: StreamingChunk = await self.queue.get()
            if next_chunk == self.DONE_MARKER:
                break
            if next_chunk.content:
                payload = {
                    "query_id": str(self.query_id),
                    "delta": {"text": next_chunk.content, "meta": next_chunk.meta},
                    "type": "delta",
                }
                yield to_stream_message(payload)


class ComponentCallbackDecorator:
    """
    Decorator to augment the StreamingChunk's meta with the component's name.
    """

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None], component: str) -> None:
        self.streaming_callback = streaming_callback
        self.component = component

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        chunk_received.meta["deepset_cloud"] = {"component": self.component}
        self.streaming_callback(chunk_received)


class ToolCallRenderingCallbackDecorator:
    """
    Decorator to augment the StreamingChunk's content with the tool call data from meta.
    """

    TOOL_START = '\n\n**Tool Use:**\n```json\n{{\n  "name": "{tool_name}",\n  "arguments": '
    TOOL_END = "\n}\n```\n"

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None]) -> None:
        self.streaming_callback = streaming_callback
        self._openai_tool_call_index = 0

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        chunk_received = self._render_anthropic_tool_call(chunk_received)
        chunk_received = self._render_openai_tool_call(chunk_received)

        self.streaming_callback(chunk_received)

    def _render_openai_tool_call(self, chunk_received: StreamingChunk) -> StreamingChunk:
        tool_calls = chunk_received.meta.get("tool_calls") or []
        for tool_call in tool_calls:
            if not tool_call.function:
                continue
            # multiple tool calls (distinguished by index) can be concatenated without finish_reason in between
            if self._openai_tool_call_index < tool_call.index:
                chunk_received.content += self.TOOL_END
            self._openai_tool_call_index = tool_call.index
            if tool_name := tool_call.function.name:
                chunk_received.content += self.TOOL_START.format(tool_name=tool_name)
            if arguments := tool_call.function.arguments:
                chunk_received.content += arguments
        if chunk_received.meta.get("finish_reason") == "tool_calls":
            chunk_received.content += self.TOOL_END
        return chunk_received

    def _render_anthropic_tool_call(self, chunk_received: StreamingChunk) -> StreamingChunk:
        content_block = chunk_received.meta.get("content_block") or {}
        if content_block.get("type") == "tool_use":
            tool_name = content_block.get("name") or ""
            content = self.TOOL_START.format(tool_name=tool_name)
            chunk_received.content += content
        delta = chunk_received.meta.get("delta") or {}
        if delta.get("type") == "input_json_delta":
            partial_json = delta.get("partial_json") or ""
            chunk_received.content += partial_json
        if delta.get("stop_reason") == "tool_use":
            content = self.TOOL_END
            chunk_received.content += content
        return chunk_received
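
For completeness, here is a sketch of how run_pipeline_streaming could be wired into a FastAPI endpoint. The route path, the get_pipeline helper, and the list of streaming component names are assumptions for illustration; only run_pipeline_streaming and QueryStreamRequest come from the snippet above.

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.post("/query/stream")
async def query_stream(request: QueryStreamRequest, background_tasks: BackgroundTasks) -> StreamingResponse:
    pipeline = get_pipeline()  # hypothetical helper returning the loaded pipeline
    streaming_generators = ["llm"]  # names of the streaming components in this pipeline (assumed)
    return StreamingResponse(
        run_pipeline_streaming(pipeline, request, streaming_generators, background_tasks),
        media_type="text/event-stream",
    )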
