Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/ragas/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,6 @@ def evaluate(
executor = Executor(
desc="Evaluating",
keep_progress_bar=True,
is_async=True,
max_workers=max_workers,
raise_exceptions=raise_exceptions,
)
# new evaluation chain
Expand Down
175 changes: 79 additions & 96 deletions src/ragas/executor.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,38 @@
from __future__ import annotations

import asyncio
import logging
import typing as t
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from threading import Thread

import numpy as np
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)

@dataclass
class Executor:
desc: str = "Evaluating"
keep_progress_bar: bool = True
is_async: bool = True
max_workers: t.Optional[int] = None
futures: t.List[t.Any] = field(default_factory=list, repr=False)
raise_exceptions: bool = False
_is_new_eventloop: bool = False

def __post_init__(self):
    """Pick the backend stored in ``self.executor`` based on ``is_async``.

    When ``is_async`` is false a ``ThreadPoolExecutor`` limited to
    ``max_workers`` is created. Otherwise the currently running asyncio
    event loop is reused; if there is none, a fresh loop is created and
    ``_is_new_eventloop`` is set to ``True``.
    """
    if not self.is_async:
        self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
        return

    try:
        self.executor = asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running.
        self.executor = asyncio.new_event_loop()
        self._is_new_eventloop = True

def _validation_for_mode(self):
    """Reject the combination of async mode with an explicit worker count.

    Raises:
        ValueError: if ``is_async`` is true and ``max_workers`` is not ``None``.
    """
    threads_requested = self.max_workers is not None
    if not (self.is_async and threads_requested):
        return
    raise ValueError(
        "Cannot evaluate with both async and threads. Either set is_async=False or max_workers=None."  # noqa
    )

def wrap_callable_with_index(self, callable: t.Callable, counter):
    """Return a wrapper around ``callable`` that tags its result with ``counter``.

    The wrapper forwards all positional and keyword arguments and returns
    the tuple ``(counter, result)``. When ``self.is_async`` is true the
    wrapper is a coroutine function that awaits ``callable``; otherwise it
    is a plain function that calls it directly.
    """
    if self.is_async:

        async def wrapped_callable_async(*args, **kwargs):
            return counter, await callable(*args, **kwargs)

        return wrapped_callable_async

    def wrapped_callable(*args, **kwargs):
        return counter, callable(*args, **kwargs)

    return wrapped_callable

