Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_current_event_loop,
send_to_queue,
)
from mellea.helpers.event_loop_helper import _run_async_in_thread
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
CBlock,
Expand Down Expand Up @@ -404,28 +405,26 @@ def _generate_from_raw(
# See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests.
prompts = [self.formatter.print(action) for action in actions]

async def get_response(coroutines):
async def get_response():
    """Fan out one raw generate request per prompt and gather the results.

    Builds a coroutine per prompt against the backend's shared async
    Ollama client, then awaits them all concurrently so Ollama can serve
    the requests in parallel. With ``return_exceptions=True``, a failed
    prompt yields its exception object in the result list instead of
    cancelling the sibling requests, so callers can inspect per-prompt
    failures.
    """
    # Run async so that we can make use of Ollama's concurrency.
    coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = []
    for prompt in prompts:
        co = self._async_client.generate(
            model=self._get_ollama_model_id(),
            prompt=prompt,
            raw=True,  # NOTE(review): presumably sends the prompt without chat templating — confirm against Ollama docs
            think=model_opts.get(ModelOption.THINKING, None),
            # Only forward a JSON schema when structured output was requested.
            format=format.model_json_schema() if format is not None else None,
            options=self._make_backend_specific_and_remove(model_opts),
        )
        coroutines.append(co)

    # Await all requests together; exceptions come back in-place (see docstring).
    responses = await asyncio.gather(*coroutines, return_exceptions=True)
    return responses

async_client = ollama.AsyncClient(self._base_url)
# Run async so that we can make use of Ollama's concurrency.
coroutines = []
for prompt in prompts:
co = async_client.generate(
model=self._get_ollama_model_id(),
prompt=prompt,
raw=True,
think=model_opts.get(ModelOption.THINKING, None),
format=format.model_json_schema() if format is not None else None,
options=self._make_backend_specific_and_remove(model_opts),
)
coroutines.append(co)

# Revisit this once we start using async elsewhere. Only one asyncio event
# loop can be running in a given thread.
responses: list[ollama.GenerateResponse | BaseException] = asyncio.run(
get_response(coroutines)
# Run in the same event_loop like other Mellea async code called from a sync function.
responses: list[ollama.GenerateResponse | BaseException] = _run_async_in_thread(
get_response()
)

results = []
Expand Down
5 changes: 5 additions & 0 deletions mellea/helpers/event_loop_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from collections.abc import Coroutine
from typing import Any, TypeVar

from mellea.helpers.async_helpers import get_current_event_loop

R = TypeVar("R")


Expand Down Expand Up @@ -52,6 +54,9 @@ async def finalize_tasks():

def __call__(self, co: Coroutine[Any, Any, R]) -> R:
    """Run the coroutine on this handler's event loop and block for its result.

    Args:
        co: The coroutine to execute.

    Returns:
        The coroutine's result; any exception raised by ``co`` propagates.
    """
    # Identity comparison (`is`, not `==`): event loop objects do not define
    # value equality, so identity is both what is meant and what `==` would
    # fall back to anyway — `is` states the intent explicitly.
    if self._event_loop is get_current_event_loop():
        # Called from our own event loop: blocking on .result() here would
        # block that loop. Hand off to a fresh handler, which runs the
        # coroutine in a separate thread to prevent blocking.
        return _EventLoopHandler()(co)
    # Called from a different thread/loop: schedule onto our loop and block
    # the calling thread until the result is available.
    return asyncio.run_coroutine_threadsafe(co, self._event_loop).result()


Expand Down
3 changes: 1 addition & 2 deletions test/backends/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ async def get_client_async():

fourth_client = asyncio.run(get_client_async())
assert fourth_client in backend._client_cache.cache.values()
assert second_client not in backend._client_cache.cache.values()

assert len(backend._client_cache.cache.values()) == 2

if __name__ == "__main__":
pytest.main([__file__])
2 changes: 1 addition & 1 deletion test/backends/test_openai_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def get_client_async():

fourth_client = asyncio.run(get_client_async())
assert fourth_client in backend._client_cache.cache.values()
assert second_client not in backend._client_cache.cache.values()
assert len(backend._client_cache.cache.values()) == 2

if __name__ == "__main__":
import pytest
Expand Down
2 changes: 1 addition & 1 deletion test/backends/test_watsonx.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def get_client_async():

fourth_client = asyncio.run(get_client_async())
assert fourth_client in backend._client_cache.cache.values()
assert second_client not in backend._client_cache.cache.values()
assert len(backend._client_cache.cache.values()) == 2

if __name__ == "__main__":
import pytest
Expand Down