In [None]:
!cd container; ./build_and_push.sh gnn-inference

In [None]:
# Adapted from the SageMaker examples (multi-model endpoint, bring your own container):
# https://github.com/awslabs/amazon-sagemaker-examples/blob/master/advanced_functionality/multi_model_bring_your_own/multi_model_endpoint_bring_your_own.ipynb

# AWS clients and account context shared by the cells below:
#   sm_client         -> create_model / create_endpoint_config / create_endpoint / waiters
#   runtime_sm_client -> invoke_endpoint
import boto3
from sagemaker import get_execution_role

sm_client = boto3.client('sagemaker')
runtime_sm_client = boto3.client('sagemaker-runtime')

# Account id and region are used to build the ECR image URI; role is passed to create_model.
caller_identity = boto3.client('sts').get_caller_identity()
account_id = caller_identity['Account']
region = boto3.Session().region_name
role = get_execution_role()

In [None]:
# Register the trained model artifacts with SageMaker hosting as a multi-model model.

from time import gmtime, strftime

# S3 *prefix* holding the model artifacts. In 'MultiModel' mode the request's
# TargetModel (see the inference cell) is resolved relative to this prefix.
model_url = 's3://jdurago-insight-2020a/output/baseline/drug-prediction-gcn-200128-2135-015-1feea50f/output/'
# Hyphen before the timestamp, matching the endpoint-config and endpoint naming below.
model_name = 'DEMO-MultiModelModel-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
# Image named 'gnn-inference', as built by build_and_push.sh in the first cell
# (assumed to be pushed to this account's ECR in the current region).
container_image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account_id, region, 'gnn-inference')

print('Model name: ' + model_name)
print('Model data Url: ' + model_url)
print('Container image: ' + container_image)

container = {
    'Image': container_image,
    'ModelDataUrl': model_url,
    'Mode': 'MultiModel'
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    Containers=[container])

print("Model Arn: " + create_model_response['ModelArn'])

In [None]:
# Endpoint configuration: one production variant serving the model registered above.

endpoint_config_name = 'DEMO-MultiModelEndpointConfig-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f'Endpoint config name: {endpoint_config_name}')

production_variant = {
    'VariantName': 'AllTraffic',
    'ModelName': model_name,
    'InstanceType': 'ml.t2.medium',
    'InitialInstanceCount': 1,
    'InitialVariantWeight': 1,
}
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)

print(f"Endpoint config Arn: {create_endpoint_config_response['EndpointConfigArn']}")

In [None]:
# Create the endpoint from the config above, then block until it is InService.
import time  # NOTE(review): not used in this cell

endpoint_name = 'DEMO-MultiModelEndpoint-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f'Endpoint name: {endpoint_name}')

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

# Status snapshot taken immediately after the create call.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print(f'Endpoint Status: {status}')

# Poll until the endpoint reaches InService (per boto3 docs the waiter raises
# a WaiterError if it fails or times out instead).
print(f'Waiting for {endpoint_name} endpoint to be in service...')
waiter = sm_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)

# Invoke Models

In [None]:
import json

# Molecule to score (appears to be a SMILES string); it is JSON-encoded because the
# inference request is sent with ContentType 'application/json'.
molecule = 'CCC1=[O+][Cu-3]2([O+]=C(CC)C1)[O+]=C(CC)CC(CC)=[O+]2'
payload = json.dumps(molecule)
payload

In [None]:
# By default, invoke the endpoint created (and waited on) above. Unconditionally
# overwriting `endpoint_name` here would send requests to a different, possibly
# nonexistent endpoint on Restart & Run All. To target a previously deployed
# endpoint instead, set its name here, e.g. 'graph-neural-net-endpoint'.
EXISTING_ENDPOINT_NAME = None
if EXISTING_ENDPOINT_NAME:
    endpoint_name = EXISTING_ENDPOINT_NAME

In [None]:
# Run inference: send the JSON payload to the endpoint and print the decoded reply.

invoke_kwargs = dict(
    EndpointName=endpoint_name,
    ContentType='application/json',
    # Remainder of the S3 path (under the model's ModelDataUrl prefix) of the artifacts to load.
    TargetModel='model.tar.gz',
    Body=payload,
)
response = runtime_sm_client.invoke_endpoint(**invoke_kwargs)

response_body = response['Body'].read()
print(response_body.decode())