Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wait for all futures #6554

Merged
merged 3 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 47 additions & 12 deletions langchain/callbacks/tracers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import logging
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
Expand All @@ -21,6 +21,7 @@

logger = logging.getLogger(__name__)
_LOGGED = set()
_TRACERS: List[LangChainTracer] = []


def log_error_once(method: str, exception: Exception) -> None:
Expand All @@ -32,6 +33,12 @@ def log_error_once(method: str, exception: Exception) -> None:
logger.error(exception)


def wait_for_all_tracers() -> None:
    """Block until every registered tracer has flushed its pending futures.

    Iterates the module-level ``_TRACERS`` registry (every
    ``LangChainTracer`` appends itself on construction) and waits for
    each tracer's queued submit/update futures to complete.
    """
    # No `global` needed: the name is only read and iterated, never rebound.
    for tracer in _TRACERS:
        tracer.wait_for_futures()


class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""

Expand All @@ -52,6 +59,9 @@ def __init__(
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)
self.client = client or LangChainPlusClient()
self._futures: List[Future] = []
global _TRACERS
_TRACERS.append(self)

def on_chat_model_start(
self,
Expand Down Expand Up @@ -93,7 +103,7 @@ def _persist_run_single(self, run: Run) -> None:
extra["runtime"] = get_runtime_environment()
run_dict["extra"] = extra
try:
run = self.client.create_run(**run_dict, session_name=self.session_name)
self.client.create_run(**run_dict, session_name=self.session_name)
except Exception as e:
# Errors are swallowed by the thread executor so we need to log them here
log_error_once("post", e)
Expand All @@ -110,40 +120,65 @@ def _update_run_single(self, run: Run) -> None:

def _on_llm_start(self, run: Run) -> None:
    """Persist an LLM run.

    Submits the POST on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )

def _on_chat_model_start(self, run: Run) -> None:
    """Persist a chat model run.

    Submits the POST on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )

def _on_llm_end(self, run: Run) -> None:
    """Process the LLM Run.

    Submits the update on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )

def _on_llm_error(self, run: Run) -> None:
    """Process the LLM Run upon error.

    Submits the update on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )

def _on_chain_start(self, run: Run) -> None:
    """Process the Chain Run upon start.

    Submits the POST on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )

def _on_chain_end(self, run: Run) -> None:
    """Process the Chain Run.

    Submits the update on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )

def _on_chain_error(self, run: Run) -> None:
    """Process the Chain Run upon error.

    Submits the update on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )

def _on_tool_start(self, run: Run) -> None:
    """Process the Tool Run upon start.

    Submits the POST on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._persist_run_single, run.copy(deep=True))
    )

def _on_tool_end(self, run: Run) -> None:
    """Process the Tool Run.

    Submits the update on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )

def _on_tool_error(self, run: Run) -> None:
    """Process the Tool Run upon error.

    Submits the update on the single-worker executor and records the
    Future so ``wait_for_futures`` can block until it completes.
    """
    # Deep-copy so later mutation of the live run cannot race the upload.
    self._futures.append(
        self.executor.submit(self._update_run_single, run.copy(deep=True))
    )

def wait_for_futures(self) -> None:
    """Block until all currently pending futures have completed.

    Snapshots the pending list before waiting so futures submitted
    concurrently while we block are neither waited on prematurely nor
    silently dropped; afterwards only the snapshotted (now finished)
    futures are removed from the pending list.  Without the removal the
    list would grow without bound across calls.
    """
    futures = list(self._futures)
    wait(futures)
    for future in futures:
        self._futures.remove(future)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's better to do something like

futures = [*self._futures]
wait(futures)
for future in futures:
  # remove from self._futures
  

in case other futures are added here while you're waiting?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point

self._futures.clear()
14 changes: 12 additions & 2 deletions langchain/client/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,17 @@ async def run_coroutine_with_semaphore(
tracer_queue.put_nowait(tracer)
return result

return await asyncio.gather(
results = await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs)
)
while tracer_queue:
try:
tracer = tracer_queue.get_nowait()
except asyncio.QueueEmpty:
break
if tracer:
tracer.wait_for_futures()
return results


async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
Expand Down Expand Up @@ -411,7 +419,9 @@ def run_on_examples(
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
results[str(example.id)] = result
if tracer:
tracer.wait_for_futures()
return results


Expand Down