From a85fef8b350171fe92ecf931d5068095fe8b9e53 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Wed, 31 Jan 2024 22:50:30 -0800 Subject: [PATCH 1/5] cleaned up ThreadpoolExecutor --- src/ragas/executor.py | 101 +++++++++--------------------------------- 1 file changed, 21 insertions(+), 80 deletions(-) diff --git a/src/ragas/executor.py b/src/ragas/executor.py index 079fca4f1..062a0eb04 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -11,63 +11,33 @@ class Executor: desc: str = "Evaluating" keep_progress_bar: bool = True - is_async: bool = True - max_workers: t.Optional[int] = None futures: t.List[t.Any] = field(default_factory=list, repr=False) raise_exceptions: bool = False _is_new_eventloop: bool = False def __post_init__(self): - if self.is_async: - try: - self.executor = asyncio.get_running_loop() - except RuntimeError: - self.executor = asyncio.new_event_loop() - self._is_new_eventloop = True - else: - self.executor = ThreadPoolExecutor(max_workers=self.max_workers) - - def _validation_for_mode(self): - if self.is_async and self.max_workers is not None: - raise ValueError( - "Cannot evaluate with both async and threads. Either set is_async=False or max_workers=None." # noqa - ) + try: + self.executor = asyncio.get_running_loop() + except RuntimeError: + self.executor = asyncio.new_event_loop() + self._is_new_eventloop = True def wrap_callable_with_index(self, callable: t.Callable, counter): - def wrapped_callable(*args, **kwargs): - return counter, callable(*args, **kwargs) - async def wrapped_callable_async(*args, **kwargs): return counter, await callable(*args, **kwargs) - if self.is_async: - return wrapped_callable_async - else: - return wrapped_callable + return wrapped_callable_async def submit( self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs ): - if self.is_async: - self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) - callable_with_index = self.wrap_callable_with_index( - callable, len(self.futures) - ) - # is type correct? - callable_with_index = t.cast(t.Callable, callable_with_index) - self.futures.append( - self.executor.create_task( - callable_with_index(*args, **kwargs), name=name - ) - ) - else: - self.executor = t.cast(ThreadPoolExecutor, self.executor) - callable_with_index = self.wrap_callable_with_index( - callable, len(self.futures) - ) - self.futures.append( - self.executor.submit(callable_with_index, *args, **kwargs) - ) + self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) + callable_with_index = self.wrap_callable_with_index(callable, len(self.futures)) + # is type correct? + callable_with_index = t.cast(t.Callable, callable_with_index) + self.futures.append( + self.executor.create_task(callable_with_index(*args, **kwargs), name=name) + ) async def _aresults(self) -> t.List[t.Any]: results = [] @@ -90,42 +60,13 @@ async def _aresults(self) -> t.List[t.Any]: def results(self) -> t.List[t.Any]: results = [] - if self.is_async: - self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) - try: - if self._is_new_eventloop: - results = self.executor.run_until_complete(self._aresults()) - - # event loop is running use nested_asyncio to hijack the event loop - else: - import nest_asyncio - - nest_asyncio.apply() - results = self.executor.run_until_complete(self._aresults()) - finally: - [f.cancel() for f in self.futures] - - else: - self.executor = t.cast(ThreadPoolExecutor, self.executor) - try: - for future in tqdm( - as_completed(self.futures), - desc=self.desc, - total=len(self.futures), - # whether you want to keep the progress bar after completion - leave=self.keep_progress_bar, - ): - r = (-1, np.nan) - try: - r = future.result() - except Exception as e: - r = (-1, np.nan) - if self.raise_exceptions: - raise e - finally: - results.append(r) - finally: - self.executor.shutdown(wait=False) - + self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) + try: + if self._is_new_eventloop: + results = self.executor.run_until_complete(self._aresults()) + else: + results = self.executor.create_task(self._aresults()) + finally: + [f.cancel() for f in self.futures] sorted_results = sorted(results, key=lambda x: x[0]) return [r[1] for r in sorted_results] From aacee3bef2cdcee38fad8dd4c6dd72bf0774b8a2 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Thu, 1 Feb 2024 00:09:16 -0800 Subject: [PATCH 2/5] remove cleanup from everywhere else --- src/ragas/evaluation.py | 2 -- src/ragas/testset/docstore.py | 1 - src/ragas/testset/generator.py | 1 - 3 files changed, 4 deletions(-) diff --git a/src/ragas/evaluation.py b/src/ragas/evaluation.py index b3ec588a8..858d0ab46 100644 --- a/src/ragas/evaluation.py +++ b/src/ragas/evaluation.py @@ -170,8 +170,6 @@ def evaluate( executor = Executor( desc="Evaluating", keep_progress_bar=True, - is_async=True, - max_workers=max_workers, raise_exceptions=raise_exceptions, ) # new evaluation chain diff --git a/src/ragas/testset/docstore.py b/src/ragas/testset/docstore.py index 41380b925..0106127bc 100644 --- a/src/ragas/testset/docstore.py +++ b/src/ragas/testset/docstore.py @@ -223,7 +223,6 @@ def add_nodes( executor = Executor( desc="embedding nodes", keep_progress_bar=False, - is_async=True, raise_exceptions=True, ) result_idx = 0 diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index efabd1237..75018d771 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -189,7 +189,6 @@ def generate( desc="Generating", keep_progress_bar=True, raise_exceptions=True, - is_async=True, ) current_nodes = [ From 94374477260d9b5cf30c93f5fe9b9dd38061b720 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Thu, 1 Feb 2024 01:20:03 -0800 Subject: [PATCH 3/5] run in new thread --- src/ragas/executor.py | 112 ++++++++++++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 36 deletions(-) diff --git a/src/ragas/executor.py b/src/ragas/executor.py index 062a0eb04..7464eb209 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -1,43 +1,36 @@ import asyncio import typing as t -from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field +from threading import Thread +import logging import numpy as np from tqdm.auto import tqdm +logger = logging.getLogger(__name__) -@dataclass -class Executor: - desc: str = "Evaluating" - keep_progress_bar: bool = True - futures: t.List[t.Any] = field(default_factory=list, repr=False) - raise_exceptions: bool = False - _is_new_eventloop: bool = False - - def __post_init__(self): - try: - self.executor = asyncio.get_running_loop() - except RuntimeError: - self.executor = asyncio.new_event_loop() - self._is_new_eventloop = True - - def wrap_callable_with_index(self, callable: t.Callable, counter): - async def wrapped_callable_async(*args, **kwargs): - return counter, await callable(*args, **kwargs) - return wrapped_callable_async - - def submit( - self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs +class Runner(Thread): + def __init__( + self, + name: str, + jobs: list[t.Tuple[t.Coroutine, str]], + desc: str, + keep_progress_bar: bool = True, + raise_exceptions: bool = True, ): - self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) - callable_with_index = self.wrap_callable_with_index(callable, len(self.futures)) - # is type correct? - callable_with_index = t.cast(t.Callable, callable_with_index) - self.futures.append( - self.executor.create_task(callable_with_index(*args, **kwargs), name=name) - ) + super().__init__(name=name) + self.jobs = jobs + self.desc = desc + self.keep_progress_bar = keep_progress_bar + self.raise_exceptions = raise_exceptions + self.futures = [] + + # create task + self.loop = asyncio.new_event_loop() + for job in self.jobs: + coroutine, name = job + self.futures.append(self.loop.create_task(coroutine, name=name)) async def _aresults(self) -> t.List[t.Any]: results = [] @@ -58,15 +51,62 @@ async def _aresults(self) -> t.List[t.Any]: return results - def results(self) -> t.List[t.Any]: + def run(self): results = [] - self.executor = t.cast(asyncio.AbstractEventLoop, self.executor) try: - if self._is_new_eventloop: - results = self.executor.run_until_complete(self._aresults()) + results = self.loop.run_until_complete(self._aresults()) + except Exception as e: + if self.raise_exceptions: + raise e else: - results = self.executor.create_task(self._aresults()) + logger.error("Runner in Executor raised an exception", exc_info=True) + results = None finally: + self.results = results [f.cancel() for f in self.futures] - sorted_results = sorted(results, key=lambda x: x[0]) + self.loop.stop() + + +@dataclass +class Executor: + desc: str = "Evaluating" + keep_progress_bar: bool = True + jobs: t.List[t.Any] = field(default_factory=list, repr=False) + raise_exceptions: bool = False + + def wrap_callable_with_index(self, callable: t.Callable, counter): + async def wrapped_callable_async(*args, **kwargs): + return counter, await callable(*args, **kwargs) + + return wrapped_callable_async + + def submit( + self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs + ): + callable_with_index = self.wrap_callable_with_index(callable, len(self.jobs)) + self.jobs.append((callable_with_index(*args, **kwargs), name)) + + def results(self) -> t.List[t.Any]: + executor_job = Runner( + name="ExecutorRunner", + jobs=self.jobs, + desc=self.desc, + keep_progress_bar=self.keep_progress_bar, + raise_exceptions=self.raise_exceptions, + ) + executor_job.start() + try: + executor_job.join() + finally: + ... + + if executor_job.results is None: + if self.raise_exceptions: + raise RuntimeError( + "Executor failed to complete. Please check logs above for full info." + ) + else: + logger.error("Executor failed to complete. Please check logs above.") + return [] + sorted_results = sorted(executor_job.results, key=lambda x: x[0]) return [r[1] for r in sorted_results] From 6e7bfd0896771b7258f7aaf07d6e2ba195301251 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Thu, 1 Feb 2024 01:20:28 -0800 Subject: [PATCH 4/5] fmt --- src/ragas/executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ragas/executor.py b/src/ragas/executor.py index 7464eb209..a4c2e514f 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -1,8 +1,8 @@ import asyncio +import logging import typing as t from dataclasses import dataclass, field from threading import Thread -import logging import numpy as np from tqdm.auto import tqdm From e0098da8e816fe53247c117668c407e8885accc9 Mon Sep 17 00:00:00 2001 From: jjmachan Date: Thu, 1 Feb 2024 01:24:48 -0800 Subject: [PATCH 5/5] fix fmt --- src/ragas/executor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ragas/executor.py b/src/ragas/executor.py index a4c2e514f..66b6d9746 100644 --- a/src/ragas/executor.py +++ b/src/ragas/executor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import logging import typing as t @@ -14,7 +16,7 @@ class Runner(Thread): def __init__( self, name: str, - jobs: list[t.Tuple[t.Coroutine, str]], + jobs: t.List[t.Tuple[t.Coroutine, str]], desc: str, keep_progress_bar: bool = True, raise_exceptions: bool = True,