In [4]:
import os
import boto3
import sagemaker
from sentence_transformers import SentenceTransformer, util

role = 'arn:aws:iam::484906661071:role/sagemaker'

sess = sagemaker.Session()
bucket = sess.default_bucket()
prefix = "sagemaker/sentence-transformers/multi-qa-mpnet-base-dot-v1"

In [None]:
model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')
model.save_pretrained("./model")

In [1]:
!ls -rtlh ./model/

total 419M
-rw-rw-r-- 1 josh josh  116 Nov 11 14:01 config_sentence_transformers.json
-rw-rw-r-- 1 josh josh  669 Nov 11 14:01 config.json
-rw-rw-r-- 1 josh josh  438 Nov 11 14:01 tokenizer_config.json
-rw-rw-r-- 1 josh josh  239 Nov 11 14:01 special_tokens_map.json
-rw-rw-r-- 1 josh josh 418M Nov 11 14:01 pytorch_model.bin
-rw-rw-r-- 1 josh josh 227K Nov 11 14:01 vocab.txt
-rw-rw-r-- 1 josh josh 456K Nov 11 14:01 tokenizer.json
-rw-rw-r-- 1 josh josh   53 Nov 11 14:01 sentence_bert_config.json
drwxrwxr-x 2 josh josh 4.0K Nov 11 14:01 1_Pooling
-rw-rw-r-- 1 josh josh 8.2K Nov 11 14:01 README.md
-rw-rw-r-- 1 josh josh  229 Nov 11 14:01 modules.json


In [2]:
!cd model && tar czvf ../model.tar.gz *

1_Pooling/
1_Pooling/config.json
config.json
config_sentence_transformers.json
modules.json
pytorch_model.bin
README.md
sentence_bert_config.json
special_tokens_map.json
tokenizer_config.json
tokenizer.json
vocab.txt


In [5]:
fObj = open("model.tar.gz", "rb")
key = os.path.join(prefix, "model.tar.gz")
boto3.Session().resource("s3").Bucket(bucket).Object(key).upload_fileobj(fObj)
print(os.path.join(bucket, key))

sagemaker-us-east-1-484906661071/sagemaker/sentence-transformers/multi-qa-mpnet-base-dot-v1/model.tar.gz


In [6]:
pretrained_model_data = "s3://{}/{}".format(bucket, key)
pretrained_model_data

's3://sagemaker-us-east-1-484906661071/sagemaker/sentence-transformers/multi-qa-mpnet-base-dot-v1/model.tar.gz'

In [7]:
from sagemaker.pytorch import PyTorch, PyTorchModel
from sagemaker.predictor import RealTimePredictor
from sagemaker import get_execution_role

class StringPredictor(RealTimePredictor):
    def __init__(self, endpoint_name, sagemaker_session):
        super(StringPredictor, self).__init__(endpoint_name, sagemaker_session, content_type='text/plain')

In [25]:
pytorch_model = PyTorchModel(model_data = pretrained_model_data, 
                             role=role, 
                             entry_point ='inference.py',
                             source_dir = './code', 
                             framework_version = '1.3.1',
                             py_version = 'py3',
                             predictor_cls=StringPredictor)

predictor = pytorch_model.deploy(instance_type='ml.m5.large', initial_instance_count=1)

--------!

The class RealTimePredictor has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
content_type is a no-op in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


In [32]:
import boto3

client = boto3.client('sagemaker-runtime')

custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"         # An example of a trace ID.
endpoint_name = predictor.endpoint_name                            # Your endpoint name.
content_type = "text/plain"                                        # The MIME type of the input data in the request body.
accept = "text/plain"                                              # The desired MIME type of the inference in the response.
payload = "test input"                                             # Payload for inference.
response = client.invoke_endpoint(
    EndpointName=endpoint_name, 
    CustomAttributes=custom_attributes, 
    ContentType=content_type,
    Accept=accept,
    Body=payload
    )

print(response)          

res = response['Body'].read()

embedding = [float(i) for i in res.decode('UTF-8').split(']')[0].split('[')[1].split(',\n')]

print(embedding)

{'ResponseMetadata': {'RequestId': '2562b011-31d0-43f5-bce6-b85922580642', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '2562b011-31d0-43f5-bce6-b85922580642', 'x-amzn-invoked-production-variant': 'AllTraffic', 'date': 'Fri, 12 Nov 2021 20:38:03 GMT', 'content-type': 'text/plain', 'content-length': '17907'}, 'RetryAttempts': 0}, 'ContentType': 'text/plain', 'InvokedProductionVariant': 'AllTraffic', 'Body': <botocore.response.StreamingBody object at 0x7f9bb7e6a100>}
[-0.0976128950715065, -0.7333332300186157, -0.2931416630744934, -0.1154574528336525, -0.02998112328350544, -0.12487849593162537, 0.591175377368927, 0.2437036782503128, 0.16138815879821777, 0.22071115672588348, 0.39980119466781616, -0.15567228198051453, -0.2770650088787079, 0.291347473859787, 0.012367546558380127, -0.002163163386285305, -0.22430819272994995, 0.19981400668621063, -0.1338842660188675, -0.10652517527341843, 0.11312411725521088, -0.12170013785362244, -0.27420100569725037, 0.02895766869187355, -0.173