Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,18 @@ Instead of using `datacustomcode configure`, you can also set credentials via en
| `SFDC_REFRESH_TOKEN` | OAuth refresh token |
| `SFDC_ACCESS_TOKEN` | (Optional) OAuth core/access token |

**Einstein Platform API Environment (Optional):**
| Variable | Description |
|----------|-------------|
| `SFDC_EINSTEIN_API_ENV` | Einstein Platform API environment: `dev`, `test`, `stage`, or `prod`. If not set, automatically inferred from login URL. Set this explicitly if auto-detection fails. |

Example usage:
```bash
export SFDC_LOGIN_URL="https://login.salesforce.com"
export SFDC_CLIENT_ID="your_client_id"
export SFDC_CLIENT_SECRET="your_client_secret"
export SFDC_REFRESH_TOKEN="your_refresh_token"
export SFDC_EINSTEIN_API_ENV="test" # optional

datacustomcode run ./payload/entrypoint.py
```
Expand Down
31 changes: 27 additions & 4 deletions src/datacustomcode/einstein_platform_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import (
Any,
Dict,
Expand All @@ -30,10 +31,6 @@


class EinsteinPlatformClient:
EINSTEIN_PLATFORM_MODELS_URL = (
"https://api.salesforce.com/einstein/platform/v1/models"
)

def __init__(
self,
credentials_profile: Optional[str] = None,
Expand All @@ -48,8 +45,34 @@ def __init__(
self._token_provider = CredentialsTokenProvider(profile)
logger.debug(f"Using credentials token provider with profile: {profile}")
self.token_response = None
self._einstein_url_cache: Optional[str] = None
super().__init__(**kwargs)

def _get_einstein_platform_url(self) -> str:
if self._einstein_url_cache is not None:
return self._einstein_url_cache

env = os.environ.get("SFDC_EINSTEIN_API_ENV", "prod").lower()
if env not in ("dev", "test", "stage", "prod"):
logger.warning(
f"Unknown SFDC_EINSTEIN_API_ENV value '{env}', defaulting to prod"
)
env = "prod"

base_url = self._get_base_url_for_env(env)
logger.info(f"Using Einstein Platform API endpoint: {base_url} (env={env})")
self._einstein_url_cache = f"{base_url}/einstein/platform/v1/models"
return self._einstein_url_cache

def _get_base_url_for_env(self, env: str) -> str:
env_map = {
"dev": "https://dev.api.salesforce.com",
"test": "https://test.api.salesforce.com",
"stage": "https://stage.api.salesforce.com",
"prod": "https://api.salesforce.com",
}
return env_map.get(env, "https://api.salesforce.com")

def _get_headers(self):
if self.token_response is None:
self.token_response = self._token_provider.get_token()
Expand Down
2 changes: 1 addition & 1 deletion src/datacustomcode/einstein_predictions/impl/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse:
)

api_url = (
f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_api_name}/{endpoint}"
f"{self._get_einstein_platform_url()}/{request.model_api_name}/{endpoint}"
)

prediction_columns: List[Dict[str, Any]] = []
Expand Down
2 changes: 1 addition & 1 deletion src/datacustomcode/llm_gateway/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DefaultLLMGateway(EinsteinPlatformClient, LLMGateway):

def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
api_url = (
f"{self.EINSTEIN_PLATFORM_MODELS_URL}/{request.model_name}/generations"
f"{self._get_einstein_platform_url()}/{request.model_name}/generations"
)

payload: Dict[str, Any] = {"prompt": request.prompt}
Expand Down
Loading