From 002e016a4f64cb880f94ce0369a67ae0e266d5f4 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 20 Oct 2025 08:47:46 +0100 Subject: [PATCH 1/2] Fixed bootstrap issue --- .pylintrc | 4 - src/launch_client.py | 6 +- src/server/api/core/bootstrap.py | 16 - src/server/api/core/databases.py | 2 +- src/server/api/core/oci.py | 55 --- src/server/api/core/prompts.py | 2 +- src/server/api/core/settings.py | 2 +- src/server/api/utils/chat.py | 5 +- src/server/api/utils/databases.py | 2 +- src/server/api/utils/models.py | 12 +- src/server/api/utils/oci.py | 159 +++++--- src/server/api/v1/embed.py | 5 +- src/server/api/v1/oci.py | 7 +- src/server/api/v1/testbed.py | 42 +- src/server/bootstrap/bootstrap.py | 6 +- .../unit/api/core/test_core_databases.py | 2 +- tests/server/unit/api/core/test_core_oci.py | 104 ----- .../server/unit/api/utils/test_utils_chat.py | 12 +- .../unit/api/utils/test_utils_databases.py | 2 +- .../unit/{ => api/utils}/test_utils_models.py | 0 tests/server/unit/api/utils/test_utils_oci.py | 378 +++++++++++++++++- .../test_bootstrap.py} | 4 +- 22 files changed, 558 insertions(+), 269 deletions(-) delete mode 100644 src/server/api/core/bootstrap.py delete mode 100644 src/server/api/core/oci.py delete mode 100644 tests/server/unit/api/core/test_core_oci.py rename tests/server/unit/{ => api/utils}/test_utils_models.py (100%) rename tests/server/unit/{api/core/test_core_bootstrap.py => bootstrap/test_bootstrap.py} (94%) diff --git a/.pylintrc b/.pylintrc index 49770063..5beeb9a4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -104,10 +104,6 @@ recursive=no # source root. source-roots= -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no diff --git a/src/launch_client.py b/src/launch_client.py index 05825a2a..20b4cac8 100644 --- a/src/launch_client.py +++ b/src/launch_client.py @@ -22,7 +22,7 @@ logger = logging_config.logging.getLogger("launch_client") # Import launch_server if it exists -LAUNCH_SERVER_EXISTS = True +launch_server_exists = True try: from launch_server import start_server, get_api_key @@ -31,7 +31,7 @@ except ImportError as ex: logger.debug("API Server not present: %s", ex) os.environ.pop("API_SERVER_CONTROL", None) - LAUNCH_SERVER_EXISTS = False + launch_server_exists = False BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -159,7 +159,7 @@ def main() -> None: if __name__ == "__main__": # Start Server if not running init_server_state() - if LAUNCH_SERVER_EXISTS: + if launch_server_exists: try: logger.debug("Server PID: %i", state.server["pid"]) except KeyError: diff --git a/src/server/api/core/bootstrap.py b/src/server/api/core/bootstrap.py deleted file mode 100644 index fd970758..00000000 --- a/src/server/api/core/bootstrap.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker:ignore genai - -from server.bootstrap import databases, models, oci, prompts, settings -from common import logging_config - -logger = logging_config.logging.getLogger("api.core.bootstrap") - -DATABASE_OBJECTS = databases.main() -MODEL_OBJECTS = models.main() -OCI_OBJECTS = oci.main() -PROMPT_OBJECTS = prompts.main() -SETTINGS_OBJECTS = settings.main() diff --git a/src/server/api/core/databases.py b/src/server/api/core/databases.py index 4919ff24..e74ec78b 100644 --- a/src/server/api/core/databases.py +++ b/src/server/api/core/databases.py @@ -4,7 +4,7 @@ """ from typing import Optional, Union -from server.api.core import bootstrap +from server.bootstrap import bootstrap from common.schema import Database, DatabaseNameType from common import logging_config diff --git a/src/server/api/core/oci.py b/src/server/api/core/oci.py deleted file mode 100644 index b48c8726..00000000 --- a/src/server/api/core/oci.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker:ignore genai ocids - -from typing import Optional, Union - -from server.api.core import bootstrap, settings -from common.schema import OracleCloudSettings, ClientIdType, OCIProfileType -from common import logging_config - -logger = logging_config.logging.getLogger("api.core.oci") - - -##################################################### -# Functions -##################################################### -def get_oci( - client: Optional[ClientIdType] = None, auth_profile: Optional[OCIProfileType] = None -) -> Union[list[OracleCloudSettings], OracleCloudSettings]: - """ - Return all OCI Settings if no client or auth_profile is specified. - Raises ValueError if both client and auth_profile are provided. - If client is provided, derives auth_profile and returns matching OCI settings. - If auth_profile is provided, returns matching OCI settings. - Raises ValueError if no matching OCI found. - """ - logger.debug("Getting OCI config for client: %s; auth_profile: %s", client, auth_profile) - if client is not None and auth_profile is not None: - raise ValueError("provide either 'client' or 'auth_profile', not both") - - oci_objects = bootstrap.OCI_OBJECTS - if client is not None: - client_settings = settings.get_client_settings(client) - derived_auth_profile = ( - getattr(client_settings.oci, "auth_profile", "DEFAULT") if client_settings.oci else "DEFAULT" - ) - - matching_oci = next((oci for oci in oci_objects if oci.auth_profile == derived_auth_profile), None) - if matching_oci is None: - raise ValueError(f"No settings found for client '{client}' with auth_profile '{derived_auth_profile}'") - return matching_oci - - if auth_profile is not None: - matching_oci = next((oci for oci in oci_objects if oci.auth_profile == auth_profile), None) - if matching_oci is None: - raise ValueError(f"profile '{auth_profile}' not found") - return matching_oci - - # No filters, return all - if not oci_objects: - raise ValueError("not configured") - - return oci_objects diff --git a/src/server/api/core/prompts.py b/src/server/api/core/prompts.py index 78409376..57a70a8b 100644 --- a/src/server/api/core/prompts.py +++ b/src/server/api/core/prompts.py @@ -5,7 +5,7 @@ # spell-checker:ignore from typing import Optional, Union -from server.api.core import bootstrap +from server.bootstrap import bootstrap from common.schema import PromptCategoryType, PromptNameType, Prompt from common import logging_config diff --git a/src/server/api/core/settings.py b/src/server/api/core/settings.py index 0f993721..bc89de5a 100644 --- a/src/server/api/core/settings.py +++ b/src/server/api/core/settings.py @@ -6,7 +6,7 @@ import os import copy import json -from server.api.core import bootstrap +from server.bootstrap import bootstrap from common.schema import Settings, Configuration, ClientIdType from common import logging_config diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 9f5d09b3..e68b39ca 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -11,8 +11,9 @@ from langchain_core.runnables import RunnableConfig import server.api.core.settings as core_settings -import server.api.core.oci as core_oci import server.api.core.prompts as core_prompts + +import server.api.utils.oci as utils_oci import server.api.utils.models as utils_models import server.api.utils.databases as utils_databases import server.api.utils.selectai as utils_selectai @@ -41,7 +42,7 @@ async def completion_generator( if not model["model"]: model = client_settings.ll_model.model_dump() - oci_config = core_oci.get_oci(client=client) + oci_config = utils_oci.get(client=client) # Setup Client Model try: diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index cf6d6011..3b85da05 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -156,7 +156,7 @@ def connect(config: Database) -> oracledb.Connection: raise except OSError as ex: raise ConnectionError(f"Error connecting to database: {ex}") from ex - + logger.debug("Connected to Databases: %s", config.dsn) return conn diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index 78ac2a01..ab9f70d7 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -264,16 +264,20 @@ def get_litellm_config( if provider == "oci": oci_params = { - "oci_auth_type": oci_config.authentication, - "oci_tenancy": oci_config.tenancy, "oci_region": oci_config.genai_region, "oci_compartment_id": oci_config.genai_compartment_id, } - # Only add credentials if NOT using instance principals or workload identity - if oci_config.authentication not in ("instance_principal", "oke_workload_identity"): + # Get OCI signer (returns None for API key auth) + signer = utils_oci.get_signer(oci_config) + if signer: + # Use signer for instance principals/workload identity + oci_params["oci_signer"] = signer + else: + # Use API key authentication (traditional method) oci_params.update( { + "oci_tenancy": oci_config.tenancy, "oci_user": oci_config.user, "oci_fingerprint": oci_config.fingerprint, "oci_key_file": oci_config.key_file, diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index d8a25639..bc0510c7 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -7,12 +7,13 @@ import os import base64 import json -from typing import Union +from typing import Union, Optional import urllib3.exceptions import oci -from common.schema import OracleCloudSettings +from server.bootstrap import bootstrap +from common.schema import OracleCloudSettings, ClientIdType, OCIProfileType from common import logging_config logger = logging_config.logging.getLogger("api.utils.oci") @@ -31,8 +32,76 @@ def __init__(self, status_code: int, detail: str): ##################################################### -# Functions +# CRUD Functions ##################################################### +def get( + client: Optional[ClientIdType] = None, auth_profile: Optional[OCIProfileType] = None +) -> Union[list[OracleCloudSettings], OracleCloudSettings]: + """ + Return all OCI Settings if no client or auth_profile is specified. + Raises ValueError if both client and auth_profile are provided. + If client is provided, derives auth_profile and returns matching OCI settings. + If auth_profile is provided, returns matching OCI settings. + Raises ValueError if no matching OCI found. + """ + logger.debug("Getting OCI config for client: %s; auth_profile: %s", client, auth_profile) + if client is not None and auth_profile is not None: + raise ValueError("provide either 'client' or 'auth_profile', not both") + + oci_objects = bootstrap.OCI_OBJECTS + if client is not None: + # Get client settings directly from SETTINGS_OBJECTS + logger.debug("Looking for client %s in SETTINGS_OBJECTS", client) + logger.debug( + "SETTINGS_OBJECTS has %d entries: %s", + len(bootstrap.SETTINGS_OBJECTS), + [s.client for s in bootstrap.SETTINGS_OBJECTS], + ) + client_settings = next((s for s in bootstrap.SETTINGS_OBJECTS if s.client == client), None) + if not client_settings: + available_clients = [s.client for s in bootstrap.SETTINGS_OBJECTS] + raise ValueError(f"client {client} not found in SETTINGS_OBJECTS with clients: {available_clients}") + + derived_auth_profile = ( + getattr(client_settings.oci, "auth_profile", "DEFAULT") if client_settings.oci else "DEFAULT" + ) + + matching_oci = next((oci for oci in oci_objects if oci.auth_profile == derived_auth_profile), None) + if matching_oci is None: + raise ValueError(f"No settings found for client '{client}' with auth_profile '{derived_auth_profile}'") + return matching_oci + + if auth_profile is not None: + matching_oci = next((oci for oci in oci_objects if oci.auth_profile == auth_profile), None) + if matching_oci is None: + raise ValueError(f"profile '{auth_profile}' not found") + return matching_oci + + # No filters, return all + if not oci_objects: + raise ValueError("not configured") + + return oci_objects + + +##################################################### +# Utility Functions +##################################################### +def get_signer(config: OracleCloudSettings) -> Optional[object]: + """Get OCI signer for instance principal or workload identity authentication.""" + + if config.authentication == "instance_principal": + logger.info("Creating Instance Principal signer") + return oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + + if config.authentication == "oke_workload_identity": + logger.info("Creating OKE Workload Identity signer") + return oci.auth.signers.get_oke_workload_identity_resource_principal_signer() + + # API key or security token authentication - no signer needed + return None + + def init_client( client_type: Union[ oci.object_storage.ObjectStorageClient, @@ -67,23 +136,24 @@ def init_client( config_json = config.model_dump(exclude_none=False) client = None try: - if config_json["authentication"] == "instance_principal": - logger.info("OCI Authentication with Instance Principal") - instance_signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - client = client_type(config={"region": config_json["region"]}, signer=instance_signer, **client_kwargs) - if not config.tenancy: - config.tenancy = instance_signer.tenancy_id - elif config_json["authentication"] == "oke_workload_identity": - logger.info("OCI Authentication with Workload Identity") - oke_workload_signer = oci.auth.signers.get_oke_workload_identity_resource_principal_signer() - client = client_type(config={"region": config_json["region"]}, signer=oke_workload_signer) + # Get signer for instance principal or workload identity + signer = get_signer(config) + + if signer: + # Use signer-based authentication + client = client_type(config={"region": config_json["region"]}, signer=signer, **client_kwargs) + + # Set tenancy from signer if not already set if not config.tenancy: - token = oke_workload_signer.get_security_token() - payload_part = token.split(".")[1] - padding = "=" * (-len(payload_part) % 4) - decoded_bytes = base64.urlsafe_b64decode(payload_part + padding) - payload = json.loads(decoded_bytes) - config.tenancy = payload.get("tenant") + if config_json["authentication"] == "instance_principal": + config.tenancy = signer.tenancy_id + elif config_json["authentication"] == "oke_workload_identity": + token = signer.get_security_token() + payload_part = token.split(".")[1] + padding = "=" * (-len(payload_part) % 4) + decoded_bytes = base64.urlsafe_b64decode(payload_part + padding) + payload = json.loads(decoded_bytes) + config.tenancy = payload.get("tenant") elif config_json["authentication"] == "security_token" and config_json["security_token_file"]: logger.info("OCI Authentication with Security Token") token = None @@ -151,27 +221,24 @@ def get_regions(config: OracleCloudSettings = None) -> list[dict]: def get_genai_models(config: OracleCloudSettings, regional: bool = False) -> list: """Get a list of GenAI models in a regions compartment""" - if not hasattr(config, "genai_compartment_id") or not config.genai_compartment_id: + if not config.genai_compartment_id: raise OciException(status_code=400, detail="Missing genai_compartment_id") - genai_models = [] - # Track unique models by (region, display_name) to avoid duplicates - seen_models = set() - + # Determine regions to query if regional: - # Limit models to configured region - if not hasattr(config, "genai_region") or not config.genai_region: + if not config.genai_region: raise OciException(status_code=400, detail="Missing genai_region") regions = [{"region_name": config.genai_region}] else: - # Limit models to subscribed regions regions = get_regions(config) + genai_models = [] + seen_models = set() # Track unique models by (region, display_name) + for region in regions: region_config = config.model_copy(deep=True) region_config.region = region["region_name"] - client_type = oci.generative_ai.GenerativeAiClient - client = init_client(client_type, region_config) + client = init_client(oci.generative_ai.GenerativeAiClient, region_config) logger.info( "Checking Region: %s; Compartment: %s for GenAI services", region["region_name"], @@ -186,28 +253,25 @@ def get_genai_models(config: OracleCloudSettings, regional: bool = False) -> lis sort_by="displayName", retry_strategy=oci.retry.NoneRetryStrategy(), ) - # Identify all display_names that have been deprecated - excluded_display_names = set() + # Identify deprecated model names + excluded_display_names = { + model.display_name + for model in response.data.items + if model.time_deprecated or model.time_dedicated_retired or model.time_on_demand_retired + } + + # Build list of models (excluding deprecated ones and duplicates) for model in response.data.items: - if model.time_deprecated or model.time_dedicated_retired or model.time_on_demand_retired: - excluded_display_names.add(model.display_name) - - # Build our list of models (excluding deprecated ones and duplicates) - for model in response.data.items: - # Skip deprecated models - if model.display_name in excluded_display_names: - continue - # Skip cohere models without TEXT_EMBEDDINGS capability - if model.vendor == "cohere" and "TEXT_EMBEDDINGS" not in model.capabilities: - continue - - # Skip duplicate models (same region + display_name) model_key = (region["region_name"], model.display_name) - if model_key in seen_models: - logger.debug("Skipping duplicate model: %s in %s", model.display_name, region["region_name"]) + # Skip if deprecated, duplicate, or cohere model without TEXT_EMBEDDINGS + if ( + model.display_name in excluded_display_names + or model_key in seen_models + or (model.vendor == "cohere" and "TEXT_EMBEDDINGS" not in model.capabilities) + ): continue - seen_models.add(model_key) + seen_models.add(model_key) genai_models.append( { "region": region["region_name"], @@ -218,6 +282,7 @@ def get_genai_models(config: OracleCloudSettings, regional: bool = False) -> lis "id": model.id, } ) + logger.info("Registered %i GenAI Models", len(genai_models)) except oci.exceptions.ServiceError as ex: logger.info("Unable to get GenAI Models in Region: %s (%s)", region["region_name"], ex.message) except (oci.exceptions.RequestException, urllib3.exceptions.MaxRetryError): diff --git a/src/server/api/v1/embed.py b/src/server/api/v1/embed.py index 440437dd..77176356 100644 --- a/src/server/api/v1/embed.py +++ b/src/server/api/v1/embed.py @@ -14,8 +14,7 @@ from pydantic import HttpUrl import requests -import server.api.core.oci as core_oci - +import server.api.utils.oci as utils_oci import server.api.utils.databases as utils_databases import server.api.utils.embed as utils_embed import server.api.utils.models as utils_models @@ -134,7 +133,7 @@ async def split_embed( ) -> Response: """Perform Split and Embed""" logger.debug("Received split_embed - rate_limit: %i; request: %s", rate_limit, request) - oci_config = core_oci.get_oci(client=client) + oci_config = utils_oci.get(client=client) temp_directory = utils_embed.get_temp_directory(client, "embedding") try: diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py index 919abaef..d61ce054 100644 --- a/src/server/api/v1/oci.py +++ b/src/server/api/v1/oci.py @@ -7,7 +7,6 @@ from fastapi import APIRouter, HTTPException, Header from fastapi.responses import JSONResponse -import server.api.core.oci as core_oci import server.api.utils.embed as utils_embed import server.api.utils.oci as utils_oci import server.api.utils.models as utils_models @@ -19,7 +18,7 @@ # Validate the DEFAULT OCI Profile and get models try: - default_config = core_oci.get_oci(auth_profile="DEFAULT") + default_config = utils_oci.get(auth_profile="DEFAULT") _ = utils_oci.get_namespace(config=default_config) _ = utils_models.create_genai(config=default_config) except utils_oci.OciException: @@ -37,7 +36,7 @@ async def oci_list() -> list[schema.OracleCloudSettings]: """List OCI Configuration""" logger.debug("Received oci_list") try: - return core_oci.get_oci() + return utils_oci.get() except ValueError as ex: raise HTTPException(status_code=404, detail=f"OCI: {str(ex)}.") from ex @@ -53,7 +52,7 @@ async def oci_get( """List OCI Configuration""" logger.debug("Received oci_get - auth_profile: %s", auth_profile) try: - return core_oci.get_oci(auth_profile=auth_profile) + return utils_oci.get(auth_profile=auth_profile) except ValueError as ex: raise HTTPException(status_code=404, detail=f"OCI: {str(ex)}.") from ex diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index 6b27cb76..926de33d 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -19,7 +19,7 @@ from langchain_core.messages import ChatMessage import server.api.core.settings as core_settings -import server.api.core.oci as core_oci +import server.api.utils.oci as utils_oci import server.api.utils.embed as utils_embed import server.api.utils.testbed as utils_testbed import server.api.utils.databases as utils_databases @@ -150,7 +150,7 @@ async def testbed_generate_qa( """Retrieve contents from a local file uploaded and generate Q&A""" # Get the Model Configuration try: - oci_config = core_oci.get_oci(client) + oci_config = utils_oci.get(client) except ValueError as ex: raise HTTPException(status_code=400, detail=str(ex)) from ex @@ -177,20 +177,40 @@ async def testbed_generate_qa( open(full_testsets, "a", encoding="utf-8") as destination, ): destination.write(source.read()) + except KeyError as ex: + # Handle empty testset error (when no questions are generated due to model issues) + shutil.rmtree(temp_directory) + if "None of" in str(ex) and "are in the columns" in str(ex): + error_message = ( + f"Failed to generate any questions using model '{ll_model}'. " + "This may indicate the model is unavailable, retired, or not found. " + "Please verify the model name and try a different model." + ) + logger.error("TestSet Generation Failed: %s", error_message) + raise HTTPException(status_code=400, detail=error_message) from ex + # Re-raise other KeyErrors + raise + except ValueError as ex: + # Handle model validation errors (e.g., empty testset due to model issues) + shutil.rmtree(temp_directory) + error_message = str(ex) + logger.error("TestSet Validation Error: %s", error_message) + raise HTTPException(status_code=400, detail=error_message) from ex except litellm.APIConnectionError as ex: shutil.rmtree(temp_directory) - logger.error("APIConnectionError Exception: %s", str(ex)) - raise HTTPException(status_code=424, detail=str(ex)) from ex + error_message = str(ex) + logger.error("APIConnectionError Exception: %s", error_message) + raise HTTPException(status_code=424, detail=f"Model API error: {error_message}") from ex except Exception as ex: shutil.rmtree(temp_directory) logger.error("Unknown TestSet Exception: %s", str(ex)) - raise HTTPException(status_code=500, detail=f"Unexpected testset error: {str(ex)}.") from ex + raise HTTPException(status_code=500, detail=f"Unexpected TestSet error: {str(ex)}.") from ex - # Store tests in database - with open(full_testsets, "rb") as file: - upload_file = UploadFile(file=file, filename=full_testsets) - testset_qa = await testbed_upsert_testsets(client=client, files=[upload_file], name=name) - shutil.rmtree(temp_directory) + # Store tests in database (only if we successfully generated testsets) + with open(full_testsets, "rb") as file: + upload_file = UploadFile(file=file, filename=full_testsets) + testset_qa = await testbed_upsert_testsets(client=client, files=[upload_file], name=name) + shutil.rmtree(temp_directory) return testset_qa @@ -233,7 +253,7 @@ def get_answer(question: str): # Setup Judge Model logger.debug("Starting evaluation with Judge: %s", judge) - oci_config = core_oci.get_oci(client) + oci_config = utils_oci.get(client) judge_config = utils_models.get_litellm_config(model_config={"model": judge}, oci_config=oci_config, giskard=True) set_llm_model(llm_model=judge, **judge_config) diff --git a/src/server/bootstrap/bootstrap.py b/src/server/bootstrap/bootstrap.py index 3aa47026..da05592b 100644 --- a/src/server/bootstrap/bootstrap.py +++ b/src/server/bootstrap/bootstrap.py @@ -4,9 +4,13 @@ """ # spell-checker:ignore genai -from server.bootstrap import models +from server.bootstrap import databases, models, oci, prompts, settings from common import logging_config logger = logging_config.logging.getLogger("bootstrap") +DATABASE_OBJECTS = databases.main() MODEL_OBJECTS = models.main() +OCI_OBJECTS = oci.main() +PROMPT_OBJECTS = prompts.main() +SETTINGS_OBJECTS = settings.main() diff --git a/tests/server/unit/api/core/test_core_databases.py b/tests/server/unit/api/core/test_core_databases.py index e3e9deca..9aacd384 100644 --- a/tests/server/unit/api/core/test_core_databases.py +++ b/tests/server/unit/api/core/test_core_databases.py @@ -9,7 +9,7 @@ import pytest from server.api.core import databases -from server.api.core import bootstrap +from server.bootstrap import bootstrap from common.schema import Database diff --git a/tests/server/unit/api/core/test_core_oci.py b/tests/server/unit/api/core/test_core_oci.py deleted file mode 100644 index dcdc88a9..00000000 --- a/tests/server/unit/api/core/test_core_oci.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Copyright (c) 2024, 2025, Oracle and/or its affiliates. -Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. -""" -# spell-checker: disable - -from unittest.mock import patch, MagicMock - -import pytest - -from server.api.core import oci -from common.schema import OracleCloudSettings, Settings, OciSettings - - -class TestOci: - """Test OCI module functionality""" - - def setup_method(self): - """Setup test data before each test""" - self.sample_oci_default = OracleCloudSettings( - auth_profile="DEFAULT", compartment_id="ocid1.compartment.oc1..default" - ) - self.sample_oci_custom = OracleCloudSettings( - auth_profile="CUSTOM", compartment_id="ocid1.compartment.oc1..custom" - ) - self.sample_client_settings = Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) - - @patch("server.api.core.oci.bootstrap") - def test_get_oci_all(self, mock_bootstrap): - """Test getting all OCI settings when no filters are provided""" - all_oci = [self.sample_oci_default, self.sample_oci_custom] - mock_bootstrap.OCI_OBJECTS = all_oci - - result = oci.get_oci() - - assert result == all_oci - - @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") - def test_get_oci_no_objects_configured(self, mock_oci_objects): - """Test getting OCI settings when none are configured""" - mock_oci_objects.__bool__ = MagicMock(return_value=False) - - with pytest.raises(ValueError, match="not configured"): - oci.get_oci() - - @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") - def test_get_oci_by_auth_profile_found(self, mock_oci_objects): - """Test getting OCI settings by auth_profile when it exists""" - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) - - result = oci.get_oci(auth_profile="CUSTOM") - - assert result == self.sample_oci_custom - - @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") - def test_get_oci_by_auth_profile_not_found(self, mock_oci_objects): - """Test getting OCI settings by auth_profile when it doesn't exist""" - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) - - with pytest.raises(ValueError, match="profile 'NONEXISTENT' not found"): - oci.get_oci(auth_profile="NONEXISTENT") - - @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") - @patch("server.api.core.oci.settings.get_client_settings") - def test_get_oci_by_client_with_oci_settings(self, mock_get_client_settings, mock_oci_objects): - """Test getting OCI settings by client when client has OCI settings""" - mock_get_client_settings.return_value = self.sample_client_settings - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) - - result = oci.get_oci(client="test_client") - - assert result == self.sample_oci_custom - - @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") - @patch("server.api.core.oci.settings.get_client_settings") - def test_get_oci_by_client_without_oci_settings(self, mock_get_client_settings, mock_oci_objects): - """Test getting OCI settings by client when client has no OCI settings""" - client_settings_no_oci = Settings(client="test_client", oci=None) - mock_get_client_settings.return_value = client_settings_no_oci - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) - - result = oci.get_oci(client="test_client") - - assert result == self.sample_oci_default - - @patch("server.api.core.oci.bootstrap.OCI_OBJECTS") - @patch("server.api.core.oci.settings.get_client_settings") - def test_get_oci_by_client_no_matching_profile(self, mock_get_client_settings, mock_oci_objects): - """Test getting OCI settings by client when no matching profile exists""" - mock_get_client_settings.return_value = self.sample_client_settings - mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) # Only DEFAULT profile - - with pytest.raises(ValueError, match="No settings found for client 'test_client' with auth_profile 'CUSTOM'"): - oci.get_oci(client="test_client") - - def test_get_oci_both_client_and_auth_profile(self): - """Test that providing both client and auth_profile raises an error""" - with pytest.raises(ValueError, match="provide either 'client' or 'auth_profile', not both"): - oci.get_oci(client="test_client", auth_profile="CUSTOM") - - def test_logger_exists(self): - """Test that logger is properly configured""" - assert hasattr(oci, "logger") - assert oci.logger.name == "api.core.oci" diff --git a/tests/server/unit/api/utils/test_utils_chat.py b/tests/server/unit/api/utils/test_utils_chat.py index 80dfa095..a78249bd 100644 --- a/tests/server/unit/api/utils/test_utils_chat.py +++ b/tests/server/unit/api/utils/test_utils_chat.py @@ -40,7 +40,7 @@ def setup_method(self): ) @patch("server.api.core.settings.get_client_settings") - @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.oci.get") @patch("server.api.utils.models.get_litellm_config") @patch("server.api.core.prompts.get_prompts") @patch("server.agents.chatbot.chatbot_graph.astream") @@ -77,7 +77,7 @@ async def mock_generator(): mock_get_oci.assert_called_once_with(client="test_client") @patch("server.api.core.settings.get_client_settings") - @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.oci.get") @patch("server.api.utils.models.get_litellm_config") @patch("server.api.core.prompts.get_prompts") @patch("server.agents.chatbot.chatbot_graph.astream") @@ -112,7 +112,7 @@ async def mock_generator(): assert results[2] == "[stream_finished]" @patch("server.api.core.settings.get_client_settings") - @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.oci.get") @patch("server.api.utils.models.get_litellm_config") @patch("server.api.core.prompts.get_prompts") @patch("server.api.utils.databases.get_client_database") @@ -162,7 +162,7 @@ async def mock_generator(): assert len(results) == 1 @patch("server.api.core.settings.get_client_settings") - @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.oci.get") @patch("server.api.utils.models.get_litellm_config") @patch("server.api.core.prompts.get_prompts") @patch("server.api.utils.databases.get_client_database") @@ -213,7 +213,7 @@ async def mock_generator(): assert len(results) == 1 @patch("server.api.core.settings.get_client_settings") - @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.oci.get") @patch("server.api.utils.models.get_litellm_config") @patch("server.api.core.prompts.get_prompts") @patch("server.agents.chatbot.chatbot_graph.astream") @@ -246,7 +246,7 @@ async def mock_generator(): assert len(results) == 1 @patch("server.api.core.settings.get_client_settings") - @patch("server.api.core.oci.get_oci") + @patch("server.api.utils.oci.get") @patch("server.api.utils.models.get_litellm_config") @patch("server.api.core.prompts.get_prompts") @patch("server.agents.chatbot.chatbot_graph.astream") diff --git a/tests/server/unit/api/utils/test_utils_databases.py b/tests/server/unit/api/utils/test_utils_databases.py index 9161e788..c6db7c1d 100644 --- a/tests/server/unit/api/utils/test_utils_databases.py +++ b/tests/server/unit/api/utils/test_utils_databases.py @@ -14,7 +14,7 @@ from server.api.utils import databases from server.api.utils.databases import DbException -from server.api.core import bootstrap +from server.bootstrap import bootstrap from common.schema import Database diff --git a/tests/server/unit/test_utils_models.py b/tests/server/unit/api/utils/test_utils_models.py similarity index 100% rename from tests/server/unit/test_utils_models.py rename to tests/server/unit/api/utils/test_utils_models.py diff --git a/tests/server/unit/api/utils/test_utils_oci.py b/tests/server/unit/api/utils/test_utils_oci.py index 8c2157be..c15ec5c6 100644 --- a/tests/server/unit/api/utils/test_utils_oci.py +++ b/tests/server/unit/api/utils/test_utils_oci.py @@ -11,7 +11,7 @@ from server.api.utils import oci as oci_utils from server.api.utils.oci import OciException -from common.schema import OracleCloudSettings +from common.schema import OracleCloudSettings, Settings, OciSettings class TestOciException: @@ -25,6 +25,329 @@ def test_oci_exception_initialization(self): assert str(exc) == "Invalid configuration" +class TestOciGet: + """Test OCI get() function""" + + def setup_method(self): + """Setup test data before each test""" + self.sample_oci_default = OracleCloudSettings( + auth_profile="DEFAULT", compartment_id="ocid1.compartment.oc1..default" + ) + self.sample_oci_custom = OracleCloudSettings( + auth_profile="CUSTOM", compartment_id="ocid1.compartment.oc1..custom" + ) + self.sample_client_settings = Settings(client="test_client", oci=OciSettings(auth_profile="CUSTOM")) + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS", []) + def test_get_no_objects_configured(self): + """Test getting OCI settings when none are configured""" + with pytest.raises(ValueError, match="not configured"): + oci_utils.get() + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS", new_callable=list) + def test_get_all(self, mock_oci_objects): + """Test getting all OCI settings when no filters are provided""" + all_oci = [self.sample_oci_default, self.sample_oci_custom] + mock_oci_objects.extend(all_oci) + + result = oci_utils.get() + + assert result == all_oci + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS") + def test_get_by_auth_profile_found(self, mock_oci_objects): + """Test getting OCI settings by auth_profile when it exists""" + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default, self.sample_oci_custom])) + + result = oci_utils.get(auth_profile="CUSTOM") + + assert result == self.sample_oci_custom + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS") + def test_get_by_auth_profile_not_found(self, mock_oci_objects): + """Test getting OCI settings by auth_profile when it doesn't exist""" + mock_oci_objects.__iter__ = MagicMock(return_value=iter([self.sample_oci_default])) + + with pytest.raises(ValueError, match="profile 'NONEXISTENT' not found"): + oci_utils.get(auth_profile="NONEXISTENT") + + def test_get_by_client_with_oci_settings(self): + """Test getting OCI settings by client when client has OCI settings""" + from server.bootstrap import bootstrap + + # Save originals + orig_settings = bootstrap.SETTINGS_OBJECTS + orig_oci = bootstrap.OCI_OBJECTS + + try: + # Replace with test data + bootstrap.SETTINGS_OBJECTS = [self.sample_client_settings] + bootstrap.OCI_OBJECTS = [self.sample_oci_default, self.sample_oci_custom] + + result = oci_utils.get(client="test_client") + + assert result == self.sample_oci_custom + finally: + # Restore originals + bootstrap.SETTINGS_OBJECTS = orig_settings + bootstrap.OCI_OBJECTS = orig_oci + + def test_get_by_client_without_oci_settings(self): + """Test getting OCI settings by client when client has no OCI settings""" + from server.bootstrap import bootstrap + + client_settings_no_oci = Settings(client="test_client", oci=None) + + # Save originals + orig_settings = bootstrap.SETTINGS_OBJECTS + orig_oci = bootstrap.OCI_OBJECTS + + try: + # Replace with test data + bootstrap.SETTINGS_OBJECTS = [client_settings_no_oci] + bootstrap.OCI_OBJECTS = [self.sample_oci_default] + + result = oci_utils.get(client="test_client") + + assert result == self.sample_oci_default + finally: + # Restore originals + bootstrap.SETTINGS_OBJECTS = orig_settings + bootstrap.OCI_OBJECTS = orig_oci + + @patch("server.bootstrap.bootstrap.OCI_OBJECTS") + @patch("server.bootstrap.bootstrap.SETTINGS_OBJECTS") + def test_get_by_client_not_found(self, mock_settings_objects, mock_oci_objects): + """Test getting OCI settings when client doesn't exist""" + mock_settings_objects.__iter__ = MagicMock(return_value=iter([])) + + with pytest.raises(ValueError, match="client test_client not found"): + oci_utils.get(client="test_client") + + def test_get_by_client_no_matching_profile(self): + """Test getting OCI settings by client when no matching profile exists""" + from server.bootstrap import bootstrap + + # Save originals + orig_settings = bootstrap.SETTINGS_OBJECTS + orig_oci = bootstrap.OCI_OBJECTS + + try: + # Replace with test data + bootstrap.SETTINGS_OBJECTS = [self.sample_client_settings] + bootstrap.OCI_OBJECTS = [self.sample_oci_default] # Only DEFAULT profile + + with pytest.raises(ValueError, match="No settings found for client 'test_client' with auth_profile 'CUSTOM'"): + oci_utils.get(client="test_client") + finally: + # Restore originals + bootstrap.SETTINGS_OBJECTS = orig_settings + bootstrap.OCI_OBJECTS = orig_oci + + def test_get_both_client_and_auth_profile(self): + """Test that providing both client and auth_profile raises an error""" + with pytest.raises(ValueError, match="provide either 'client' or 'auth_profile', not both"): + oci_utils.get(client="test_client", auth_profile="CUSTOM") + + +class TestGetSigner: + """Test get_signer() function""" + + def test_get_signer_instance_principal(self): + """Test get_signer with instance_principal authentication""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="instance_principal") + + with patch("oci.auth.signers.InstancePrincipalsSecurityTokenSigner") as mock_signer: + mock_instance = MagicMock() + mock_signer.return_value = mock_instance + + result = oci_utils.get_signer(config) + + assert result == mock_instance + mock_signer.assert_called_once() + + def test_get_signer_oke_workload_identity(self): + """Test get_signer with oke_workload_identity authentication""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="oke_workload_identity") + + with patch("oci.auth.signers.get_oke_workload_identity_resource_principal_signer") as mock_signer: + mock_instance = MagicMock() + mock_signer.return_value = mock_instance + + result = oci_utils.get_signer(config) + + assert result == mock_instance + mock_signer.assert_called_once() + + def test_get_signer_api_key(self): + """Test get_signer with api_key authentication (returns None)""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="api_key") + + result = oci_utils.get_signer(config) + + assert result is None + + def test_get_signer_security_token(self): + """Test get_signer with security_token authentication (returns None)""" + config = OracleCloudSettings(auth_profile="DEFAULT", authentication="security_token") + + result = oci_utils.get_signer(config) + + assert result is None + + +class TestInitClient: + """Test init_client() function""" + + def setup_method(self): + """Setup test data""" + self.api_key_config = OracleCloudSettings( + auth_profile="DEFAULT", + authentication="api_key", + region="us-ashburn-1", + user="ocid1.user.oc1..testuser", + fingerprint="test-fingerprint", + tenancy="ocid1.tenancy.oc1..testtenant", + key_file="/path/to/key.pem", + ) + + @patch("oci.object_storage.ObjectStorageClient") + @patch.object(oci_utils, "get_signer", return_value=None) + def test_init_client_api_key(self, mock_get_signer, mock_client_class): + """Test init_client with API key authentication""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + result = oci_utils.init_client(oci.object_storage.ObjectStorageClient, self.api_key_config) + + assert result == mock_client + mock_get_signer.assert_called_once_with(self.api_key_config) + mock_client_class.assert_called_once() + + @patch("oci.generative_ai_inference.GenerativeAiInferenceClient") + @patch.object(oci_utils, "get_signer", return_value=None) + def test_init_client_genai_with_endpoint(self, mock_get_signer, mock_client_class): + """Test init_client for GenAI sets correct service endpoint""" + genai_config = self.api_key_config.model_copy() + genai_config.genai_compartment_id = "ocid1.compartment.oc1..test" + genai_config.genai_region = "us-chicago-1" + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + result = oci_utils.init_client(oci.generative_ai_inference.GenerativeAiInferenceClient, genai_config) + + assert result == mock_client + # Verify service_endpoint was set in kwargs + call_kwargs = mock_client_class.call_args[1] + assert "service_endpoint" in call_kwargs + assert "us-chicago-1" in call_kwargs["service_endpoint"] + + @patch("oci.identity.IdentityClient") + @patch.object(oci_utils, "get_signer") + def test_init_client_with_instance_principal_signer(self, mock_get_signer, mock_client_class): + """Test init_client with instance principal signer""" + instance_config = OracleCloudSettings( + auth_profile="DEFAULT", + authentication="instance_principal", + region="us-ashburn-1", + tenancy=None # Will be set from signer + ) + + mock_signer = MagicMock() + mock_signer.tenancy_id = "ocid1.tenancy.oc1..test" + mock_get_signer.return_value = mock_signer + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + result = oci_utils.init_client(oci.identity.IdentityClient, instance_config) + + assert result == mock_client + # Verify signer was used + call_kwargs = mock_client_class.call_args[1] + assert call_kwargs["signer"] == mock_signer + # Verify tenancy was set from signer + assert instance_config.tenancy == "ocid1.tenancy.oc1..test" + + @patch("oci.identity.IdentityClient") + @patch.object(oci_utils, "get_signer") + def test_init_client_with_workload_identity_signer(self, mock_get_signer, mock_client_class): + """Test init_client with OKE workload identity signer""" + workload_config = OracleCloudSettings( + auth_profile="DEFAULT", + authentication="oke_workload_identity", + region="us-ashburn-1", + tenancy=None # Will be extracted from token + ) + + # Mock JWT token with tenant claim + import base64 + import json + payload = {"tenant": "ocid1.tenancy.oc1..workload"} + payload_json = json.dumps(payload) + payload_b64 = base64.urlsafe_b64encode(payload_json.encode()).decode().rstrip("=") + mock_token = f"header.{payload_b64}.signature" + + mock_signer = MagicMock() + mock_signer.get_security_token.return_value = mock_token + mock_get_signer.return_value = mock_signer + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + result = oci_utils.init_client(oci.identity.IdentityClient, workload_config) + + assert result == mock_client + # Verify tenancy was extracted from token + assert workload_config.tenancy == "ocid1.tenancy.oc1..workload" + + @patch("oci.identity.IdentityClient") + @patch.object(oci_utils, "get_signer", return_value=None) + @patch("builtins.open", new_callable=MagicMock) + @patch("oci.signer.load_private_key_from_file") + @patch("oci.auth.signers.SecurityTokenSigner") + def test_init_client_with_security_token( + self, mock_sec_token_signer, mock_load_key, mock_open, mock_get_signer, mock_client_class + ): + """Test init_client with security token authentication""" + token_config = OracleCloudSettings( + auth_profile="DEFAULT", + authentication="security_token", + region="us-ashburn-1", + security_token_file="/path/to/token", + key_file="/path/to/key.pem" + ) + + # Mock file reading + mock_open.return_value.__enter__.return_value.read.return_value = "mock_token_content" + mock_private_key = MagicMock() + mock_load_key.return_value = mock_private_key + mock_signer_instance = MagicMock() + mock_sec_token_signer.return_value = mock_signer_instance + + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + result = oci_utils.init_client(oci.identity.IdentityClient, token_config) + + assert result == mock_client + mock_load_key.assert_called_once_with("/path/to/key.pem") + mock_sec_token_signer.assert_called_once_with("mock_token_content", mock_private_key) + + @patch("oci.object_storage.ObjectStorageClient") + @patch.object(oci_utils, "get_signer", return_value=None) + def test_init_client_invalid_config(self, mock_get_signer, mock_client_class): + """Test init_client with invalid config raises OciException""" + mock_client_class.side_effect = oci.exceptions.InvalidConfig("Bad config") + + with pytest.raises(OciException) as exc_info: + oci_utils.init_client(oci.object_storage.ObjectStorageClient, self.api_key_config) + + assert exc_info.value.status_code == 400 + assert "Invalid Config" in str(exc_info.value) + + class TestOciUtils: """Test OCI utility functions""" @@ -89,6 +412,59 @@ def test_get_namespace_file_not_found(self, mock_init_client): assert exc_info.value.status_code == 400 assert "Invalid Key Path" in str(exc_info.value) + @patch.object(oci_utils, "init_client") + def test_get_namespace_service_error(self, mock_init_client): + """Test namespace retrieval with service error""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = oci.exceptions.ServiceError( + status=401, code="NotAuthenticated", headers={}, message="Auth failed" + ) + mock_init_client.return_value = mock_client + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 401 + assert "AuthN Error" in str(exc_info.value) + + @patch.object(oci_utils, "init_client") + def test_get_namespace_unbound_local_error(self, mock_init_client): + """Test namespace retrieval with unbound local error""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = UnboundLocalError("local variable referenced before assignment") + mock_init_client.return_value = mock_client + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 500 + assert "No Configuration" in str(exc_info.value) + + @patch.object(oci_utils, "init_client") + def test_get_namespace_request_exception(self, mock_init_client): + """Test namespace retrieval with request exception""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = oci.exceptions.RequestException("Connection timeout") + mock_init_client.return_value = mock_client + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 503 + + @patch.object(oci_utils, "init_client") + def test_get_namespace_generic_exception(self, mock_init_client): + """Test namespace retrieval with generic exception""" + mock_client = MagicMock() + mock_client.get_namespace.side_effect = Exception("Unexpected error") + mock_init_client.return_value = mock_client + + with pytest.raises(OciException) as exc_info: + oci_utils.get_namespace(self.sample_oci_config) + + assert exc_info.value.status_code == 500 + assert "Unexpected error" in str(exc_info.value) + @patch.object(oci_utils, "init_client") def test_get_regions_success(self, mock_init_client): """Test successful regions retrieval""" diff --git a/tests/server/unit/api/core/test_core_bootstrap.py b/tests/server/unit/bootstrap/test_bootstrap.py similarity index 94% rename from tests/server/unit/api/core/test_core_bootstrap.py rename to tests/server/unit/bootstrap/test_bootstrap.py index 090e99d4..2f1c6d42 100644 --- a/tests/server/unit/api/core/test_core_bootstrap.py +++ b/tests/server/unit/bootstrap/test_bootstrap.py @@ -7,7 +7,7 @@ import importlib from unittest.mock import patch, MagicMock -from server.api.core import bootstrap +from server.bootstrap import bootstrap class TestBootstrap: @@ -50,4 +50,4 @@ def test_module_imports_and_initialization( def test_logger_exists(self): """Test that logger is properly configured""" assert hasattr(bootstrap, "logger") - assert bootstrap.logger.name == "api.core.bootstrap" + assert bootstrap.logger.name == "bootstrap" From 8e470203a484589002b3a7a3a0f738424e6b0c0d Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 20 Oct 2025 08:51:42 +0100 Subject: [PATCH 2/2] break out overrides --- src/server/bootstrap/oci.py | 83 ++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/src/server/bootstrap/oci.py b/src/server/bootstrap/oci.py index c40c297f..81cb3c3d 100644 --- a/src/server/bootstrap/oci.py +++ b/src/server/bootstrap/oci.py @@ -16,6 +16,49 @@ logger = logging_config.logging.getLogger("bootstrap.oci") +def _apply_env_overrides_to_default_profile(config: list[dict]) -> None: + """Apply environment variable overrides to the DEFAULT OCI profile.""" + def override(profile: dict, key: str, env_key: str, env: dict, overrides: dict, default=None): + val = env.get(env_key) + if val is not None and val != profile.get(key): + overrides[key] = (profile.get(key), val) + return val + return profile.get(key, default) + + env = os.environ + + for profile in config: + if profile["auth_profile"] == oci.config.DEFAULT_PROFILE: + overrides = {} + + profile.update( + { + "tenancy": override(profile, "tenancy", "OCI_CLI_TENANCY", env, overrides), + "region": override(profile, "region", "OCI_CLI_REGION", env, overrides), + "user": override(profile, "user", "OCI_CLI_USER", env, overrides), + "fingerprint": override(profile, "fingerprint", "OCI_CLI_FINGERPRINT", env, overrides), + "key_file": override(profile, "key_file", "OCI_CLI_KEY_FILE", env, overrides), + "security_token_file": override( + profile, "security_token_file", "OCI_CLI_SECURITY_TOKEN_FILE", env, overrides + ), + "authentication": env.get("OCI_CLI_AUTH") + or ("security_token" if profile.get("security_token_file") else "api_key"), + "genai_compartment_id": override( + profile, "genai_compartment_id", "OCI_GENAI_COMPARTMENT_ID", env, overrides, None + ), + "genai_region": override(profile, "genai_region", "OCI_GENAI_REGION", env, overrides, None), + "log_requests": profile.get("log_requests", False), + "additional_user_agent": profile.get("additional_user_agent", ""), + "pass_phrase": profile.get("pass_phrase"), + } + ) + + if overrides: + logger.info("Environment variable overrides for OCI DEFAULT profile:") + for key, (old, new) in overrides.items(): + logger.info(" %s: '%s' -> '%s'", key, old, new) + + def main() -> list[OracleCloudSettings]: """Read in OCI Configuration options into an object""" logger.debug("*** Bootstrapping OCI - Start") @@ -62,45 +105,7 @@ def main() -> list[OracleCloudSettings]: config.append({"auth_profile": oci.config.DEFAULT_PROFILE}) # Override DEFAULT profile with environment variables - def override(profile: dict, key: str, env_key: str, env: dict, overrides: dict, default=None): - val = env.get(env_key) - if val is not None and val != profile.get(key): - overrides[key] = (profile.get(key), val) - return val - return profile.get(key, default) - - env = os.environ - - for profile in config: - if profile["auth_profile"] == oci.config.DEFAULT_PROFILE: - overrides = {} - - profile.update( - { - "tenancy": override(profile, "tenancy", "OCI_CLI_TENANCY", env, overrides), - "region": override(profile, "region", "OCI_CLI_REGION", env, overrides), - "user": override(profile, "user", "OCI_CLI_USER", env, overrides), - "fingerprint": override(profile, "fingerprint", "OCI_CLI_FINGERPRINT", env, overrides), - "key_file": override(profile, "key_file", "OCI_CLI_KEY_FILE", env, overrides), - "security_token_file": override( - profile, "security_token_file", "OCI_CLI_SECURITY_TOKEN_FILE", env, overrides - ), - "authentication": env.get("OCI_CLI_AUTH") - or ("security_token" if profile.get("security_token_file") else "api_key"), - "genai_compartment_id": override( - profile, "genai_compartment_id", "OCI_GENAI_COMPARTMENT_ID", env, overrides, None - ), - "genai_region": override(profile, "genai_region", "OCI_GENAI_REGION", env, overrides, None), - "log_requests": profile.get("log_requests", False), - "additional_user_agent": profile.get("additional_user_agent", ""), - "pass_phrase": profile.get("pass_phrase"), - } - ) - - if overrides: - logger.info("Environment variable overrides for OCI DEFAULT profile:") - for key, (old, new) in overrides.items(): - logger.info(" %s: '%s' -> '%s'", key, old, new) + _apply_env_overrides_to_default_profile(config) # Build final OracleCloudSettings objects oci_objects = []