diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py index b3ec588a8..858d0ab46 100644 --- a/src/ragas/evaluation.py +++ b/src/ragas/evaluation.py @@ -170,8 +170,6 @@ def evaluate( executor = Executor( desc="Evaluating", keep_progress_bar=True, - is_async=True, - max_workers=max_workers, raise_exceptions=raise_exceptions, ) # new evaluation chain diff --git a/src/ragas/executor.py b/src/ragas/executor.py index 079fca4f1..66b6d9746 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -1,73 +1,38 @@ +from __future__ import annotations + import asyncio +import logging import typing as t -from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field +from threading import Thread import numpy as np from tqdm.auto import tqdm +logger = logging.getLogger(__name__) -@dataclass -class Executor: - desc: str = "Evaluating" - keep_progress_bar: bool = True - is_async: bool = True - max_workers: t.Optional[int] = None - futures: t.List[t.Any] = field(default_factory=list, repr=False) - raise_exceptions: bool = False - _is_new_eventloop: bool = False - - def __post_init__(self): - if self.is_async: - try: - self.executor = asyncio.get_running_loop() - except RuntimeError: - self.executor = asyncio.new_event_loop() - self._is_new_eventloop = True - else: - self.executor = ThreadPoolExecutor(max_workers=self.max_workers) - - def _validation_for_mode(self): - if self.is_async and self.max_workers is not None: - raise ValueError( - "Cannot evaluate with both async and threads. Either set is_async=False or max_workers=None." # noqa - ) - - def wrap_callable_with_index(self, callable: t.Callable, counter): - def wrapped_callable(*args, **kwargs): - return counter, callable(*args, **kwargs) - async def wrapped_callable_async(*args, **kwargs): - return counter, await callable(*args, **kwargs) - - if self.is_async: - return wrapped_callable_async - else: - return wrapped_callable - - def submit( - self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs +class Runner(Thread): + def __init__( + self, + name: str, + jobs: t.List[t.Tuple[t.Coroutine, str]], + desc: str, + keep_progress_bar: bool = True, + raise_exceptions: bool = True, ): - if self.is_async: - self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) - callable_with_index = self.wrap_callable_with_index( - callable, len(self.futures) - ) - # is type correct? - callable_with_index = t.cast(t.Callable, callable_with_index) - self.futures.append( - self.executor.create_task( - callable_with_index(*args, **kwargs), name=name - ) - ) - else: - self.executor = t.cast(ThreadPoolExecutor, self.executor) - callable_with_index = self.wrap_callable_with_index( - callable, len(self.futures) - ) - self.futures.append( - self.executor.submit(callable_with_index, *args, **kwargs) - ) + super().__init__(name=name) + self.jobs = jobs + self.desc = desc + self.keep_progress_bar = keep_progress_bar + self.raise_exceptions = raise_exceptions + self.futures = [] + + # create task + self.loop = asyncio.new_event_loop() + for job in self.jobs: + coroutine, name = job + self.futures.append(self.loop.create_task(coroutine, name=name)) async def _aresults(self) -> t.List[t.Any]: results = [] @@ -88,44 +53,62 @@ async def _aresults(self) -> t.List[t.Any]: return results - def results(self) -> t.List[t.Any]: + def run(self): results = [] - if self.is_async: - self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) - try: - if self._is_new_eventloop: - results = self.executor.run_until_complete(self._aresults()) + try: + results = self.loop.run_until_complete(self._aresults()) + except Exception as e: + if self.raise_exceptions: + raise e + else: + logger.error("Runner in Executor raised an exception", exc_info=True) + results = None + finally: + self.results = results + [f.cancel() for f in self.futures] + self.loop.stop() - # event loop is running use nested_asyncio to hijack the event loop - else: - import nest_asyncio - nest_asyncio.apply() - results = self.executor.run_until_complete(self._aresults()) - finally: - [f.cancel() for f in self.futures] +@dataclass +class Executor: + desc: str = "Evaluating" + keep_progress_bar: bool = True + jobs: t.List[t.Any] = field(default_factory=list, repr=False) + raise_exceptions: bool = False - else: - self.executor = t.cast(ThreadPoolExecutor, self.executor) - try: - for future in tqdm( - as_completed(self.futures), - desc=self.desc, - total=len(self.futures), - # whether you want to keep the progress bar after completion - leave=self.keep_progress_bar, - ): - r = (-1, np.nan) - try: - r = future.result() - except Exception as e: - r = (-1, np.nan) - if self.raise_exceptions: - raise e - finally: - results.append(r) - finally: - self.executor.shutdown(wait=False) - - sorted_results = sorted(results, key=lambda x: x[0]) + def wrap_callable_with_index(self, callable: t.Callable, counter): + async def wrapped_callable_async(*args, **kwargs): + return counter, await callable(*args, **kwargs) + + return wrapped_callable_async + + def submit( + self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs + ): + callable_with_index = self.wrap_callable_with_index(callable, len(self.jobs)) + self.jobs.append((callable_with_index(*args, **kwargs), name)) + + def results(self) -> t.List[t.Any]: + executor_job = Runner( + name="ExecutorRunner", + jobs=self.jobs, + desc=self.desc, + keep_progress_bar=self.keep_progress_bar, + raise_exceptions=self.raise_exceptions, + ) + executor_job.start() + try: + executor_job.join() + finally: + ... + + if executor_job.results is None: + if self.raise_exceptions: + raise RuntimeError( + "Executor failed to complete. Please check logs above for full info." + ) + else: + logger.error("Executor failed to complete. Please check logs above.") + return [] + sorted_results = sorted(executor_job.results, key=lambda x: x[0]) return [r[1] for r in sorted_results] diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index 41380b925..0106127bc 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -223,7 +223,6 @@ def add_nodes( executor = Executor( desc="embedding nodes", keep_progress_bar=False, - is_async=True, raise_exceptions=True, ) result_idx = 0 diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index efabd1237..75018d771 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -189,7 +189,6 @@ def generate( desc="Generating", keep_progress_bar=True, raise_exceptions=True, - is_async=True, ) current_nodes = [