-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfix_model_permission.py
44 lines (35 loc) · 1.76 KB
/
fix_model_permission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
import json
import logging
import os
import tempfile

import boto3
from botocore.exceptions import ClientError
# This script is a workaround for permission issues on the model file stored
# in an S3 bucket: it downloads the file and uploads it again to the same key.
# AWS service clients, created once when the module is loaded and used by the
# script body below: S3 for the model file, SageMaker for the package lookup.
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
def find_model_package_name(params):
    """Return the ``ParameterValue`` of the ``ModelPackageName`` entry.

    Args:
        params: Iterable of dicts, each looked up by its ``ParameterKey``
            and ``ParameterValue`` keys.

    Returns:
        The value of the last entry whose ``ParameterKey`` equals
        ``"ModelPackageName"``, or ``None`` when no entry matches.
    """
    model_package_name = None
    for param in params:
        if param.get('ParameterKey') == 'ModelPackageName':
            # Keep scanning: when the key appears more than once, the last
            # occurrence wins.
            model_package_name = param.get('ParameterValue')
    return model_package_name


def split_s3_url(url):
    """Split an ``s3://bucket/key`` URL into a ``(bucket, key)`` tuple.

    Args:
        url: The S3 URL to split. The key may itself contain ``/``.

    Returns:
        A ``(bucket_name, key)`` tuple of strings.

    Raises:
        ValueError: If *url* does not start with ``s3://`` or is missing the
            bucket or the key part.
    """
    prefix = 's3://'
    if not url.startswith(prefix):
        raise ValueError(f"Expected an s3:// URL, got: {url!r}")
    # Only the first '/' separates bucket from key; the rest belongs to the key.
    bucket_name, _, key = url[len(prefix):].partition('/')
    if not bucket_name or not key:
        raise ValueError(f"S3 URL must include both a bucket and a key: {url!r}")
    return bucket_name, key


def main():
    """Re-upload the model file of the configured model package in place.

    Reads the ``ModelPackageName`` parameter from the JSON file given by
    ``--prod-config-file``, describes that package with the module-level
    ``sm_client`` to obtain its ``ModelDataUrl``, then downloads the object
    and uploads it back to the same bucket and key with ``s3_client``.

    Raises:
        Exception: If the configuration file has no ``ModelPackageName``
            parameter.
        ValueError: If the model data URL is not a well-formed ``s3://`` URL.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--log-level", type=str, default=os.environ.get("LOGLEVEL", "INFO").upper())
    parser.add_argument("--prod-config-file", type=str, default="prod-config-export.json")
    args, _ = parser.parse_known_args()

    # Configure logging to output the line number and message.
    # Upper-case here as well so that e.g. `--log-level debug` is accepted;
    # only the environment default was normalised before.
    log_format = "%(levelname)s: [%(filename)s:%(lineno)s] %(message)s"
    logging.basicConfig(format=log_format, level=args.log_level.upper())

    # First retrieve the name of the package that will be deployed.
    with open(args.prod_config_file, 'r') as f:
        model_package_name = find_model_package_name(json.load(f))
    if model_package_name is None:
        raise Exception("Configuration file must include ModelPackageName parameter")

    # Then describe it to get the S3 URL of the model.
    resp = sm_client.describe_model_package(ModelPackageName=model_package_name)
    model_data_url = resp['InferenceSpecification']['Containers'][0]['ModelDataUrl']
    bucket_name, key = split_s3_url(model_data_url)

    # Finally, copy the file onto itself to override the permissions. An
    # anonymous temporary file replaces the fixed /tmp path so nothing is left
    # behind and concurrent runs cannot clobber each other.
    logging.info("Re-uploading %s in place", model_data_url)
    with tempfile.TemporaryFile() as data:
        s3_client.download_fileobj(bucket_name, key, data)
        # Rewind: the upload reads from the current file position.
        data.seek(0)
        s3_client.upload_fileobj(data, bucket_name, key)


if __name__ == "__main__":
    main()