core, standard tests, partner packages: add test for model params (#21677)

1. Adds `._get_ls_params` to `BaseChatModel`, which returns
```python
class LangSmithParams(TypedDict, total=False):
    ls_provider: str
    ls_model_name: str
    ls_model_type: Literal["chat"]
    ls_temperature: Optional[float]
    ls_max_tokens: Optional[int]
    ls_stop: Optional[List[str]]
```
By default it returns only
```python
{"ls_model_type": "chat", "ls_stop": stop}
```

2. Adds these params to the inheritable metadata in
`CallbackManager.configure` and `AsyncCallbackManager.configure`

3. Implements `._get_ls_params` and populates all params for Anthropic and
all subclasses of `BaseChatOpenAI` (see the sketch below)

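For illustration, here is a minimal sketch of how a provider integration can override `_get_ls_params`. It assumes a hypothetical `MyChatModel` with `model_name`, `temperature`, and `max_tokens` fields and a `"my-provider"` value; none of this code is in the diff.

```python
from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult


class MyChatModel(BaseChatModel):
    """Hypothetical chat model used only to illustrate the override."""

    model_name: str = "my-model"
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

    @property
    def _llm_type(self) -> str:
        return "my-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Canned response; a real integration would call a provider API here.
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content="hi"))]
        )

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        # Start from the defaults (ls_model_type, ls_stop), then fill in
        # provider-specific values.
        params = super()._get_ls_params(stop=stop, **kwargs)
        params["ls_provider"] = "my-provider"
        params["ls_model_name"] = self.model_name
        if self.temperature is not None:
            params["ls_temperature"] = self.temperature
        if self.max_tokens is not None:
            params["ls_max_tokens"] = self.max_tokens
        return params
```

Calling `super()._get_ls_params` first keeps the default `ls_model_type` and `ls_stop` behavior shown above.
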
Sample trace:
https://smith.langchain.com/public/d2962673-4c83-47c7-b51e-61d07aaffb1b/r

**OpenAI**:
<img width="984" alt="Screenshot 2024-05-17 at 10 03 35 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/2ef41f74-a9df-4e0e-905d-da74fa82a910">

**Anthropic**:
<img width="978" alt="Screenshot 2024-05-17 at 10 06 07 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/39701c9f-7da5-4f1a-ab14-84e9169d63e7">

**Mistral** (and all others for which params are not yet populated):
<img width="977" alt="Screenshot 2024-05-17 at 10 08 43 AM"
src="https://github.com/langchain-ai/langchain/assets/26529506/37d7d894-fec2-4300-986f-49a5f0191b03">
ccurme committed May 17, 2024
1 parent 4ca2149 commit 181dfef
Showing 17 changed files with 293 additions and 50 deletions.
47 changes: 43 additions & 4 deletions libs/core/langchain_core/language_models/chat_models.py
```diff
@@ -13,13 +13,16 @@
     Dict,
     Iterator,
     List,
+    Literal,
     Optional,
     Sequence,
     Type,
     Union,
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache
 from langchain_core.callbacks import (
@@ -60,6 +63,15 @@
 from langchain_core.tools import BaseTool
 
 
+class LangSmithParams(TypedDict, total=False):
+    ls_provider: str
+    ls_model_name: str
+    ls_model_type: Literal["chat"]
+    ls_temperature: Optional[float]
+    ls_max_tokens: Optional[int]
+    ls_stop: Optional[List[str]]
+
+
 def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
     """Generate from a stream."""
 
@@ -206,13 +218,17 @@ def stream(
         messages = self._convert_input(input).to_messages()
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
+        inheritable_metadata = {
+            **(config.get("metadata") or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
         callback_manager = CallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
             self.verbose,
             config.get("tags"),
             self.tags,
-            config.get("metadata"),
+            inheritable_metadata,
             self.metadata,
         )
         (run_manager,) = callback_manager.on_chat_model_start(
@@ -273,13 +289,17 @@ async def astream(
         messages = self._convert_input(input).to_messages()
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
+        inheritable_metadata = {
+            **(config.get("metadata") or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
             self.verbose,
             config.get("tags"),
             self.tags,
-            config.get("metadata"),
+            inheritable_metadata,
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
@@ -336,6 +356,17 @@ def _get_invocation_params(
         params["stop"] = stop
         return {**params, **kwargs}
 
+    def _get_ls_params(
+        self,
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        ls_params = LangSmithParams(ls_model_type="chat")
+        if stop:
+            ls_params["ls_stop"] = stop
+        return ls_params
+
     def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
@@ -385,14 +416,18 @@ def generate(
         """
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
+        inheritable_metadata = {
+            **(metadata or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
 
         callback_manager = CallbackManager.configure(
             callbacks,
             self.callbacks,
             self.verbose,
             tags,
             self.tags,
-            metadata,
+            inheritable_metadata,
             self.metadata,
         )
         run_managers = callback_manager.on_chat_model_start(
@@ -472,14 +507,18 @@ async def agenerate(
         """
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
+        inheritable_metadata = {
+            **(metadata or {}),
+            **self._get_ls_params(stop=stop, **kwargs),
+        }
 
         callback_manager = AsyncCallbackManager.configure(
             callbacks,
             self.callbacks,
             self.verbose,
             tags,
             self.tags,
-            metadata,
+            inheritable_metadata,
             self.metadata,
         )
 
```

Large diffs are not rendered by default.
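The `CallbackManager.configure` change above means the `ls_*` params arrive in handlers as ordinary inheritable metadata. A minimal sketch of observing them; the `LSParamsLogger` handler and its printed output are illustrative, not part of this commit:

```python
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class LSParamsLogger(BaseCallbackHandler):
    """Prints the ls_* tracing params attached to a chat model run."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        # The ls_* keys are merged into the run's inheritable metadata.
        ls_params = {k: v for k, v in (metadata or {}).items() if k.startswith("ls_")}
        print(ls_params)  # e.g. {"ls_model_type": "chat", "ls_stop": ["\n"]}


# Usage with any BaseChatModel subclass, e.g. the MyChatModel sketch earlier:
# MyChatModel().invoke("hello", stop=["\n"], config={"callbacks": [LSParamsLogger()]})
```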

55 changes: 40 additions & 15 deletions libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py
```diff
@@ -476,31 +476,31 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
         {
             "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
         },
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
         },
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
         },
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -526,7 +526,7 @@ def i_dont_stream(input: Any, config: RunnableConfig) -> Any:
                 },
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -569,31 +569,31 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
         {
             "data": {"input": {"messages": [[HumanMessage(content="hello")]]}},
             "event": "on_chat_model_start",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
         },
         {
             "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
         },
         {
             "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
         },
         {
             "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -619,7 +619,7 @@ async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any:
                 },
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b"},
+            "metadata": {"a": "b", "ls_model_type": "chat", "ls_stop": "<stop_token>"},
             "name": "my_model",
             "run_id": "",
             "tags": ["my_model"],
@@ -724,15 +724,25 @@ async def test_event_stream_with_simple_chain() -> None:
                 }
             },
             "event": "on_chat_model_start",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
         },
         {
             "data": {"chunk": AIMessageChunk(content="hello", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -748,7 +758,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content=" ", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -764,7 +779,12 @@ async def test_event_stream_with_simple_chain() -> None:
         {
             "data": {"chunk": AIMessageChunk(content="world!", id="ai1")},
             "event": "on_chat_model_stream",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
@@ -805,7 +825,12 @@ async def test_event_stream_with_simple_chain() -> None:
                 },
             },
             "event": "on_chat_model_end",
-            "metadata": {"a": "b", "foo": "bar"},
+            "metadata": {
+                "a": "b",
+                "foo": "bar",
+                "ls_model_type": "chat",
+                "ls_stop": "<stop_token>",
+            },
             "name": "my_model",
             "run_id": "",
             "tags": ["my_chain", "my_model", "seq:step:2"],
```
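
The updated expectations can be reproduced against any chat model. A hedged sketch, assuming `GenericFakeChatModel` behaves as in core's test utilities:

```python
import asyncio

from langchain_core.language_models import GenericFakeChatModel
from langchain_core.messages import AIMessage


async def main() -> None:
    model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))
    async for event in model.astream_events("hi", version="v1"):
        # Every chat-model event now carries the default ls_* metadata.
        assert event["metadata"]["ls_model_type"] == "chat"


asyncio.run(main())
```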
