diff --git a/poetry.lock b/poetry.lock index cf0885d..1fba4b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3539,4 +3539,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.12" -content-hash = "dbcd5823f305acc0a54e7f1fdd0b85a32da36eafdcab48ee37fe0dc299b80bfe" +content-hash = "f147a38ecba7556f5b328cff4841ad7087ad41f9a6da1c48e80a3dbef0c773d9" diff --git a/pyproject.toml b/pyproject.toml index 250c9dc..7450e15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ pydantic = "2.13.1" pyspark = "3.5.1" python = ">=3.10,<3.12" pyyaml = "^6.0" +requests = "2.33.1" salesforce-cdp-connector = ">=1.0.19" setuptools_scm = "^7.1.0" diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py index 2662e74..00cfae3 100644 --- a/src/datacustomcode/__init__.py +++ b/src/datacustomcode/__init__.py @@ -17,15 +17,11 @@ from datacustomcode.credentials import AuthType, Credentials from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader from datacustomcode.io.writer.print import PrintDataCloudWriter -from datacustomcode.proxy.client.LocalProxyClientProvider import ( - LocalProxyClientProvider, -) __all__ = [ "AuthType", "Client", "Credentials", - "LocalProxyClientProvider", "PrintDataCloudWriter", "QueryAPIDataCloudReader", ] diff --git a/src/datacustomcode/auth.py b/src/datacustomcode/auth.py index ef43ddc..6636379 100644 --- a/src/datacustomcode/auth.py +++ b/src/datacustomcode/auth.py @@ -170,7 +170,7 @@ def do_oauth_browser_flow( # Start callback server click.echo(f"\nStarting local callback server on {redirect_uri}...") - server, actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue) + server, _actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue) # Build authorization URL with final redirect_uri auth_url = ( diff --git a/src/datacustomcode/cli.py b/src/datacustomcode/cli.py index d8a00fc..c6e9c5c 100644 --- a/src/datacustomcode/cli.py +++ b/src/datacustomcode/cli.py @@ -179,14 +179,15 @@ def deploy( function_invoke_opt: str, sf_cli_org: Optional[str], ): - from datacustomcode.credentials import Credentials from datacustomcode.deploy import ( COMPUTE_TYPES, - AccessTokenResponse, CodeExtensionMetadata, - _retrieve_access_token_from_sf_cli, deploy_full, ) + from datacustomcode.token_provider import ( + CredentialsTokenProvider, + SFCLITokenProvider, + ) logger.debug("Deploying project") @@ -220,22 +221,15 @@ def deploy( function_invoke_options = function_invoke_opt.split(",") metadata.functionInvokeOptions = function_invoke_options - auth: Union[Credentials, AccessTokenResponse] - if sf_cli_org: - try: - auth = _retrieve_access_token_from_sf_cli(sf_cli_org) - except RuntimeError as e: - click.secho(f"Error: {e}", fg="red") - raise click.Abort() from None - else: - try: - auth = Credentials.from_available(profile=profile) - except ValueError as e: - click.secho( - f"Error: {e}", - fg="red", - ) - raise click.Abort() from None + try: + if sf_cli_org: + auth = SFCLITokenProvider(sf_cli_org).get_token() + else: + auth = CredentialsTokenProvider(profile).get_token() + except RuntimeError as e: + click.secho(f"Error: {e}", fg="red") + raise click.Abort() from None + deploy_full(path, metadata, auth, network) diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py index c1f5763..faecf0a 100644 --- a/src/datacustomcode/client.py +++ b/src/datacustomcode/client.py @@ -33,7 +33,6 @@ from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode - from datacustomcode.proxy.client.base import BaseProxyClient from datacustomcode.spark.base import BaseSparkSessionProvider @@ -107,7 +106,6 @@ class Client: _reader: BaseDataCloudReader _writer: BaseDataCloudWriter _file: DefaultFindFilePath - _proxy: Optional[BaseProxyClient] _data_layer_history: dict[DataCloudObjectType, set[str]] _code_type: str @@ -115,7 +113,6 @@ def __new__( cls, reader: Optional[BaseDataCloudReader] = None, writer: Optional["BaseDataCloudWriter"] = None, - proxy: Optional[BaseProxyClient] = None, spark_provider: Optional["BaseSparkSessionProvider"] = None, code_type: str = "script", ) -> Client: @@ -223,11 +220,6 @@ def write_to_dmo( self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO) return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs) # type: ignore[no-any-return] - def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str: - if self._proxy is None: - raise ValueError("No proxy configured; set proxy or proxy_config") - return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens) # type: ignore[no-any-return] - def find_file_path(self, file_name: str) -> Path: """Return a file path""" diff --git a/src/datacustomcode/cmd.py b/src/datacustomcode/cmd.py index e656316..0bb4d2c 100644 --- a/src/datacustomcode/cmd.py +++ b/src/datacustomcode/cmd.py @@ -104,6 +104,6 @@ def _cmd_output( def cmd_output(*cmd: str, **kwargs: Any) -> Union[str, None]: - returncode, stdout_b, stderr_b = _cmd_output(*cmd, **kwargs) + _returncode, stdout_b, _stderr_b = _cmd_output(*cmd, **kwargs) stdout = stdout_b.decode() if stdout_b is not None else None return stdout diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index 820c512..901b295 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -37,15 +37,14 @@ # This lets all readers and writers to be findable via config from datacustomcode.io import * # noqa: F403 from datacustomcode.io.base import BaseDataAccessLayer -from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH002 -from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH002 -from datacustomcode.proxy.base import BaseProxyAccessLayer -from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002 from datacustomcode.spark.base import BaseSparkSessionProvider if TYPE_CHECKING: from pyspark.sql import SparkSession + from datacustomcode.io.reader.base import BaseDataCloudReader + from datacustomcode.io.writer.base import BaseDataCloudWriter + _T = TypeVar("_T", bound="BaseDataAccessLayer") @@ -55,7 +54,7 @@ class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]): def to_object(self, spark: SparkSession) -> _T: type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_T, type_(spark=spark, **self.options)) + return cast("_T", type_(spark=spark, **self.options)) class SparkConfig(ForceableConfig): @@ -74,31 +73,18 @@ class SparkConfig(ForceableConfig): _P = TypeVar("_P", bound=BaseSparkSessionProvider) -_PX = TypeVar("_PX", bound=BaseProxyAccessLayer) - - -class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]): - """Config for proxy clients that take no constructor args (e.g. no spark).""" - - type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer - - def to_object(self) -> _PX: - type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_PX, type_(**self.options)) - class SparkProviderConfig(BaseObjectConfig, Generic[_P]): type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider def to_object(self) -> _P: type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_P, type_(**self.options)) + return cast("_P", type_(**self.options)) class ClientConfig(BaseConfig): - reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None - writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None - proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None + reader_config: Union[AccessLayerObjectConfig["BaseDataCloudReader"], None] = None + writer_config: Union[AccessLayerObjectConfig["BaseDataCloudWriter"], None] = None spark_config: Union[SparkConfig, None] = None spark_provider_config: Union[ SparkProviderConfig[BaseSparkSessionProvider], None @@ -126,7 +112,6 @@ def merge( self.reader_config = merge(self.reader_config, other.reader_config) self.writer_config = merge(self.writer_config, other.writer_config) - self.proxy_config = merge(self.proxy_config, other.proxy_config) self.spark_config = merge(self.spark_config, other.spark_config) self.spark_provider_config = merge( self.spark_provider_config, other.spark_provider_config diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index d58bc7f..8a6c334 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -19,11 +19,12 @@ spark_config: spark.sql.execution.arrow.pyspark.enabled: 'true' spark.driver.extraJavaOptions: -Djava.security.manager=allow -proxy_config: - type_config_name: LocalProxyClientProvider +einstein_predictions_config: + type_config_name: DefaultEinsteinPredictions options: credentials_profile: default -einstein_predictions_config: - type_config_name: DefaultEinsteinPredictions - options: {} +llm_gateway_config: + type_config_name: DefaultLLMGateway + options: + credentials_profile: default diff --git a/src/datacustomcode/deploy.py b/src/datacustomcode/deploy.py index 4505438..114252a 100644 --- a/src/datacustomcode/deploy.py +++ b/src/datacustomcode/deploy.py @@ -19,11 +19,9 @@ import os import re import shutil -import subprocess import tempfile import time from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -37,15 +35,10 @@ import requests from datacustomcode.cmd import cmd_output -from datacustomcode.credentials import AuthType from datacustomcode.scan import find_base_directory, get_package_type -if TYPE_CHECKING: - from datacustomcode.credentials import Credentials - DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code" DATA_TRANSFORMS_PATH = "services/data/v63.0/ssot/data-transforms" -AUTH_PATH = "services/oauth2/token" WAIT_FOR_DEPLOYMENT_TIMEOUT = 3000 # Available compute types for Data Cloud deployments. @@ -163,80 +156,6 @@ class AccessTokenResponse(BaseModel): instance_url: str -def _retrieve_access_token(credentials: Credentials) -> AccessTokenResponse: - """Get an access token for the Salesforce API.""" - logger.debug("Getting oauth token...") - - url = f"{credentials.login_url.rstrip('/')}/{AUTH_PATH.lstrip('/')}" - - if credentials.auth_type == AuthType.OAUTH_TOKENS: - data = { - "grant_type": "refresh_token", - "refresh_token": credentials.refresh_token, - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - } - elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS: - data = { - "grant_type": "client_credentials", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - } - else: - raise ValueError(f"Unsupported auth_type: {credentials.auth_type}") - - response = _make_api_call(url, "POST", data=data) - return AccessTokenResponse(**response) - - -def _retrieve_access_token_from_sf_cli(sf_cli_org: str) -> AccessTokenResponse: - """Get an access token from the Salesforce CLI.""" - try: - result = subprocess.run( - ["sf", "org", "display", "--target-org", sf_cli_org, "--json"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - except FileNotFoundError as exc: - raise RuntimeError( - "The 'sf' command was not found. " - "Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" - ) from exc - except subprocess.TimeoutExpired as exc: - raise RuntimeError( - f"'sf org display' timed out for org '{sf_cli_org}'" - ) from exc - except subprocess.CalledProcessError as exc: - raise RuntimeError( - f"'sf org display' failed for org '{sf_cli_org}'.\n" - f"Ensure the org is authenticated via 'sf org login web'.\n" - f"stderr: {exc.stderr.strip()}" - ) from exc - - try: - data = json.loads(result.stdout) - except json.JSONDecodeError as exc: - raise RuntimeError(f"Failed to parse 'sf org display' output: {exc}") from exc - - if data.get("status") != 0: - raise RuntimeError( - f"SF CLI error for org '{sf_cli_org}': " - f"{data.get('message', 'unknown error')}" - ) - - org_result = data.get("result", {}) - access_token = org_result.get("accessToken") - instance_url = org_result.get("instanceUrl") - if not access_token or not instance_url: - raise RuntimeError( - f"'sf org display' did not return an access token or instance URL " - f"for org '{sf_cli_org}'" - ) - return AccessTokenResponse(access_token=access_token, instance_url=instance_url) - - class CreateDeploymentResponse(BaseModel): fileUploadUrl: str @@ -567,16 +486,11 @@ def zip( def deploy_full( directory: str, metadata: CodeExtensionMetadata, - credentials: Union["Credentials", AccessTokenResponse], + access_token: AccessTokenResponse, docker_network: str, callback=None, ) -> AccessTokenResponse: """Deploy a data transform in the DataCloud.""" - if isinstance(credentials, AccessTokenResponse): - access_token = credentials - else: - access_token = _retrieve_access_token(credentials) - # prepare payload config = get_config(directory) @@ -587,7 +501,6 @@ def deploy_full( wait_for_deployment(access_token, metadata, callback) # create data transform - if isinstance(config, DataTransformConfig): create_data_transform(directory, access_token, metadata, config) return access_token diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py new file mode 100644 index 0000000..761f80a --- /dev/null +++ b/src/datacustomcode/einstein_platform_client.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + Any, + Dict, + Optional, +) + +from loguru import logger +import requests + +from datacustomcode.token_provider import ( + CredentialsTokenProvider, + SFCLITokenProvider, + TokenProvider, +) + + +class EinsteinPlatformClient: + EINSTEIN_PLATFORM_MODELS_URL = ( + "https://api.salesforce.com/einstein/platform/v1/models" + ) + + def __init__( + self, + credentials_profile: Optional[str] = None, + sf_cli_org: Optional[str] = None, + **kwargs: Any, + ): + if sf_cli_org: + self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org) + logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}") + else: + profile = credentials_profile or "default" + self._token_provider = CredentialsTokenProvider(profile) + logger.debug(f"Using credentials token provider with profile: {profile}") + self.token_response = None + super().__init__(**kwargs) + + def _get_headers(self): + if self.token_response is None: + self.token_response = self._token_provider.get_token() + + return { + "Authorization": f"Bearer {self.token_response.access_token}", + "Content-Type": "application/json", + "x-sfdc-app-context": "EinsteinGPT", + "x-client-feature-id": "ai-platform-models-connected-app", + } + + def make_post_request(self, url, payload): + try: + response = requests.post( + url, json=payload, headers=self._get_headers(), timeout=180 + ) + if not response.ok: + error_msg = ( + f"Request to {url} failed. " + f"Reason: {response.status_code} {response.reason} - " + f"Response body: {response.text}" + ) + logger.error(error_msg) + return response + except requests.exceptions.RequestException as e: + logger.error(f"Request to {url} failed: {e}") + raise RuntimeError(f"Request to {url} failed {e}") from e + + def parse_response(self, response): + response_data: Dict[str, Any] = {} + if response.content: + try: + response_data = response.json() + except ValueError: + logger.warning("Failed to parse response as JSON") + response_data = {"raw_response": response.text} + return response_data diff --git a/src/datacustomcode/einstein_platform_config.py b/src/datacustomcode/einstein_platform_config.py new file mode 100644 index 0000000..135809d --- /dev/null +++ b/src/datacustomcode/einstein_platform_config.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + ClassVar, + Optional, + Type, + cast, +) + +from datacustomcode.common_config import BaseObjectConfig + + +class CredentialsObjectConfig(BaseObjectConfig): + type_to_create: ClassVar[Type] + credentials_profile: Optional[str] = None + sf_cli_org: Optional[str] = None + + def to_object(self): + """Create an object instance, automatically including credentials in options""" + + options = self.options.copy() + if self.credentials_profile is not None: + options["credentials_profile"] = self.credentials_profile + if self.sf_cli_org is not None: + options["sf_cli_org"] = self.sf_cli_org + + type_ = self.type_to_create.subclass_from_config_name(self.type_config_name) + return cast(type_, type_(**options)) diff --git a/src/datacustomcode/einstein_predictions/__init__.py b/src/datacustomcode/einstein_predictions/__init__.py index 4a4b388..9aaa2a3 100644 --- a/src/datacustomcode/einstein_predictions/__init__.py +++ b/src/datacustomcode/einstein_predictions/__init__.py @@ -17,6 +17,6 @@ from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions __all__ = [ - "EinsteinPredictions", "DefaultEinsteinPredictions", + "EinsteinPredictions", ] diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 1f741fc..28e51f0 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -13,23 +13,67 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import ( + Any, + ClassVar, + Dict, + List, +) + +from datacustomcode.einstein_platform_client import EinsteinPlatformClient from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( PredictionRequest, PredictionResponse, + PredictionType, ) -class DefaultEinsteinPredictions(EinsteinPredictions): +class DefaultEinsteinPredictions(EinsteinPlatformClient, EinsteinPredictions): CONFIG_NAME = "DefaultEinsteinPredictions" - - def __init__(self, **kwargs): - super().__init__(**kwargs) + ENDPOINT_MAP: ClassVar[dict[PredictionType, str]] = { + PredictionType.REGRESSION: "regression", + PredictionType.CLUSTERING: "clustering", + PredictionType.CLASSIFICATION: "classification", + PredictionType.BINARY_CLASSIFICATION: "binary-classification", + PredictionType.MULTI_OUTCOME: "multi-outcome", + } def predict(self, request: PredictionRequest) -> PredictionResponse: + endpoint = self.ENDPOINT_MAP.get(request.prediction_type) + if not endpoint: + raise RuntimeError( + f"Unknown prediction type: {request.prediction_type}. " + f"Valid types: {list(self.ENDPOINT_MAP.keys())}" + ) + + api_url = ( + f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_api_name}/{endpoint}" + ) + + prediction_columns: List[Dict[str, Any]] = [] + for col in request.prediction_columns: + col_data: Dict[str, Any] = {"columnName": col.column_name} + if col.string_values: + col_data["stringValues"] = col.string_values + if col.double_values: + col_data["doubleValues"] = col.double_values + if col.boolean_values: + col_data["booleanValues"] = col.boolean_values + if col.date_values: + col_data["dateValues"] = col.date_values + if col.datetime_values: + col_data["datetimeValues"] = col.datetime_values + prediction_columns.append(col_data) + + payload: Dict[str, Any] = {"predictionColumns": prediction_columns} + + if request.settings: + payload["settings"] = request.settings + + response = self.make_post_request(api_url, payload) return PredictionResponse( - version="v1", prediction_type=request.prediction_type, - status_code=200, - data={"results": [{"prediction": {"predictedValue": 1.0}}]}, + status_code=response.status_code, + data=self.parse_response(response), ) diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py index 4e83164..1b4758f 100644 --- a/src/datacustomcode/einstein_predictions_config.py +++ b/src/datacustomcode/einstein_predictions_config.py @@ -19,25 +19,17 @@ Type, TypeVar, Union, - cast, ) -from datacustomcode.common_config import ( - BaseConfig, - BaseObjectConfig, - default_config_file, -) +from datacustomcode.common_config import BaseConfig, default_config_file +from datacustomcode.einstein_platform_config import CredentialsObjectConfig from datacustomcode.einstein_predictions.base import EinsteinPredictions _E = TypeVar("_E", bound=EinsteinPredictions) -class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]): - type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] - - def to_object(self) -> _E: - type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_E, type_(**self.options)) +class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]): + type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] class EinsteinPredictionsConfig(BaseConfig): diff --git a/src/datacustomcode/io/reader/sf_cli.py b/src/datacustomcode/io/reader/sf_cli.py index adfb3ed..cfeb06e 100644 --- a/src/datacustomcode/io/reader/sf_cli.py +++ b/src/datacustomcode/io/reader/sf_cli.py @@ -16,7 +16,6 @@ import json import logging -import subprocess from typing import ( TYPE_CHECKING, Final, @@ -29,6 +28,7 @@ from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.reader.utils import _pandas_to_spark_schema +from datacustomcode.token_provider import SFCLITokenProvider if TYPE_CHECKING: from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession @@ -78,64 +78,9 @@ def __init__( logger.debug(f"Initialized SFCLIDataCloudReader for org '{sf_cli_org}'") def _get_token(self) -> tuple[str, str]: - """Fetch a fresh access token and instance URL from the SF CLI. - - Returns: - ``(access_token, instance_url)`` - - Raises: - RuntimeError: If the ``sf`` command is not on PATH, times out, or - returns an error. - """ - try: - result = subprocess.run( - ["sf", "org", "display", "--target-org", self.sf_cli_org, "--json"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - except FileNotFoundError as exc: - raise RuntimeError( - "The 'sf' command was not found. " - "Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" - ) from exc - except subprocess.TimeoutExpired as exc: - raise RuntimeError( - f"'sf org display' timed out for org '{self.sf_cli_org}'" - ) from exc - except subprocess.CalledProcessError as exc: - raise RuntimeError( - f"'sf org display' failed for org '{self.sf_cli_org}'.\n" - f"Ensure the org is authenticated via 'sf org login web'.\n" - f"stderr: {exc.stderr.strip()}" - ) from exc - - try: - data = json.loads(result.stdout) - except json.JSONDecodeError as exc: - raise RuntimeError( - f"Failed to parse 'sf org display' output: {exc}" - ) from exc - - if data.get("status") != 0: - raise RuntimeError( - f"SF CLI error for org '{self.sf_cli_org}': " - f"{data.get('message', 'unknown error')}" - ) - - org_result = data.get("result", {}) - access_token = org_result.get("accessToken") - instance_url = org_result.get("instanceUrl") - - if not access_token or not instance_url: - raise RuntimeError( - f"'sf org display' did not return an access token or instance URL " - f"for org '{self.sf_cli_org}'" - ) - + token_response = SFCLITokenProvider(self.sf_cli_org).get_token() logger.debug(f"Fetched token from SF CLI for org '{self.sf_cli_org}'") - return access_token, instance_url + return token_response.access_token, token_response.instance_url def _execute_query(self, sql: str) -> pd.DataFrame: """Execute *sql* against the Data Cloud REST endpoint. diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 9fefbc7..88374e3 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict + +from datacustomcode.einstein_platform_client import EinsteinPlatformClient from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse @@ -21,15 +24,24 @@ ) -class DefaultLLMGateway(LLMGateway): +class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway): CONFIG_NAME = "DefaultLLMGateway" def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: + api_url = ( + f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_name}/generations" + ) - response_data = { - "version": "v1", - "status_code": 200, - "data": {"generation": {"generatedText": "Hello World"}}, - } + payload: Dict[str, Any] = {"prompt": request.prompt} - return GenerateTextResponseBuilder.build(response_data) + if request.localization: + payload["localization"] = request.localization + if request.tags: + payload["tags"] = request.tags + + response = self.make_post_request(api_url, payload) + response_dict = { + "status_code": response.status_code, + "data": self.parse_response(response), + } + return GenerateTextResponseBuilder.build(response_dict) diff --git a/src/datacustomcode/llm_gateway_config.py b/src/datacustomcode/llm_gateway_config.py index 7e7e53d..a65d0eb 100644 --- a/src/datacustomcode/llm_gateway_config.py +++ b/src/datacustomcode/llm_gateway_config.py @@ -19,25 +19,17 @@ Type, TypeVar, Union, - cast, ) -from datacustomcode.common_config import ( - BaseConfig, - BaseObjectConfig, - default_config_file, -) +from datacustomcode.common_config import BaseConfig, default_config_file +from datacustomcode.einstein_platform_config import CredentialsObjectConfig from datacustomcode.llm_gateway.base import LLMGateway _E = TypeVar("_E", bound=LLMGateway) -class LLMGatewayObjectConfig(BaseObjectConfig, Generic[_E]): - type_base: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract] - - def to_object(self) -> _E: - type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_E, type_(**self.options)) +class LLMGatewayObjectConfig(CredentialsObjectConfig, Generic[_E]): + type_to_create: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract] class LLMGatewayConfig(BaseConfig): diff --git a/src/datacustomcode/proxy/client/LocalProxyClientProvider.py b/src/datacustomcode/proxy/client/LocalProxyClientProvider.py deleted file mode 100644 index 9c08b54..0000000 --- a/src/datacustomcode/proxy/client/LocalProxyClientProvider.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from datacustomcode.proxy.client.base import BaseProxyClient - - -class LocalProxyClientProvider(BaseProxyClient): - """Default proxy client provider.""" - - CONFIG_NAME = "LocalProxyClientProvider" - - def __init__(self, **kwargs: object) -> None: - pass - - def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: - return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}" - - def llm_gateway_generate_text( - self, template, values, llmModelId: str, maxTokens: int - ): - return f"Using Generate Text with {llmModelId} and maxTokens: {maxTokens}" diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py index b581072..0e5052a 100644 --- a/src/datacustomcode/run.py +++ b/src/datacustomcode/run.py @@ -25,10 +25,12 @@ ) from datacustomcode.config import config +from datacustomcode.einstein_predictions_config import einstein_predictions_config +from datacustomcode.llm_gateway_config import llm_gateway_config from datacustomcode.scan import find_base_directory, get_package_type -def _set_config_option(config_obj, key: str, value: str) -> None: +def _set_config_option(config_obj, key: str, value: Optional[str]) -> None: """Set an option on a config object if it exists and has options attribute. Args: @@ -36,10 +38,33 @@ def _set_config_option(config_obj, key: str, value: str) -> None: key: Option key to set value: Option value to set """ - if config_obj and hasattr(config_obj, "options"): + if config_obj and hasattr(config_obj, "options") and value is not None: config_obj.options[key] = value +def _update_config_options(profile: Optional[str], sf_cli_org: Optional[str]): + if sf_cli_org: + config_key = "sf_cli_org" + _set_config_option(config.reader_config, config_key, sf_cli_org) + _set_config_option(config.writer_config, config_key, sf_cli_org) + _set_config_option( + einstein_predictions_config.einstein_predictions_config, + config_key, + sf_cli_org, + ) + _set_config_option( + llm_gateway_config.llm_gateway_config, config_key, sf_cli_org + ) + elif profile != "default": + config_key = "credentials_profile" + _set_config_option(config.reader_config, config_key, profile) + _set_config_option(config.writer_config, config_key, profile) + _set_config_option( + einstein_predictions_config.einstein_predictions_config, config_key, profile + ) + _set_config_option(llm_gateway_config.llm_gateway_config, config_key, profile) + + def run_entrypoint( entrypoint: str, config_file: Union[str, None], @@ -47,7 +72,7 @@ def run_entrypoint( profile: str, sf_cli_org: Optional[str] = None, ) -> None: - """Run the entrypoint script with the given config and dependencies. + """Run the entrypoint for script or function with the given config and dependencies. Args: entrypoint: The entrypoint script to run. @@ -98,12 +123,8 @@ def run_entrypoint( _set_config_option(config.reader_config, "dataspace", dataspace) _set_config_option(config.writer_config, "dataspace", dataspace) - if sf_cli_org: - _set_config_option(config.reader_config, "sf_cli_org", sf_cli_org) - _set_config_option(config.writer_config, "sf_cli_org", sf_cli_org) - elif profile != "default": - _set_config_option(config.reader_config, "credentials_profile", profile) - _set_config_option(config.writer_config, "credentials_profile", profile) + _update_config_options(profile, sf_cli_org) + for dependency in dependencies: try: importlib.import_module(dependency) diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 613d142..a1cd685 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -59,9 +59,23 @@ def make_einstein_prediction(runtime: Runtime) -> None: ) prediction_response = runtime.einstein_predictions.predict(prediction_request) - print( - f"Einstein prediction results - success: {prediction_response.is_success} \ - response data: {prediction_response.data}" + logger.info( + f"Einstein prediction results - success: [{prediction_response.is_success}] " + f"response data: {prediction_response.data}" + ) + + +def generate_text(runtime: Runtime): + builder = GenerateTextRequestBuilder() + llm_request = ( + builder.set_prompt("Generate 2 dog names") + .set_model("sfdc_ai__DefaultGPT52") + .build() + ) + llm_response = runtime.llm_gateway.generate_text(llm_request) + logger.info( + f"LLM Gateway generate text results - success: [{llm_response.is_success}] " + f"response data: {llm_response.data}" ) @@ -73,16 +87,14 @@ def function(request: dict, runtime: Runtime) -> dict: output_chunks = [] current_seq_no = 1 # Start sequence number from 1 - make_einstein_prediction(runtime) - - builder = GenerateTextRequestBuilder() - llm_request = builder.set_prompt("Hello").set_model("modelName").build() - llm_response = runtime.llm_gateway.generate_text(llm_request) - - if llm_response.is_success: - print(llm_response.text) - else: - print(llm_response.error_code) + """ + You can use your AI models configured in Salesforce + to generate texts or predict an outcome. + First configure an external client app before using these AI APIs + https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app" + """ + # generate_text(runtime) + # make_einstein_prediction(runtime) for item in items: # Item is DocElement as dict diff --git a/src/datacustomcode/token_provider.py b/src/datacustomcode/token_provider.py new file mode 100644 index 0000000..ee98aeb --- /dev/null +++ b/src/datacustomcode/token_provider.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datacustomcode.deploy import AccessTokenResponse + + +class TokenProvider(ABC): + """Abstract base class for providing authentication tokens.""" + + @abstractmethod + def get_token(self) -> "AccessTokenResponse": + """Retrieve a fresh access token and instance URL. + + Returns: + AccessTokenResponse containing access_token and instance_url + """ + ... + + +class CredentialsTokenProvider(TokenProvider): + """Token provider that uses stored credentials with refresh token flow.""" + + def __init__(self, credentials_profile: str = "default"): + self.credentials_profile = credentials_profile + + def get_token(self) -> "AccessTokenResponse": + """Get token by refreshing credentials from stored profile.""" + import requests + + from datacustomcode.credentials import AuthType, Credentials + from datacustomcode.deploy import AccessTokenResponse + + # Load credentials (freshly, not cached) + credentials = Credentials.from_available(profile=self.credentials_profile) + + token_url = f"{credentials.login_url.rstrip('/')}/services/oauth2/token" + + if credentials.auth_type == AuthType.OAUTH_TOKENS: + data = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token, + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } + elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS: + data = { + "grant_type": "client_credentials", + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } + else: + raise RuntimeError(f"Unsupported auth_type: {credentials.auth_type}") + + try: + response = requests.post(token_url, data=data, timeout=30) + response.raise_for_status() + token_data = response.json() + + access_token = token_data.get("access_token") + instance_url = token_data.get("instance_url") + + if not access_token or not instance_url: + raise RuntimeError( + "Token refresh response missing access_token or instance_url" + ) + + return AccessTokenResponse( + access_token=access_token, instance_url=instance_url + ) + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to get access token: {e}") from e + + +class SFCLITokenProvider(TokenProvider): + """Token provider that uses Salesforce CLI for authentication.""" + + def __init__(self, sf_cli_org: str): + self.sf_cli_org = sf_cli_org + + def get_token(self) -> "AccessTokenResponse": + """Get token from Salesforce SF CLI""" + import json + import subprocess + + from datacustomcode.deploy import AccessTokenResponse + + try: + result = subprocess.run( + ["sf", "org", "display", "--target-org", self.sf_cli_org, "--json"], + capture_output=True, + text=True, + check=True, + timeout=30, + ) + except FileNotFoundError as exc: + raise RuntimeError( + "The 'sf' command was not found. " + "Install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" + ) from exc + except subprocess.TimeoutExpired as exc: + raise RuntimeError( + f"'sf org display' timed out for org '{self.sf_cli_org}'" + ) from exc + except subprocess.CalledProcessError as exc: + raise RuntimeError( + f"'sf org display' failed for org '{self.sf_cli_org}': {exc.stderr}" + ) from exc + + try: + data = json.loads(result.stdout) + except json.JSONDecodeError as exc: + raise RuntimeError( + f"Failed to parse JSON from 'sf org display': {result.stdout}" + ) from exc + + if data.get("status") != 0: + raise RuntimeError( + f"SF CLI error for org '{self.sf_cli_org}': " + f"{data.get('message', 'unknown error')}" + ) + + result_data = data.get("result", {}) + access_token = result_data.get("accessToken") + instance_url = result_data.get("instanceUrl") + + if not access_token or not instance_url: + raise RuntimeError( + f"'sf org display' did not return an access token or instance URL " + f"for org '{self.sf_cli_org}'" + ) + + return AccessTokenResponse(access_token=access_token, instance_url=instance_url) diff --git a/tests/file/test_path_default.py b/tests/file/test_path_default.py index c92eacd..8350122 100644 --- a/tests/file/test_path_default.py +++ b/tests/file/test_path_default.py @@ -72,7 +72,7 @@ def test_find_file_path_file_not_found(self): with pytest.raises( FileNotFoundError, - match="File 'test.txt' not found in any search location", + match=r"File 'test\.txt' not found in any search location", ): finder.find_file_path("test.txt") diff --git a/tests/io/reader/test_query_api.py b/tests/io/reader/test_query_api.py index 33cb54d..f9baa81 100644 --- a/tests/io/reader/test_query_api.py +++ b/tests/io/reader/test_query_api.py @@ -121,8 +121,12 @@ def test_pandas_to_spark_schema_datetime_types(self): assert field_dict[field_name].nullable # Verify the actual pandas dtypes to ensure our test data has the expected types - assert str(df["datetime_ns"].dtype) == "datetime64[ns]" - assert str(df["datetime_ns_utc"].dtype) == "datetime64[ns, UTC]" + # Pandas may use 'ns' or 'us' precision depending on version + assert str(df["datetime_ns"].dtype) in ["datetime64[ns]", "datetime64[us]"] + assert str(df["datetime_ns_utc"].dtype) in [ + "datetime64[ns, UTC]", + "datetime64[us, UTC]", + ] assert str(df["datetime_ms"].dtype) == "datetime64[ms]" assert str(df["datetime_ms_utc"].dtype) == "datetime64[ms]" diff --git a/tests/test_auth.py b/tests/test_auth.py index 1e335ac..974b3c0 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -148,7 +148,7 @@ def test_run_oauth_callback_server_valid_uri( mock_server.server_address = ("localhost", 5555) mock_tcpserver.return_value = mock_server - server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) + _server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) assert port == 5555 mock_tcpserver.assert_called_once() @@ -186,7 +186,7 @@ def test_run_oauth_callback_server_different_port( mock_server.server_address = ("localhost", 8080) mock_tcpserver.return_value = mock_server - server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) + _server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) assert port == 8080 mock_tcpserver.assert_called_once() diff --git a/tests/test_cli.py b/tests/test_cli.py index ab18fa1..e26cbdc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -72,11 +72,14 @@ def test_init_command( class TestDeploy: @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") - def test_deploy_command_success(self, mock_credentials, mock_deploy_full): + @patch("datacustomcode.token_provider.CredentialsTokenProvider") + def test_deploy_command_success(self, mock_token_provider, mock_deploy_full): """Test successful deploy command.""" - # Mock credentials - mock_creds = mock_credentials.return_value + # Mock token provider + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) runner = CliRunner() with runner.isolated_filesystem(): @@ -86,7 +89,7 @@ def test_deploy_command_success(self, mock_credentials, mock_deploy_full): result = runner.invoke(deploy, ["--name", "test-job", "--version", "1.0.0"]) assert result.exit_code == 0 - mock_credentials.assert_called_once() + mock_token_provider.assert_called_once_with("default") mock_deploy_full.assert_called_once() # Check that deploy_full was called with correct arguments @@ -97,14 +100,20 @@ def test_deploy_command_success(self, mock_credentials, mock_deploy_full): ) # metadata (hyphen sanitized to underscore) assert call_args[0][1].version == "1.0.0" assert call_args[0][1].description == "Custom Data Transform Code" - assert call_args[0][2] == mock_creds # credentials + assert call_args[0][2].access_token == "test_token" + assert call_args[0][2].instance_url == "https://instance.example.com" @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") + @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_function_invoke_options( - self, mock_credentials, mock_deploy_full + self, mock_token_provider, mock_deploy_full ): """Test deploy command with function invoke options.""" + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) + runner = CliRunner() with runner.isolated_filesystem(): # Create test payload directory @@ -123,13 +132,13 @@ def test_deploy_command_function_invoke_options( call_args = mock_deploy_full.call_args assert call_args[0][1].functionInvokeOptions == ["option1", "option2"] - @patch("datacustomcode.credentials.Credentials.from_available") - def test_deploy_command_credentials_error(self, mock_credentials): + @patch("datacustomcode.token_provider.CredentialsTokenProvider") + def test_deploy_command_credentials_error(self, mock_token_provider): """Test deploy command when credentials are not available.""" - # Mock credentials to raise ValueError - mock_credentials.side_effect = ValueError( - "Credentials not found in env or ini file. " - "Run `datacustomcode configure` to create a credentials file." + # Mock token provider to raise RuntimeError + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.side_effect = RuntimeError( + "Failed to refresh access token" ) runner = CliRunner() @@ -140,12 +149,17 @@ def test_deploy_command_credentials_error(self, mock_credentials): result = runner.invoke(deploy, ["--name", "test-job"]) assert result.exit_code == 1 - assert "Error: Credentials not found in env or ini file" in result.output + assert "Error: Failed to refresh access token" in result.output @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") - def test_deploy_command_custom_path(self, mock_credentials, mock_deploy_full): + @patch("datacustomcode.token_provider.CredentialsTokenProvider") + def test_deploy_command_custom_path(self, mock_token_provider, mock_deploy_full): """Test deploy command with custom path.""" + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) + runner = CliRunner() with runner.isolated_filesystem(): # Create test directory @@ -163,11 +177,16 @@ def test_deploy_command_custom_path(self, mock_credentials, mock_deploy_full): assert call_args[0][0] == "custom_path" # path @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") + @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_custom_description( - self, mock_credentials, mock_deploy_full + self, mock_token_provider, mock_deploy_full ): """Test deploy command with custom description.""" + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) + runner = CliRunner() with runner.isolated_filesystem(): # Create test payload directory @@ -185,13 +204,13 @@ def test_deploy_command_custom_description( assert call_args[0][1].description == "Custom description" @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") - def test_deploy_command_sf_cli_org(self, mock_sf_cli_token, mock_deploy_full): + @patch("datacustomcode.token_provider.SFCLITokenProvider") + def test_deploy_command_sf_cli_org(self, mock_sf_cli_provider, mock_deploy_full): """Test deploy command with --sf-cli-org flag.""" - mock_token = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="test_token", instance_url="https://test.salesforce.com" ) - mock_sf_cli_token.return_value = mock_token runner = CliRunner() with runner.isolated_filesystem(): @@ -201,15 +220,20 @@ def test_deploy_command_sf_cli_org(self, mock_sf_cli_token, mock_deploy_full): ) assert result.exit_code == 0 - mock_sf_cli_token.assert_called_once_with("my-org") + mock_sf_cli_provider.assert_called_once_with("my-org") mock_deploy_full.assert_called_once() call_args = mock_deploy_full.call_args - assert call_args[0][2] == mock_token # AccessTokenResponse passed directly + # Check AccessTokenResponse passed directly + assert call_args[0][2].access_token == "test_token" + assert call_args[0][2].instance_url == "https://test.salesforce.com" - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") - def test_deploy_command_sf_cli_org_error(self, mock_sf_cli_token): + @patch("datacustomcode.token_provider.SFCLITokenProvider") + def test_deploy_command_sf_cli_org_error(self, mock_sf_cli_provider): """Test deploy command when --sf-cli-org fails.""" - mock_sf_cli_token.side_effect = RuntimeError("sf command not found") + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.side_effect = RuntimeError( + "sf command not found" + ) runner = CliRunner() with runner.isolated_filesystem(): diff --git a/tests/test_client.py b/tests/test_client.py index 5981ca9..c2cf46a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,7 +17,6 @@ ) from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode -from datacustomcode.proxy.client.base import BaseProxyClient class MockDataCloudReader(BaseDataCloudReader): @@ -76,13 +75,6 @@ def mock_config(mock_spark): ) -@pytest.fixture -def mock_proxy(): - """Mock proxy client to avoid starting Spark when reader/writer are provided.""" - proxy = MagicMock(spec=BaseProxyClient) - return proxy - - @pytest.fixture def reset_client(): """Reset the Client singleton between tests.""" @@ -93,12 +85,12 @@ def reset_client(): class TestClient: - def test_singleton_pattern(self, reset_client, mock_spark, mock_proxy): + def test_singleton_pattern(self, reset_client, mock_spark): """Test that Client behaves as a singleton.""" reader = MockDataCloudReader(mock_spark) writer = MockDataCloudWriter(mock_spark) - client1 = Client(reader=reader, writer=writer, proxy=mock_proxy) + client1 = Client(reader=reader, writer=writer) client2 = Client() assert client1 is client2 @@ -144,38 +136,38 @@ def test_initialization_with_config(self, mock_config, reset_client, mock_spark) assert client._reader is mock_reader assert client._writer is mock_writer - def test_read_dlo(self, reset_client, mock_spark, mock_proxy): + def test_read_dlo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) reader.read_dlo.return_value = mock_df - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) result = client.read_dlo("test_dlo") reader.read_dlo.assert_called_once_with("test_dlo") assert result is mock_df assert "test_dlo" in client._data_layer_history[DataCloudObjectType.DLO] - def test_read_dmo(self, reset_client, mock_spark, mock_proxy): + def test_read_dmo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) reader.read_dmo.return_value = mock_df - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) result = client.read_dmo("test_dmo") reader.read_dmo.assert_called_once_with("test_dmo") assert result is mock_df assert "test_dmo" in client._data_layer_history[DataCloudObjectType.DMO] - def test_write_to_dlo(self, reset_client, mock_spark, mock_proxy): + def test_write_to_dlo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dlo_access("some_dlo") client.write_to_dlo("test_dlo", mock_df, WriteMode.APPEND, extra_param=True) @@ -184,12 +176,12 @@ def test_write_to_dlo(self, reset_client, mock_spark, mock_proxy): "test_dlo", mock_df, WriteMode.APPEND, extra_param=True ) - def test_write_to_dmo(self, reset_client, mock_spark, mock_proxy): + def test_write_to_dmo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dmo_access("some_dmo") client.write_to_dmo("test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True) @@ -198,13 +190,13 @@ def test_write_to_dmo(self, reset_client, mock_spark, mock_proxy): "test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True ) - def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark, mock_proxy): + def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark): """Test that mixing DLOs and DMOs raises an exception.""" reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dlo_access("test_dlo") with pytest.raises(DataCloudAccessLayerException) as exc_info: @@ -212,13 +204,13 @@ def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark, mock_pro assert "test_dlo" in str(exc_info.value) - def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark, mock_proxy): + def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark): """Test that mixing DMOs and DLOs raises an exception (converse case).""" reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dmo_access("test_dmo") with pytest.raises(DataCloudAccessLayerException) as exc_info: @@ -226,14 +218,14 @@ def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark, mock_pro assert "test_dmo" in str(exc_info.value) - def test_read_pattern_flow(self, reset_client, mock_spark, mock_proxy): + def test_read_pattern_flow(self, reset_client, mock_spark): """Test a complete flow of reading and writing within the same object type.""" reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) reader.read_dlo.return_value = mock_df - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) df = client.read_dlo("source_dlo") client.write_to_dlo("target_dlo", df, WriteMode.APPEND) @@ -247,7 +239,7 @@ def test_read_pattern_flow(self, reset_client, mock_spark, mock_proxy): # Reset for DMO test Client._instance = None - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) reader.read_dmo.return_value = mock_df df = client.read_dmo("source_dmo") diff --git a/tests/test_deploy.py b/tests/test_deploy.py index a6ab5aa..2fd3ce2 100644 --- a/tests/test_deploy.py +++ b/tests/test_deploy.py @@ -1,6 +1,5 @@ """Tests for the deploy module.""" -import json from unittest.mock import ( ANY, MagicMock, @@ -12,7 +11,6 @@ import pytest import requests -from datacustomcode.credentials import AuthType, Credentials from datacustomcode.deploy import ( DloPermission, Permissions, @@ -28,8 +26,6 @@ DataTransformConfig, DeploymentsResponse, _make_api_call, - _retrieve_access_token, - _retrieve_access_token_from_sf_cli, _sanitize_api_name, create_data_transform, create_deployment, @@ -598,58 +594,6 @@ def test_make_api_call_invalid_response(self, mock_request): _make_api_call("https://example.com", "GET") -class TestRetrieveAccessToken: - @patch("datacustomcode.deploy._make_api_call") - def test_retrieve_access_token(self, mock_make_api_call): - """Test retrieving access token.""" - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.OAUTH_TOKENS, - refresh_token="refresh", - client_secret="secret", - ) - - mock_make_api_call.return_value = { - "access_token": "test_token", - "instance_url": "https://instance.example.com", - } - - result = _retrieve_access_token(credentials) - - mock_make_api_call.assert_called_once() - call_args = mock_make_api_call.call_args - assert call_args.kwargs["data"]["grant_type"] == "refresh_token" - assert call_args.kwargs["data"]["refresh_token"] == "refresh" - assert isinstance(result, AccessTokenResponse) - assert result.access_token == "test_token" - assert result.instance_url == "https://instance.example.com" - - @patch("datacustomcode.deploy._make_api_call") - def test_retrieve_access_token_client_credentials(self, mock_make_api_call): - """Test retrieving access token with client credentials flow.""" - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.CLIENT_CREDENTIALS, - client_secret="secret", - ) - - mock_make_api_call.return_value = { - "access_token": "test_token", - "instance_url": "https://instance.example.com", - } - - result = _retrieve_access_token(credentials) - - mock_make_api_call.assert_called_once() - call_args = mock_make_api_call.call_args - assert call_args.kwargs["data"]["grant_type"] == "client_credentials" - assert isinstance(result, AccessTokenResponse) - assert result.access_token == "test_token" - assert result.instance_url == "https://instance.example.com" - - class TestCreateDeployment: @patch("datacustomcode.deploy._make_api_call") def test_create_deployment_success(self, mock_make_api_call): @@ -963,7 +907,7 @@ def test_verify_data_transform_config_missing(self, mock_exists): mock_exists.return_value = False with pytest.raises( FileNotFoundError, - match="config.json not found at /test/dir/payload/config.json", + match=r"config\.json not found at /test/dir/payload/config\.json", ): get_config("/test/dir/payload") @@ -974,7 +918,7 @@ def test_verify_data_transform_config_invalid_json(self, mock_file, mock_exists) mock_exists.return_value = True with pytest.raises( ValueError, - match="config.json at /test/dir/payload/config.json is not valid JSON", + match=r"config\.json at /test/dir/payload/config\.json is not valid JSON", ): get_config("/test/dir/payload") @@ -985,8 +929,8 @@ def test_verify_data_transform_config_missing_fields(self, mock_file, mock_exist mock_exists.return_value = True with pytest.raises( ValueError, - match="config.json at /test/dir/payload/config.json is missing " - "required fields: entryPoint, dataspace, permissions", + match=r"config\.json at /test/dir/payload/config\.json is missing " + r"required fields: entryPoint, dataspace, permissions", ): get_config("/test/dir/payload") @@ -1048,10 +992,8 @@ class TestDeployFull: @patch("datacustomcode.deploy.upload_zip") @patch("datacustomcode.deploy.zip") @patch("datacustomcode.deploy.create_deployment") - @patch("datacustomcode.deploy._retrieve_access_token") def test_deploy_full( self, - mock_retrieve_token, mock_create_deployment, mock_zip, mock_upload_zip, @@ -1070,13 +1012,6 @@ def test_deploy_full( ), ) mock_get_config.return_value = data_transform_config - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.OAUTH_TOKENS, - refresh_token="refresh", - client_secret="secret", - ) metadata = CodeExtensionMetadata( name="test_job", version="1.0.0", @@ -1090,16 +1025,14 @@ def test_deploy_full( access_token = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) - mock_retrieve_token.return_value = access_token mock_create_deployment.return_value = CreateDeploymentResponse( fileUploadUrl="https://upload.example.com" ) # Call function - result = deploy_full("/test/dir", metadata, credentials, "default", callback) + result = deploy_full("/test/dir", metadata, access_token, "default", callback) # Assertions - mock_retrieve_token.assert_called_once_with(credentials) mock_get_config.assert_called_once_with("/test/dir") mock_create_deployment.assert_called_once_with(access_token, metadata) mock_zip.assert_called_once_with("/test/dir", "default", "script") @@ -1110,7 +1043,6 @@ def test_deploy_full( ) assert result == access_token - @patch("datacustomcode.deploy._retrieve_access_token") @patch("datacustomcode.deploy.get_config") @patch("datacustomcode.deploy.create_deployment") @patch("datacustomcode.deploy.zip") @@ -1125,7 +1057,6 @@ def test_deploy_full_client_credentials( mock_zip, mock_create_deployment, mock_get_config, - mock_retrieve_token, ): """Test full deployment process using client credentials auth.""" data_transform_config = DataTransformConfig( @@ -1138,12 +1069,6 @@ def test_deploy_full_client_credentials( ), ) mock_get_config.return_value = data_transform_config - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.CLIENT_CREDENTIALS, - client_secret="secret", - ) metadata = CodeExtensionMetadata( name="test_job", version="1.0.0", @@ -1156,14 +1081,12 @@ def test_deploy_full_client_credentials( access_token = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) - mock_retrieve_token.return_value = access_token mock_create_deployment.return_value = CreateDeploymentResponse( fileUploadUrl="https://upload.example.com" ) - result = deploy_full("/test/dir", metadata, credentials, "default", callback) + result = deploy_full("/test/dir", metadata, access_token, "default", callback) - mock_retrieve_token.assert_called_once_with(credentials) mock_get_config.assert_called_once_with("/test/dir") mock_create_deployment.assert_called_once_with(access_token, metadata) mock_zip.assert_called_once_with("/test/dir", "default", "script") @@ -1199,7 +1122,6 @@ def test_run_data_transform(self, mock_make_api_call): class TestDeployFullWithDockerIntegration: - @patch("datacustomcode.deploy._retrieve_access_token") @patch("datacustomcode.deploy.create_deployment") @patch("datacustomcode.deploy.zip") @patch("datacustomcode.deploy.upload_zip") @@ -1216,16 +1138,8 @@ def test_deploy_full_happy_path( mock_upload_zip, mock_zip, mock_create_deployment, - mock_retrieve_token, ): """Test full deployment process with Docker dependency building.""" - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.OAUTH_TOKENS, - refresh_token="refresh", - client_secret="secret", - ) metadata = CodeExtensionMetadata( name="test_job", version="1.0.0", @@ -1249,27 +1163,17 @@ def test_deploy_full_happy_path( access_token = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) - mock_retrieve_token.return_value = access_token mock_create_deployment.return_value = CreateDeploymentResponse( fileUploadUrl="https://upload.example.com" ) # Mock that requirements.txt exists and has dependencies mock_has_requirements.return_value = True - data_transform_config = DataTransformConfig( - sdkVersion="1.0.0", - entryPoint="entrypoint.py", - dataspace="test_dataspace", - permissions=Permissions( - read=DloPermission(dlo=["input_dlo"]), - write=DloPermission(dlo=["output_dlo"]), - ), - ) + # Call function - result = deploy_full("/test/dir", metadata, credentials, "default", callback) + result = deploy_full("/test/dir", metadata, access_token, "default", callback) # Assertions - mock_retrieve_token.assert_called_once_with(credentials) mock_get_config.assert_called_once_with("/test/dir") mock_create_deployment.assert_called_once_with(access_token, metadata) mock_zip.assert_called_once_with("/test/dir", "default", "script") @@ -1281,104 +1185,6 @@ def test_deploy_full_happy_path( assert result == access_token -class TestRetrieveAccessTokenFromSFCLI: - """Tests for _retrieve_access_token_from_sf_cli.""" - - SF_CLI_OUTPUT = json.dumps( - { - "status": 0, - "result": { - "accessToken": "sf_access_token", - "instanceUrl": "https://sf.salesforce.com", - }, - } - ) - - @patch("datacustomcode.deploy.subprocess.run") - def test_happy_path(self, mock_run): - """Successful sf org display returns AccessTokenResponse.""" - mock_run.return_value = MagicMock(stdout=self.SF_CLI_OUTPUT, returncode=0) - - result = _retrieve_access_token_from_sf_cli("my-org") - - assert result.access_token == "sf_access_token" - assert result.instance_url == "https://sf.salesforce.com" - mock_run.assert_called_once_with( - ["sf", "org", "display", "--target-org", "my-org", "--json"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - - @patch("datacustomcode.deploy.subprocess.run") - def test_file_not_found(self, mock_run): - """FileNotFoundError raised when sf CLI is not installed.""" - mock_run.side_effect = FileNotFoundError("No such file or directory: 'sf'") - - with pytest.raises(RuntimeError, match="'sf' command was not found"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_timeout_expired(self, mock_run): - """TimeoutExpired raised when sf CLI times out.""" - import subprocess - - mock_run.side_effect = subprocess.TimeoutExpired(cmd="sf", timeout=30) - - with pytest.raises(RuntimeError, match="timed out"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_called_process_error(self, mock_run): - """CalledProcessError raised when sf CLI exits non-zero.""" - import subprocess - - mock_run.side_effect = subprocess.CalledProcessError( - returncode=1, cmd="sf", stderr="Org not found" - ) - - with pytest.raises(RuntimeError, match="failed for org"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_json_decode_error(self, mock_run): - """RuntimeError raised when output is not valid JSON.""" - mock_run.return_value = MagicMock(stdout="not-json", returncode=0) - - with pytest.raises(RuntimeError, match="Failed to parse"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_nonzero_status_in_json(self, mock_run): - """RuntimeError raised when JSON status field is non-zero.""" - output = json.dumps({"status": 1, "message": "org not found"}) - mock_run.return_value = MagicMock(stdout=output, returncode=0) - - with pytest.raises(RuntimeError, match="SF CLI error"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_missing_access_token(self, mock_run): - """RuntimeError raised when accessToken is absent.""" - output = json.dumps( - {"status": 0, "result": {"instanceUrl": "https://sf.salesforce.com"}} - ) - mock_run.return_value = MagicMock(stdout=output, returncode=0) - - with pytest.raises(RuntimeError, match="did not return"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_missing_instance_url(self, mock_run): - """RuntimeError raised when instanceUrl is absent.""" - output = json.dumps({"status": 0, "result": {"accessToken": "sf_access_token"}}) - mock_run.return_value = MagicMock(stdout=output, returncode=0) - - with pytest.raises(RuntimeError, match="did not return"): - _retrieve_access_token_from_sf_cli("my-org") - - class TestDeployFullWithAccessTokenResponse: """Test deploy_full when passed an AccessTokenResponse directly.""" @@ -1388,10 +1194,8 @@ class TestDeployFullWithAccessTokenResponse: @patch("datacustomcode.deploy.zip") @patch("datacustomcode.deploy.create_deployment") @patch("datacustomcode.deploy.get_config") - @patch("datacustomcode.deploy._retrieve_access_token") def test_deploy_full_with_access_token_response_skips_token_exchange( self, - mock_retrieve_token, mock_get_config, mock_create_deployment, mock_zip, @@ -1399,7 +1203,7 @@ def test_deploy_full_with_access_token_response_skips_token_exchange( mock_wait, mock_create_transform, ): - """deploy_full skips token exchange when given an AccessTokenResponse.""" + """deploy_full now only accepts AccessTokenResponse.""" access_token = AccessTokenResponse( access_token="direct_token", instance_url="https://instance.example.com" ) @@ -1417,7 +1221,6 @@ def test_deploy_full_with_access_token_response_skips_token_exchange( result = deploy_full("/test/dir", metadata, access_token, "default") - mock_retrieve_token.assert_not_called() mock_create_deployment.assert_called_once_with(access_token, metadata) assert result == access_token diff --git a/tests/test_llm_gateway.py b/tests/test_llm_gateway.py index e05bc02..7875e21 100644 --- a/tests/test_llm_gateway.py +++ b/tests/test_llm_gateway.py @@ -3,8 +3,6 @@ from pydantic import ValidationError import pytest -from datacustomcode.llm_gateway.base import LLMGateway -from datacustomcode.llm_gateway.default import DefaultLLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest from datacustomcode.llm_gateway.types.generate_text_request_builder import ( GenerateTextRequestBuilder, @@ -210,28 +208,3 @@ def test_builder_with_minimal_dict(self): response = GenerateTextResponseBuilder.build(response_dict) assert response.status_code == 200 assert response.version == "v1" # Default value - - -class TestDefaultLLMGateway: - """Test DefaultLLMGateway implementation.""" - - def test_default_gateway_is_llm_gateway(self): - """Test DefaultLLMGateway inherits from LLMGateway.""" - gateway = DefaultLLMGateway() - assert isinstance(gateway, LLMGateway) - - def test_generate_text_returns_response(self): - """Test generate_text returns GenerateTextResponse.""" - gateway = DefaultLLMGateway() - request = GenerateTextRequest(model_name="gpt-4", prompt="Hello") - response = gateway.generate_text(request) - assert isinstance(response, GenerateTextResponse) - - def test_generate_text_success_response(self): - """Test generate_text returns successful response.""" - gateway = DefaultLLMGateway() - request = GenerateTextRequest(model_name="gpt-4", prompt="Hello") - response = gateway.generate_text(request) - assert response.is_success is True - assert response.status_code == 200 - assert len(response.text) > 0 diff --git a/tests/test_scan.py b/tests/test_scan.py index 528f728..bfdad8b 100644 --- a/tests/test_scan.py +++ b/tests/test_scan.py @@ -524,8 +524,8 @@ def test_rejects_missing_dataspace(self): # Should raise ValueError when dataspace field is missing with pytest.raises( ValueError, - match="dataspace must be defined. Please add a 'dataspace' field to " - "the config.json file.", + match=r"dataspace must be defined\. Please add a 'dataspace' field to " + r"the config\.json file\.", ): update_config(temp_path) finally: diff --git a/tests/test_sf_cli_contract.py b/tests/test_sf_cli_contract.py index e60f68f..b96ab35 100644 --- a/tests/test_sf_cli_contract.py +++ b/tests/test_sf_cli_contract.py @@ -154,48 +154,51 @@ class TestDeployArgContract: "--cpu-size", "CPU_2XL", ] # fmt: skip - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") + @patch("datacustomcode.token_provider.SFCLITokenProvider") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.cli.find_base_directory") @patch("datacustomcode.cli.get_package_type") def test_accepts_required_flags( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_token + self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider ): mock_find_base.return_value = "payload" mock_pkg_type.return_value = "script" - mock_sf_cli_token.return_value = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="tok", instance_url="https://example.com" ) runner = CliRunner() result = runner.invoke(deploy, self._BASE_ARGS) assert result.exit_code != 2, result.output - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") + @patch("datacustomcode.token_provider.SFCLITokenProvider") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.cli.find_base_directory") @patch("datacustomcode.cli.get_package_type") def test_accepts_network_flag( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_token + self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider ): mock_find_base.return_value = "payload" mock_pkg_type.return_value = "script" - mock_sf_cli_token.return_value = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="tok", instance_url="https://example.com" ) runner = CliRunner() result = runner.invoke(deploy, [*self._BASE_ARGS, "--network", "custom"]) assert result.exit_code != 2, result.output - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") + @patch("datacustomcode.token_provider.SFCLITokenProvider") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.cli.find_base_directory") @patch("datacustomcode.cli.get_package_type") def test_accepts_function_invoke_opt_flag( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_token + self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider ): mock_find_base.return_value = "payload" mock_pkg_type.return_value = "function" - mock_sf_cli_token.return_value = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="tok", instance_url="https://example.com" ) runner = CliRunner() diff --git a/tests/test_token_provider.py b/tests/test_token_provider.py new file mode 100644 index 0000000..5125e37 --- /dev/null +++ b/tests/test_token_provider.py @@ -0,0 +1,188 @@ +"""Tests for token_provider module.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from datacustomcode.credentials import AuthType, Credentials +from datacustomcode.deploy import AccessTokenResponse +from datacustomcode.token_provider import CredentialsTokenProvider, SFCLITokenProvider + + +class TestCredentialsTokenProvider: + """Tests for CredentialsTokenProvider.""" + + def test_oauth_tokens_auth_type(self): + """Test token retrieval with OAUTH_TOKENS auth type.""" + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + refresh_token="test_refresh_token", + auth_type=AuthType.OAUTH_TOKENS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test_access_token", + "instance_url": "https://instance.salesforce.com", + } + mock_post.return_value = mock_response + + result = provider.get_token() + + assert isinstance(result, AccessTokenResponse) + assert result.access_token == "test_access_token" + assert result.instance_url == "https://instance.salesforce.com" + + # Verify correct auth type was used + call_args = mock_post.call_args + assert call_args[1]["data"]["grant_type"] == "refresh_token" + assert call_args[1]["data"]["refresh_token"] == "test_refresh_token" + + def test_client_credentials_auth_type(self): + """Test token retrieval with CLIENT_CREDENTIALS auth type.""" + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + auth_type=AuthType.CLIENT_CREDENTIALS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test_access_token", + "instance_url": "https://instance.salesforce.com", + } + mock_post.return_value = mock_response + + result = provider.get_token() + + assert isinstance(result, AccessTokenResponse) + assert result.access_token == "test_access_token" + assert result.instance_url == "https://instance.salesforce.com" + + # Verify correct auth type was used + call_args = mock_post.call_args + assert call_args[1]["data"]["grant_type"] == "client_credentials" + assert "refresh_token" not in call_args[1]["data"] + + def test_unsupported_auth_type_raises_error(self): + """Test that unsupported auth type raises RuntimeError.""" + provider = CredentialsTokenProvider("test_profile") + + # Create credentials with invalid auth_type + mock_credentials = MagicMock() + mock_credentials.login_url = "https://login.salesforce.com" + mock_credentials.auth_type = "INVALID_AUTH_TYPE" + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with pytest.raises(RuntimeError, match="Unsupported auth_type"): + provider.get_token() + + def test_missing_access_token_raises_error(self): + """Test that missing access_token in response raises RuntimeError.""" + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + refresh_token="test_refresh_token", + auth_type=AuthType.OAUTH_TOKENS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = { + "instance_url": "https://instance.salesforce.com" + } + mock_post.return_value = mock_response + + with pytest.raises( + RuntimeError, match="missing access_token or instance_url" + ): + provider.get_token() + + def test_request_exception_raises_runtime_error(self): + """Test that request exceptions are wrapped in RuntimeError.""" + import requests + + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + refresh_token="test_refresh_token", + auth_type=AuthType.OAUTH_TOKENS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch( + "requests.post", + side_effect=requests.exceptions.RequestException("Network error"), + ): + with pytest.raises(RuntimeError, match="Failed to get access token"): + provider.get_token() + + +class TestSFCLITokenProvider: + """Tests for SFCLITokenProvider.""" + + def test_successful_token_retrieval(self): + """Test successful token retrieval from SF CLI.""" + import json + + provider = SFCLITokenProvider("test_org") + + cli_output = json.dumps( + { + "status": 0, + "result": { + "accessToken": "cli_access_token", + "instanceUrl": "https://cli.salesforce.com", + }, + } + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(stdout=cli_output) + + result = provider.get_token() + + assert isinstance(result, AccessTokenResponse) + assert result.access_token == "cli_access_token" + assert result.instance_url == "https://cli.salesforce.com" + + def test_sf_command_not_found(self): + """Test that FileNotFoundError is wrapped in RuntimeError.""" + provider = SFCLITokenProvider("test_org") + + with patch("subprocess.run", side_effect=FileNotFoundError()): + with pytest.raises(RuntimeError, match="'sf' command was not found"): + provider.get_token()