Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests: test stream events with chat model output #420

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 180 additions & 1 deletion tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from enum import Enum
from itertools import cycle
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -33,6 +34,10 @@
from langchain.schema.runnable.base import RunnableLambda
from langchain.schema.runnable.utils import ConfigurableField, Input, Output
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langsmith import schemas as ls_schemas
from pytest import MonkeyPatch
from pytest_mock import MockerFixture
Expand All @@ -53,7 +58,7 @@
except ImportError:
from pydantic import BaseModel, Field
from langserve.server import add_routes
from tests.unit_tests.utils.llms import FakeListLLM
from tests.unit_tests.utils.llms import FakeListLLM, GenericFakeChatModel
from tests.unit_tests.utils.tracer import FakeTracer


Expand Down Expand Up @@ -2475,6 +2480,180 @@ def back_to_serializable(inputs) -> str:
assert cb.value.response.status_code == 500


async def test_astream_events_with_prompt_model_parser_chain(
async_remote_runnable: RemoteRunnable,
) -> None:
"""Test prompt + model + parser chain"""

app = FastAPI()

messages = cycle([AIMessage(content="Hello World!")])

model = GenericFakeChatModel(messages=messages)

prompt = ChatPromptTemplate.from_messages(
[("system", "You are a cat."), ("user", "{question}")]
)

chain = prompt | model | StrOutputParser()
add_routes(app, chain)

async with get_async_remote_runnable(app) as runnable:
# Test good requests
events = [
event
async for event in runnable.astream_events(
{"question": "hello"}, version="v1"
)
]
_clean_up_events(events)
assert events == [
{
"data": {"input": {"question": "hello"}},
"event": "on_chain_start",
"name": "RunnableSequence",
"tags": [],
},
{
"data": {"input": {"question": "hello"}},
"event": "on_prompt_start",
"name": "ChatPromptTemplate",
"tags": ["seq:step:1"],
},
{
"data": {
"input": {"question": "hello"},
"output": {
"messages": [
SystemMessage(content="You are a cat."),
HumanMessage(content="hello"),
]
},
},
"event": "on_prompt_end",
"name": "ChatPromptTemplate",
"tags": ["seq:step:1"],
},
{
"data": {
"input": {
"messages": [
[
SystemMessage(content="You are a cat."),
HumanMessage(content="hello"),
]
]
}
},
"event": "on_chat_model_start",
"name": "GenericFakeChatModel",
"tags": ["seq:step:2"],
},
{
"data": {},
"event": "on_parser_start",
"name": "StrOutputParser",
"tags": ["seq:step:3"],
},
{
"data": {"chunk": "Hello"},
"event": "on_parser_stream",
"name": "StrOutputParser",
"tags": ["seq:step:3"],
},
{
"data": {"chunk": "Hello"},
"event": "on_chain_stream",
"name": "RunnableSequence",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="Hello")},
"event": "on_chat_model_stream",
"name": "GenericFakeChatModel",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": " "},
"event": "on_parser_stream",
"name": "StrOutputParser",
"tags": ["seq:step:3"],
},
{
"data": {"chunk": " "},
"event": "on_chain_stream",
"name": "RunnableSequence",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content=" ")},
"event": "on_chat_model_stream",
"name": "GenericFakeChatModel",
"tags": ["seq:step:2"],
},
{
"data": {"chunk": "World!"},
"event": "on_parser_stream",
"name": "StrOutputParser",
"tags": ["seq:step:3"],
},
{
"data": {"chunk": "World!"},
"event": "on_chain_stream",
"name": "RunnableSequence",
"tags": [],
},
{
"data": {"chunk": AIMessageChunk(content="World!")},
"event": "on_chat_model_stream",
"name": "GenericFakeChatModel",
"tags": ["seq:step:2"],
},
{
"data": {
"input": {
"messages": [
[
SystemMessage(content="You are a cat."),
HumanMessage(content="hello"),
]
]
},
"output": LLMResult(
generations=[
[
ChatGenerationChunk(
text="Hello World!",
message=AIMessageChunk(content="Hello World!"),
)
]
],
llm_output=None,
run=None,
),
},
"event": "on_chat_model_end",
"name": "GenericFakeChatModel",
"tags": ["seq:step:2"],
},
{
"data": {
"input": AIMessageChunk(content="Hello World!"),
"output": "Hello World!",
},
"event": "on_parser_end",
"name": "StrOutputParser",
"tags": ["seq:step:3"],
},
{
"data": {"output": "Hello World!"},
"event": "on_chain_end",
"name": "RunnableSequence",
"tags": [],
},
]


async def test_path_dependencies() -> None:
"""Test path dependencies."""

Expand Down
Loading