Merged
30 changes: 29 additions & 1 deletion mellea/backends/litellm.py
@@ -22,7 +22,7 @@
convert_tools_to_json,
)
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.async_helpers import get_current_event_loop, send_to_queue
from mellea.helpers.fancy_logger import FancyLogger
from mellea.helpers.openai_compatible_helpers import (
chat_completion_delta_merge,
@@ -105,6 +105,8 @@ def __init__(
ModelOption.STREAM: "stream",
}

self._past_event_loops: set[int] = set()

def generate_from_context(
self,
action: Component | CBlock,
@@ -270,6 +272,11 @@ def _generate_from_chat_context_standard(

model_specific_options = self._make_backend_specific_and_remove(model_opts)

if self._has_potential_event_loop_errors():
FancyLogger().get_logger().warning(
"There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call."
)

chat_response: Coroutine[
Any, Any, litellm.ModelResponse | litellm.ModelResponseStream # type: ignore
] = litellm.acompletion(
@@ -488,3 +495,24 @@ def _extract_model_tool_requests(
if len(model_tool_calls) > 0:
return model_tool_calls
return None

def _has_potential_event_loop_errors(self) -> bool:
"""In some cases litellm doesn't create a new async client. There doesn't appear to be any way for us to force that behavior. As a result, log a warning for known cases.

This whole function can be removed once the bug is fixed: https://github.com/BerriAI/litellm/issues/15294.
"""
# Async clients are tied to event loops.
key = id(get_current_event_loop())

has_potential_issue = False
if (
len(self._past_event_loops) > 0
and key not in self._past_event_loops
and "watsonx/" in str(self.model_id)
):
has_potential_issue = True

# Add this loop to the known set.
self._past_event_loops.add(key)

return has_potential_issue
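
The `_has_potential_event_loop_errors` check above warns when a watsonx generation call arrives on an event loop the backend has not seen before. A minimal sketch, using only the standard library and independent of this PR, of why the loop identity changes between separate asyncio.run() calls:

import asyncio

async def current_loop() -> asyncio.AbstractEventLoop:
    # Return the loop this coroutine is running on.
    return asyncio.get_running_loop()

# Each asyncio.run() call creates and then closes its own event loop, so any
# client bound to the first loop cannot be reused inside the second one.
loop_a = asyncio.run(current_loop())
loop_b = asyncio.run(current_loop())
print(loop_a is loop_b)  # False: two distinct loops were used.
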
27 changes: 22 additions & 5 deletions mellea/backends/ollama.py
@@ -18,7 +18,11 @@
add_tools_from_model_options,
)
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.async_helpers import (
ClientCache,
get_current_event_loop,
send_to_queue,
)
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
CBlock,
@@ -69,6 +73,11 @@ def __init__(
self._base_url = base_url
self._client = ollama.Client(base_url)

self._client_cache = ClientCache(2)

# Call once to set up an async client and prepopulate the cache.
_ = self._async_client

if not self._check_ollama_server():
err = f"could not create OllamaModelBackend: ollama server not running at {base_url}"
FancyLogger.get_logger().error(err)
@@ -181,6 +190,17 @@ def _pull_ollama_model(self) -> bool:
except ollama.ResponseError:
return False

@property
def _async_client(self) -> ollama.AsyncClient:
"""Ollama's client gets tied to a specific event loop. Reset it if needed here."""
key = id(get_current_event_loop())

_async_client = self._client_cache.get(key)
if _async_client is None:
_async_client = ollama.AsyncClient(self._base_url)
self._client_cache.put(key, _async_client)
return _async_client

def _simplify_and_merge(
self, model_options: dict[str, Any] | None
) -> dict[str, Any]:
@@ -318,13 +338,10 @@ def generate_from_chat_context(
add_tools_from_context_actions(tools, [action])
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")

# Ollama ties its async client to an event loop so we have to create it here.
async_client = ollama.AsyncClient(self._base_url)

# Generate a chat response from ollama, using the chat messages. Can be either type since stream is passed as a model option.
chat_response: Coroutine[
Any, Any, AsyncIterator[ollama.ChatResponse] | ollama.ChatResponse
] = async_client.chat(
] = self._async_client.chat(
model=self._get_ollama_model_id(),
messages=conversation,
tools=list(tools.values()),
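
The `_async_client` property above is one instance of a pattern this PR repeats for the OpenAI and watsonx backends: look up a client keyed by the current event loop's id and create a new one on a miss. A minimal self-contained sketch of that pattern using the helpers added in this PR (`DummyClient` is a hypothetical stand-in for a provider SDK client):

import asyncio

from mellea.helpers.async_helpers import ClientCache, get_current_event_loop

class DummyClient:
    """Hypothetical stand-in for an SDK client that must not cross event loops."""

_cache = ClientCache(2)

def client_for_current_loop() -> DummyClient:
    # Key the cache by the id of the running loop (or of None if no loop is running).
    key = id(get_current_event_loop())
    client = _cache.get(key)
    if client is None:
        client = DummyClient()
        _cache.put(key, client)
    return client

async def reuse_within_one_loop() -> bool:
    # Two lookups on the same loop return the same cached client.
    return client_for_current_loop() is client_for_current_loop()

print(asyncio.run(reuse_within_one_loop()))  # True
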
33 changes: 27 additions & 6 deletions mellea/backends/openai.py
@@ -29,7 +29,11 @@
convert_tools_to_json,
)
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.async_helpers import (
ClientCache,
get_current_event_loop,
send_to_queue,
)
from mellea.helpers.fancy_logger import FancyLogger
from mellea.helpers.openai_compatible_helpers import (
chat_completion_delta_merge,
@@ -164,18 +168,35 @@ def __init__(
else:
self._api_key = api_key

openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
self._openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)

self._client = openai.OpenAI( # type: ignore
api_key=self._api_key, base_url=self._base_url, **openai_client_kwargs
)
self._async_client = openai.AsyncOpenAI(
api_key=self._api_key, base_url=self._base_url, **openai_client_kwargs
api_key=self._api_key, base_url=self._base_url, **self._openai_client_kwargs
)

self._client_cache = ClientCache(2)

# Call once to create an async_client and populate the cache.
_ = self._async_client

# ALoras that have been loaded for this model.
self._aloras: dict[str, OpenAIAlora] = {}

@property
def _async_client(self) -> openai.AsyncOpenAI:
"""OpenAI's client usually handles changing event loops but explicitly handle it here for edge cases."""
key = id(get_current_event_loop())

_async_client = self._client_cache.get(key)
if _async_client is None:
_async_client = openai.AsyncOpenAI(
api_key=self._api_key,
base_url=self._base_url,
**self._openai_client_kwargs,
)
self._client_cache.put(key, _async_client)
return _async_client

@staticmethod
def filter_openai_client_kwargs(**kwargs) -> dict:
"""Filter kwargs to only include valid OpenAI client parameters."""
47 changes: 27 additions & 20 deletions mellea/backends/watsonx.py
@@ -22,7 +22,11 @@
convert_tools_to_json,
)
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.async_helpers import (
ClientCache,
get_current_event_loop,
send_to_queue,
)
from mellea.helpers.fancy_logger import FancyLogger
from mellea.helpers.openai_compatible_helpers import (
chat_completion_delta_merge,
@@ -93,15 +97,12 @@ def __init__(
self._project_id = os.environ.get("WATSONX_PROJECT_ID")

self._creds = Credentials(url=base_url, api_key=api_key)
_client = APIClient(credentials=self._creds)
self._model_inference = ModelInference(
model_id=self._get_watsonx_model_id(),
api_client=_client,
credentials=self._creds,
project_id=self._project_id,
params=self.model_options,
**kwargs,
)
self._kwargs = kwargs

self._client_cache = ClientCache(2)

# Call once to set up the model inference and prepopulate the cache.
_ = self._model

# A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
# These are usually values that must be extracted before hand or that are common among backend providers.
@@ -134,16 +135,22 @@ def __init__(

@property
def _model(self) -> ModelInference:
"""Watsonx's client gets tied to a specific event loop. Reset it here."""
_client = APIClient(credentials=self._creds)
self._model_inference = ModelInference(
model_id=self._get_watsonx_model_id(),
api_client=_client,
credentials=self._creds,
project_id=self._project_id,
params=self.model_options,
)
return self._model_inference
"""Watsonx's client gets tied to a specific event loop. Reset it if needed here."""
key = id(get_current_event_loop())

_model_inference = self._client_cache.get(key)
if _model_inference is None:
_client = APIClient(credentials=self._creds)
_model_inference = ModelInference(
model_id=self._get_watsonx_model_id(),
api_client=_client,
credentials=self._creds,
project_id=self._project_id,
params=self.model_options,
**self._kwargs,
)
self._client_cache.put(key, _model_inference)
return _model_inference

def _get_watsonx_model_id(self) -> str:
"""Gets the watsonx model id from the model_id that was provided in the constructor. Raises AssertionError if the ModelIdentifier does not provide a watsonx_name."""
53 changes: 52 additions & 1 deletion mellea/helpers/async_helpers.py
@@ -1,8 +1,9 @@
"""Async helper functions."""

import asyncio
from collections import OrderedDict
from collections.abc import AsyncIterator, Coroutine
from typing import Any
from typing import Any, TypeVar

from mellea.stdlib.base import ModelOutputThunk

Expand Down Expand Up @@ -46,3 +47,53 @@ async def wait_for_all_mots(mots: list[ModelOutputThunk]):
coroutines.append(mot.avalue())

await asyncio.gather(*coroutines)


def get_current_event_loop() -> None | asyncio.AbstractEventLoop:
"""Get the current event loop without having to catch exceptions."""
loop = None
try:
loop = asyncio.get_running_loop()
except RuntimeError:
pass
return loop


class ClientCache:
Contributor:
Could we just use the inbuilt LRU here?

Contributor Author:
My understanding is that inbuilt LRU is for caching function results. It seems like it could be used here but may be somewhat clunky? How would you go about using it here since we have to pass in a key to get?

I guess we would have to do something like:

@functools.lru_cache
def _get_client(self, key):
  return Client()

def _client(self):
  key = id(event_loop)
  return self._get_client(key)

Contributor Author:
I tested this out more. I think I would like to keep the separate client cache. The functools version is much harder to debug / test, and I think the multiple functions make it less straightforward what is happening. If you feel strongly about this, I am willing to implement it using the functools.lru_cache though.

Contributor:
Ah, for some reason I thought that it could cache both objects (with @cached_property). But I agree with you re: debugging efficiency.

"""A simple [LRU](https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_Recently_Used_(LRU)) cache.

Used to keep track of clients for backends where the client is tied to a specific event loop.
"""

def __init__(self, capacity: int):
"""Initializes the LRU cache with a certain capacity.

The `ClientCache` either holds a value for a key or it doesn't; `get` returns None on a miss.
"""
self.capacity = capacity
self.cache: OrderedDict = OrderedDict()

def current_size(self):
"""Just return the size of the key set. This isn't necessarily safe."""
return len(self.cache.keys())

def get(self, key: int) -> Any | None:
"""Gets a value from the cache."""
if key not in self.cache:
return None
else:
# Move the accessed item to the end (most recent)
value = self.cache.pop(key)
self.cache[key] = value
return value

def put(self, key: int, value: Any):
"""Put a value into the cache."""
if key in self.cache:
# If the key exists, move it to the end (most recent)
self.cache.pop(key)
elif len(self.cache) >= self.capacity:
# If the cache is full, remove the least recently used item
self.cache.popitem(last=False)
# Add the new key-value pair to the end (most recent)
self.cache[key] = value
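
The review thread above sketches replacing `ClientCache` with `functools.lru_cache`. A minimal self-contained version of that idea, with hypothetical names and not the implementation that was merged, might look like this:

import asyncio
import functools

class Client:
    """Hypothetical per-event-loop client."""

class Backend:
    @functools.lru_cache(maxsize=2)
    def _get_client(self, key: int) -> Client:
        # Cached per (backend instance, event-loop id); the least recently used entry is evicted.
        return Client()

    @property
    def _client(self) -> Client:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        return self._get_client(id(loop))

As discussed in the thread, the PR keeps the explicit `ClientCache` instead because the decorator-based version is harder to debug and test.
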
86 changes: 86 additions & 0 deletions mellea/helpers/event_loop_helper.py
@@ -0,0 +1,86 @@
"""Helper for event loop management. Allows consistently running async generate requests in sync code."""

import asyncio
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar

R = TypeVar("R")


class _EventLoopHandler:
"""A class that handles the event loop for Mellea code. Do not directly instantiate this. Use `_run_async_in_thread`."""

def __init__(self):
"""Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea.

Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
"""
self._event_loop = asyncio.new_event_loop()
self._thread: threading.Thread = threading.Thread(
target=self._event_loop.run_forever, daemon=True
)
self._thread.start()

def __del__(self):
"""Delete the event loop handler."""
self._close_event_loop()

def _close_event_loop(self) -> None:
"""Called when deleting the event loop handler. Cleans up the event loop and thread."""
if self._event_loop:
try:
tasks = asyncio.all_tasks(self._event_loop)
for task in tasks:
task.cancel()

async def finalize_tasks():
# TODO: We can log errors here if needed.
await asyncio.gather(*tasks, return_exceptions=True)

out = asyncio.run_coroutine_threadsafe(
finalize_tasks(), self._event_loop
)

# Timeout if needed.
out.result(5)
except Exception:
pass

# Finally stop the event loop for this session.
self._event_loop.stop()

def __call__(self, co: Coroutine[Any, Any, R]) -> R:
"""Runs the coroutine in the event loop."""
return asyncio.run_coroutine_threadsafe(co, self._event_loop).result()


# Instantiate this class once. It will not be re-instantiated.
__event_loop_handler = _EventLoopHandler()


def _run_async_in_thread(co: Coroutine[Any, Any, R]) -> R:
"""Call to run async code from synchronous code in Mellea.

In Mellea, we utilize async code underneath sync code to speed up
inference requests. This puts us in a difficult situation since most
API providers and SDKs use async clients that get bound to a specific event
loop to make requests. These clients are typically long-lasting and sometimes
cannot be easily reinstantiated on demand to avoid these issues.
By declaring a single event loop for these async requests,
Mellea avoids these client issues.

Note: This implementation requires that sessions/backends be run only through
the top-level / session sync or async interfaces, not both. You will need to
reinstantiate your backend if switching between the two.

Args:
co: coroutine to run

Returns:
output of the coroutine
"""
return __event_loop_handler(co)


__all__ = ["_run_async_in_thread"]
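
A short usage sketch of the new helper, assuming the module path added in this PR (`fake_generate` is a hypothetical stand-in for an async inference request):

import asyncio

from mellea.helpers.event_loop_helper import _run_async_in_thread

async def fake_generate() -> str:
    await asyncio.sleep(0.1)  # Stand-in for an awaited backend call.
    return "done"

# Called from ordinary synchronous code: the coroutine is submitted to the single
# long-lived background event loop and its result is returned to this thread.
result = _run_async_in_thread(fake_generate())
print(result)  # done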