diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1cdc942afea..b45de32dddd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -44,7 +44,7 @@ jobs: - name: Install packages and dependencies run: | python -m pip install --upgrade pip wheel - pip install -e . + pip install -e .[cosmosdb] python -c "import autogen" pip install pytest mock - name: Install optional dependencies for code executors @@ -67,12 +67,16 @@ jobs: if: matrix.python-version != '3.10' && matrix.os != 'ubuntu-latest' run: | pytest test --ignore=test/agentchat/contrib --skip-openai --skip-docker --durations=10 --durations-min=1.0 - - name: Coverage + - name: Coverage with Redis if: matrix.python-version == '3.10' run: | pip install -e .[test,redis,websockets] coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0 coverage xml + - name: Test with Cosmos DB + run: | + pip install -e .[test,cosmosdb] + coverage run -a -m pytest test/cache/test_cosmos_db_cache.py --skip-openai --durations=10 --durations-min=1.0 - name: Upload coverage to Codecov if: matrix.python-version == '3.10' uses: codecov/codecov-action@v3 diff --git a/autogen/cache/cache.py b/autogen/cache/cache.py index 0770079f295..6a15d993ff6 100644 --- a/autogen/cache/cache.py +++ b/autogen/cache/cache.py @@ -2,7 +2,7 @@ import sys from types import TracebackType -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypedDict, Union from .abstract_cache_base import AbstractCache from .cache_factory import CacheFactory @@ -26,7 +26,12 @@ class Cache(AbstractCache): cache: The cache instance created based on the provided configuration. """ - ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"] + ALLOWED_CONFIG_KEYS = [ + "cache_seed", + "redis_url", + "cache_path_root", + "cosmos_db_config", + ] @staticmethod def redis(cache_seed: Union[str, int] = 42, redis_url: str = "redis://localhost:6379/0") -> "Cache": @@ -56,6 +61,32 @@ def disk(cache_seed: Union[str, int] = 42, cache_path_root: str = ".cache") -> " """ return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root}) + @staticmethod + def cosmos_db( + connection_string: Optional[str] = None, + container_id: Optional[str] = None, + cache_seed: Union[str, int] = 42, + client: Optional[any] = None, + ) -> "Cache": + """ + Create a Cosmos DB cache instance with 'autogen_cache' as database ID. + + Args: + connection_string (str, optional): Connection string to the Cosmos DB account. + container_id (str, optional): The container ID for the Cosmos DB account. + cache_seed (Union[str, int], optional): A seed for the cache. + client: Optional[CosmosClient]: Pass an existing Cosmos DB client. + Returns: + Cache: A Cache instance configured for Cosmos DB. + """ + cosmos_db_config = { + "connection_string": connection_string, + "database_id": "autogen_cache", + "container_id": container_id, + "client": client, + } + return Cache({"cache_seed": str(cache_seed), "cosmos_db_config": cosmos_db_config}) + def __init__(self, config: Dict[str, Any]): """ Initialize the Cache with the given configuration. @@ -69,15 +100,19 @@ def __init__(self, config: Dict[str, Any]): ValueError: If an invalid configuration key is provided. """ self.config = config + # Ensure that the seed is always treated as a string before being passed to any cache factory or stored. + self.config["cache_seed"] = str(self.config.get("cache_seed", 42)) + # validate config for key in self.config.keys(): if key not in self.ALLOWED_CONFIG_KEYS: raise ValueError(f"Invalid config key: {key}") # create cache instance self.cache = CacheFactory.cache_factory( - self.config.get("cache_seed", "42"), - self.config.get("redis_url", None), - self.config.get("cache_path_root", None), + seed=self.config["cache_seed"], + redis_url=self.config.get("redis_url"), + cache_path_root=self.config.get("cache_path_root"), + cosmosdb_config=self.config.get("cosmos_db_config"), ) def __enter__(self) -> "Cache": diff --git a/autogen/cache/cache_factory.py b/autogen/cache/cache_factory.py index 8fc4713f06e..437893570b4 100644 --- a/autogen/cache/cache_factory.py +++ b/autogen/cache/cache_factory.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Union +from typing import Any, Dict, Optional, Union from .abstract_cache_base import AbstractCache from .disk_cache import DiskCache @@ -8,25 +8,28 @@ class CacheFactory: @staticmethod def cache_factory( - seed: Union[str, int], redis_url: Optional[str] = None, cache_path_root: str = ".cache" + seed: Union[str, int], + redis_url: Optional[str] = None, + cache_path_root: str = ".cache", + cosmosdb_config: Optional[Dict[str, Any]] = None, ) -> AbstractCache: """ Factory function for creating cache instances. - Based on the provided redis_url, this function decides whether to create a RedisCache - or DiskCache instance. If RedisCache is available and redis_url is provided, - a RedisCache instance is created. Otherwise, a DiskCache instance is used. + This function decides whether to create a RedisCache, DiskCache, or CosmosDBCache instance + based on the provided parameters. If RedisCache is available and a redis_url is provided, + a RedisCache instance is created. If connection_string, database_id, and container_id + are provided, a CosmosDBCache is created. Otherwise, a DiskCache instance is used. Args: - seed (Union[str, int]): A string or int used as a seed or namespace for the cache. - This could be useful for creating distinct cache instances - or for namespacing keys in the cache. - redis_url (str or None): The URL for the Redis server. If this is None - or if RedisCache is not available, a DiskCache instance is created. + seed (Union[str, int]): Used as a seed or namespace for the cache. + redis_url (Optional[str]): URL for the Redis server. + cache_path_root (str): Root path for the disk cache. + cosmosdb_config (Optional[Dict[str, str]]): Dictionary containing 'connection_string', + 'database_id', and 'container_id' for Cosmos DB cache. Returns: - An instance of either RedisCache or DiskCache, depending on the availability of RedisCache - and the provided redis_url. + An instance of RedisCache, DiskCache, or CosmosDBCache. Examples: @@ -40,14 +43,35 @@ def cache_factory( ```python disk_cache = cache_factory("myseed", None) ``` + + Creating a Cosmos DB cache: + ```python + cosmos_cache = cache_factory("myseed", cosmosdb_config={ + "connection_string": "your_connection_string", + "database_id": "your_database_id", + "container_id": "your_container_id"} + ) + ``` + """ - if redis_url is not None: + if redis_url: try: from .redis_cache import RedisCache return RedisCache(seed, redis_url) except ImportError: - logging.warning("RedisCache is not available. Creating a DiskCache instance instead.") - return DiskCache(f"./{cache_path_root}/{seed}") - else: - return DiskCache(f"./{cache_path_root}/{seed}") + logging.warning( + "RedisCache is not available. Checking other cache options. The last fallback is DiskCache." + ) + + if cosmosdb_config: + try: + from .cosmos_db_cache import CosmosDBCache + + return CosmosDBCache.create_cache(seed, cosmosdb_config) + + except ImportError: + logging.warning("CosmosDBCache is not available. Fallback to DiskCache.") + + # Default to DiskCache if neither Redis nor Cosmos DB configurations are provided + return DiskCache(f"./{cache_path_root}/{seed}") diff --git a/autogen/cache/cosmos_db_cache.py b/autogen/cache/cosmos_db_cache.py new file mode 100644 index 00000000000..b85be923c2f --- /dev/null +++ b/autogen/cache/cosmos_db_cache.py @@ -0,0 +1,144 @@ +# Install Azure Cosmos DB SDK if not already + +import pickle +from typing import Any, Optional, TypedDict, Union + +from azure.cosmos import CosmosClient, PartitionKey, exceptions +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +from autogen.cache.abstract_cache_base import AbstractCache + + +class CosmosDBConfig(TypedDict, total=False): + connection_string: str + database_id: str + container_id: str + cache_seed: Optional[Union[str, int]] + client: Optional[CosmosClient] + + +class CosmosDBCache(AbstractCache): + """ + Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API. + + This class provides a concrete implementation of the AbstractCache + interface using Azure Cosmos DB for caching data, with synchronous operations. + + Attributes: + seed (Union[str, int]): A seed or namespace used as a partition key. + client (CosmosClient): The Cosmos DB client used for caching. + container: The container instance used for caching. + """ + + def __init__(self, seed: Union[str, int], cosmosdb_config: CosmosDBConfig): + """ + Initialize the CosmosDBCache instance. + + Args: + seed (Union[str, int]): A seed or namespace for the cache, used as a partition key. + connection_string (str): The connection string for the Cosmos DB account. + container_id (str): The container ID to be used for caching. + client (Optional[CosmosClient]): An existing CosmosClient instance to be used for caching. + """ + self.seed = str(seed) + self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string( + cosmosdb_config["connection_string"] + ) + database_id = cosmosdb_config.get("database_id", "autogen_cache") + self.database = self.client.get_database_client(database_id) + container_id = cosmosdb_config.get("container_id") + self.container = self.database.create_container_if_not_exists( + id=container_id, partition_key=PartitionKey(path="/partitionKey") + ) + + @classmethod + def create_cache(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig): + """ + Factory method to create a CosmosDBCache instance based on the provided configuration. + This method decides whether to use an existing CosmosClient or create a new one. + """ + if "client" in cosmosdb_config and isinstance(cosmosdb_config["client"], CosmosClient): + return cls.from_existing_client(seed, **cosmosdb_config) + else: + return cls.from_config(seed, cosmosdb_config) + + @classmethod + def from_config(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig): + return cls(str(seed), cosmosdb_config) + + @classmethod + def from_connection_string(cls, seed: Union[str, int], connection_string: str, database_id: str, container_id: str): + config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id} + return cls(str(seed), config) + + @classmethod + def from_existing_client(cls, seed: Union[str, int], client: CosmosClient, database_id: str, container_id: str): + config = {"client": client, "database_id": database_id, "container_id": container_id} + return cls(str(seed), config) + + def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]: + """ + Retrieve an item from the Cosmos DB cache. + + Args: + key (str): The key identifying the item in the cache. + default (optional): The default value to return if the key is not found. + + Returns: + The deserialized value associated with the key if found, else the default value. + """ + try: + response = self.container.read_item(item=key, partition_key=str(self.seed)) + return pickle.loads(response["data"]) + except CosmosResourceNotFoundError: + return default + except Exception as e: + # Log the exception or rethrow after logging if needed + # Consider logging or handling the error appropriately here + raise e + + def set(self, key: str, value: Any) -> None: + """ + Set an item in the Cosmos DB cache. + + Args: + key (str): The key under which the item is to be stored. + value: The value to be stored in the cache. + + Notes: + The value is serialized using pickle before being stored. + """ + try: + serialized_value = pickle.dumps(value) + item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value} + self.container.upsert_item(item) + except Exception as e: + # Log or handle exception + raise e + + def close(self) -> None: + """ + Close the Cosmos DB client. + + Perform any necessary cleanup, such as closing network connections. + """ + # CosmosClient doesn"t require explicit close in the current SDK + # If you created the client inside this class, you should close it if necessary + pass + + def __enter__(self): + """ + Context management entry. + + Returns: + self: The instance itself. + """ + return self + + def __exit__(self, exc_type: Optional[type], exc_value: Optional[Exception], traceback: Optional[Any]) -> None: + """ + Context management exit. + + Perform cleanup actions such as closing the Cosmos DB client. + """ + self.close() diff --git a/test/cache/test_cache.py b/test/cache/test_cache.py index 45043ccc9e7..d01b1cf4952 100755 --- a/test/cache/test_cache.py +++ b/test/cache/test_cache.py @@ -1,55 +1,103 @@ #!/usr/bin/env python3 -m pytest import unittest -from unittest.mock import MagicMock, patch +from typing import Optional, TypedDict, Union +from unittest.mock import ANY, MagicMock, patch + +try: + from azure.cosmos import CosmosClient +except ImportError: + CosmosClient = None from autogen.cache.cache import Cache +from autogen.cache.cosmos_db_cache import CosmosDBConfig class TestCache(unittest.TestCase): def setUp(self): - self.config = {"cache_seed": "test_seed", "redis_url": "redis://test", "cache_path_root": ".test_cache"} + self.redis_config = { + "cache_seed": "test_seed", + "redis_url": "redis://test", + "cache_path_root": ".test_cache", + } + self.cosmos_config = { + "cosmos_db_config": { + "connection_string": "AccountEndpoint=https://example.documents.azure.com:443/;", + "database_id": "autogen_cache", + "container_id": "TestContainer", + "cache_seed": "42", + "client": MagicMock(spec=CosmosClient), + } + } @patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock()) - def test_init(self, mock_cache_factory): - cache = Cache(self.config) + def test_redis_cache_initialization(self, mock_cache_factory): + cache = Cache(self.redis_config) self.assertIsInstance(cache.cache, MagicMock) - mock_cache_factory.assert_called_with("test_seed", "redis://test", ".test_cache") + mock_cache_factory.assert_called() @patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock()) - def test_context_manager(self, mock_cache_factory): + def test_cosmosdb_cache_initialization(self, mock_cache_factory): + cache = Cache(self.cosmos_config) + self.assertIsInstance(cache.cache, MagicMock) + mock_cache_factory.assert_called_with( + seed="42", + redis_url=None, + cache_path_root=None, + cosmosdb_config={ + "connection_string": "AccountEndpoint=https://example.documents.azure.com:443/;", + "database_id": "autogen_cache", + "container_id": "TestContainer", + "cache_seed": "42", + "client": ANY, + }, + ) + + def context_manager_common(self, config): mock_cache_instance = MagicMock() - mock_cache_factory.return_value = mock_cache_instance + with patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=mock_cache_instance): + with Cache(config) as cache: + self.assertIsInstance(cache, MagicMock) - with Cache(self.config) as cache: - self.assertIsInstance(cache, MagicMock) + mock_cache_instance.__enter__.assert_called() + mock_cache_instance.__exit__.assert_called() - mock_cache_instance.__enter__.assert_called() - mock_cache_instance.__exit__.assert_called() + def test_redis_context_manager(self): + self.context_manager_common(self.redis_config) - @patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock()) - def test_get_set(self, mock_cache_factory): + def test_cosmos_context_manager(self): + self.context_manager_common(self.cosmos_config) + + def get_set_common(self, config): key = "key" value = "value" mock_cache_instance = MagicMock() - mock_cache_factory.return_value = mock_cache_instance + with patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=mock_cache_instance): + cache = Cache(config) + cache.set(key, value) + cache.get(key) - cache = Cache(self.config) - cache.set(key, value) - cache.get(key) + mock_cache_instance.set.assert_called_with(key, value) + mock_cache_instance.get.assert_called_with(key, None) - mock_cache_instance.set.assert_called_with(key, value) - mock_cache_instance.get.assert_called_with(key, None) + def test_redis_get_set(self): + self.get_set_common(self.redis_config) - @patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=MagicMock()) - def test_close(self, mock_cache_factory): + def test_cosmos_get_set(self): + self.get_set_common(self.cosmos_config) + + def close_common(self, config): mock_cache_instance = MagicMock() - mock_cache_factory.return_value = mock_cache_instance + with patch("autogen.cache.cache_factory.CacheFactory.cache_factory", return_value=mock_cache_instance): + cache = Cache(config) + cache.close() + mock_cache_instance.close.assert_called() - cache = Cache(self.config) - cache.close() + def test_redis_close(self): + self.close_common(self.redis_config) - mock_cache_instance.close.assert_called() + def test_cosmos_close(self): + self.close_common(self.cosmos_config) if __name__ == "__main__": diff --git a/test/cache/test_cosmos_db_cache.py b/test/cache/test_cosmos_db_cache.py new file mode 100644 index 00000000000..f89a4c96cf4 --- /dev/null +++ b/test/cache/test_cosmos_db_cache.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 -m pytest + +import pickle +import unittest +from unittest.mock import MagicMock, patch + +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +from autogen.cache.cosmos_db_cache import CosmosDBCache + + +class TestCosmosDBCache(unittest.TestCase): + def setUp(self): + self.seed = "42" + self.connection_string = "AccountEndpoint=https://example.documents.azure.com:443/;" + self.database_id = "autogen_cache" + self.container_id = "TestContainer" + self.client = MagicMock() + + @patch("autogen.cache.cosmos_db_cache.CosmosClient.from_connection_string", return_value=MagicMock()) + def test_init(self, mock_from_connection_string): + cache = CosmosDBCache.from_connection_string( + self.seed, self.connection_string, self.database_id, self.container_id + ) + self.assertEqual(cache.seed, self.seed) + mock_from_connection_string.assert_called_with(self.connection_string) + + def test_get(self): + key = "key" + value = "value" + serialized_value = pickle.dumps(value) + cache = CosmosDBCache( + self.seed, + { + "connection_string": self.connection_string, + "database_id": self.database_id, + "container_id": self.container_id, + "client": self.client, + }, + ) + cache.container.read_item.return_value = {"data": serialized_value} + self.assertEqual(cache.get(key), value) + cache.container.read_item.assert_called_with(item=key, partition_key=str(self.seed)) + + cache.container.read_item.side_effect = CosmosResourceNotFoundError(status_code=404, message="Item not found") + self.assertIsNone(cache.get(key, default=None)) + + def test_set(self): + key = "key" + value = "value" + serialized_value = pickle.dumps(value) + cache = CosmosDBCache( + self.seed, + { + "connection_string": self.connection_string, + "database_id": self.database_id, + "container_id": self.container_id, + "client": self.client, + }, + ) + cache.set(key, value) + expected_item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value} + cache.container.upsert_item.assert_called_with(expected_item) + + def test_context_manager(self): + with patch("autogen.cache.cosmos_db_cache.CosmosDBCache.close", MagicMock()) as mock_close: + with CosmosDBCache( + self.seed, + { + "connection_string": self.connection_string, + "database_id": self.database_id, + "container_id": self.container_id, + "client": self.client, + }, + ) as cache: + self.assertIsInstance(cache, CosmosDBCache) + mock_close.assert_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/website/docs/topics/llm-caching.md b/website/docs/topics/llm-caching.md index 870fdad56b5..3d14fe49f51 100644 --- a/website/docs/topics/llm-caching.md +++ b/website/docs/topics/llm-caching.md @@ -3,8 +3,7 @@ AutoGen supports caching API requests so that they can be reused when the same request is issued. This is useful when repeating or continuing experiments for reproducibility and cost saving. Since version [`0.2.8`](https://github.com/microsoft/autogen/releases/tag/v0.2.8), a configurable context manager allows you to easily -configure LLM cache, using either [`DiskCache`](/docs/reference/cache/disk_cache#diskcache) or [`RedisCache`](/docs/reference/cache/redis_cache#rediscache). All agents inside the -context manager will use the same cache. +configure LLM cache, using either [`DiskCache`](/docs/reference/cache/disk_cache#diskcache), [`RedisCache`](/docs/reference/cache/redis_cache#rediscache), or Cosmos DB Cache. All agents inside the context manager will use the same cache. ```python from autogen import Cache @@ -16,6 +15,11 @@ with Cache.redis(redis_url="redis://localhost:6379/0") as cache: # Use DiskCache as cache with Cache.disk() as cache: user.initiate_chat(assistant, message=coding_task, cache=cache) + +# Use Azure Cosmos DB as cache +with Cache.cosmos_db(connection_string="your_connection_string", database_id="your_database_id", container_id="your_container_id") as cache: + user.initiate_chat(assistant, message=coding_task, cache=cache) + ``` The cache can also be passed directly to the model client's create call.