Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions examples/corpus_search/azure_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ async def release_lock(self, token: LockToken) -> None:
# CORPUS_AZURE_CONTAINER_URL=https://<account>.blob.core.windows.net/<container>

_CONTAINER_URL_ENV = "CORPUS_AZURE_CONTAINER_URL"
_MI_CLIENT_ID_ENV = "CORPUS_AZURE_MANAGED_IDENTITY_CLIENT_ID"
_SQLITE_SUFFIX = ".sqlite"


Expand All @@ -312,16 +313,26 @@ class AzureCorpusBackendRegistry:
Listing walks the container to recover corpus ids.
"""

def __init__(self, container_url: str) -> None:
def __init__(self, container_url: str, *, managed_identity_client_id: str | None = None) -> None:
# ``DefaultAzureCredential`` is imported lazily so this module
# is importable without the Azure SDK installed; constructing
# the registry is when the cost is paid.
from azure.identity import DefaultAzureCredential # type: ignore[import-not-found]

self._container_url = container_url.rstrip("/")
# One credential for the lifetime of the registry so token
# caches survive across calls.
self._credential = DefaultAzureCredential()
# When a managed identity client id is supplied, pin
# ``ManagedIdentityCredential`` to it explicitly. Otherwise the
# SDK picks ``AZURE_CLIENT_ID`` from the env, which on this
# service is the OAuth App Registration's client id — not a
# managed identity — so the MI step of the credential chain
# fails with "No User Assigned … Managed Identity found for
# specified ClientId". Passing the kwarg overrides only the MI
# step; other credentials in the chain (env, workload identity,
# az CLI for local dev) are unaffected.
if managed_identity_client_id:
self._credential = DefaultAzureCredential(managed_identity_client_id=managed_identity_client_id)
else:
self._credential = DefaultAzureCredential()

@property
def source(self) -> str:
Expand Down Expand Up @@ -369,8 +380,13 @@ async def list_corpora(self) -> list[dict[str, Any]]:
def build_registry() -> AzureCorpusBackendRegistry:
"""Factory entry point referenced by ``CORPUS_BACKEND_REGISTRY_FACTORY``.

Reads the container URL from the environment so operators don't have
to pass parameters through the env-var-spec mechanism.
Reads the container URL from ``CORPUS_AZURE_CONTAINER_URL`` and an
optional managed-identity client id from
``CORPUS_AZURE_MANAGED_IDENTITY_CLIENT_ID``. The MI env var is
needed whenever the host also sets ``AZURE_CLIENT_ID`` to something
that isn't a managed identity (e.g. the MCP server uses
``AZURE_CLIENT_ID`` for OAuth audience validation, which would
otherwise confuse ``DefaultAzureCredential``'s MI step).
"""
container_url = os.environ.get(_CONTAINER_URL_ENV)
if not container_url:
Expand All @@ -379,4 +395,7 @@ def build_registry() -> AzureCorpusBackendRegistry:
"(e.g. https://<account>.blob.core.windows.net/<container>) for the "
"Azure-backed corpus registry."
)
return AzureCorpusBackendRegistry(container_url)
return AzureCorpusBackendRegistry(
container_url,
managed_identity_client_id=os.environ.get(_MI_CLIENT_ID_ENV),
)
Loading