Skip to content

Commit

Permalink
Tests: test stream events with chat model output (#420)
Browse files Browse the repository at this point in the history
Verify that serialization with chat model output works
  • Loading branch information
eyurtsev committed Jan 26, 2024
1 parent 5a9adbb commit 65f50b3
Showing 1 changed file with 180 additions and 1 deletion.
181 changes: 180 additions & 1 deletion tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from enum import Enum
from itertools import cycle
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -33,6 +34,10 @@
from langchain.schema.runnable.base import RunnableLambda
from langchain.schema.runnable.utils import ConfigurableField, Input, Output
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.output_parsers import StrOutputParser
from langchain_core.outputs import ChatGenerationChunk, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langsmith import schemas as ls_schemas
from pytest import MonkeyPatch
from pytest_mock import MockerFixture
Expand All @@ -53,7 +58,7 @@
except ImportError:
from pydantic import BaseModel, Field
from langserve.server import add_routes
from tests.unit_tests.utils.llms import FakeListLLM
from tests.unit_tests.utils.llms import FakeListLLM, GenericFakeChatModel
from tests.unit_tests.utils.tracer import FakeTracer


Expand Down Expand Up @@ -2475,6 +2480,180 @@ def back_to_serializable(inputs) -> str:
assert cb.value.response.status_code == 500


async def test_astream_events_with_prompt_model_parser_chain() -> None:
    """Verify astream_events serialization for a prompt | chat model | parser chain.

    The fake chat model streams "Hello World!" token by token ("Hello", " ",
    "World!"), so the expected event list interleaves on_chat_model_stream,
    on_parser_stream and on_chain_stream events for each chunk, followed by the
    end events carrying the aggregated outputs.

    Note: this test builds its own app and client; it deliberately takes no
    fixtures (a previously-declared ``async_remote_runnable`` fixture parameter
    was unused and only spun up an unrelated server).
    """

    app = FastAPI()

    # cycle() lets the fake model answer any number of invocations with the
    # same message.
    messages = cycle([AIMessage(content="Hello World!")])

    model = GenericFakeChatModel(messages=messages)

    prompt = ChatPromptTemplate.from_messages(
        [("system", "You are a cat."), ("user", "{question}")]
    )

    chain = prompt | model | StrOutputParser()
    add_routes(app, chain)

    async with get_async_remote_runnable(app) as runnable:
        # Test good requests
        events = [
            event
            async for event in runnable.astream_events(
                {"question": "hello"}, version="v1"
            )
        ]
        # Strip run ids / metadata that vary between runs so the literal
        # comparison below is deterministic.
        _clean_up_events(events)
        assert events == [
            {
                "data": {"input": {"question": "hello"}},
                "event": "on_chain_start",
                "name": "RunnableSequence",
                "tags": [],
            },
            {
                "data": {"input": {"question": "hello"}},
                "event": "on_prompt_start",
                "name": "ChatPromptTemplate",
                "tags": ["seq:step:1"],
            },
            {
                "data": {
                    "input": {"question": "hello"},
                    "output": {
                        "messages": [
                            SystemMessage(content="You are a cat."),
                            HumanMessage(content="hello"),
                        ]
                    },
                },
                "event": "on_prompt_end",
                "name": "ChatPromptTemplate",
                "tags": ["seq:step:1"],
            },
            {
                "data": {
                    "input": {
                        "messages": [
                            [
                                SystemMessage(content="You are a cat."),
                                HumanMessage(content="hello"),
                            ]
                        ]
                    }
                },
                "event": "on_chat_model_start",
                "name": "GenericFakeChatModel",
                "tags": ["seq:step:2"],
            },
            {
                "data": {},
                "event": "on_parser_start",
                "name": "StrOutputParser",
                "tags": ["seq:step:3"],
            },
            {
                "data": {"chunk": "Hello"},
                "event": "on_parser_stream",
                "name": "StrOutputParser",
                "tags": ["seq:step:3"],
            },
            {
                "data": {"chunk": "Hello"},
                "event": "on_chain_stream",
                "name": "RunnableSequence",
                "tags": [],
            },
            {
                "data": {"chunk": AIMessageChunk(content="Hello")},
                "event": "on_chat_model_stream",
                "name": "GenericFakeChatModel",
                "tags": ["seq:step:2"],
            },
            {
                "data": {"chunk": " "},
                "event": "on_parser_stream",
                "name": "StrOutputParser",
                "tags": ["seq:step:3"],
            },
            {
                "data": {"chunk": " "},
                "event": "on_chain_stream",
                "name": "RunnableSequence",
                "tags": [],
            },
            {
                "data": {"chunk": AIMessageChunk(content=" ")},
                "event": "on_chat_model_stream",
                "name": "GenericFakeChatModel",
                "tags": ["seq:step:2"],
            },
            {
                "data": {"chunk": "World!"},
                "event": "on_parser_stream",
                "name": "StrOutputParser",
                "tags": ["seq:step:3"],
            },
            {
                "data": {"chunk": "World!"},
                "event": "on_chain_stream",
                "name": "RunnableSequence",
                "tags": [],
            },
            {
                "data": {"chunk": AIMessageChunk(content="World!")},
                "event": "on_chat_model_stream",
                "name": "GenericFakeChatModel",
                "tags": ["seq:step:2"],
            },
            {
                "data": {
                    "input": {
                        "messages": [
                            [
                                SystemMessage(content="You are a cat."),
                                HumanMessage(content="hello"),
                            ]
                        ]
                    },
                    "output": LLMResult(
                        generations=[
                            [
                                ChatGenerationChunk(
                                    text="Hello World!",
                                    message=AIMessageChunk(content="Hello World!"),
                                )
                            ]
                        ],
                        llm_output=None,
                        run=None,
                    ),
                },
                "event": "on_chat_model_end",
                "name": "GenericFakeChatModel",
                "tags": ["seq:step:2"],
            },
            {
                "data": {
                    "input": AIMessageChunk(content="Hello World!"),
                    "output": "Hello World!",
                },
                "event": "on_parser_end",
                "name": "StrOutputParser",
                "tags": ["seq:step:3"],
            },
            {
                "data": {"output": "Hello World!"},
                "event": "on_chain_end",
                "name": "RunnableSequence",
                "tags": [],
            },
        ]


async def test_path_dependencies() -> None:
"""Test path dependencies."""

Expand Down

0 comments on commit 65f50b3

Please sign in to comment.