def submit(
self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs
class Runner(Thread):
def __init__(
self,
name: str,
jobs: t.List[t.Tuple[t.Coroutine, str]],
desc: str,
keep_progress_bar: bool = True,
raise_exceptions: bool = True,
):
if self.is_async:
self.executor = t.cast(asyncio.AbstractEventLoop, self.executor)
callable_with_index = self.wrap_callable_with_index(
callable, len(self.futures)
)
# is type correct?
callable_with_index = t.cast(t.Callable, callable_with_index)
self.futures.append(
self.executor.create_task(
callable_with_index(*args, **kwargs), name=name
)
)
else:
self.executor = t.cast(ThreadPoolExecutor, self.executor)
callable_with_index = self.wrap_callable_with_index(
callable, len(self.futures)
)
self.futures.append(
self.executor.submit(callable_with_index, *args, **kwargs)
)
super().__init__(name=name)
self.jobs = jobs
self.desc = desc
self.keep_progress_bar = keep_progress_bar
self.raise_exceptions = raise_exceptions
self.futures = []

# create task
self.loop = asyncio.new_event_loop()
for job in self.jobs:
coroutine, name = job
self.futures.append(self.loop.create_task(coroutine, name=name))

async def _aresults(self) -> t.List[t.Any]:
results = []
Expand All @@ -88,44 +53,62 @@ async def _aresults(self) -> t.List[t.Any]:

return results

def results(self) -> t.List[t.Any]:
def run(self):
results = []
if self.is_async:
self.executor = t.cast(asyncio.AbstractEventLoop, self.executor)
try:
if self._is_new_eventloop:
results = self.executor.run_until_complete(self._aresults())
try:
results = self.loop.run_until_complete(self._aresults())
except Exception as e:
if self.raise_exceptions:
raise e
else:
logger.error("Runner in Executor raised an exception", exc_info=True)
results = None
finally:
self.results = results
[f.cancel() for f in self.futures]
self.loop.stop()

# event loop is running use nested_asyncio to hijack the event loop
else:
import nest_asyncio

nest_asyncio.apply()
results = self.executor.run_until_complete(self._aresults())
finally:
[f.cancel() for f in self.futures]
@dataclass
class Executor:
desc: str = "Evaluating"
keep_progress_bar: bool = True
jobs: t.List[t.Any] = field(default_factory=list, repr=False)
raise_exceptions: bool = False

else:
self.executor = t.cast(ThreadPoolExecutor, self.executor)
try:
for future in tqdm(
as_completed(self.futures),
desc=self.desc,
total=len(self.futures),
# whether you want to keep the progress bar after completion
leave=self.keep_progress_bar,
):
r = (-1, np.nan)
try:
r = future.result()
except Exception as e:
r = (-1, np.nan)
if self.raise_exceptions:
raise e
finally:
results.append(r)
finally:
self.executor.shutdown(wait=False)

sorted_results = sorted(results, key=lambda x: x[0])
def wrap_callable_with_index(self, callable: t.Callable, counter):
    """Return a coroutine function that pairs ``callable``'s result with ``counter``.

    The returned wrapper forwards its arguments to ``callable``, awaits the
    result, and returns ``(counter, result)``.
    """

    async def wrapped_callable_async(*args, **kwargs):
        outcome = await callable(*args, **kwargs)
        return counter, outcome

    return wrapped_callable_async

def submit(
    self, callable: t.Callable, *args, name: t.Optional[str] = None, **kwargs
):
    """Queue one call of ``callable`` as a job.

    ``callable`` is wrapped via ``wrap_callable_with_index`` using the
    current number of queued jobs as its index, invoked immediately with
    ``*args``/``**kwargs``, and the resulting object is appended to
    ``self.jobs`` together with ``name`` as a ``(job, name)`` tuple.
    """
    position = len(self.jobs)
    indexed = self.wrap_callable_with_index(callable, position)
    pending = indexed(*args, **kwargs)
    self.jobs.append((pending, name))

def results(self) -> t.List[t.Any]:
    """Run all queued jobs on a ``Runner`` and return their values in order.

    A ``Runner`` named "ExecutorRunner" is built from ``self.jobs`` and the
    executor's display/exception settings, started, and joined. Its
    ``results`` attribute is then inspected:

    * ``None`` means the run failed: a ``RuntimeError`` is raised when
      ``raise_exceptions`` is true, otherwise an error is logged and an
      empty list is returned.
    * Otherwise the entries are sorted by their first element and the
      second element of each entry is returned.

    Raises:
        RuntimeError: if the runner reported failure and
            ``raise_exceptions`` is true.
    """
    runner = Runner(
        name="ExecutorRunner",
        jobs=self.jobs,
        desc=self.desc,
        keep_progress_bar=self.keep_progress_bar,
        raise_exceptions=self.raise_exceptions,
    )
    runner.start()
    # Block until the runner has finished.
    runner.join()

    if runner.results is None:
        if not self.raise_exceptions:
            logger.error("Executor failed to complete. Please check logs above.")
            return []
        raise RuntimeError(
            "Executor failed to complete. Please check logs above for full info."
        )

    ordered = sorted(runner.results, key=lambda x: x[0])
    return [entry[1] for entry in ordered]
1 change: 0 additions & 1 deletion src/ragas/testset/docstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def add_nodes(
executor = Executor(
desc="embedding nodes",
keep_progress_bar=False,
is_async=True,
raise_exceptions=True,
)
result_idx = 0
Expand Down
1 change: 0 additions & 1 deletion src/ragas/testset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def generate(
desc="Generating",
keep_progress_bar=True,
raise_exceptions=True,
is_async=True,
)

current_nodes = [
Expand Down