2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -105,6 +105,7 @@ pydantic = "2.13.1"
pyspark = "3.5.1"
python = ">=3.10,<3.12"
pyyaml = "^6.0"
requests = "2.33.1"
salesforce-cdp-connector = ">=1.0.19"
setuptools_scm = "^7.1.0"

4 changes: 0 additions & 4 deletions src/datacustomcode/__init__.py
@@ -17,15 +17,11 @@
from datacustomcode.credentials import AuthType, Credentials
from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.print import PrintDataCloudWriter
from datacustomcode.proxy.client.LocalProxyClientProvider import (
LocalProxyClientProvider,
)

__all__ = [
"AuthType",
"Client",
"Credentials",
"LocalProxyClientProvider",
"PrintDataCloudWriter",
"QueryAPIDataCloudReader",
]
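
For orientation, a minimal sketch of importing the trimmed public surface; the names come from the __all__ list above, and the usage itself is illustrative rather than taken from this PR.

# Illustrative only: these are the names that remain exported at the package root.
from datacustomcode import (
    AuthType,
    Client,
    Credentials,
    PrintDataCloudWriter,
    QueryAPIDataCloudReader,
)
# LocalProxyClientProvider is no longer importable from the package root after this change.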
2 changes: 1 addition & 1 deletion src/datacustomcode/auth.py
@@ -170,7 +170,7 @@ def do_oauth_browser_flow(

# Start callback server
click.echo(f"\nStarting local callback server on {redirect_uri}...")
server, actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue)
server, _actual_port = _run_oauth_callback_server(redirect_uri, auth_code_queue)

# Build authorization URL with final redirect_uri
auth_url = (
32 changes: 13 additions & 19 deletions src/datacustomcode/cli.py
@@ -179,14 +179,15 @@ def deploy(
function_invoke_opt: str,
sf_cli_org: Optional[str],
):
from datacustomcode.credentials import Credentials
from datacustomcode.deploy import (
COMPUTE_TYPES,
AccessTokenResponse,
CodeExtensionMetadata,
_retrieve_access_token_from_sf_cli,
deploy_full,
)
from datacustomcode.token_provider import (
CredentialsTokenProvider,
SFCLITokenProvider,
)

logger.debug("Deploying project")

@@ -220,22 +221,15 @@
function_invoke_options = function_invoke_opt.split(",")
metadata.functionInvokeOptions = function_invoke_options

auth: Union[Credentials, AccessTokenResponse]
if sf_cli_org:
try:
auth = _retrieve_access_token_from_sf_cli(sf_cli_org)
except RuntimeError as e:
click.secho(f"Error: {e}", fg="red")
raise click.Abort() from None
else:
try:
auth = Credentials.from_available(profile=profile)
except ValueError as e:
click.secho(
f"Error: {e}",
fg="red",
)
raise click.Abort() from None
try:
if sf_cli_org:
auth = SFCLITokenProvider(sf_cli_org).get_token()
else:
auth = CredentialsTokenProvider(profile).get_token()
except RuntimeError as e:
click.secho(f"Error: {e}", fg="red")
raise click.Abort() from None

deploy_full(path, metadata, auth, network)
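
Condensed as a hedged sketch, this is the token-acquisition path the deploy command now follows; the wrapper function name and its defaults are illustrative, while the provider calls and the deploy_full argument order mirror the hunk above.

from typing import Optional

from datacustomcode.deploy import CodeExtensionMetadata, deploy_full
from datacustomcode.token_provider import CredentialsTokenProvider, SFCLITokenProvider


def deploy_with_token(
    path: str,
    metadata: CodeExtensionMetadata,
    network: str,
    sf_cli_org: Optional[str] = None,
    profile: str = "default",
) -> None:
    # Both providers expose get_token(); the branch mirrors the CLI logic above,
    # minus the click error handling.
    if sf_cli_org:
        auth = SFCLITokenProvider(sf_cli_org).get_token()
    else:
        auth = CredentialsTokenProvider(profile).get_token()
    deploy_full(path, metadata, auth, network)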


8 changes: 0 additions & 8 deletions src/datacustomcode/client.py
@@ -33,7 +33,6 @@

from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
from datacustomcode.proxy.client.base import BaseProxyClient
from datacustomcode.spark.base import BaseSparkSessionProvider


@@ -107,15 +106,13 @@ class Client:
_reader: BaseDataCloudReader
_writer: BaseDataCloudWriter
_file: DefaultFindFilePath
_proxy: Optional[BaseProxyClient]
_data_layer_history: dict[DataCloudObjectType, set[str]]
_code_type: str

def __new__(
cls,
reader: Optional[BaseDataCloudReader] = None,
writer: Optional["BaseDataCloudWriter"] = None,
proxy: Optional[BaseProxyClient] = None,
spark_provider: Optional["BaseSparkSessionProvider"] = None,
code_type: str = "script",
) -> Client:
@@ -223,11 +220,6 @@ def write_to_dmo(
self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs) # type: ignore[no-any-return]

def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
Collaborator comment: @markdlv-sf was mentioning this is still used?

if self._proxy is None:
raise ValueError("No proxy configured; set proxy or proxy_config")
return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens) # type: ignore[no-any-return]

def find_file_path(self, file_name: str) -> Path:
"""Return a file path"""

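A small usage sketch against the Client methods visible in this diff; constructing Client with no arguments is assumed to fall back to the packaged configuration, and the file name is illustrative.

from datacustomcode import Client

# Sketch only: Client() is assumed to rely on configured defaults
# (reader, writer, Spark provider).
client = Client()

# find_file_path is shown unchanged above; it resolves a bundled file to a Path.
lookup_path = client.find_file_path("lookup.csv")
print(lookup_path)
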
2 changes: 1 addition & 1 deletion src/datacustomcode/cmd.py
@@ -104,6 +104,6 @@ def _cmd_output(


def cmd_output(*cmd: str, **kwargs: Any) -> Union[str, None]:
returncode, stdout_b, stderr_b = _cmd_output(*cmd, **kwargs)
_returncode, stdout_b, _stderr_b = _cmd_output(*cmd, **kwargs)
stdout = stdout_b.decode() if stdout_b is not None else None
return stdout
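
A brief usage sketch of cmd_output, which keeps only decoded stdout; the git command is an arbitrary example, not something this repository runs.

from datacustomcode.cmd import cmd_output

# cmd_output returns decoded stdout, or None when the command produced no output.
head = cmd_output("git", "rev-parse", "HEAD")
if head is not None:
    print(head.strip())
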
29 changes: 7 additions & 22 deletions src/datacustomcode/config.py
@@ -37,15 +37,14 @@
# This lets all readers and writers be findable via config
from datacustomcode.io import * # noqa: F403
from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH002
from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH002
from datacustomcode.proxy.base import BaseProxyAccessLayer
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
from datacustomcode.spark.base import BaseSparkSessionProvider

if TYPE_CHECKING:
from pyspark.sql import SparkSession

from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter


_T = TypeVar("_T", bound="BaseDataAccessLayer")

@@ -55,7 +54,7 @@ class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):

def to_object(self, spark: SparkSession) -> _T:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_T, type_(spark=spark, **self.options))
return cast("_T", type_(spark=spark, **self.options))


class SparkConfig(ForceableConfig):
@@ -74,31 +73,18 @@ class SparkConfig(ForceableConfig):

_P = TypeVar("_P", bound=BaseSparkSessionProvider)

_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)


class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
"""Config for proxy clients that take no constructor args (e.g. no spark)."""

type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer

def to_object(self) -> _PX:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_PX, type_(**self.options))


class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider

def to_object(self) -> _P:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_P, type_(**self.options))
return cast("_P", type_(**self.options))


class ClientConfig(BaseConfig):
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
reader_config: Union[AccessLayerObjectConfig["BaseDataCloudReader"], None] = None
writer_config: Union[AccessLayerObjectConfig["BaseDataCloudWriter"], None] = None
spark_config: Union[SparkConfig, None] = None
spark_provider_config: Union[
SparkProviderConfig[BaseSparkSessionProvider], None
@@ -126,7 +112,6 @@ def merge(

self.reader_config = merge(self.reader_config, other.reader_config)
self.writer_config = merge(self.writer_config, other.writer_config)
self.proxy_config = merge(self.proxy_config, other.proxy_config)
self.spark_config = merge(self.spark_config, other.spark_config)
self.spark_provider_config = merge(
self.spark_provider_config, other.spark_provider_config
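A hedged sketch of the reader-config path that remains after this change; it assumes BaseObjectConfig exposes type_config_name and options as constructor fields (both are read in the to_object methods above) and that the config name matches the reader class name, neither of which is shown in this diff.

from pyspark.sql import SparkSession

from datacustomcode.config import AccessLayerObjectConfig

# Assumed field names and config-name convention; options are forwarded as kwargs
# alongside spark, per to_object above.
reader_config = AccessLayerObjectConfig(
    type_config_name="QueryAPIDataCloudReader",
    options={},
)

spark = SparkSession.builder.getOrCreate()
reader = reader_config.to_object(spark)  # subclass looked up by name, then type_(spark=spark, ...)
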
11 changes: 6 additions & 5 deletions src/datacustomcode/config.yaml
@@ -19,11 +19,12 @@ spark_config:
spark.sql.execution.arrow.pyspark.enabled: 'true'
spark.driver.extraJavaOptions: -Djava.security.manager=allow

proxy_config:
type_config_name: LocalProxyClientProvider
einstein_predictions_config:
type_config_name: DefaultEinsteinPredictions
options:
credentials_profile: default

einstein_predictions_config:
type_config_name: DefaultEinsteinPredictions
options: {}
llm_gateway_config:
type_config_name: DefaultLLMGateway
options:
credentials_profile: default
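
A small Python sketch that inspects the two sections added above; the file path is an assumption about where the packaged config.yaml lives relative to the working directory, and PyYAML is already a declared dependency.

import yaml

# Illustration only: load the packaged config and show the new access-layer entries.
with open("src/datacustomcode/config.yaml") as fh:
    cfg = yaml.safe_load(fh)

for key in ("einstein_predictions_config", "llm_gateway_config"):
    entry = cfg.get(key, {})
    print(key, entry.get("type_config_name"), entry.get("options"))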
89 changes: 1 addition & 88 deletions src/datacustomcode/deploy.py
@@ -19,11 +19,9 @@
import os
import re
import shutil
import subprocess
import tempfile
import time
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -37,15 +35,10 @@
import requests

from datacustomcode.cmd import cmd_output
from datacustomcode.credentials import AuthType
from datacustomcode.scan import find_base_directory, get_package_type

if TYPE_CHECKING:
from datacustomcode.credentials import Credentials

DATA_CUSTOM_CODE_PATH = "services/data/v63.0/ssot/data-custom-code"
DATA_TRANSFORMS_PATH = "services/data/v63.0/ssot/data-transforms"
AUTH_PATH = "services/oauth2/token"
WAIT_FOR_DEPLOYMENT_TIMEOUT = 3000

# Available compute types for Data Cloud deployments.
@@ -163,80 +156,6 @@ class AccessTokenResponse(BaseModel):
instance_url: str


def _retrieve_access_token(credentials: Credentials) -> AccessTokenResponse:
"""Get an access token for the Salesforce API."""
logger.debug("Getting oauth token...")

url = f"{credentials.login_url.rstrip('/')}/{AUTH_PATH.lstrip('/')}"

if credentials.auth_type == AuthType.OAUTH_TOKENS:
data = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token,
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
}
elif credentials.auth_type == AuthType.CLIENT_CREDENTIALS:
data = {
"grant_type": "client_credentials",
"client_id": credentials.client_id,
"client_secret": credentials.client_secret,
}
else:
raise ValueError(f"Unsupported auth_type: {credentials.auth_type}")

response = _make_api_call(url, "POST", data=data)
return AccessTokenResponse(**response)


def _retrieve_access_token_from_sf_cli(sf_cli_org: str) -> AccessTokenResponse:
"""Get an access token from the Salesforce CLI."""
try:
result = subprocess.run(
["sf", "org", "display", "--target-org", sf_cli_org, "--json"],
capture_output=True,
text=True,
check=True,
timeout=30,
)
except FileNotFoundError as exc:
raise RuntimeError(
"The 'sf' command was not found. "
"Please install Salesforce CLI: https://developer.salesforce.com/tools/salesforcecli"
) from exc
except subprocess.TimeoutExpired as exc:
raise RuntimeError(
f"'sf org display' timed out for org '{sf_cli_org}'"
) from exc
except subprocess.CalledProcessError as exc:
raise RuntimeError(
f"'sf org display' failed for org '{sf_cli_org}'.\n"
f"Ensure the org is authenticated via 'sf org login web'.\n"
f"stderr: {exc.stderr.strip()}"
) from exc

try:
data = json.loads(result.stdout)
except json.JSONDecodeError as exc:
raise RuntimeError(f"Failed to parse 'sf org display' output: {exc}") from exc

if data.get("status") != 0:
raise RuntimeError(
f"SF CLI error for org '{sf_cli_org}': "
f"{data.get('message', 'unknown error')}"
)

org_result = data.get("result", {})
access_token = org_result.get("accessToken")
instance_url = org_result.get("instanceUrl")
if not access_token or not instance_url:
raise RuntimeError(
f"'sf org display' did not return an access token or instance URL "
f"for org '{sf_cli_org}'"
)
return AccessTokenResponse(access_token=access_token, instance_url=instance_url)


class CreateDeploymentResponse(BaseModel):
fileUploadUrl: str

@@ -567,16 +486,11 @@ def zip(
def deploy_full(
directory: str,
metadata: CodeExtensionMetadata,
credentials: Union["Credentials", AccessTokenResponse],
access_token: AccessTokenResponse,
docker_network: str,
callback=None,
) -> AccessTokenResponse:
"""Deploy a data transform in the DataCloud."""
if isinstance(credentials, AccessTokenResponse):
access_token = credentials
else:
access_token = _retrieve_access_token(credentials)

# prepare payload
config = get_config(directory)

@@ -587,7 +501,6 @@ def deploy_full(
wait_for_deployment(access_token, metadata, callback)

# create data transform

if isinstance(config, DataTransformConfig):
create_data_transform(directory, access_token, metadata, config)
return access_token
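
A hedged sketch of calling the slimmed-down deploy_full with a token built by hand; the field values are placeholders, and the assumption is that any AccessTokenResponse, however obtained, is acceptable now that deploy_full no longer exchanges Credentials itself.

from datacustomcode.deploy import (
    AccessTokenResponse,
    CodeExtensionMetadata,
    deploy_full,
)


def deploy_with_existing_token(
    directory: str,
    metadata: CodeExtensionMetadata,
    network: str,
) -> AccessTokenResponse:
    # Placeholder values; in practice the token comes from a token provider or the SF CLI.
    token = AccessTokenResponse(
        access_token="REPLACE_WITH_REAL_TOKEN",
        instance_url="https://example.my.salesforce.com",
    )
    return deploy_full(directory, metadata, token, network)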