From 5f201e3696d9700121f1a7eb70cd38a242da3afc Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Fri, 24 Apr 2026 14:59:17 -0400 Subject: [PATCH 01/14] test Einstein Prediction with actual production API --- src/datacustomcode/cli.py | 32 ++- src/datacustomcode/config.yaml | 8 +- src/datacustomcode/deploy.py | 89 +------- .../einstein_predictions/impl/default.py | 114 +++++++++- src/datacustomcode/io/reader/sf_cli.py | 61 +----- src/datacustomcode/run.py | 39 +++- src/datacustomcode/token_provider.py | 149 +++++++++++++ tests/io/reader/test_query_api.py | 8 +- tests/test_cli.py | 80 ++++--- tests/test_deploy.py | 207 +----------------- tests/test_sf_cli_contract.py | 21 +- tests/test_token_provider.py | 188 ++++++++++++++++ 12 files changed, 577 insertions(+), 419 deletions(-) create mode 100644 src/datacustomcode/token_provider.py create mode 100644 tests/test_token_provider.py diff --git a/src/datacustomcode/cli.py b/src/datacustomcode/cli.py index d8a00fc..c6e9c5c 100644 --- a/src/datacustomcode/cli.py +++ b/src/datacustomcode/cli.py @@ -179,14 +179,15 @@ def deploy( function_invoke_opt: str, sf_cli_org: Optional[str], ): - from datacustomcode.credentials import Credentials from datacustomcode.deploy import ( COMPUTE_TYPES, - AccessTokenResponse, CodeExtensionMetadata, - _retrieve_access_token_from_sf_cli, deploy_full, ) + from datacustomcode.token_provider import ( + CredentialsTokenProvider, + SFCLITokenProvider, + ) logger.debug("Deploying project") @@ -220,22 +221,15 @@ def deploy( function_invoke_options = function_invoke_opt.split(",") metadata.functionInvokeOptions = function_invoke_options - auth: Union[Credentials, AccessTokenResponse] - if sf_cli_org: - try: - auth = _retrieve_access_token_from_sf_cli(sf_cli_org) - except RuntimeError as e: - click.secho(f"Error: {e}", fg="red") - raise click.Abort() from None - else: - try: - auth = Credentials.from_available(profile=profile) - except ValueError as e: - click.secho( - f"Error: {e}", - fg="red", - ) - raise click.Abort() from None + try: + if sf_cli_org: + auth = SFCLITokenProvider(sf_cli_org).get_token() + else: + auth = CredentialsTokenProvider(profile).get_token() + except RuntimeError as e: + click.secho(f"Error: {e}", fg="red") + raise click.Abort() from None + deploy_full(path, metadata, auth, network) diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index d58bc7f..d630a81 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -26,4 +26,10 @@ proxy_config: einstein_predictions_config: type_config_name: DefaultEinsteinPredictions - options: {} + options: + credentials_profile: default + +llm_gateway_config: + type_config_name: DefaultLLMGateway + options: + credentials_profile: default diff --git a/src/datacustomcode/deploy.py b/src/datacustomcode/deploy.py index 4505438..114252a 100644 --- a/src/datacustomcode/deploy.py +++ b/src/datacustomcode/deploy.py @@ -19,11 +19,9 @@ import os import re import shutil -import subprocess import tempfile import time from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -37,15 +35,10 @@ import requests from datacustomcode.cmd import cmd_output -from datacustomcode.credentials import AuthType from datacustomcode.scan import find_base_directory, get_package_type -if TYPE_CHECKING: - from datacustomcode.credentials import Credentials - DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code" DATA_TRANSFORMS_PATH = "services/data/v63.0/ssot/data-transforms" -AUTH_PATH = "services/oauth2/token" WAIT_FOR_DEPLOYMENT_TIMEOUT = 3000 # Available compute types for Data Cloud deployments. @@ -163,80 +156,6 @@ class AccessTokenResponse(BaseModel): instance_url: str -def _retrieve_access_token(credentials: Credentials) -> AccessTokenResponse: - """Get an access token for the Salesforce API.""" - logger.debug("Getting oauth token...") - - url = f"{credentials.login_url.rstrip('/')}/{AUTH_PATH.lstrip('/')}" - - if credentials.auth_type == AuthType.OAUTH_TOKENS: - data = { - "grant_type": "refresh_token", - "refresh_token": credentials.refresh_token, - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - } - elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS: - data = { - "grant_type": "client_credentials", - "client_id": credentials.client_id, - "client_secret": credentials.client_secret, - } - else: - raise ValueError(f"Unsupported auth_type: {credentials.auth_type}") - - response = _make_api_call(url, "POST", data=data) - return AccessTokenResponse(**response) - - -def _retrieve_access_token_from_sf_cli(sf_cli_org: str) -> AccessTokenResponse: - """Get an access token from the Salesforce CLI.""" - try: - result = subprocess.run( - ["sf", "org", "display", "--target-org", sf_cli_org, "--json"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - except FileNotFoundError as exc: - raise RuntimeError( - "The 'sf' command was not found. " - "Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" - ) from exc - except subprocess.TimeoutExpired as exc: - raise RuntimeError( - f"'sf org display' timed out for org '{sf_cli_org}'" - ) from exc - except subprocess.CalledProcessError as exc: - raise RuntimeError( - f"'sf org display' failed for org '{sf_cli_org}'.\n" - f"Ensure the org is authenticated via 'sf org login web'.\n" - f"stderr: {exc.stderr.strip()}" - ) from exc - - try: - data = json.loads(result.stdout) - except json.JSONDecodeError as exc: - raise RuntimeError(f"Failed to parse 'sf org display' output: {exc}") from exc - - if data.get("status") != 0: - raise RuntimeError( - f"SF CLI error for org '{sf_cli_org}': " - f"{data.get('message', 'unknown error')}" - ) - - org_result = data.get("result", {}) - access_token = org_result.get("accessToken") - instance_url = org_result.get("instanceUrl") - if not access_token or not instance_url: - raise RuntimeError( - f"'sf org display' did not return an access token or instance URL " - f"for org '{sf_cli_org}'" - ) - return AccessTokenResponse(access_token=access_token, instance_url=instance_url) - - class CreateDeploymentResponse(BaseModel): fileUploadUrl: str @@ -567,16 +486,11 @@ def zip( def deploy_full( directory: str, metadata: CodeExtensionMetadata, - credentials: Union["Credentials", AccessTokenResponse], + access_token: AccessTokenResponse, docker_network: str, callback=None, ) -> AccessTokenResponse: """Deploy a data transform in the DataCloud.""" - if isinstance(credentials, AccessTokenResponse): - access_token = credentials - else: - access_token = _retrieve_access_token(credentials) - # prepare payload config = get_config(directory) @@ -587,7 +501,6 @@ def deploy_full( wait_for_deployment(access_token, metadata, callback) # create data transform - if isinstance(config, DataTransformConfig): create_data_transform(directory, access_token, metadata, config) return access_token diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 1f741fc..ae138e1 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -13,23 +13,131 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import ( + Any, + ClassVar, + Dict, + List, + Optional, +) + +from loguru import logger +import requests + from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( PredictionRequest, PredictionResponse, + PredictionType, +) +from datacustomcode.token_provider import ( + CredentialsTokenProvider, + SFCLITokenProvider, + TokenProvider, ) class DefaultEinsteinPredictions(EinsteinPredictions): CONFIG_NAME = "DefaultEinsteinPredictions" + EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1" - def __init__(self, **kwargs): + ENDPOINT_MAP: ClassVar[dict[PredictionType, str]] = { + PredictionType.REGRESSION: "regression", + PredictionType.CLUSTERING: "clustering", + PredictionType.CLASSIFICATION: "classification", + PredictionType.BINARY_CLASSIFICATION: "binary-classification", + PredictionType.MULTI_OUTCOME: "multi-outcome", + } + + def __init__( + self, + credentials_profile: Optional[str] = None, + sf_cli_org: Optional[str] = None, + **kwargs, + ): super().__init__(**kwargs) + if sf_cli_org: + self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org) + logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}") + else: + profile = credentials_profile or "default" + self._token_provider = CredentialsTokenProvider(profile) + logger.debug(f"Using credentials token provider with profile: {profile}") + def predict(self, request: PredictionRequest) -> PredictionResponse: + """Make a prediction request to the Einstein Predictions API""" + token_response = self._token_provider.get_token() + access_token = token_response.access_token + + endpoint = self.ENDPOINT_MAP.get(request.prediction_type) + if not endpoint: + raise RuntimeError( + f"Unknown prediction type: {request.prediction_type}. " + f"Valid types: {list(self.ENDPOINT_MAP.keys())}" + ) + + api_url = ( + f"{self.EINSTEIN_PLATFORM_URL}/models/" + f"{request.model_api_name}/{endpoint}" + ) + + prediction_columns: List[Dict[str, Any]] = [] + for col in request.prediction_columns: + col_data: Dict[str, Any] = {"columnName": col.column_name} + if col.string_values: + col_data["stringValues"] = col.string_values + if col.double_values: + col_data["doubleValues"] = col.double_values + if col.boolean_values: + col_data["booleanValues"] = col.boolean_values + if col.date_values: + col_data["dateValues"] = col.date_values + if col.datetime_values: + col_data["datetimeValues"] = col.datetime_values + prediction_columns.append(col_data) + + payload: Dict[str, Any] = {"predictionColumns": prediction_columns} + + if request.settings: + payload["settings"] = request.settings + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + "x-sfdc-app-context": "EinsteinGPT", + "x-client-feature-id": "ai-platform-models-connected-app", + } + + logger.debug(f"Making Einstein prediction request to: {api_url}") + try: + response = requests.post(api_url, json=payload, headers=headers, timeout=60) + if not response.ok and not response.text: + error_msg = ( + f"Einstein Prediction request failed: {api_url} - " + f"{response.status_code} {response.reason}. " + "If your code uses Einstein APIs, make sure you have " + 'configured the SDK to use "client_credentials" auth type. ' + "Refer to https://developer.salesforce.com/docs/ai/agentforce/" + "guide/agent-api-get-started.html#create-a-salesforce-app " + "to create your external client app." + ) + logger.error(error_msg) + except requests.exceptions.RequestException as e: + logger.error(f"Prediction API request failed: {api_url} {e}") + raise RuntimeError(f"Prediction API request failed: {e}") from e + + response_data: Dict[str, Any] = {} + if response.content: + try: + response_data = response.json() + except ValueError: + logger.warning("Failed to parse response as JSON") + response_data = {"raw_response": response.text} + return PredictionResponse( version="v1", prediction_type=request.prediction_type, - status_code=200, - data={"results": [{"prediction": {"predictedValue": 1.0}}]}, + status_code=response.status_code, + data=response_data, ) diff --git a/src/datacustomcode/io/reader/sf_cli.py b/src/datacustomcode/io/reader/sf_cli.py index adfb3ed..cfeb06e 100644 --- a/src/datacustomcode/io/reader/sf_cli.py +++ b/src/datacustomcode/io/reader/sf_cli.py @@ -16,7 +16,6 @@ import json import logging -import subprocess from typing import ( TYPE_CHECKING, Final, @@ -29,6 +28,7 @@ from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.reader.utils import _pandas_to_spark_schema +from datacustomcode.token_provider import SFCLITokenProvider if TYPE_CHECKING: from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession @@ -78,64 +78,9 @@ def __init__( logger.debug(f"Initialized SFCLIDataCloudReader for org '{sf_cli_org}'") def _get_token(self) -> tuple[str, str]: - """Fetch a fresh access token and instance URL from the SF CLI. - - Returns: - ``(access_token, instance_url)`` - - Raises: - RuntimeError: If the ``sf`` command is not on PATH, times out, or - returns an error. - """ - try: - result = subprocess.run( - ["sf", "org", "display", "--target-org", self.sf_cli_org, "--json"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - except FileNotFoundError as exc: - raise RuntimeError( - "The 'sf' command was not found. " - "Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" - ) from exc - except subprocess.TimeoutExpired as exc: - raise RuntimeError( - f"'sf org display' timed out for org '{self.sf_cli_org}'" - ) from exc - except subprocess.CalledProcessError as exc: - raise RuntimeError( - f"'sf org display' failed for org '{self.sf_cli_org}'.\n" - f"Ensure the org is authenticated via 'sf org login web'.\n" - f"stderr: {exc.stderr.strip()}" - ) from exc - - try: - data = json.loads(result.stdout) - except json.JSONDecodeError as exc: - raise RuntimeError( - f"Failed to parse 'sf org display' output: {exc}" - ) from exc - - if data.get("status") != 0: - raise RuntimeError( - f"SF CLI error for org '{self.sf_cli_org}': " - f"{data.get('message', 'unknown error')}" - ) - - org_result = data.get("result", {}) - access_token = org_result.get("accessToken") - instance_url = org_result.get("instanceUrl") - - if not access_token or not instance_url: - raise RuntimeError( - f"'sf org display' did not return an access token or instance URL " - f"for org '{self.sf_cli_org}'" - ) - + token_response = SFCLITokenProvider(self.sf_cli_org).get_token() logger.debug(f"Fetched token from SF CLI for org '{self.sf_cli_org}'") - return access_token, instance_url + return token_response.access_token, token_response.instance_url def _execute_query(self, sql: str) -> pd.DataFrame: """Execute *sql* against the Data Cloud REST endpoint. diff --git a/src/datacustomcode/run.py b/src/datacustomcode/run.py index b581072..0e5052a 100644 --- a/src/datacustomcode/run.py +++ b/src/datacustomcode/run.py @@ -25,10 +25,12 @@ ) from datacustomcode.config import config +from datacustomcode.einstein_predictions_config import einstein_predictions_config +from datacustomcode.llm_gateway_config import llm_gateway_config from datacustomcode.scan import find_base_directory, get_package_type -def _set_config_option(config_obj, key: str, value: str) -> None: +def _set_config_option(config_obj, key: str, value: Optional[str]) -> None: """Set an option on a config object if it exists and has options attribute. Args: @@ -36,10 +38,33 @@ def _set_config_option(config_obj, key: str, value: str) -> None: key: Option key to set value: Option value to set """ - if config_obj and hasattr(config_obj, "options"): + if config_obj and hasattr(config_obj, "options") and value is not None: config_obj.options[key] = value +def _update_config_options(profile: Optional[str], sf_cli_org: Optional[str]): + if sf_cli_org: + config_key = "sf_cli_org" + _set_config_option(config.reader_config, config_key, sf_cli_org) + _set_config_option(config.writer_config, config_key, sf_cli_org) + _set_config_option( + einstein_predictions_config.einstein_predictions_config, + config_key, + sf_cli_org, + ) + _set_config_option( + llm_gateway_config.llm_gateway_config, config_key, sf_cli_org + ) + elif profile != "default": + config_key = "credentials_profile" + _set_config_option(config.reader_config, config_key, profile) + _set_config_option(config.writer_config, config_key, profile) + _set_config_option( + einstein_predictions_config.einstein_predictions_config, config_key, profile + ) + _set_config_option(llm_gateway_config.llm_gateway_config, config_key, profile) + + def run_entrypoint( entrypoint: str, config_file: Union[str, None], @@ -47,7 +72,7 @@ def run_entrypoint( profile: str, sf_cli_org: Optional[str] = None, ) -> None: - """Run the entrypoint script with the given config and dependencies. + """Run the entrypoint for script or function with the given config and dependencies. Args: entrypoint: The entrypoint script to run. @@ -98,12 +123,8 @@ def run_entrypoint( _set_config_option(config.reader_config, "dataspace", dataspace) _set_config_option(config.writer_config, "dataspace", dataspace) - if sf_cli_org: - _set_config_option(config.reader_config, "sf_cli_org", sf_cli_org) - _set_config_option(config.writer_config, "sf_cli_org", sf_cli_org) - elif profile != "default": - _set_config_option(config.reader_config, "credentials_profile", profile) - _set_config_option(config.writer_config, "credentials_profile", profile) + _update_config_options(profile, sf_cli_org) + for dependency in dependencies: try: importlib.import_module(dependency) diff --git a/src/datacustomcode/token_provider.py b/src/datacustomcode/token_provider.py new file mode 100644 index 0000000..ee98aeb --- /dev/null +++ b/src/datacustomcode/token_provider.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datacustomcode.deploy import AccessTokenResponse + + +class TokenProvider(ABC): + """Abstract base class for providing authentication tokens.""" + + @abstractmethod + def get_token(self) -> "AccessTokenResponse": + """Retrieve a fresh access token and instance URL. + + Returns: + AccessTokenResponse containing access_token and instance_url + """ + ... + + +class CredentialsTokenProvider(TokenProvider): + """Token provider that uses stored credentials with refresh token flow.""" + + def __init__(self, credentials_profile: str = "default"): + self.credentials_profile = credentials_profile + + def get_token(self) -> "AccessTokenResponse": + """Get token by refreshing credentials from stored profile.""" + import requests + + from datacustomcode.credentials import AuthType, Credentials + from datacustomcode.deploy import AccessTokenResponse + + # Load credentials (freshly, not cached) + credentials = Credentials.from_available(profile=self.credentials_profile) + + token_url = f"{credentials.login_url.rstrip('/')}/services/oauth2/token" + + if credentials.auth_type == AuthType.OAUTH_TOKENS: + data = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token, + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } + elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS: + data = { + "grant_type": "client_credentials", + "client_id": credentials.client_id, + "client_secret": credentials.client_secret, + } + else: + raise RuntimeError(f"Unsupported auth_type: {credentials.auth_type}") + + try: + response = requests.post(token_url, data=data, timeout=30) + response.raise_for_status() + token_data = response.json() + + access_token = token_data.get("access_token") + instance_url = token_data.get("instance_url") + + if not access_token or not instance_url: + raise RuntimeError( + "Token refresh response missing access_token or instance_url" + ) + + return AccessTokenResponse( + access_token=access_token, instance_url=instance_url + ) + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to get access token: {e}") from e + + +class SFCLITokenProvider(TokenProvider): + """Token provider that uses Salesforce CLI for authentication.""" + + def __init__(self, sf_cli_org: str): + self.sf_cli_org = sf_cli_org + + def get_token(self) -> "AccessTokenResponse": + """Get token from Salesforce SF CLI""" + import json + import subprocess + + from datacustomcode.deploy import AccessTokenResponse + + try: + result = subprocess.run( + ["sf", "org", "display", "--target-org", self.sf_cli_org, "--json"], + capture_output=True, + text=True, + check=True, + timeout=30, + ) + except FileNotFoundError as exc: + raise RuntimeError( + "The 'sf' command was not found. " + "Install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli" + ) from exc + except subprocess.TimeoutExpired as exc: + raise RuntimeError( + f"'sf org display' timed out for org '{self.sf_cli_org}'" + ) from exc + except subprocess.CalledProcessError as exc: + raise RuntimeError( + f"'sf org display' failed for org '{self.sf_cli_org}': {exc.stderr}" + ) from exc + + try: + data = json.loads(result.stdout) + except json.JSONDecodeError as exc: + raise RuntimeError( + f"Failed to parse JSON from 'sf org display': {result.stdout}" + ) from exc + + if data.get("status") != 0: + raise RuntimeError( + f"SF CLI error for org '{self.sf_cli_org}': " + f"{data.get('message', 'unknown error')}" + ) + + result_data = data.get("result", {}) + access_token = result_data.get("accessToken") + instance_url = result_data.get("instanceUrl") + + if not access_token or not instance_url: + raise RuntimeError( + f"'sf org display' did not return an access token or instance URL " + f"for org '{self.sf_cli_org}'" + ) + + return AccessTokenResponse(access_token=access_token, instance_url=instance_url) diff --git a/tests/io/reader/test_query_api.py b/tests/io/reader/test_query_api.py index 33cb54d..f9baa81 100644 --- a/tests/io/reader/test_query_api.py +++ b/tests/io/reader/test_query_api.py @@ -121,8 +121,12 @@ def test_pandas_to_spark_schema_datetime_types(self): assert field_dict[field_name].nullable # Verify the actual pandas dtypes to ensure our test data has the expected types - assert str(df["datetime_ns"].dtype) == "datetime64[ns]" - assert str(df["datetime_ns_utc"].dtype) == "datetime64[ns, UTC]" + # Pandas may use 'ns' or 'us' precision depending on version + assert str(df["datetime_ns"].dtype) in ["datetime64[ns]", "datetime64[us]"] + assert str(df["datetime_ns_utc"].dtype) in [ + "datetime64[ns, UTC]", + "datetime64[us, UTC]", + ] assert str(df["datetime_ms"].dtype) == "datetime64[ms]" assert str(df["datetime_ms_utc"].dtype) == "datetime64[ms]" diff --git a/tests/test_cli.py b/tests/test_cli.py index ab18fa1..e26cbdc 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -72,11 +72,14 @@ def test_init_command( class TestDeploy: @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") - def test_deploy_command_success(self, mock_credentials, mock_deploy_full): + @patch("datacustomcode.token_provider.CredentialsTokenProvider") + def test_deploy_command_success(self, mock_token_provider, mock_deploy_full): """Test successful deploy command.""" - # Mock credentials - mock_creds = mock_credentials.return_value + # Mock token provider + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) runner = CliRunner() with runner.isolated_filesystem(): @@ -86,7 +89,7 @@ def test_deploy_command_success(self, mock_credentials, mock_deploy_full): result = runner.invoke(deploy, ["--name", "test-job", "--version", "1.0.0"]) assert result.exit_code == 0 - mock_credentials.assert_called_once() + mock_token_provider.assert_called_once_with("default") mock_deploy_full.assert_called_once() # Check that deploy_full was called with correct arguments @@ -97,14 +100,20 @@ def test_deploy_command_success(self, mock_credentials, mock_deploy_full): ) # metadata (hyphen sanitized to underscore) assert call_args[0][1].version == "1.0.0" assert call_args[0][1].description == "Custom Data Transform Code" - assert call_args[0][2] == mock_creds # credentials + assert call_args[0][2].access_token == "test_token" + assert call_args[0][2].instance_url == "https://instance.example.com" @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") + @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_function_invoke_options( - self, mock_credentials, mock_deploy_full + self, mock_token_provider, mock_deploy_full ): """Test deploy command with function invoke options.""" + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) + runner = CliRunner() with runner.isolated_filesystem(): # Create test payload directory @@ -123,13 +132,13 @@ def test_deploy_command_function_invoke_options( call_args = mock_deploy_full.call_args assert call_args[0][1].functionInvokeOptions == ["option1", "option2"] - @patch("datacustomcode.credentials.Credentials.from_available") - def test_deploy_command_credentials_error(self, mock_credentials): + @patch("datacustomcode.token_provider.CredentialsTokenProvider") + def test_deploy_command_credentials_error(self, mock_token_provider): """Test deploy command when credentials are not available.""" - # Mock credentials to raise ValueError - mock_credentials.side_effect = ValueError( - "Credentials not found in env or ini file. " - "Run `datacustomcode configure` to create a credentials file." + # Mock token provider to raise RuntimeError + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.side_effect = RuntimeError( + "Failed to refresh access token" ) runner = CliRunner() @@ -140,12 +149,17 @@ def test_deploy_command_credentials_error(self, mock_credentials): result = runner.invoke(deploy, ["--name", "test-job"]) assert result.exit_code == 1 - assert "Error: Credentials not found in env or ini file" in result.output + assert "Error: Failed to refresh access token" in result.output @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") - def test_deploy_command_custom_path(self, mock_credentials, mock_deploy_full): + @patch("datacustomcode.token_provider.CredentialsTokenProvider") + def test_deploy_command_custom_path(self, mock_token_provider, mock_deploy_full): """Test deploy command with custom path.""" + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) + runner = CliRunner() with runner.isolated_filesystem(): # Create test directory @@ -163,11 +177,16 @@ def test_deploy_command_custom_path(self, mock_credentials, mock_deploy_full): assert call_args[0][0] == "custom_path" # path @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.credentials.Credentials.from_available") + @patch("datacustomcode.token_provider.CredentialsTokenProvider") def test_deploy_command_custom_description( - self, mock_credentials, mock_deploy_full + self, mock_token_provider, mock_deploy_full ): """Test deploy command with custom description.""" + mock_provider_instance = mock_token_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( + access_token="test_token", instance_url="https://instance.example.com" + ) + runner = CliRunner() with runner.isolated_filesystem(): # Create test payload directory @@ -185,13 +204,13 @@ def test_deploy_command_custom_description( assert call_args[0][1].description == "Custom description" @patch("datacustomcode.deploy.deploy_full") - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") - def test_deploy_command_sf_cli_org(self, mock_sf_cli_token, mock_deploy_full): + @patch("datacustomcode.token_provider.SFCLITokenProvider") + def test_deploy_command_sf_cli_org(self, mock_sf_cli_provider, mock_deploy_full): """Test deploy command with --sf-cli-org flag.""" - mock_token = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="test_token", instance_url="https://test.salesforce.com" ) - mock_sf_cli_token.return_value = mock_token runner = CliRunner() with runner.isolated_filesystem(): @@ -201,15 +220,20 @@ def test_deploy_command_sf_cli_org(self, mock_sf_cli_token, mock_deploy_full): ) assert result.exit_code == 0 - mock_sf_cli_token.assert_called_once_with("my-org") + mock_sf_cli_provider.assert_called_once_with("my-org") mock_deploy_full.assert_called_once() call_args = mock_deploy_full.call_args - assert call_args[0][2] == mock_token # AccessTokenResponse passed directly + # Check AccessTokenResponse passed directly + assert call_args[0][2].access_token == "test_token" + assert call_args[0][2].instance_url == "https://test.salesforce.com" - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") - def test_deploy_command_sf_cli_org_error(self, mock_sf_cli_token): + @patch("datacustomcode.token_provider.SFCLITokenProvider") + def test_deploy_command_sf_cli_org_error(self, mock_sf_cli_provider): """Test deploy command when --sf-cli-org fails.""" - mock_sf_cli_token.side_effect = RuntimeError("sf command not found") + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.side_effect = RuntimeError( + "sf command not found" + ) runner = CliRunner() with runner.isolated_filesystem(): diff --git a/tests/test_deploy.py b/tests/test_deploy.py index a6ab5aa..40f3281 100644 --- a/tests/test_deploy.py +++ b/tests/test_deploy.py @@ -1,6 +1,5 @@ """Tests for the deploy module.""" -import json from unittest.mock import ( ANY, MagicMock, @@ -12,7 +11,6 @@ import pytest import requests -from datacustomcode.credentials import AuthType, Credentials from datacustomcode.deploy import ( DloPermission, Permissions, @@ -28,8 +26,6 @@ DataTransformConfig, DeploymentsResponse, _make_api_call, - _retrieve_access_token, - _retrieve_access_token_from_sf_cli, _sanitize_api_name, create_data_transform, create_deployment, @@ -598,58 +594,6 @@ def test_make_api_call_invalid_response(self, mock_request): _make_api_call("https://example.com", "GET") -class TestRetrieveAccessToken: - @patch("datacustomcode.deploy._make_api_call") - def test_retrieve_access_token(self, mock_make_api_call): - """Test retrieving access token.""" - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.OAUTH_TOKENS, - refresh_token="refresh", - client_secret="secret", - ) - - mock_make_api_call.return_value = { - "access_token": "test_token", - "instance_url": "https://instance.example.com", - } - - result = _retrieve_access_token(credentials) - - mock_make_api_call.assert_called_once() - call_args = mock_make_api_call.call_args - assert call_args.kwargs["data"]["grant_type"] == "refresh_token" - assert call_args.kwargs["data"]["refresh_token"] == "refresh" - assert isinstance(result, AccessTokenResponse) - assert result.access_token == "test_token" - assert result.instance_url == "https://instance.example.com" - - @patch("datacustomcode.deploy._make_api_call") - def test_retrieve_access_token_client_credentials(self, mock_make_api_call): - """Test retrieving access token with client credentials flow.""" - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.CLIENT_CREDENTIALS, - client_secret="secret", - ) - - mock_make_api_call.return_value = { - "access_token": "test_token", - "instance_url": "https://instance.example.com", - } - - result = _retrieve_access_token(credentials) - - mock_make_api_call.assert_called_once() - call_args = mock_make_api_call.call_args - assert call_args.kwargs["data"]["grant_type"] == "client_credentials" - assert isinstance(result, AccessTokenResponse) - assert result.access_token == "test_token" - assert result.instance_url == "https://instance.example.com" - - class TestCreateDeployment: @patch("datacustomcode.deploy._make_api_call") def test_create_deployment_success(self, mock_make_api_call): @@ -1048,10 +992,8 @@ class TestDeployFull: @patch("datacustomcode.deploy.upload_zip") @patch("datacustomcode.deploy.zip") @patch("datacustomcode.deploy.create_deployment") - @patch("datacustomcode.deploy._retrieve_access_token") def test_deploy_full( self, - mock_retrieve_token, mock_create_deployment, mock_zip, mock_upload_zip, @@ -1070,13 +1012,6 @@ def test_deploy_full( ), ) mock_get_config.return_value = data_transform_config - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.OAUTH_TOKENS, - refresh_token="refresh", - client_secret="secret", - ) metadata = CodeExtensionMetadata( name="test_job", version="1.0.0", @@ -1090,16 +1025,14 @@ def test_deploy_full( access_token = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) - mock_retrieve_token.return_value = access_token mock_create_deployment.return_value = CreateDeploymentResponse( fileUploadUrl="https://upload.example.com" ) # Call function - result = deploy_full("/test/dir", metadata, credentials, "default", callback) + result = deploy_full("/test/dir", metadata, access_token, "default", callback) # Assertions - mock_retrieve_token.assert_called_once_with(credentials) mock_get_config.assert_called_once_with("/test/dir") mock_create_deployment.assert_called_once_with(access_token, metadata) mock_zip.assert_called_once_with("/test/dir", "default", "script") @@ -1110,7 +1043,6 @@ def test_deploy_full( ) assert result == access_token - @patch("datacustomcode.deploy._retrieve_access_token") @patch("datacustomcode.deploy.get_config") @patch("datacustomcode.deploy.create_deployment") @patch("datacustomcode.deploy.zip") @@ -1125,7 +1057,6 @@ def test_deploy_full_client_credentials( mock_zip, mock_create_deployment, mock_get_config, - mock_retrieve_token, ): """Test full deployment process using client credentials auth.""" data_transform_config = DataTransformConfig( @@ -1138,12 +1069,6 @@ def test_deploy_full_client_credentials( ), ) mock_get_config.return_value = data_transform_config - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.CLIENT_CREDENTIALS, - client_secret="secret", - ) metadata = CodeExtensionMetadata( name="test_job", version="1.0.0", @@ -1156,14 +1081,12 @@ def test_deploy_full_client_credentials( access_token = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) - mock_retrieve_token.return_value = access_token mock_create_deployment.return_value = CreateDeploymentResponse( fileUploadUrl="https://upload.example.com" ) - result = deploy_full("/test/dir", metadata, credentials, "default", callback) + result = deploy_full("/test/dir", metadata, access_token, "default", callback) - mock_retrieve_token.assert_called_once_with(credentials) mock_get_config.assert_called_once_with("/test/dir") mock_create_deployment.assert_called_once_with(access_token, metadata) mock_zip.assert_called_once_with("/test/dir", "default", "script") @@ -1199,7 +1122,6 @@ def test_run_data_transform(self, mock_make_api_call): class TestDeployFullWithDockerIntegration: - @patch("datacustomcode.deploy._retrieve_access_token") @patch("datacustomcode.deploy.create_deployment") @patch("datacustomcode.deploy.zip") @patch("datacustomcode.deploy.upload_zip") @@ -1216,16 +1138,8 @@ def test_deploy_full_happy_path( mock_upload_zip, mock_zip, mock_create_deployment, - mock_retrieve_token, ): """Test full deployment process with Docker dependency building.""" - credentials = Credentials( - login_url="https://example.com", - client_id="id", - auth_type=AuthType.OAUTH_TOKENS, - refresh_token="refresh", - client_secret="secret", - ) metadata = CodeExtensionMetadata( name="test_job", version="1.0.0", @@ -1249,27 +1163,17 @@ def test_deploy_full_happy_path( access_token = AccessTokenResponse( access_token="test_token", instance_url="https://instance.example.com" ) - mock_retrieve_token.return_value = access_token mock_create_deployment.return_value = CreateDeploymentResponse( fileUploadUrl="https://upload.example.com" ) # Mock that requirements.txt exists and has dependencies mock_has_requirements.return_value = True - data_transform_config = DataTransformConfig( - sdkVersion="1.0.0", - entryPoint="entrypoint.py", - dataspace="test_dataspace", - permissions=Permissions( - read=DloPermission(dlo=["input_dlo"]), - write=DloPermission(dlo=["output_dlo"]), - ), - ) + # Call function - result = deploy_full("/test/dir", metadata, credentials, "default", callback) + result = deploy_full("/test/dir", metadata, access_token, "default", callback) # Assertions - mock_retrieve_token.assert_called_once_with(credentials) mock_get_config.assert_called_once_with("/test/dir") mock_create_deployment.assert_called_once_with(access_token, metadata) mock_zip.assert_called_once_with("/test/dir", "default", "script") @@ -1281,104 +1185,6 @@ def test_deploy_full_happy_path( assert result == access_token -class TestRetrieveAccessTokenFromSFCLI: - """Tests for _retrieve_access_token_from_sf_cli.""" - - SF_CLI_OUTPUT = json.dumps( - { - "status": 0, - "result": { - "accessToken": "sf_access_token", - "instanceUrl": "https://sf.salesforce.com", - }, - } - ) - - @patch("datacustomcode.deploy.subprocess.run") - def test_happy_path(self, mock_run): - """Successful sf org display returns AccessTokenResponse.""" - mock_run.return_value = MagicMock(stdout=self.SF_CLI_OUTPUT, returncode=0) - - result = _retrieve_access_token_from_sf_cli("my-org") - - assert result.access_token == "sf_access_token" - assert result.instance_url == "https://sf.salesforce.com" - mock_run.assert_called_once_with( - ["sf", "org", "display", "--target-org", "my-org", "--json"], - capture_output=True, - text=True, - check=True, - timeout=30, - ) - - @patch("datacustomcode.deploy.subprocess.run") - def test_file_not_found(self, mock_run): - """FileNotFoundError raised when sf CLI is not installed.""" - mock_run.side_effect = FileNotFoundError("No such file or directory: 'sf'") - - with pytest.raises(RuntimeError, match="'sf' command was not found"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_timeout_expired(self, mock_run): - """TimeoutExpired raised when sf CLI times out.""" - import subprocess - - mock_run.side_effect = subprocess.TimeoutExpired(cmd="sf", timeout=30) - - with pytest.raises(RuntimeError, match="timed out"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_called_process_error(self, mock_run): - """CalledProcessError raised when sf CLI exits non-zero.""" - import subprocess - - mock_run.side_effect = subprocess.CalledProcessError( - returncode=1, cmd="sf", stderr="Org not found" - ) - - with pytest.raises(RuntimeError, match="failed for org"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_json_decode_error(self, mock_run): - """RuntimeError raised when output is not valid JSON.""" - mock_run.return_value = MagicMock(stdout="not-json", returncode=0) - - with pytest.raises(RuntimeError, match="Failed to parse"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_nonzero_status_in_json(self, mock_run): - """RuntimeError raised when JSON status field is non-zero.""" - output = json.dumps({"status": 1, "message": "org not found"}) - mock_run.return_value = MagicMock(stdout=output, returncode=0) - - with pytest.raises(RuntimeError, match="SF CLI error"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_missing_access_token(self, mock_run): - """RuntimeError raised when accessToken is absent.""" - output = json.dumps( - {"status": 0, "result": {"instanceUrl": "https://sf.salesforce.com"}} - ) - mock_run.return_value = MagicMock(stdout=output, returncode=0) - - with pytest.raises(RuntimeError, match="did not return"): - _retrieve_access_token_from_sf_cli("my-org") - - @patch("datacustomcode.deploy.subprocess.run") - def test_missing_instance_url(self, mock_run): - """RuntimeError raised when instanceUrl is absent.""" - output = json.dumps({"status": 0, "result": {"accessToken": "sf_access_token"}}) - mock_run.return_value = MagicMock(stdout=output, returncode=0) - - with pytest.raises(RuntimeError, match="did not return"): - _retrieve_access_token_from_sf_cli("my-org") - - class TestDeployFullWithAccessTokenResponse: """Test deploy_full when passed an AccessTokenResponse directly.""" @@ -1388,10 +1194,8 @@ class TestDeployFullWithAccessTokenResponse: @patch("datacustomcode.deploy.zip") @patch("datacustomcode.deploy.create_deployment") @patch("datacustomcode.deploy.get_config") - @patch("datacustomcode.deploy._retrieve_access_token") def test_deploy_full_with_access_token_response_skips_token_exchange( self, - mock_retrieve_token, mock_get_config, mock_create_deployment, mock_zip, @@ -1399,7 +1203,7 @@ def test_deploy_full_with_access_token_response_skips_token_exchange( mock_wait, mock_create_transform, ): - """deploy_full skips token exchange when given an AccessTokenResponse.""" + """deploy_full now only accepts AccessTokenResponse.""" access_token = AccessTokenResponse( access_token="direct_token", instance_url="https://instance.example.com" ) @@ -1417,7 +1221,6 @@ def test_deploy_full_with_access_token_response_skips_token_exchange( result = deploy_full("/test/dir", metadata, access_token, "default") - mock_retrieve_token.assert_not_called() mock_create_deployment.assert_called_once_with(access_token, metadata) assert result == access_token diff --git a/tests/test_sf_cli_contract.py b/tests/test_sf_cli_contract.py index e60f68f..b96ab35 100644 --- a/tests/test_sf_cli_contract.py +++ b/tests/test_sf_cli_contract.py @@ -154,48 +154,51 @@ class TestDeployArgContract: "--cpu-size", "CPU_2XL", ] # fmt: skip - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") + @patch("datacustomcode.token_provider.SFCLITokenProvider") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.cli.find_base_directory") @patch("datacustomcode.cli.get_package_type") def test_accepts_required_flags( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_token + self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider ): mock_find_base.return_value = "payload" mock_pkg_type.return_value = "script" - mock_sf_cli_token.return_value = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="tok", instance_url="https://example.com" ) runner = CliRunner() result = runner.invoke(deploy, self._BASE_ARGS) assert result.exit_code != 2, result.output - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") + @patch("datacustomcode.token_provider.SFCLITokenProvider") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.cli.find_base_directory") @patch("datacustomcode.cli.get_package_type") def test_accepts_network_flag( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_token + self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider ): mock_find_base.return_value = "payload" mock_pkg_type.return_value = "script" - mock_sf_cli_token.return_value = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="tok", instance_url="https://example.com" ) runner = CliRunner() result = runner.invoke(deploy, [*self._BASE_ARGS, "--network", "custom"]) assert result.exit_code != 2, result.output - @patch("datacustomcode.deploy._retrieve_access_token_from_sf_cli") + @patch("datacustomcode.token_provider.SFCLITokenProvider") @patch("datacustomcode.deploy.deploy_full") @patch("datacustomcode.cli.find_base_directory") @patch("datacustomcode.cli.get_package_type") def test_accepts_function_invoke_opt_flag( - self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_token + self, mock_pkg_type, mock_find_base, mock_deploy_full, mock_sf_cli_provider ): mock_find_base.return_value = "payload" mock_pkg_type.return_value = "function" - mock_sf_cli_token.return_value = AccessTokenResponse( + mock_provider_instance = mock_sf_cli_provider.return_value + mock_provider_instance.get_token.return_value = AccessTokenResponse( access_token="tok", instance_url="https://example.com" ) runner = CliRunner() diff --git a/tests/test_token_provider.py b/tests/test_token_provider.py new file mode 100644 index 0000000..5125e37 --- /dev/null +++ b/tests/test_token_provider.py @@ -0,0 +1,188 @@ +"""Tests for token_provider module.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from datacustomcode.credentials import AuthType, Credentials +from datacustomcode.deploy import AccessTokenResponse +from datacustomcode.token_provider import CredentialsTokenProvider, SFCLITokenProvider + + +class TestCredentialsTokenProvider: + """Tests for CredentialsTokenProvider.""" + + def test_oauth_tokens_auth_type(self): + """Test token retrieval with OAUTH_TOKENS auth type.""" + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + refresh_token="test_refresh_token", + auth_type=AuthType.OAUTH_TOKENS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test_access_token", + "instance_url": "https://instance.salesforce.com", + } + mock_post.return_value = mock_response + + result = provider.get_token() + + assert isinstance(result, AccessTokenResponse) + assert result.access_token == "test_access_token" + assert result.instance_url == "https://instance.salesforce.com" + + # Verify correct auth type was used + call_args = mock_post.call_args + assert call_args[1]["data"]["grant_type"] == "refresh_token" + assert call_args[1]["data"]["refresh_token"] == "test_refresh_token" + + def test_client_credentials_auth_type(self): + """Test token retrieval with CLIENT_CREDENTIALS auth type.""" + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + auth_type=AuthType.CLIENT_CREDENTIALS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "test_access_token", + "instance_url": "https://instance.salesforce.com", + } + mock_post.return_value = mock_response + + result = provider.get_token() + + assert isinstance(result, AccessTokenResponse) + assert result.access_token == "test_access_token" + assert result.instance_url == "https://instance.salesforce.com" + + # Verify correct auth type was used + call_args = mock_post.call_args + assert call_args[1]["data"]["grant_type"] == "client_credentials" + assert "refresh_token" not in call_args[1]["data"] + + def test_unsupported_auth_type_raises_error(self): + """Test that unsupported auth type raises RuntimeError.""" + provider = CredentialsTokenProvider("test_profile") + + # Create credentials with invalid auth_type + mock_credentials = MagicMock() + mock_credentials.login_url = "https://login.salesforce.com" + mock_credentials.auth_type = "INVALID_AUTH_TYPE" + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with pytest.raises(RuntimeError, match="Unsupported auth_type"): + provider.get_token() + + def test_missing_access_token_raises_error(self): + """Test that missing access_token in response raises RuntimeError.""" + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + refresh_token="test_refresh_token", + auth_type=AuthType.OAUTH_TOKENS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch("requests.post") as mock_post: + mock_response = MagicMock() + mock_response.json.return_value = { + "instance_url": "https://instance.salesforce.com" + } + mock_post.return_value = mock_response + + with pytest.raises( + RuntimeError, match="missing access_token or instance_url" + ): + provider.get_token() + + def test_request_exception_raises_runtime_error(self): + """Test that request exceptions are wrapped in RuntimeError.""" + import requests + + provider = CredentialsTokenProvider("test_profile") + + mock_credentials = Credentials( + login_url="https://login.salesforce.com", + client_id="test_client_id", + client_secret="test_secret", + refresh_token="test_refresh_token", + auth_type=AuthType.OAUTH_TOKENS, + ) + + with patch( + "datacustomcode.credentials.Credentials.from_available", + return_value=mock_credentials, + ): + with patch( + "requests.post", + side_effect=requests.exceptions.RequestException("Network error"), + ): + with pytest.raises(RuntimeError, match="Failed to get access token"): + provider.get_token() + + +class TestSFCLITokenProvider: + """Tests for SFCLITokenProvider.""" + + def test_successful_token_retrieval(self): + """Test successful token retrieval from SF CLI.""" + import json + + provider = SFCLITokenProvider("test_org") + + cli_output = json.dumps( + { + "status": 0, + "result": { + "accessToken": "cli_access_token", + "instanceUrl": "https://cli.salesforce.com", + }, + } + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(stdout=cli_output) + + result = provider.get_token() + + assert isinstance(result, AccessTokenResponse) + assert result.access_token == "cli_access_token" + assert result.instance_url == "https://cli.salesforce.com" + + def test_sf_command_not_found(self): + """Test that FileNotFoundError is wrapped in RuntimeError.""" + provider = SFCLITokenProvider("test_org") + + with patch("subprocess.run", side_effect=FileNotFoundError()): + with pytest.raises(RuntimeError, match="'sf' command was not found"): + provider.get_token() From afa12009a2238f5397d5de5ef7751ee0214dc258 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Fri, 24 Apr 2026 16:54:01 -0400 Subject: [PATCH 02/14] use production llm gateway --- .../einstein_platform_client.py | 75 +++++++++++++++++++ .../einstein_predictions/__init__.py | 2 +- .../einstein_predictions/impl/default.py | 59 ++++----------- src/datacustomcode/llm_gateway/default.py | 64 +++++++++++++--- tests/test_llm_gateway.py | 27 ------- 5 files changed, 144 insertions(+), 83 deletions(-) create mode 100644 src/datacustomcode/einstein_platform_client.py diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py new file mode 100644 index 0000000..92a241f --- /dev/null +++ b/src/datacustomcode/einstein_platform_client.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + Any, + Dict, + Optional, +) + +from abc import ABC +from loguru import logger + +from datacustomcode.token_provider import ( + CredentialsTokenProvider, + SFCLITokenProvider, + TokenProvider, +) + + +class EinsteinPlatformClient(ABC): + EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1" + EINSTEIN_WARNING_MESSAGE = ( + "If your code uses Einstein APIs, make sure you have " + 'configured the SDK to use "client_credentials" auth type. ' + "Refer to https://developer.salesforce.com/docs/ai/agentforce/" + "guide/agent-api-get-started.html#create-a-salesforce-app " + "to create your external client app." + ) + + def __init__( + self, + credentials_profile: Optional[str] = None, + sf_cli_org: Optional[str] = None, + ): + if sf_cli_org: + self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org) + logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}") + else: + profile = credentials_profile or "default" + self._token_provider = CredentialsTokenProvider(profile) + logger.debug(f"Using credentials token provider with profile: {profile}") + self.token_response = None + + def get_headers(self): + if self.token_response is None: + self.token_response = self._token_provider.get_token() + + return { + "Authorization": f"Bearer {self.token_response.access_token}", + "Content-Type": "application/json", + "x-sfdc-app-context": "EinsteinGPT", + "x-client-feature-id": "ai-platform-models-connected-app", + } + + def parse_response(self, response): + response_data: Dict[str, Any] = {} + if response.content: + try: + response_data = response.json() + except ValueError: + logger.warning("Failed to parse response as JSON") + response_data = {"raw_response": response.text} + return response_data diff --git a/src/datacustomcode/einstein_predictions/__init__.py b/src/datacustomcode/einstein_predictions/__init__.py index 4a4b388..9aaa2a3 100644 --- a/src/datacustomcode/einstein_predictions/__init__.py +++ b/src/datacustomcode/einstein_predictions/__init__.py @@ -17,6 +17,6 @@ from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions __all__ = [ - "EinsteinPredictions", "DefaultEinsteinPredictions", + "EinsteinPredictions", ] diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index ae138e1..431c41e 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -24,23 +24,17 @@ from loguru import logger import requests +from datacustomcode.einstein_platform_client import EinsteinPlatformClient from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( PredictionRequest, PredictionResponse, PredictionType, ) -from datacustomcode.token_provider import ( - CredentialsTokenProvider, - SFCLITokenProvider, - TokenProvider, -) -class DefaultEinsteinPredictions(EinsteinPredictions): +class DefaultEinsteinPredictions(EinsteinPlatformClient, EinsteinPredictions): CONFIG_NAME = "DefaultEinsteinPredictions" - EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1" - ENDPOINT_MAP: ClassVar[dict[PredictionType, str]] = { PredictionType.REGRESSION: "regression", PredictionType.CLUSTERING: "clustering", @@ -55,21 +49,12 @@ def __init__( sf_cli_org: Optional[str] = None, **kwargs, ): - super().__init__(**kwargs) - - if sf_cli_org: - self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org) - logger.debug(f"Using SF CLI token provider for org: {sf_cli_org}") - else: - profile = credentials_profile or "default" - self._token_provider = CredentialsTokenProvider(profile) - logger.debug(f"Using credentials token provider with profile: {profile}") + EinsteinPlatformClient.__init__( + self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org + ) + EinsteinPredictions.__init__(self, **kwargs) def predict(self, request: PredictionRequest) -> PredictionResponse: - """Make a prediction request to the Einstein Predictions API""" - token_response = self._token_provider.get_token() - access_token = token_response.access_token - endpoint = self.ENDPOINT_MAP.get(request.prediction_type) if not endpoint: raise RuntimeError( @@ -102,42 +87,24 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: if request.settings: payload["settings"] = request.settings - headers = { - "Authorization": f"Bearer {access_token}", - "Content-Type": "application/json", - "x-sfdc-app-context": "EinsteinGPT", - "x-client-feature-id": "ai-platform-models-connected-app", - } - logger.debug(f"Making Einstein prediction request to: {api_url}") try: - response = requests.post(api_url, json=payload, headers=headers, timeout=60) + response = requests.post( + api_url, json=payload, headers=self.get_headers(), timeout=180 + ) if not response.ok and not response.text: error_msg = ( f"Einstein Prediction request failed: {api_url} - " f"{response.status_code} {response.reason}. " - "If your code uses Einstein APIs, make sure you have " - 'configured the SDK to use "client_credentials" auth type. ' - "Refer to https://developer.salesforce.com/docs/ai/agentforce/" - "guide/agent-api-get-started.html#create-a-salesforce-app " - "to create your external client app." + f"{self.EINSTEIN_WARNING_MESSAGE}" ) logger.error(error_msg) except requests.exceptions.RequestException as e: - logger.error(f"Prediction API request failed: {api_url} {e}") - raise RuntimeError(f"Prediction API request failed: {e}") from e - - response_data: Dict[str, Any] = {} - if response.content: - try: - response_data = response.json() - except ValueError: - logger.warning("Failed to parse response as JSON") - response_data = {"raw_response": response.text} + logger.error(f"Einstein Prediction request failed: {api_url} {e}") + raise RuntimeError(f"Einstein Prediction request failed: {e}") from e return PredictionResponse( - version="v1", prediction_type=request.prediction_type, status_code=response.status_code, - data=response_data, + data=self.parse_response(response), ) diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 9fefbc7..256291b 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -13,23 +13,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import ( + Any, + Dict, + Optional, +) + +from loguru import logger +import requests + +from datacustomcode.einstein_platform_client import EinsteinPlatformClient + + from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse -from datacustomcode.llm_gateway.types.generate_text_response_builder import ( - GenerateTextResponseBuilder, -) -class DefaultLLMGateway(LLMGateway): +class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway): CONFIG_NAME = "DefaultLLMGateway" + def __init__( + self, + credentials_profile: Optional[str] = None, + sf_cli_org: Optional[str] = None, + **kwargs, + ): + EinsteinPlatformClient.__init__( + self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org + ) + LLMGateway.__init__(self, **kwargs) + def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: + api_url = ( + f"{self.EINSTEIN_PLATFORM_URL}/models/" + f"{request.model_name}/generations" + ) - response_data = { - "version": "v1", - "status_code": 200, - "data": {"generation": {"generatedText": "Hello World"}}, + payload: Dict[str, Any] = { + "prompt": request.prompt } - return GenerateTextResponseBuilder.build(response_data) + if request.localization: + payload["localization"] = request.localization + if request.tags: + payload["tags"] = request.tags + + logger.debug(f"Making Generate text request: {api_url}") + try: + response = requests.post( + api_url, json=payload, headers=self.get_headers(), timeout=180 + ) + if not response.ok and not response.text: + error_msg = ( + f"Generate text request failed: {api_url} - " + f"{response.status_code} {response.reason}. " + f"{self.EINSTEIN_WARNING_MESSAGE}" + ) + logger.error(error_msg) + except requests.exceptions.RequestException as e: + logger.error(f"Generate text request failed: {api_url} {e}") + raise RuntimeError(f"Generate text request failed: {e}") from e + + return GenerateTextResponse( + status_code=response.status_code, + data=self.parse_response(response) + ) diff --git a/tests/test_llm_gateway.py b/tests/test_llm_gateway.py index e05bc02..7875e21 100644 --- a/tests/test_llm_gateway.py +++ b/tests/test_llm_gateway.py @@ -3,8 +3,6 @@ from pydantic import ValidationError import pytest -from datacustomcode.llm_gateway.base import LLMGateway -from datacustomcode.llm_gateway.default import DefaultLLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest from datacustomcode.llm_gateway.types.generate_text_request_builder import ( GenerateTextRequestBuilder, @@ -210,28 +208,3 @@ def test_builder_with_minimal_dict(self): response = GenerateTextResponseBuilder.build(response_dict) assert response.status_code == 200 assert response.version == "v1" # Default value - - -class TestDefaultLLMGateway: - """Test DefaultLLMGateway implementation.""" - - def test_default_gateway_is_llm_gateway(self): - """Test DefaultLLMGateway inherits from LLMGateway.""" - gateway = DefaultLLMGateway() - assert isinstance(gateway, LLMGateway) - - def test_generate_text_returns_response(self): - """Test generate_text returns GenerateTextResponse.""" - gateway = DefaultLLMGateway() - request = GenerateTextRequest(model_name="gpt-4", prompt="Hello") - response = gateway.generate_text(request) - assert isinstance(response, GenerateTextResponse) - - def test_generate_text_success_response(self): - """Test generate_text returns successful response.""" - gateway = DefaultLLMGateway() - request = GenerateTextRequest(model_name="gpt-4", prompt="Hello") - response = gateway.generate_text(request) - assert response.is_success is True - assert response.status_code == 200 - assert len(response.text) > 0 From 10adca3e61fd8740d88c014afc6f08298bc4205e Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Fri, 24 Apr 2026 17:12:32 -0400 Subject: [PATCH 03/14] Remote proxy phase 1 --- src/datacustomcode/__init__.py | 4 -- src/datacustomcode/client.py | 8 ---- src/datacustomcode/config.py | 16 ------- src/datacustomcode/config.yaml | 5 --- .../proxy/client/LocalProxyClientProvider.py | 34 --------------- .../templates/function/payload/entrypoint.py | 20 +++++---- tests/test_client.py | 42 ++++++++----------- 7 files changed, 28 insertions(+), 101 deletions(-) delete mode 100644 src/datacustomcode/proxy/client/LocalProxyClientProvider.py diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py index 2662e74..00cfae3 100644 --- a/src/datacustomcode/__init__.py +++ b/src/datacustomcode/__init__.py @@ -17,15 +17,11 @@ from datacustomcode.credentials import AuthType, Credentials from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader from datacustomcode.io.writer.print import PrintDataCloudWriter -from datacustomcode.proxy.client.LocalProxyClientProvider import ( - LocalProxyClientProvider, -) __all__ = [ "AuthType", "Client", "Credentials", - "LocalProxyClientProvider", "PrintDataCloudWriter", "QueryAPIDataCloudReader", ] diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py index c1f5763..faecf0a 100644 --- a/src/datacustomcode/client.py +++ b/src/datacustomcode/client.py @@ -33,7 +33,6 @@ from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode - from datacustomcode.proxy.client.base import BaseProxyClient from datacustomcode.spark.base import BaseSparkSessionProvider @@ -107,7 +106,6 @@ class Client: _reader: BaseDataCloudReader _writer: BaseDataCloudWriter _file: DefaultFindFilePath - _proxy: Optional[BaseProxyClient] _data_layer_history: dict[DataCloudObjectType, set[str]] _code_type: str @@ -115,7 +113,6 @@ def __new__( cls, reader: Optional[BaseDataCloudReader] = None, writer: Optional["BaseDataCloudWriter"] = None, - proxy: Optional[BaseProxyClient] = None, spark_provider: Optional["BaseSparkSessionProvider"] = None, code_type: str = "script", ) -> Client: @@ -223,11 +220,6 @@ def write_to_dmo( self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO) return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs) # type: ignore[no-any-return] - def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str: - if self._proxy is None: - raise ValueError("No proxy configured; set proxy or proxy_config") - return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens) # type: ignore[no-any-return] - def find_file_path(self, file_name: str) -> Path: """Return a file path""" diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index 820c512..740f4be 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -39,8 +39,6 @@ from datacustomcode.io.base import BaseDataAccessLayer from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH002 from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH002 -from datacustomcode.proxy.base import BaseProxyAccessLayer -from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002 from datacustomcode.spark.base import BaseSparkSessionProvider if TYPE_CHECKING: @@ -74,18 +72,6 @@ class SparkConfig(ForceableConfig): _P = TypeVar("_P", bound=BaseSparkSessionProvider) -_PX = TypeVar("_PX", bound=BaseProxyAccessLayer) - - -class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]): - """Config for proxy clients that take no constructor args (e.g. no spark).""" - - type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer - - def to_object(self) -> _PX: - type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_PX, type_(**self.options)) - class SparkProviderConfig(BaseObjectConfig, Generic[_P]): type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider @@ -98,7 +84,6 @@ def to_object(self) -> _P: class ClientConfig(BaseConfig): reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None - proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None spark_config: Union[SparkConfig, None] = None spark_provider_config: Union[ SparkProviderConfig[BaseSparkSessionProvider], None @@ -126,7 +111,6 @@ def merge( self.reader_config = merge(self.reader_config, other.reader_config) self.writer_config = merge(self.writer_config, other.writer_config) - self.proxy_config = merge(self.proxy_config, other.proxy_config) self.spark_config = merge(self.spark_config, other.spark_config) self.spark_provider_config = merge( self.spark_provider_config, other.spark_provider_config diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index d630a81..8a6c334 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -19,11 +19,6 @@ spark_config: spark.sql.execution.arrow.pyspark.enabled: 'true' spark.driver.extraJavaOptions: -Djava.security.manager=allow -proxy_config: - type_config_name: LocalProxyClientProvider - options: - credentials_profile: default - einstein_predictions_config: type_config_name: DefaultEinsteinPredictions options: diff --git a/src/datacustomcode/proxy/client/LocalProxyClientProvider.py b/src/datacustomcode/proxy/client/LocalProxyClientProvider.py deleted file mode 100644 index 9c08b54..0000000 --- a/src/datacustomcode/proxy/client/LocalProxyClientProvider.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2025, Salesforce, Inc. -# SPDX-License-Identifier: Apache-2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -from datacustomcode.proxy.client.base import BaseProxyClient - - -class LocalProxyClientProvider(BaseProxyClient): - """Default proxy client provider.""" - - CONFIG_NAME = "LocalProxyClientProvider" - - def __init__(self, **kwargs: object) -> None: - pass - - def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: - return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}" - - def llm_gateway_generate_text( - self, template, values, llmModelId: str, maxTokens: int - ): - return f"Using Generate Text with {llmModelId} and maxTokens: {maxTokens}" diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 613d142..e77970b 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -65,6 +65,16 @@ def make_einstein_prediction(runtime: Runtime) -> None: ) +def generate_text(runtime: Runtime): + builder = GenerateTextRequestBuilder() + llm_request = builder.set_prompt("Hello").set_model("modelName").build() + llm_response = runtime.llm_gateway.generate_text(llm_request) + + if llm_response.is_success: + print(llm_response.text) + else: + print(llm_response.error_code) + def function(request: dict, runtime: Runtime) -> dict: logger.info("Inside Function") logger.info(request) @@ -73,17 +83,9 @@ def function(request: dict, runtime: Runtime) -> dict: output_chunks = [] current_seq_no = 1 # Start sequence number from 1 + generate_text(runtime) make_einstein_prediction(runtime) - builder = GenerateTextRequestBuilder() - llm_request = builder.set_prompt("Hello").set_model("modelName").build() - llm_response = runtime.llm_gateway.generate_text(llm_request) - - if llm_response.is_success: - print(llm_response.text) - else: - print(llm_response.error_code) - for item in items: # Item is DocElement as dict logger.info(f"Processing item: {item}") diff --git a/tests/test_client.py b/tests/test_client.py index 5981ca9..c2cf46a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,7 +17,6 @@ ) from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode -from datacustomcode.proxy.client.base import BaseProxyClient class MockDataCloudReader(BaseDataCloudReader): @@ -76,13 +75,6 @@ def mock_config(mock_spark): ) -@pytest.fixture -def mock_proxy(): - """Mock proxy client to avoid starting Spark when reader/writer are provided.""" - proxy = MagicMock(spec=BaseProxyClient) - return proxy - - @pytest.fixture def reset_client(): """Reset the Client singleton between tests.""" @@ -93,12 +85,12 @@ def reset_client(): class TestClient: - def test_singleton_pattern(self, reset_client, mock_spark, mock_proxy): + def test_singleton_pattern(self, reset_client, mock_spark): """Test that Client behaves as a singleton.""" reader = MockDataCloudReader(mock_spark) writer = MockDataCloudWriter(mock_spark) - client1 = Client(reader=reader, writer=writer, proxy=mock_proxy) + client1 = Client(reader=reader, writer=writer) client2 = Client() assert client1 is client2 @@ -144,38 +136,38 @@ def test_initialization_with_config(self, mock_config, reset_client, mock_spark) assert client._reader is mock_reader assert client._writer is mock_writer - def test_read_dlo(self, reset_client, mock_spark, mock_proxy): + def test_read_dlo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) reader.read_dlo.return_value = mock_df - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) result = client.read_dlo("test_dlo") reader.read_dlo.assert_called_once_with("test_dlo") assert result is mock_df assert "test_dlo" in client._data_layer_history[DataCloudObjectType.DLO] - def test_read_dmo(self, reset_client, mock_spark, mock_proxy): + def test_read_dmo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) reader.read_dmo.return_value = mock_df - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) result = client.read_dmo("test_dmo") reader.read_dmo.assert_called_once_with("test_dmo") assert result is mock_df assert "test_dmo" in client._data_layer_history[DataCloudObjectType.DMO] - def test_write_to_dlo(self, reset_client, mock_spark, mock_proxy): + def test_write_to_dlo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dlo_access("some_dlo") client.write_to_dlo("test_dlo", mock_df, WriteMode.APPEND, extra_param=True) @@ -184,12 +176,12 @@ def test_write_to_dlo(self, reset_client, mock_spark, mock_proxy): "test_dlo", mock_df, WriteMode.APPEND, extra_param=True ) - def test_write_to_dmo(self, reset_client, mock_spark, mock_proxy): + def test_write_to_dmo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dmo_access("some_dmo") client.write_to_dmo("test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True) @@ -198,13 +190,13 @@ def test_write_to_dmo(self, reset_client, mock_spark, mock_proxy): "test_dmo", mock_df, WriteMode.OVERWRITE, extra_param=True ) - def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark, mock_proxy): + def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark): """Test that mixing DLOs and DMOs raises an exception.""" reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dlo_access("test_dlo") with pytest.raises(DataCloudAccessLayerException) as exc_info: @@ -212,13 +204,13 @@ def test_mixed_dlo_dmo_raises_exception(self, reset_client, mock_spark, mock_pro assert "test_dlo" in str(exc_info.value) - def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark, mock_proxy): + def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark): """Test that mixing DMOs and DLOs raises an exception (converse case).""" reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) client._record_dmo_access("test_dmo") with pytest.raises(DataCloudAccessLayerException) as exc_info: @@ -226,14 +218,14 @@ def test_mixed_dmo_dlo_raises_exception(self, reset_client, mock_spark, mock_pro assert "test_dmo" in str(exc_info.value) - def test_read_pattern_flow(self, reset_client, mock_spark, mock_proxy): + def test_read_pattern_flow(self, reset_client, mock_spark): """Test a complete flow of reading and writing within the same object type.""" reader = MagicMock(spec=BaseDataCloudReader) writer = MagicMock(spec=BaseDataCloudWriter) mock_df = MagicMock(spec=DataFrame) reader.read_dlo.return_value = mock_df - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) df = client.read_dlo("source_dlo") client.write_to_dlo("target_dlo", df, WriteMode.APPEND) @@ -247,7 +239,7 @@ def test_read_pattern_flow(self, reset_client, mock_spark, mock_proxy): # Reset for DMO test Client._instance = None - client = Client(reader=reader, writer=writer, proxy=mock_proxy) + client = Client(reader=reader, writer=writer) reader.read_dmo.return_value = mock_df df = client.read_dmo("source_dmo") From fc479ca8132a3777568f9a00cd017b4b364c408e Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Fri, 24 Apr 2026 17:32:20 -0400 Subject: [PATCH 04/14] update function template --- src/datacustomcode/templates/function/payload/entrypoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index e77970b..484d2bb 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -60,14 +60,14 @@ def make_einstein_prediction(runtime: Runtime) -> None: prediction_response = runtime.einstein_predictions.predict(prediction_request) print( - f"Einstein prediction results - success: {prediction_response.is_success} \ - response data: {prediction_response.data}" + f"Einstein prediction results - success: [{prediction_response.is_success}] \ +response data: {prediction_response.data}" ) def generate_text(runtime: Runtime): builder = GenerateTextRequestBuilder() - llm_request = builder.set_prompt("Hello").set_model("modelName").build() + llm_request = builder.set_prompt("Generate 2 dog names").set_model("sfdc_ai__DefaultGPT52").build() llm_response = runtime.llm_gateway.generate_text(llm_request) if llm_response.is_success: From 22afe2b4c012704104430698b41b41e7ddd91731 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Fri, 24 Apr 2026 17:44:37 -0400 Subject: [PATCH 05/14] comment out llm generate text and prediction examples --- .../templates/function/payload/entrypoint.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 484d2bb..0e17511 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -83,8 +83,11 @@ def function(request: dict, runtime: Runtime) -> dict: output_chunks = [] current_seq_no = 1 # Start sequence number from 1 - generate_text(runtime) - make_einstein_prediction(runtime) + """ + You can use your AI models configured in Salesforce to generate texts or predict an outcome + """ + # generate_text(runtime) + # make_einstein_prediction(runtime) for item in items: # Item is DocElement as dict From af91379828f15ee53b3e13191cdf595fc3b22208 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Fri, 24 Apr 2026 18:04:49 -0400 Subject: [PATCH 06/14] fix linting failures --- src/datacustomcode/einstein_platform_client.py | 3 +-- src/datacustomcode/templates/function/payload/entrypoint.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py index 92a241f..cf2bcbf 100644 --- a/src/datacustomcode/einstein_platform_client.py +++ b/src/datacustomcode/einstein_platform_client.py @@ -19,7 +19,6 @@ Optional, ) -from abc import ABC from loguru import logger from datacustomcode.token_provider import ( @@ -29,7 +28,7 @@ ) -class EinsteinPlatformClient(ABC): +class EinsteinPlatformClient: EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1" EINSTEIN_WARNING_MESSAGE = ( "If your code uses Einstein APIs, make sure you have " diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 0e17511..5db7fbc 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -67,7 +67,8 @@ def make_einstein_prediction(runtime: Runtime) -> None: def generate_text(runtime: Runtime): builder = GenerateTextRequestBuilder() - llm_request = builder.set_prompt("Generate 2 dog names").set_model("sfdc_ai__DefaultGPT52").build() + llm_request = builder.set_prompt("Generate 2 dog names").\ + set_model("sfdc_ai__DefaultGPT52").build() llm_response = runtime.llm_gateway.generate_text(llm_request) if llm_response.is_success: @@ -84,7 +85,8 @@ def function(request: dict, runtime: Runtime) -> dict: current_seq_no = 1 # Start sequence number from 1 """ - You can use your AI models configured in Salesforce to generate texts or predict an outcome + You can use your AI models configured in Salesforce + to generate texts or predict an outcome """ # generate_text(runtime) # make_einstein_prediction(runtime) From f35f190652ed054d4a3269c05a7c3eb773aa9837 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Sun, 26 Apr 2026 14:15:21 -0400 Subject: [PATCH 07/14] fix linting CI failures --- src/datacustomcode/auth.py | 2 +- src/datacustomcode/cmd.py | 2 +- src/datacustomcode/config.py | 8 ++++---- src/datacustomcode/einstein_predictions_config.py | 2 +- src/datacustomcode/llm_gateway/default.py | 11 +++-------- src/datacustomcode/llm_gateway_config.py | 2 +- .../templates/function/payload/entrypoint.py | 14 ++++++++++---- tests/file/test_path_default.py | 2 +- tests/test_auth.py | 4 ++-- tests/test_deploy.py | 8 ++++---- tests/test_scan.py | 4 ++-- 11 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/datacustomcode/auth.py b/src/datacustomcode/auth.py index ef43ddc..6636379 100644 --- a/src/datacustomcode/auth.py +++ b/src/datacustomcode/auth.py @@ -170,7 +170,7 @@ def do_oauth_browser_flow( # Start callback server click.echo(f"\nStarting local callback server on {redirect_uri}...") - server, actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue) + server, _actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue) # Build authorization URL with final redirect_uri auth_url = ( diff --git a/src/datacustomcode/cmd.py b/src/datacustomcode/cmd.py index e656316..0bb4d2c 100644 --- a/src/datacustomcode/cmd.py +++ b/src/datacustomcode/cmd.py @@ -104,6 +104,6 @@ def _cmd_output( def cmd_output(*cmd: str, **kwargs: Any) -> Union[str, None]: - returncode, stdout_b, stderr_b = _cmd_output(*cmd, **kwargs) + _returncode, stdout_b, _stderr_b = _cmd_output(*cmd, **kwargs) stdout = stdout_b.decode() if stdout_b is not None else None return stdout diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index 740f4be..666e134 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -37,8 +37,8 @@ # This lets all readers and writers to be findable via config from datacustomcode.io import * # noqa: F403 from datacustomcode.io.base import BaseDataAccessLayer -from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH002 -from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH002 +from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.io.writer.base import BaseDataCloudWriter from datacustomcode.spark.base import BaseSparkSessionProvider if TYPE_CHECKING: @@ -53,7 +53,7 @@ class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]): def to_object(self, spark: SparkSession) -> _T: type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_T, type_(spark=spark, **self.options)) + return cast("_T", type_(spark=spark, **self.options)) class SparkConfig(ForceableConfig): @@ -78,7 +78,7 @@ class SparkProviderConfig(BaseObjectConfig, Generic[_P]): def to_object(self) -> _P: type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_P, type_(**self.options)) + return cast("_P", type_(**self.options)) class ClientConfig(BaseConfig): diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py index 4e83164..41fe879 100644 --- a/src/datacustomcode/einstein_predictions_config.py +++ b/src/datacustomcode/einstein_predictions_config.py @@ -37,7 +37,7 @@ class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]): def to_object(self) -> _E: type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_E, type_(**self.options)) + return cast("_E", type_(**self.options)) class EinsteinPredictionsConfig(BaseConfig): diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 256291b..ed2e570 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -19,12 +19,10 @@ Optional, ) -from loguru import logger import requests +from loguru import logger from datacustomcode.einstein_platform_client import EinsteinPlatformClient - - from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse @@ -50,9 +48,7 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: f"{request.model_name}/generations" ) - payload: Dict[str, Any] = { - "prompt": request.prompt - } + payload: Dict[str, Any] = {"prompt": request.prompt} if request.localization: payload["localization"] = request.localization @@ -76,6 +72,5 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: raise RuntimeError(f"Generate text request failed: {e}") from e return GenerateTextResponse( - status_code=response.status_code, - data=self.parse_response(response) + status_code=response.status_code, data=self.parse_response(response) ) diff --git a/src/datacustomcode/llm_gateway_config.py b/src/datacustomcode/llm_gateway_config.py index 7e7e53d..d96d564 100644 --- a/src/datacustomcode/llm_gateway_config.py +++ b/src/datacustomcode/llm_gateway_config.py @@ -37,7 +37,7 @@ class LLMGatewayObjectConfig(BaseObjectConfig, Generic[_E]): def to_object(self) -> _E: type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast(_E, type_(**self.options)) + return cast("_E", type_(**self.options)) class LLMGatewayConfig(BaseConfig): diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 5db7fbc..8e92dbd 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -60,15 +60,18 @@ def make_einstein_prediction(runtime: Runtime) -> None: prediction_response = runtime.einstein_predictions.predict(prediction_request) print( - f"Einstein prediction results - success: [{prediction_response.is_success}] \ -response data: {prediction_response.data}" + f"Einstein prediction results - success: [{prediction_response.is_success}] " + f"response data: {prediction_response.data}" ) def generate_text(runtime: Runtime): builder = GenerateTextRequestBuilder() - llm_request = builder.set_prompt("Generate 2 dog names").\ - set_model("sfdc_ai__DefaultGPT52").build() + llm_request = ( + builder.set_prompt("Generate 2 dog names") + .set_model("sfdc_ai__DefaultGPT52") + .build() + ) llm_response = runtime.llm_gateway.generate_text(llm_request) if llm_response.is_success: @@ -76,6 +79,9 @@ def generate_text(runtime: Runtime): else: print(llm_response.error_code) + + + def function(request: dict, runtime: Runtime) -> dict: logger.info("Inside Function") logger.info(request) diff --git a/tests/file/test_path_default.py b/tests/file/test_path_default.py index c92eacd..8350122 100644 --- a/tests/file/test_path_default.py +++ b/tests/file/test_path_default.py @@ -72,7 +72,7 @@ def test_find_file_path_file_not_found(self): with pytest.raises( FileNotFoundError, - match="File 'test.txt' not found in any search location", + match=r"File 'test\.txt' not found in any search location", ): finder.find_file_path("test.txt") diff --git a/tests/test_auth.py b/tests/test_auth.py index 1e335ac..974b3c0 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -148,7 +148,7 @@ def test_run_oauth_callback_server_valid_uri( mock_server.server_address = ("localhost", 5555) mock_tcpserver.return_value = mock_server - server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) + _server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) assert port == 5555 mock_tcpserver.assert_called_once() @@ -186,7 +186,7 @@ def test_run_oauth_callback_server_different_port( mock_server.server_address = ("localhost", 8080) mock_tcpserver.return_value = mock_server - server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) + _server, port = _run_oauth_callback_server(redirect_uri, auth_code_queue) assert port == 8080 mock_tcpserver.assert_called_once() diff --git a/tests/test_deploy.py b/tests/test_deploy.py index 40f3281..2fd3ce2 100644 --- a/tests/test_deploy.py +++ b/tests/test_deploy.py @@ -907,7 +907,7 @@ def test_verify_data_transform_config_missing(self, mock_exists): mock_exists.return_value = False with pytest.raises( FileNotFoundError, - match="config.json not found at /test/dir/payload/config.json", + match=r"config\.json not found at /test/dir/payload/config\.json", ): get_config("/test/dir/payload") @@ -918,7 +918,7 @@ def test_verify_data_transform_config_invalid_json(self, mock_file, mock_exists) mock_exists.return_value = True with pytest.raises( ValueError, - match="config.json at /test/dir/payload/config.json is not valid JSON", + match=r"config\.json at /test/dir/payload/config\.json is not valid JSON", ): get_config("/test/dir/payload") @@ -929,8 +929,8 @@ def test_verify_data_transform_config_missing_fields(self, mock_file, mock_exist mock_exists.return_value = True with pytest.raises( ValueError, - match="config.json at /test/dir/payload/config.json is missing " - "required fields: entryPoint, dataspace, permissions", + match=r"config\.json at /test/dir/payload/config\.json is missing " + r"required fields: entryPoint, dataspace, permissions", ): get_config("/test/dir/payload") diff --git a/tests/test_scan.py b/tests/test_scan.py index 528f728..bfdad8b 100644 --- a/tests/test_scan.py +++ b/tests/test_scan.py @@ -524,8 +524,8 @@ def test_rejects_missing_dataspace(self): # Should raise ValueError when dataspace field is missing with pytest.raises( ValueError, - match="dataspace must be defined. Please add a 'dataspace' field to " - "the config.json file.", + match=r"dataspace must be defined\. Please add a 'dataspace' field to " + r"the config\.json file\.", ): update_config(temp_path) finally: From f0d700be723dd24b9e9d8e08ff882f608df8fdaf Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Sun, 26 Apr 2026 14:39:03 -0400 Subject: [PATCH 08/14] fix linting ci build failures --- src/datacustomcode/config.py | 8 ++++---- src/datacustomcode/llm_gateway/default.py | 5 ++--- .../templates/function/payload/entrypoint.py | 2 -- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index 666e134..775287e 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -37,12 +37,12 @@ # This lets all readers and writers to be findable via config from datacustomcode.io import * # noqa: F403 from datacustomcode.io.base import BaseDataAccessLayer -from datacustomcode.io.reader.base import BaseDataCloudReader -from datacustomcode.io.writer.base import BaseDataCloudWriter from datacustomcode.spark.base import BaseSparkSessionProvider if TYPE_CHECKING: from pyspark.sql import SparkSession + from datacustomcode.io.reader.base import BaseDataCloudReader + from datacustomcode.io.writer.base import BaseDataCloudWriter _T = TypeVar("_T", bound="BaseDataAccessLayer") @@ -82,8 +82,8 @@ def to_object(self) -> _P: class ClientConfig(BaseConfig): - reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None - writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None + reader_config: Union[AccessLayerObjectConfig["BaseDataCloudReader"], None] = None + writer_config: Union[AccessLayerObjectConfig["BaseDataCloudWriter"], None] = None spark_config: Union[SparkConfig, None] = None spark_provider_config: Union[ SparkProviderConfig[BaseSparkSessionProvider], None diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index ed2e570..89a70de 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -19,8 +19,8 @@ Optional, ) -import requests from loguru import logger +import requests from datacustomcode.einstein_platform_client import EinsteinPlatformClient from datacustomcode.llm_gateway.base import LLMGateway @@ -44,8 +44,7 @@ def __init__( def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: api_url = ( - f"{self.EINSTEIN_PLATFORM_URL}/models/" - f"{request.model_name}/generations" + f"{self.EINSTEIN_PLATFORM_URL}/models/{request.model_name}/generations" ) payload: Dict[str, Any] = {"prompt": request.prompt} diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 8e92dbd..427326e 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -80,8 +80,6 @@ def generate_text(runtime: Runtime): print(llm_response.error_code) - - def function(request: dict, runtime: Runtime) -> dict: logger.info("Inside Function") logger.info(request) From e0ac7e7f0dba0e43fe46b8ea1787a5d12c4a5745 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Sun, 26 Apr 2026 14:46:55 -0400 Subject: [PATCH 09/14] fix isort --- src/datacustomcode/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index 775287e..901b295 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from pyspark.sql import SparkSession + from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter From 06425e468b1ec6f8a5da272de21aae73d29364a7 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Sun, 26 Apr 2026 15:37:42 -0400 Subject: [PATCH 10/14] remote unnecessary log message --- src/datacustomcode/einstein_platform_client.py | 7 ------- src/datacustomcode/einstein_predictions/impl/default.py | 3 +-- src/datacustomcode/llm_gateway/default.py | 3 +-- .../templates/function/payload/entrypoint.py | 4 +++- 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py index cf2bcbf..986f23f 100644 --- a/src/datacustomcode/einstein_platform_client.py +++ b/src/datacustomcode/einstein_platform_client.py @@ -30,13 +30,6 @@ class EinsteinPlatformClient: EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1" - EINSTEIN_WARNING_MESSAGE = ( - "If your code uses Einstein APIs, make sure you have " - 'configured the SDK to use "client_credentials" auth type. ' - "Refer to https://developer.salesforce.com/docs/ai/agentforce/" - "guide/agent-api-get-started.html#create-a-salesforce-app " - "to create your external client app." - ) def __init__( self, diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 431c41e..6a9cec3 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -95,8 +95,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: if not response.ok and not response.text: error_msg = ( f"Einstein Prediction request failed: {api_url} - " - f"{response.status_code} {response.reason}. " - f"{self.EINSTEIN_WARNING_MESSAGE}" + f"{response.status_code} {response.reason}" ) logger.error(error_msg) except requests.exceptions.RequestException as e: diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 89a70de..c7ee853 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -62,8 +62,7 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: if not response.ok and not response.text: error_msg = ( f"Generate text request failed: {api_url} - " - f"{response.status_code} {response.reason}. " - f"{self.EINSTEIN_WARNING_MESSAGE}" + f"{response.status_code} {response.reason}" ) logger.error(error_msg) except requests.exceptions.RequestException as e: diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index 427326e..bdd6301 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -90,7 +90,9 @@ def function(request: dict, runtime: Runtime) -> dict: """ You can use your AI models configured in Salesforce - to generate texts or predict an outcome + to generate texts or predict an outcome. + First configure an external client app before using these AI APIs + https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api-get-started.html#create-a-salesforce-app" """ # generate_text(runtime) # make_einstein_prediction(runtime) From 036f56f600819dafed0a0c49e6cecf875891ae92 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Tue, 28 Apr 2026 13:29:26 -0400 Subject: [PATCH 11/14] update based on the PR feedback, including removing code duplication --- pyproject.toml | 1 + .../einstein_platform_client.py | 6 ++- .../einstein_platform_config.py | 41 +++++++++++++++++++ .../einstein_predictions/impl/default.py | 15 +------ .../einstein_predictions_config.py | 16 ++------ src/datacustomcode/llm_gateway/default.py | 30 +++++--------- src/datacustomcode/llm_gateway_config.py | 16 ++------ .../templates/function/payload/entrypoint.py | 11 +++-- 8 files changed, 71 insertions(+), 65 deletions(-) create mode 100644 src/datacustomcode/einstein_platform_config.py diff --git a/pyproject.toml b/pyproject.toml index 250c9dc..2245007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ python = ">=3.10,<3.12" pyyaml = "^6.0" salesforce-cdp-connector = ">=1.0.19" setuptools_scm = "^7.1.0" +requests = "2.33.1" [tool.poetry.group.dev.dependencies] build = "*" diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py index 986f23f..7ea86ff 100644 --- a/src/datacustomcode/einstein_platform_client.py +++ b/src/datacustomcode/einstein_platform_client.py @@ -29,12 +29,15 @@ class EinsteinPlatformClient: - EINSTEIN_PLATFORM_URL = "https://api.salesforce.com/einstein/platform/v1" + EINSTEIN_PLATFORM_MODELS_URL = ( + "https://api.salesforce.com/einstein/platform/v1/models" + ) def __init__( self, credentials_profile: Optional[str] = None, sf_cli_org: Optional[str] = None, + **kwargs: Any, ): if sf_cli_org: self._token_provider: TokenProvider = SFCLITokenProvider(sf_cli_org) @@ -44,6 +47,7 @@ def __init__( self._token_provider = CredentialsTokenProvider(profile) logger.debug(f"Using credentials token provider with profile: {profile}") self.token_response = None + super().__init__(**kwargs) def get_headers(self): if self.token_response is None: diff --git a/src/datacustomcode/einstein_platform_config.py b/src/datacustomcode/einstein_platform_config.py new file mode 100644 index 0000000..135809d --- /dev/null +++ b/src/datacustomcode/einstein_platform_config.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + ClassVar, + Optional, + Type, + cast, +) + +from datacustomcode.common_config import BaseObjectConfig + + +class CredentialsObjectConfig(BaseObjectConfig): + type_to_create: ClassVar[Type] + credentials_profile: Optional[str] = None + sf_cli_org: Optional[str] = None + + def to_object(self): + """Create an object instance, automatically including credentials in options""" + + options = self.options.copy() + if self.credentials_profile is not None: + options["credentials_profile"] = self.credentials_profile + if self.sf_cli_org is not None: + options["sf_cli_org"] = self.sf_cli_org + + type_ = self.type_to_create.subclass_from_config_name(self.type_config_name) + return cast(type_, type_(**options)) diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 6a9cec3..8f68d82 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -18,7 +18,6 @@ ClassVar, Dict, List, - Optional, ) from loguru import logger @@ -43,17 +42,6 @@ class DefaultEinsteinPredictions(EinsteinPlatformClient, EinsteinPredictions): PredictionType.MULTI_OUTCOME: "multi-outcome", } - def __init__( - self, - credentials_profile: Optional[str] = None, - sf_cli_org: Optional[str] = None, - **kwargs, - ): - EinsteinPlatformClient.__init__( - self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org - ) - EinsteinPredictions.__init__(self, **kwargs) - def predict(self, request: PredictionRequest) -> PredictionResponse: endpoint = self.ENDPOINT_MAP.get(request.prediction_type) if not endpoint: @@ -63,8 +51,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: ) api_url = ( - f"{self.EINSTEIN_PLATFORM_URL}/models/" - f"{request.model_api_name}/{endpoint}" + f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_api_name}/{endpoint}" ) prediction_columns: List[Dict[str, Any]] = [] diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py index 41fe879..1b4758f 100644 --- a/src/datacustomcode/einstein_predictions_config.py +++ b/src/datacustomcode/einstein_predictions_config.py @@ -19,25 +19,17 @@ Type, TypeVar, Union, - cast, ) -from datacustomcode.common_config import ( - BaseConfig, - BaseObjectConfig, - default_config_file, -) +from datacustomcode.common_config import BaseConfig, default_config_file +from datacustomcode.einstein_platform_config import CredentialsObjectConfig from datacustomcode.einstein_predictions.base import EinsteinPredictions _E = TypeVar("_E", bound=EinsteinPredictions) -class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]): - type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] - - def to_object(self) -> _E: - type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast("_E", type_(**self.options)) +class EinsteinPredictionsObjectConfig(CredentialsObjectConfig, Generic[_E]): + type_to_create: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] class EinsteinPredictionsConfig(BaseConfig): diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index c7ee853..33fca9a 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -13,11 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import ( - Any, - Dict, - Optional, -) +from typing import Any, Dict from loguru import logger import requests @@ -26,25 +22,17 @@ from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest from datacustomcode.llm_gateway.types.generate_text_response import GenerateTextResponse +from datacustomcode.llm_gateway.types.generate_text_response_builder import ( + GenerateTextResponseBuilder, +) class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway): CONFIG_NAME = "DefaultLLMGateway" - def __init__( - self, - credentials_profile: Optional[str] = None, - sf_cli_org: Optional[str] = None, - **kwargs, - ): - EinsteinPlatformClient.__init__( - self, credentials_profile=credentials_profile, sf_cli_org=sf_cli_org - ) - LLMGateway.__init__(self, **kwargs) - def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: api_url = ( - f"{self.EINSTEIN_PLATFORM_URL}/models/{request.model_name}/generations" + f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_name}/generations" ) payload: Dict[str, Any] = {"prompt": request.prompt} @@ -69,6 +57,8 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: logger.error(f"Generate text request failed: {api_url} {e}") raise RuntimeError(f"Generate text request failed: {e}") from e - return GenerateTextResponse( - status_code=response.status_code, data=self.parse_response(response) - ) + response_dict = { + "status_code": response.status_code, + "data": self.parse_response(response), + } + return GenerateTextResponseBuilder.build(response_dict) diff --git a/src/datacustomcode/llm_gateway_config.py b/src/datacustomcode/llm_gateway_config.py index d96d564..a65d0eb 100644 --- a/src/datacustomcode/llm_gateway_config.py +++ b/src/datacustomcode/llm_gateway_config.py @@ -19,25 +19,17 @@ Type, TypeVar, Union, - cast, ) -from datacustomcode.common_config import ( - BaseConfig, - BaseObjectConfig, - default_config_file, -) +from datacustomcode.common_config import BaseConfig, default_config_file +from datacustomcode.einstein_platform_config import CredentialsObjectConfig from datacustomcode.llm_gateway.base import LLMGateway _E = TypeVar("_E", bound=LLMGateway) -class LLMGatewayObjectConfig(BaseObjectConfig, Generic[_E]): - type_base: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract] - - def to_object(self) -> _E: - type_ = self.type_base.subclass_from_config_name(self.type_config_name) - return cast("_E", type_(**self.options)) +class LLMGatewayObjectConfig(CredentialsObjectConfig, Generic[_E]): + type_to_create: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract] class LLMGatewayConfig(BaseConfig): diff --git a/src/datacustomcode/templates/function/payload/entrypoint.py b/src/datacustomcode/templates/function/payload/entrypoint.py index bdd6301..a1cd685 100644 --- a/src/datacustomcode/templates/function/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/payload/entrypoint.py @@ -59,7 +59,7 @@ def make_einstein_prediction(runtime: Runtime) -> None: ) prediction_response = runtime.einstein_predictions.predict(prediction_request) - print( + logger.info( f"Einstein prediction results - success: [{prediction_response.is_success}] " f"response data: {prediction_response.data}" ) @@ -73,11 +73,10 @@ def generate_text(runtime: Runtime): .build() ) llm_response = runtime.llm_gateway.generate_text(llm_request) - - if llm_response.is_success: - print(llm_response.text) - else: - print(llm_response.error_code) + logger.info( + f"LLM Gateway generate text results - success: [{llm_response.is_success}] " + f"response data: {llm_response.data}" + ) def function(request: dict, runtime: Runtime) -> dict: From 8419b88aead8064669fb27991fe39526bf5334ce Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Tue, 28 Apr 2026 13:44:22 -0400 Subject: [PATCH 12/14] update poetry lock file --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index cf0885d..1fba4b3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3539,4 +3539,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.12" -content-hash = "dbcd5823f305acc0a54e7f1fdd0b85a32da36eafdcab48ee37fe0dc299b80bfe" +content-hash = "f147a38ecba7556f5b328cff4841ad7087ad41f9a6da1c48e80a3dbef0c773d9" From 429fcd87c89cc301500842f9b34374952638938c Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Tue, 28 Apr 2026 14:01:07 -0400 Subject: [PATCH 13/14] Pretty format TOM --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2245007..7450e15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,9 +105,9 @@ pydantic = "2.13.1" pyspark = "3.5.1" python = ">=3.10,<3.12" pyyaml = "^6.0" +requests = "2.33.1" salesforce-cdp-connector = ">=1.0.19" setuptools_scm = "^7.1.0" -requests = "2.33.1" [tool.poetry.group.dev.dependencies] build = "*" From 28187f2afc3475916208ce3e98d081e430467880 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Wed, 29 Apr 2026 09:05:58 -0400 Subject: [PATCH 14/14] more refactoring --- .../einstein_platform_client.py | 20 ++++++++++++++++++- .../einstein_predictions/impl/default.py | 19 +----------------- src/datacustomcode/llm_gateway/default.py | 19 +----------------- 3 files changed, 21 insertions(+), 37 deletions(-) diff --git a/src/datacustomcode/einstein_platform_client.py b/src/datacustomcode/einstein_platform_client.py index 7ea86ff..761f80a 100644 --- a/src/datacustomcode/einstein_platform_client.py +++ b/src/datacustomcode/einstein_platform_client.py @@ -20,6 +20,7 @@ ) from loguru import logger +import requests from datacustomcode.token_provider import ( CredentialsTokenProvider, @@ -49,7 +50,7 @@ def __init__( self.token_response = None super().__init__(**kwargs) - def get_headers(self): + def _get_headers(self): if self.token_response is None: self.token_response = self._token_provider.get_token() @@ -60,6 +61,23 @@ def get_headers(self): "x-client-feature-id": "ai-platform-models-connected-app", } + def make_post_request(self, url, payload): + try: + response = requests.post( + url, json=payload, headers=self._get_headers(), timeout=180 + ) + if not response.ok: + error_msg = ( + f"Request to {url} failed. " + f"Reason: {response.status_code} {response.reason} - " + f"Response body: {response.text}" + ) + logger.error(error_msg) + return response + except requests.exceptions.RequestException as e: + logger.error(f"Request to {url} failed: {e}") + raise RuntimeError(f"Request to {url} failed {e}") from e + def parse_response(self, response): response_data: Dict[str, Any] = {} if response.content: diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 8f68d82..28e51f0 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -20,9 +20,6 @@ List, ) -from loguru import logger -import requests - from datacustomcode.einstein_platform_client import EinsteinPlatformClient from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( @@ -74,21 +71,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: if request.settings: payload["settings"] = request.settings - logger.debug(f"Making Einstein prediction request to: {api_url}") - try: - response = requests.post( - api_url, json=payload, headers=self.get_headers(), timeout=180 - ) - if not response.ok and not response.text: - error_msg = ( - f"Einstein Prediction request failed: {api_url} - " - f"{response.status_code} {response.reason}" - ) - logger.error(error_msg) - except requests.exceptions.RequestException as e: - logger.error(f"Einstein Prediction request failed: {api_url} {e}") - raise RuntimeError(f"Einstein Prediction request failed: {e}") from e - + response = self.make_post_request(api_url, payload) return PredictionResponse( prediction_type=request.prediction_type, status_code=response.status_code, diff --git a/src/datacustomcode/llm_gateway/default.py b/src/datacustomcode/llm_gateway/default.py index 33fca9a..88374e3 100644 --- a/src/datacustomcode/llm_gateway/default.py +++ b/src/datacustomcode/llm_gateway/default.py @@ -15,9 +15,6 @@ from typing import Any, Dict -from loguru import logger -import requests - from datacustomcode.einstein_platform_client import EinsteinPlatformClient from datacustomcode.llm_gateway.base import LLMGateway from datacustomcode.llm_gateway.types.generate_text_request import GenerateTextRequest @@ -42,21 +39,7 @@ def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse: if request.tags: payload["tags"] = request.tags - logger.debug(f"Making Generate text request: {api_url}") - try: - response = requests.post( - api_url, json=payload, headers=self.get_headers(), timeout=180 - ) - if not response.ok and not response.text: - error_msg = ( - f"Generate text request failed: {api_url} - " - f"{response.status_code} {response.reason}" - ) - logger.error(error_msg) - except requests.exceptions.RequestException as e: - logger.error(f"Generate text request failed: {api_url} {e}") - raise RuntimeError(f"Generate text request failed: {e}") from e - + response = self.make_post_request(api_url, payload) response_dict = { "status_code": response.status_code, "data": self.parse_response(response),