Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions api/core/app/llm/model_access.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any

from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
Expand All @@ -14,8 +15,21 @@


class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.

Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).

Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""

tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]

def __init__(
self,
Expand All @@ -30,8 +44,12 @@ def __init__(
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}

def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cache appears to overlap with `ModelManager._credentials_cache` when that cache is enabled (see L107): both ultimately call `provider_configuration.get_current_credentials` for the same (provider, LLM, model) key.

return deepcopy(self.credentials_cache[(provider_name, model_name)])

provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
Expand All @@ -46,6 +64,7 @@ def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials


Expand All @@ -65,7 +84,8 @@ def __init__(
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
),
enable_credentials_cache=True,
)
self.model_manager = model_manager

Expand All @@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)

return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
Expand Down
46 changes: 41 additions & 5 deletions api/core/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload

from configs import dify_config
Expand Down Expand Up @@ -36,11 +37,13 @@ class ModelInstance:
Model instance class.
"""

def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
Expand Down Expand Up @@ -434,8 +437,30 @@ def get_tts_voices(self, language: str | None = None):


class ModelManager:
def __init__(self, provider_manager: ProviderManager):
"""Resolves :class:`ModelInstance` objects for a tenant and provider.

When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.

The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""

def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache

@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
Expand Down Expand Up @@ -463,8 +488,19 @@ def get_model_instance(
tenant_id=tenant_id, provider=provider, model_type=model_type
)

model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
cred_cache_key = (tenant_id, provider, model_type.value, model)

if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)

ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret

def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""
Expand Down
2 changes: 2 additions & 0 deletions api/services/enterprise/enterprise_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from typing import TYPE_CHECKING

from cachetools.func import ttl_cache
from pydantic import BaseModel, ConfigDict, Field, model_validator

from configs import dify_config
Expand Down Expand Up @@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:

class EnterpriseService:
@classmethod
@ttl_cache(ttl=5)
def get_info(cls):
return EnterpriseRequest.send_request("GET", "/info")

Expand Down
25 changes: 24 additions & 1 deletion api/tests/unit_tests/core/test_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pytest_mock import MockerFixture

from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.model_manager import LBModelManager
from core.model_manager import LBModelManager, ModelManager
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.model_entities import ModelType

Expand Down Expand Up @@ -40,6 +40,29 @@ def is_cooldown(config: ModelLoadBalancingConfiguration):
return lb_model_manager


def test_model_manager_with_cache_enabled_reuses_stored_credentials():
    """A cache-enabled ``ModelManager`` must not re-fetch credentials for a repeated key.

    The first ``get_model_instance`` call resolves credentials from the provider
    bundle; a second call for the identical (tenant, provider, model type, model)
    key must be served from the internal cache even after the underlying
    credentials change.
    """
    fetch_credentials = MagicMock(return_value={"api_key": "first"})

    mock_bundle = MagicMock()
    mock_bundle.configuration.provider.provider = "openai"
    mock_bundle.configuration.tenant_id = "tenant-1"
    mock_bundle.configuration.model_settings = []
    mock_bundle.configuration.get_current_credentials = fetch_credentials
    mock_bundle.model_type_instance.model_type = ModelType.LLM

    mock_provider_manager = MagicMock()
    mock_provider_manager.get_provider_model_bundle.return_value = mock_bundle

    cached_manager = ModelManager(mock_provider_manager, enable_credentials_cache=True)

    # First resolution hits the provider bundle exactly once and stores the result.
    initial = cached_manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
    assert initial.credentials == {"api_key": "first"}
    fetch_credentials.assert_called_once()

    # Even after the backing credentials rotate, the cached value wins
    # and no additional fetch is performed.
    fetch_credentials.return_value = {"api_key": "second"}
    repeated = cached_manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
    assert repeated.credentials == {"api_key": "first"}
    fetch_credentials.assert_called_once()


def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
# initialize redis client
redis_client.initialize(redis.Redis())
Expand Down
Loading