In [1]:
# --- SageMaker configuration ------------------------------------------------
import sagemaker

# Session used below to upload the model archive and to deploy the endpoint.
sagemaker_session = sagemaker.Session()

# S3 location for the model artifact: the session's default bucket plus a
# project-specific key prefix (uploaded to s3://<bucket>/<prefix>/...).
bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/bert-pytorch'

# Execution role of this notebook; it is handed to PyTorchModel(role=...) below.
role = sagemaker.get_execution_role()

In [2]:
# download our trained model to the notebook runtime
!wget https://gradient-fire.s3.amazonaws.com/model.pth

--2020-02-29 06:48:15--  https://gradient-fire.s3.amazonaws.com/model.pth
Resolving gradient-fire.s3.amazonaws.com (gradient-fire.s3.amazonaws.com)... 52.216.98.147
Connecting to gradient-fire.s3.amazonaws.com (gradient-fire.s3.amazonaws.com)|52.216.98.147|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 437983063 (418M) [application/x-www-form-urlencoded]
Saving to: ‘model.pth’


2020-02-29 06:48:22 (62.6 MB/s) - ‘model.pth’ saved [437983063/437983063]



In [3]:
# Bundle the downloaded weights into a gzip-compressed tarball. The archive is
# uploaded to S3 below and used as the model's `model_data`; model.pth sits at
# the root of the archive.
import tarfile

with tarfile.open('model.tar.gz', 'w:gz') as tar:
    tar.add('model.pth', arcname='model.pth')

In [4]:
model_path = "model.tar.gz"  # local archive created by the tarfile cell above

In [5]:
# Upload the local archive under s3://<bucket>/<prefix>/ and keep the returned
# S3 URI; it is passed to the model as `model_data`.
model_artifact = sagemaker_session.upload_data(
    path=model_path,
    bucket=bucket,
    key_prefix=prefix,
)

In [6]:
# Display the S3 URI of the uploaded model archive.
model_artifact

's3://sagemaker-us-east-1-800756380562/sagemaker/bert-pytorch/model.tar.gz'

In [7]:
from sagemaker.predictor import RealTimePredictor
from sagemaker.pytorch import PyTorchModel


class StringPredictor(RealTimePredictor):
    """Predictor that sends raw text to the endpoint as ``text/plain``.

    No serializer or deserializer is configured, so the string given to
    ``predict`` is sent unchanged and the endpoint's response is returned as
    raw bytes (e.g. ``b'1'``). Turning the text into model inputs is left to the
    inference script (``serve/predict.py``) -- TODO confirm that script does the
    tokenization/tensor conversion.

    Args:
        endpoint_name: Name of the deployed SageMaker endpoint.
        sagemaker_session: Session to use. Defaults to ``None``, which lets the
            base predictor create its own session, so the predictor can be
            re-attached to an existing endpoint by name alone.
    """

    def __init__(self, endpoint_name, sagemaker_session=None):
        super().__init__(endpoint_name, sagemaker_session, content_type='text/plain')


# Build the SageMaker model from the uploaded artifact; `deploy()` wraps the
# resulting endpoint in a StringPredictor.
model = PyTorchModel(model_data=model_artifact,
                     role=role,
                     framework_version='1.0.0',
                     entry_point='predict.py',
                     source_dir='serve',
                     predictor_cls=StringPredictor)

In [8]:
%%time
# deploy the model as an endpoint
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

-------------------------!

In [16]:
# Display the deployed endpoint's name (needed by the Lambda function below).
predictor.endpoint

'sagemaker-pytorch-2020-02-29-07-02-00-055'

In [13]:
# Make a prediction on a positive review.
# Adjacent string literals are joined with no separator, so every piece except
# the last ends with an explicit space.
positive_review = (
    "Exciting, entertaining, and emotionally impactful, "
    "Avengers: Endgame does whatever it takes to deliver a satisfying finale "
    "to Marvel's epic Infinity Saga."
)
predictor.predict(positive_review)

b'1'

In [14]:
# Make a prediction on a negative review.
# Adjacent string literals are joined with no separator, so every piece except
# the last ends with an explicit space.
negative_review = (
    "To go even further, the Marvel Studios "
    "films are so bad precisely because they are good. Even when "
    "these films are firing on all cylinders – milking their computerized "
    "action set-pieces for maximum whiz-bang effect, nailing their glib one-liners, "
    "purposefully commanding a requisite sense of seriousness from their sprawling "
    "cast of superpowered characters – their impact on the motion-picture arts amounts "
    "to a net negative. They privilege sameness over invention, to such a fatal extent "
    "that even modest revisions on the established formula (as in the much-ballyhooed "
    "Thor: Ragnarok) are praised to wild excess. Their baseline tradition of passable quality "
    "inures us to demanding anything better, or anything else at all, really."
)
predictor.predict(negative_review)

b'0'

## AWS Lambda

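Set up an AWS Lambda function with the following code, replacing `<ENDPOINT NAME>` with the name of the endpoint deployed above (the value shown by `predictor.endpoint`). The function's execution role must be allowed to call `sagemaker:InvokeEndpoint` on that endpoint.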

```python
import base64
import os

import boto3

# Created once per execution environment (outside the handler) so that warm
# invocations reuse the client instead of rebuilding it on every request.
runtime = boto3.Session().client('sagemaker-runtime')

# Name of the SageMaker endpoint to invoke: set the ENDPOINT_NAME environment
# variable on the function, or replace the placeholder default.
ENDPOINT_NAME = os.environ.get('ENDPOINT_NAME', '<ENDPOINT NAME>')


def _response(status_code, body):
    """Build a plain-text, CORS-enabled response dict with a fresh headers dict."""
    return {
        'statusCode': status_code,
        'headers': {'Content-Type': 'text/plain', 'Access-Control-Allow-Origin': '*'},
        'body': body,
    }


def lambda_handler(event, context):
    """Classify the review text in the request body with the SageMaker endpoint.

    Args:
        event: Invocation event. The review text is read from ``event['body']``
            (presumably an API Gateway proxy event -- TODO confirm the
            integration type); when ``isBase64Encoded`` is true the body is
            base64-decoded before being sent.
        context: Lambda context object (unused).

    Returns:
        A dict with ``statusCode``, ``headers`` and ``body``: status 200 with the
        endpoint's response decoded as UTF-8, or status 400 with an error message
        when the request has no body.
    """
    body = event.get('body')
    if not body:
        return _response(400, 'Request body must contain the review text.')
    if event.get('isBase64Encoded'):
        body = base64.b64decode(body)

    # Send the review as-is; the endpoint's predictor was configured for text/plain.
    response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,
                                       ContentType='text/plain',
                                       Body=body)

    # The response 'Body' is a stream holding the inference result.
    result = response['Body'].read().decode('utf-8')

    return _response(200, result)
```