Skip to content

Commit

Permalink
Track RunnableAssign as a separate run trace (#13972)
Browse files Browse the repository at this point in the history
Addressing incorrect order being sent to callbacks / tracers, due to the
nature of threading

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
  • Loading branch information
dqbd and nfcampos committed Nov 28, 2023
1 parent 0f255bb commit eb67f07
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 11 deletions.
97 changes: 86 additions & 11 deletions libs/core/langchain_core/runnables/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import threading
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Expand All @@ -31,11 +32,18 @@
acall_func_with_variable_args,
call_func_with_variable_args,
get_executor_for_config,
patch_config,
)
from langchain_core.runnables.utils import AddableDict, ConfigurableFieldSpec
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.iter import safetee

if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)


def identity(x: Other) -> Other:
"""An identity function"""
Expand Down Expand Up @@ -345,46 +353,84 @@ def get_output_schema(
def config_specs(self) -> List[ConfigurableFieldSpec]:
return self.mapper.config_specs

def invoke(
def _invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."

return {
**input,
**self.mapper.invoke(input, config, **kwargs),
**self.mapper.invoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
),
}

async def ainvoke(
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return self._call_with_config(self._invoke, input, config, **kwargs)

async def _ainvoke(
self,
input: Dict[str, Any],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Dict[str, Any]:
assert isinstance(
input, dict
), "The input to RunnablePassthrough.assign() must be a dict."

return {
**input,
**await self.mapper.ainvoke(input, config, **kwargs),
**await self.mapper.ainvoke(
input,
patch_config(config, callbacks=run_manager.get_child()),
**kwargs,
),
}

def transform(
async def ainvoke(
self,
input: Iterator[Dict[str, Any]],
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

def _transform(
self,
input: Iterator[Dict[str, Any]],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())

# create map output stream
map_output = self.mapper.transform(for_map, config, **kwargs)
map_output = self.mapper.transform(
for_map,
patch_config(
config,
callbacks=run_manager.get_child(),
),
**kwargs,
)

# get executor to start map output stream in background
with get_executor_for_config(config or {}) as executor:
# start map output stream
Expand All @@ -409,18 +455,36 @@ def transform(
for chunk in map_output:
yield chunk

async def atransform(
def transform(
self,
input: AsyncIterator[Dict[str, Any]],
input: Iterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Dict[str, Any]]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
)

async def _atransform(
self,
input: AsyncIterator[Dict[str, Any]],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
# create map output stream
map_output = self.mapper.atransform(for_map, config, **kwargs)
map_output = self.mapper.atransform(
for_map,
patch_config(
config,
callbacks=run_manager.get_child(),
),
**kwargs,
)
# start map output stream
first_map_chunk_task: asyncio.Task = asyncio.create_task(
py_anext(map_output, None), # type: ignore[arg-type]
Expand All @@ -441,6 +505,17 @@ async def atransform(
async for chunk in map_output:
yield chunk

async def atransform(
self,
input: AsyncIterator[Dict[str, Any]],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
):
yield chunk

def stream(
self,
input: Dict[str, Any],
Expand Down
41 changes: 41 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4146,3 +4146,44 @@ def func(__input: dict) -> Runnable:
return idchain

assert await RunnableLambda(func).ainvoke({})


def test_invoke_stream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
return False

chain = RunnablePassthrough.assign(urls=idchain_sync)

tracer = FakeTracer()
chain.invoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))

assert tracer.runs[0].name == "RunnableAssign"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"

tracer = FakeTracer()
for item in chain.stream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
pass

assert tracer.runs[0].name == "RunnableAssign"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"


@pytest.mark.asyncio
async def test_ainvoke_astream_passthrough_assign_trace() -> None:
def idchain_sync(__input: dict) -> bool:
return False

chain = RunnablePassthrough.assign(urls=idchain_sync)

tracer = FakeTracer()
await chain.ainvoke({"example": [1, 2, 3]}, dict(callbacks=[tracer]))

assert tracer.runs[0].name == "RunnableAssign"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"

tracer = FakeTracer()
async for item in chain.astream({"example": [1, 2, 3]}, dict(callbacks=[tracer])):
pass

assert tracer.runs[0].name == "RunnableAssign"
assert tracer.runs[0].child_runs[0].name == "RunnableParallel"

0 comments on commit eb67f07

Please sign in to comment.