Skip to content

Commit

Permalink
Added region name to the sagemaker cli. (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
d18s authored and aarondav committed Jun 16, 2018
1 parent 1bb2a15 commit b218deb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions mlflow/sagemaker/cli.py
Expand Up @@ -28,7 +28,8 @@ def commands():
@click.option("--bucket", "-b", help="S3 bucket to store model artifacts", required=True)
@click.option("--run_id", "-r", default=None, help="Run id")
@click.option("--container", "-c", default="mlflow_sage", help="container name")
def deploy(app_name, model_path, execution_role_arn, bucket, run_id=None, container="mlflow_sage"): # noqa
@click.option("--region-name", default="us-west-2", help="region name")
def deploy(app_name, model_path, execution_role_arn, bucket, run_id=None, container="mlflow_sage", region_name="us-west-2"): # noqa
""" Deploy model on sagemaker.
:param app_name: Name of the deployed app.
Expand All @@ -51,7 +52,8 @@ def deploy(app_name, model_path, execution_role_arn, bucket, run_id=None, contai
container_name=container,
app_name=app_name,
model_s3_path=model_s3_path,
run_id=run_id)
run_id=run_id,
region_name=region_name)


@commands.command("run-local")
Expand Down Expand Up @@ -130,4 +132,4 @@ def _check_compatible(path):
path = os.path.abspath(path)
servable = Model.load(os.path.join(path, "MLmodel"))
if pyfunc.FLAVOR_NAME not in servable.flavors:
raise Exception("Currenlty only supports pyfunc format.")
raise Exception("Currenlty only supports pyfunc format.")
4 changes: 2 additions & 2 deletions mlflow/sagemaker/deploy.py
Expand Up @@ -47,7 +47,7 @@ def _upload_s3(local_model_path, bucket, prefix):
shutil.rmtree(tmp_dir)


def _deploy(role, container_name, app_name, model_s3_path, run_id):
def _deploy(role, container_name, app_name, model_s3_path, run_id, region_name):
"""
Deploy model on sagemaker.
:param role:
Expand All @@ -57,7 +57,7 @@ def _deploy(role, container_name, app_name, model_s3_path, run_id):
:param run_id:
:return:
"""
sage_client = boto3.client('sagemaker', region_name="us-west-2")
sage_client = boto3.client('sagemaker', region_name)
ecr_client = boto3.client("ecr")
repository_conf = ecr_client.describe_repositories(
repositoryNames=[container_name])['repositories'][0]
Expand Down

0 comments on commit b218deb

Please sign in to comment.