Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def _create_baseten_chain(
environment=baseten_options.environment,
progress_bar=progress_bar,
disable_chain_download=baseten_options.disable_chain_download,
deployment_name=baseten_options.deployment_name,
)
return BasetenChainService(
baseten_options.chain_name,
Expand Down
3 changes: 3 additions & 0 deletions truss-chains/truss_chains/private_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class PushOptionsBaseten(PushOptions):
include_git_info: bool
working_dir: pathlib.Path
disable_chain_download: bool = False
deployment_name: Optional[str] = None

@classmethod
def create(
Expand All @@ -279,6 +280,7 @@ def create(
working_dir: pathlib.Path,
environment: Optional[str] = None,
disable_chain_download: bool = False,
deployment_name: Optional[str] = None,
) -> "PushOptionsBaseten":
if promote and not environment:
environment = PRODUCTION_ENVIRONMENT_NAME
Expand All @@ -293,6 +295,7 @@ def create(
include_git_info=include_git_info,
working_dir=working_dir,
disable_chain_download=disable_chain_download,
deployment_name=deployment_name,
)


Expand Down
11 changes: 11 additions & 0 deletions truss/cli/chains_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
default=False,
help="Disable downloading of pushed chain source code from the UI.",
)
@click.option(
"--deployment-name",
type=str,
required=False,
help=(
"Name of the deployment created by the publish. Can only be used "
"in combination with '--publish' or '--promote'."
),
)
@click.pass_context
@common.common_options()
def push_chain(
Expand All @@ -236,6 +245,7 @@ def push_chain(
experimental_watch_chainlet_names: Optional[str],
include_git_info: bool = False,
disable_chain_download: bool = False,
deployment_name: Optional[str] = None,
) -> None:
"""
Deploys a chain remotely.
Expand Down Expand Up @@ -294,6 +304,7 @@ def push_chain(
include_git_info=include_git_info,
working_dir=source.parent if source.is_file() else source.resolve(),
disable_chain_download=disable_chain_download,
deployment_name=deployment_name,
)
service = deployment_client.push(
entrypoint_cls, options, progress_bar=progress.Progress
Expand Down
3 changes: 3 additions & 0 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def deploy_chain_atomic(
is_draft: bool = False,
original_source_artifact_s3_key: Optional[str] = None,
allow_truss_download: Optional[bool] = True,
deployment_name: Optional[str] = None,
):
if allow_truss_download is None:
allow_truss_download = True
Expand Down Expand Up @@ -360,6 +361,8 @@ def deploy_chain_atomic(
params.append(f"is_draft: {str(is_draft).lower()}")
if allow_truss_download is False:
params.append("allow_truss_download: false")
if deployment_name:
params.append(f'deployment_name: "{deployment_name}"')

params_str = PARAMS_INDENT.join(params)

Expand Down
4 changes: 4 additions & 0 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def create_chain_atomic(
environment: Optional[str],
original_source_artifact_s3_key: Optional[str] = None,
allow_truss_download: bool = True,
deployment_name: Optional[str] = None,
) -> ChainDeploymentHandleAtomic:
if environment and is_draft:
logging.info(
Expand All @@ -156,6 +157,7 @@ def create_chain_atomic(
truss_user_env=truss_user_env,
original_source_artifact_s3_key=original_source_artifact_s3_key,
allow_truss_download=allow_truss_download,
deployment_name=deployment_name,
)
elif chain_id:
# This is the only case where promote has relevance, since
Expand All @@ -171,6 +173,7 @@ def create_chain_atomic(
truss_user_env=truss_user_env,
original_source_artifact_s3_key=original_source_artifact_s3_key,
allow_truss_download=allow_truss_download,
deployment_name=deployment_name,
)
except ApiError as e:
if (
Expand All @@ -193,6 +196,7 @@ def create_chain_atomic(
truss_user_env=truss_user_env,
original_source_artifact_s3_key=original_source_artifact_s3_key,
allow_truss_download=allow_truss_download,
deployment_name=deployment_name,
)

return ChainDeploymentHandleAtomic(
Expand Down
3 changes: 3 additions & 0 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def push_chain_atomic(
environment: Optional[str] = None,
progress_bar: Optional[Type["progress.Progress"]] = None,
disable_chain_download: bool = False,
deployment_name: Optional[str] = None,
) -> ChainDeploymentHandleAtomic:
# If we are promoting a model to an environment after deploy, it must be published.
# Draft models cannot be promoted.
Expand All @@ -300,6 +301,7 @@ def push_chain_atomic(
origin=custom_types.ModelOrigin.CHAINS,
progress_bar=progress_bar,
disable_truss_download=disable_chain_download,
deployment_name=deployment_name,
)
oracle_data = custom_types.OracleData(
model_name=push_data.model_name,
Expand Down Expand Up @@ -337,6 +339,7 @@ def push_chain_atomic(
environment=environment,
original_source_artifact_s3_key=raw_chain_s3_key,
allow_truss_download=not disable_chain_download,
deployment_name=deployment_name,
)
logging.info("Successfully pushed to baseten. Chain is building and deploying.")
return chain_deployment_handle
Expand Down
44 changes: 44 additions & 0 deletions truss/tests/cli/test_chains_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,47 @@ def test_chains_push_help_includes_disable_chain_download():

assert result.exit_code == 0
assert "--disable-chain-download" in result.output


def test_chains_push_with_deployment_name_flag():
"""Test that --deployment-name flag is properly parsed and passed through."""
runner = CliRunner()

mock_entrypoint_cls = Mock()
mock_entrypoint_cls.meta_data.chain_name = "test_chain"
mock_entrypoint_cls.display_name = "TestChain"

mock_service = Mock()
mock_service.run_remote_url = "http://test.com/run_remote"
mock_service.is_websocket = False

with patch(
"truss_chains.framework.ChainletImporter.import_target"
) as mock_importer:
with patch("truss_chains.deployment.deployment_client.push") as mock_push:
mock_importer.return_value.__enter__.return_value = mock_entrypoint_cls
mock_push.return_value = mock_service

result = runner.invoke(
truss_cli,
[
"chains",
"push",
"test_chain.py",
"--deployment-name",
"custom_deployment",
"--remote",
"test_remote",
"--publish",
"--dryrun",
],
)

assert result.exit_code == 0

mock_push.assert_called_once()
call_args = mock_push.call_args
options = call_args[0][1]

assert hasattr(options, "deployment_name")
assert options.deployment_name == "custom_deployment"
24 changes: 24 additions & 0 deletions truss/tests/remote/baseten/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,30 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
assert 'chain_id: "chain_id"' in gql_mutation
assert "dependencies:" in gql_mutation
assert "entrypoint:" in gql_mutation
assert "deployment_name" not in gql_mutation


@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
def test_deploy_chain_deployment_with_deployment_name(mock_post, baseten_api):
baseten_api.deploy_chain_atomic(
environment="production",
chain_id="chain_id",
dependencies=[],
entrypoint=ChainletDataAtomic(
name="chainlet-1",
oracle=OracleData(
model_name="model-1",
s3_key="s3-key-1",
encoded_config_str="encoded-config-str-1",
),
),
truss_user_env=b10_types.TrussUserEnv.collect(),
deployment_name="chain-deployment-name",
)

gql_mutation = mock_post.call_args[1]["json"]["query"]

assert 'deployment_name: "chain-deployment-name"' in gql_mutation


@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
Expand Down
11 changes: 10 additions & 1 deletion truss/tests/remote/baseten/test_chain_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_push_chain_atomic_with_chain_upload(
chain_root = context["chain_root"]

context["mock_prepare_push"].return_value = mock_push_data
deployment_name = "custom_deployment"

result = remote.push_chain_atomic(
chain_name=chain_name,
Expand All @@ -194,13 +195,18 @@ def test_push_chain_atomic_with_chain_upload(
truss_user_env=truss_user_env,
chain_root=chain_root,
publish=True,
deployment_name=deployment_name,
)
assert result == mock_create_chain_atomic.return_value

mock_archive_dir.assert_called_once_with(dir=chain_root, progress_bar=None)
mock_upload_chain_artifact.assert_called_once()

mock_create_chain_atomic.assert_called_once()
create_kwargs = mock_create_chain_atomic.call_args.kwargs
assert create_kwargs["deployment_name"] == deployment_name

prepare_kwargs = context["mock_prepare_push"].call_args.kwargs
assert prepare_kwargs["deployment_name"] == deployment_name


@patch("truss.remote.baseten.remote.create_chain_atomic")
Expand Down Expand Up @@ -239,6 +245,9 @@ def test_push_chain_atomic_without_chain_upload(
mock_upload.assert_not_called()

mock_create_chain_atomic.assert_called_once()
create_kwargs = mock_create_chain_atomic.call_args.kwargs
assert "deployment_name" in create_kwargs
assert create_kwargs["deployment_name"] is None


@patch("truss.remote.baseten.core.multipart_upload_boto3")
Expand Down
51 changes: 51 additions & 0 deletions truss/tests/remote/baseten/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,57 @@ def test_create_chain_no_existing_chain(remote):
assert deployment_handle.chain_deployment_id == "new-chain-deployment-id"


def test_create_chain_with_deployment_name(remote):
with requests_mock.Mocker() as m:
m.post(
_TEST_REMOTE_GRAPHQL_PATH,
[
{"json": {"data": {"chains": []}}},
{
"json": {
"data": {
"deploy_chain_atomic": {
"chain_deployment": {
"id": "new-chain-deployment-id",
"chain": {
"id": "new-chain-id",
"hostname": "hostname",
},
}
}
}
}
},
],
)

deployment_name = "chain-deployment"
create_chain_atomic(
api=remote.api,
chain_name="new_chain",
entrypoint=ChainletDataAtomic(
name="chainlet-1",
oracle=OracleData(
model_name="model-1",
s3_key="s3-key-1",
encoded_config_str="encoded-config-str-1",
),
),
dependencies=[],
truss_user_env=b10_types.TrussUserEnv.collect(),
is_draft=False,
environment=None,
deployment_name=deployment_name,
)

create_chain_graphql_request = m.request_history[1]

assert (
'deployment_name: "chain-deployment"'
in create_chain_graphql_request.json()["query"]
)


def test_create_chain_with_existing_chain_promote_to_environment_publish_false(remote):
mock_deploy_response = {
"chain_deployment": {
Expand Down
Loading