Skip to content

Commit

Permalink
core[patch]: BaseTracer helper method for Run lookup (#14139)
Browse files Browse the repository at this point in the history
I observed the same run ID extraction logic is repeated many times in
`BaseTracer`.

This PR creates a helper method for DRY code.
  • Loading branch information
jamesbraza committed Dec 2, 2023
1 parent 41ee3be commit bdb6ae2
Showing 1 changed file with 21 additions and 62 deletions.
83 changes: 21 additions & 62 deletions libs/core/langchain_core/tracers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:

return parent_run.child_execution_order + 1

def _get_run(self, run_id: UUID, run_type: str | None = None) -> Run:
try:
run = self.run_map[str(run_id)]
except KeyError as exc:
raise TracerException(f"No indexed run ID {run_id}.") from exc
if run_type is not None and run.run_type != run_type:
raise TracerException(
f"Found {run.run_type} run at ID {run_id}, but expected {run_type} run."
)
return run

def on_llm_start(
self,
serialized: Dict[str, Any],
Expand Down Expand Up @@ -138,13 +149,7 @@ def on_llm_new_token(
**kwargs: Any,
) -> Run:
"""Run on new LLM token. Only available when streaming is enabled."""
if not run_id:
raise TracerException("No run_id provided for on_llm_new_token callback.")

run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run = self._get_run(run_id, run_type="llm")
event_kwargs: Dict[str, Any] = {"token": token}
if chunk:
event_kwargs["chunk"] = chunk
Expand All @@ -165,12 +170,7 @@ def on_retry(
run_id: UUID,
**kwargs: Any,
) -> Run:
if not run_id:
raise TracerException("No run_id provided for on_retry callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None:
raise TracerException("No Run found to be traced for on_retry")
llm_run = self._get_run(run_id)
retry_d: Dict[str, Any] = {
"slept": retry_state.idle_for,
"attempt": retry_state.attempt_number,
Expand All @@ -196,13 +196,7 @@ def on_retry(

def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_end callback.")

run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run = self._get_run(run_id, run_type="llm")
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
Expand All @@ -225,13 +219,7 @@ def on_llm_error(
**kwargs: Any,
) -> Run:
"""Handle an error for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_error callback.")

run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != "llm":
raise TracerException(f"No LLM Run found to be traced for {run_id}")
llm_run = self._get_run(run_id, run_type="llm")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
llm_run.events.append({"name": "error", "time": llm_run.end_time})
Expand Down Expand Up @@ -286,12 +274,7 @@ def on_chain_end(
**kwargs: Any,
) -> Run:
"""End a trace for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")

chain_run = self._get_run(run_id)
chain_run.outputs = (
outputs if isinstance(outputs, dict) else {"output": outputs}
)
Expand All @@ -312,12 +295,7 @@ def on_chain_error(
**kwargs: Any,
) -> Run:
"""Handle an error for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None:
raise TracerException(f"No chain Run found to be traced for {run_id}")

chain_run = self._get_run(run_id)
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
Expand Down Expand Up @@ -366,12 +344,7 @@ def on_tool_start(

def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException(f"No tool Run found to be traced for {run_id}")

tool_run = self._get_run(run_id, run_type="tool")
tool_run.outputs = {"output": output}
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "end", "time": tool_run.end_time})
Expand All @@ -387,12 +360,7 @@ def on_tool_error(
**kwargs: Any,
) -> Run:
"""Handle an error for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != "tool":
raise TracerException(f"No tool Run found to be traced for {run_id}")

tool_run = self._get_run(run_id, run_type="tool")
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
tool_run.events.append({"name": "error", "time": tool_run.end_time})
Expand Down Expand Up @@ -445,12 +413,7 @@ def on_retriever_error(
**kwargs: Any,
) -> Run:
"""Run when Retriever errors."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException(f"No retriever Run found to be traced for {run_id}")

retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.error = repr(error)
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
Expand All @@ -462,11 +425,7 @@ def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> Run:
"""Run when Retriever ends running."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException(f"No retriever Run found to be traced for {run_id}")
retrieval_run = self._get_run(run_id, run_type="retriever")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
Expand Down

0 comments on commit bdb6ae2

Please sign in to comment.