Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support: yield sse_starlette.ServerSentEvent in /stream #701

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/custom_events/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
Allows the `/stream` endpoint to return `sse_starlette.ServerSentEvent` from runnable,
allowing you to return custom events such as `event: error`.
"""

from typing import Any, AsyncIterator, Dict

from fastapi import FastAPI
from langchain_core.runnables import RunnableConfig, RunnableLambda
from sse_starlette import ServerSentEvent

from langserve import add_routes
from langserve.pydantic_v1 import BaseModel

# FastAPI application that LangServe attaches its runnable endpoints to.
app = FastAPI(
    version="1.0",
    title="LangChain Server",
    description="Spin up a simple api server using Langchain's Runnable interfaces",
)


class InputType(BaseModel): ...


class OutputType(BaseModel):
    # Each streamed chunk carries a single human-readable message string.
    message: str


async def error_event(
    _: InputType,
    config: RunnableConfig,
) -> AsyncIterator[Dict[str, Any] | ServerSentEvent]:
    """Emit four normal message chunks, then finish with an error.

    When the call arrives through the ``/stream`` endpoint (signalled by the
    ``__langserve_endpoint`` key in the config metadata), the error is yielded
    as a raw ``ServerSentEvent`` with ``event: error`` so the client receives
    a custom SSE event type; on any other endpoint a plain dict is yielded.
    """
    for index in range(4):
        yield {"message": f"Message {index}"}

    # LangServe records which endpoint invoked the runnable in the config
    # metadata; only /stream can transport a raw ServerSentEvent.
    endpoint = config.get("metadata", {}).get("__langserve_endpoint")

    if endpoint == "stream":
        yield ServerSentEvent(
            data={"message": "An error occurred"},
            event="error",
        )
    else:
        yield {"message": "An error occurred"}


# Register the demo runnable on the standard LangServe endpoints
# (/invoke, /batch, /stream, ...) with explicit input/output schemas.
chain = RunnableLambda(error_event)
add_routes(
    app,
    chain,
    input_type=InputType,
    output_type=OutputType,
)

if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="localhost", port=8000)
14 changes: 12 additions & 2 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@
from langserve.version import __version__

try:
from sse_starlette import EventSourceResponse
from sse_starlette import EventSourceResponse, ServerSentEvent
except ImportError:
EventSourceResponse = Any
ServerSentEvent = Any


def _is_hosted() -> bool:
Expand Down Expand Up @@ -1111,7 +1112,7 @@ async def stream(
feedback_key = None
task = None

async def _stream() -> AsyncIterator[dict]:
async def _stream() -> AsyncIterator[dict | ServerSentEvent]:
"""Stream the output of the runnable."""
try:
config_w_callbacks = config.copy()
Expand All @@ -1136,6 +1137,15 @@ async def _stream() -> AsyncIterator[dict]:
run_id, feedback_key, feedback_token
)

if ServerSentEvent is not Any and isinstance(
chunk, ServerSentEvent
):
yield {
"event": chunk.event,
"data": self._serializer.dumps(chunk.data).decode("utf-8"),
}
continue

yield {
# EventSourceResponse expects a string for data
# so after serializing into bytes, we decode into utf-8
Expand Down