diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py
index c49c4eb0acf47c..5631caa1a59416 100644
--- a/api/core/app/llm/model_access.py
+++ b/api/core/app/llm/model_access.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from copy import deepcopy
 from typing import Any

 from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@@ -14,8 +15,21 @@


 class DifyCredentialsProvider:
+    """Resolves and returns LLM credentials for a given provider and model.
+
+    Fetched credentials are stored in :attr:`credentials_cache` and reused for
+    subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
+    Because of that cache, a single instance can return stale credentials after
+    the tenant or provider configuration changes (e.g. API key rotation).
+
+    Do **not** keep one instance for the lifetime of a process or across
+    unrelated invocations. Create a new provider per request, workflow run, or
+    other bounded scope where up-to-date credentials matter.
+    """
+
     tenant_id: str
     provider_manager: ProviderManager
+    credentials_cache: dict[tuple[str, str], dict[str, Any]]

     def __init__(
         self,
@@ -30,8 +44,12 @@ def __init__(
             user_id=run_context.user_id,
         )
         self.provider_manager = provider_manager
+        self.credentials_cache = {}

     def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
+        if (provider_name, model_name) in self.credentials_cache:
+            return deepcopy(self.credentials_cache[(provider_name, model_name)])
+
         provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
         provider_configuration = provider_configurations.get(provider_name)
         if not provider_configuration:
@@ -46,6 +64,7 @@ def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
         if credentials is None:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

+        self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
         return credentials


@@ -65,7 +84,8 @@ def __init__(
             provider_manager=create_plugin_provider_manager(
                 tenant_id=run_context.tenant_id,
                 user_id=run_context.user_id,
-            )
+            ),
+            enable_credentials_cache=True,
         )
         self.model_manager = model_manager

@@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
         tenant_id=run_context.tenant_id,
         user_id=run_context.user_id,
     )
-    model_manager = ModelManager(provider_manager=provider_manager)
+    model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)

     return (
         DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
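The fetch path above is copy-on-store and copy-on-read. A minimal standalone sketch of that pattern, runnable on its own (`CredentialsCacheSketch` and the `loader` callable are illustrative stand-ins for `DifyCredentialsProvider` and its provider-configuration lookup, not code from this patch):

from copy import deepcopy
from typing import Any, Callable


class CredentialsCacheSketch:
    # Mirrors the caching behavior DifyCredentialsProvider.fetch adopts.
    def __init__(self, loader: Callable[[str, str], dict[str, Any]]) -> None:
        self._loader = loader  # stands in for the provider-configuration lookup
        self._cache: dict[tuple[str, str], dict[str, Any]] = {}

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        key = (provider_name, model_name)
        if key in self._cache:
            return deepcopy(self._cache[key])  # reads never alias the cached dict
        credentials = self._loader(provider_name, model_name)
        self._cache[key] = deepcopy(credentials)  # the cache never aliases the source
        return credentials


sketch = CredentialsCacheSketch(lambda provider, model: {"api_key": "k1"})
first = sketch.fetch("openai", "gpt-4")
first["api_key"] = "mutated"  # a caller mutating its returned dict...
assert sketch.fetch("openai", "gpt-4") == {"api_key": "k1"}  # ...cannot corrupt the cache

The two deepcopy calls are what make the returned dicts safe to mutate; without them, a caller editing credentials in place would silently rewrite the cached entry seen by every later fetch.
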
""" - def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): + def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None: self.provider_model_bundle = provider_model_bundle self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider - self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + if credentials is None: + credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + self.credentials = credentials # Runtime LLM invocation fields. self.parameters: Mapping[str, Any] = {} self.stop: Sequence[str] = () @@ -434,8 +437,30 @@ def get_tts_voices(self, language: str | None = None): class ModelManager: - def __init__(self, provider_manager: ProviderManager): + """Resolves :class:`ModelInstance` objects for a tenant and provider. + + When ``enable_credentials_cache`` is ``True``, resolved credentials for each + ``(tenant_id, provider, model_type, model)`` are stored in + ``_credentials_cache`` and reused. That can return **stale** credentials after + API keys or provider settings change, so a manager constructed with + ``enable_credentials_cache=True`` should not be kept for the lifetime of a + process or shared across unrelated work. Prefer a new manager per request, + workflow run, or similar bounded scope. + + The default is ``enable_credentials_cache=False``; in that mode the internal + credential cache is not populated, and each ``get_model_instance`` call + loads credentials from the current provider configuration. + """ + + def __init__( + self, + provider_manager: ProviderManager, + *, + enable_credentials_cache: bool = False, + ) -> None: self._provider_manager = provider_manager + self._credentials_cache: dict[tuple[str, str, str, str], Any] = {} + self._enable_credentials_cache = enable_credentials_cache @classmethod def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": @@ -463,8 +488,19 @@ def get_model_instance( tenant_id=tenant_id, provider=provider, model_type=model_type ) - model_instance = ModelInstance(provider_model_bundle, model) - return model_instance + cred_cache_key = (tenant_id, provider, model_type.value, model) + + if cred_cache_key in self._credentials_cache: + return ModelInstance( + provider_model_bundle, + model, + deepcopy(self._credentials_cache[cred_cache_key]), + ) + + ret = ModelInstance(provider_model_bundle, model) + if self._enable_credentials_cache: + self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials) + return ret def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 5040fcc7e3b5f1..bd7758f1c08bf8 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import TYPE_CHECKING +from cachetools.func import ttl_cache from pydantic import BaseModel, ConfigDict, Field, model_validator from configs import dify_config @@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None: class EnterpriseService: @classmethod + @ttl_cache(ttl=5) def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index afea9144c06e94..5a7e7e30a50861 
diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py
index afea9144c06e94..5a7e7e30a50861 100644
--- a/api/tests/unit_tests/core/test_model_manager.py
+++ b/api/tests/unit_tests/core/test_model_manager.py
@@ -5,7 +5,7 @@
 from pytest_mock import MockerFixture

 from core.entities.provider_entities import ModelLoadBalancingConfiguration
-from core.model_manager import LBModelManager
+from core.model_manager import LBModelManager, ModelManager
 from extensions.ext_redis import redis_client
 from graphon.model_runtime.entities.model_entities import ModelType

@@ -40,6 +40,29 @@ def is_cooldown(config: ModelLoadBalancingConfiguration):
     return lb_model_manager


+def test_model_manager_with_cache_enabled_reuses_stored_credentials():
+    """With ``enable_credentials_cache=True``, later calls for the same key return cached creds."""
+    provider_manager = MagicMock()
+    bundle = MagicMock()
+    bundle.configuration.provider.provider = "openai"
+    bundle.configuration.tenant_id = "tenant-1"
+    bundle.configuration.model_settings = []
+    bundle.model_type_instance.model_type = ModelType.LLM
+    get_creds = MagicMock(return_value={"api_key": "first"})
+    bundle.configuration.get_current_credentials = get_creds
+    provider_manager.get_provider_model_bundle.return_value = bundle
+
+    manager = ModelManager(provider_manager, enable_credentials_cache=True)
+    first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
+    assert first.credentials == {"api_key": "first"}
+    get_creds.assert_called_once()
+
+    get_creds.return_value = {"api_key": "second"}
+    second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
+    assert second.credentials == {"api_key": "first"}
+    get_creds.assert_called_once()
+
+
 def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
     # initialize redis client
     redis_client.initialize(redis.Redis())
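To make the staleness warning in the ModelManager docstring concrete, here is a sketch reusing the test's mock wiring (and, like the test, assuming credentials resolve through bundle.configuration.get_current_credentials); it contrasts a long-lived caching manager with the recommended per-scope one:

from unittest.mock import MagicMock

from core.model_manager import ModelManager
from graphon.model_runtime.entities.model_entities import ModelType

# Same stubbing as the unit test above, so the sketch runs standalone.
provider_manager = MagicMock()
bundle = MagicMock()
bundle.configuration.provider.provider = "openai"
bundle.configuration.tenant_id = "tenant-1"
bundle.configuration.model_settings = []
bundle.model_type_instance.model_type = ModelType.LLM
bundle.configuration.get_current_credentials = MagicMock(return_value={"api_key": "old"})
provider_manager.get_provider_model_bundle.return_value = bundle

long_lived = ModelManager(provider_manager, enable_credentials_cache=True)
long_lived.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")

# The tenant rotates its API key...
bundle.configuration.get_current_credentials.return_value = {"api_key": "new"}

# ...a manager kept across the rotation keeps serving the stale key:
stale = long_lived.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
assert stale.credentials == {"api_key": "old"}

# ...while a manager created per bounded scope picks up the new key:
per_scope = ModelManager(provider_manager, enable_credentials_cache=True)
fresh = per_scope.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
assert fresh.credentials == {"api_key": "new"}

This is the scoping rule both new docstrings ask for: the cache's lifetime, not a TTL, is what bounds credential staleness here.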