# Text Embedding: How to Run Inference on the Endpoint You Have Created

In [2]:
import json
import boto3

Let's define some example input texts. You can supply any text, and the model will return a single, fixed-length embedding vector for each input text.

In [3]:
# Example inputs: two sentences about a dog and one unrelated sentence.
text1, text2, text3 = (
    "How cute your dog is!",
    "Your dog is so cute.",
    "The mitochondria is the powerhouse of the cell.",
)

### Query the endpoint that you have created
You can query the endpoint with a batch of input texts inside a JSON payload. Here, we send a single request to the endpoint, and the parsed response is a list of embedding vectors, one per input text.

In [4]:
# Output-formatting helpers: a line break and the ANSI escape sequences that
# switch bold text on ('\033[1m') and reset formatting ('\033[0m').
newline = '\n'
bold = '\033[1m'
unbold = '\033[0m'

# Name of the endpoint that the helper functions in this cell invoke.
endpoint_name = 'jumpstart-dft-hf-textembedding-gpt-j-6b-fp16'


def query_endpoint_with_json_payload(encoded_json):
    """Send a JSON request body to the endpoint and return the raw response.

    Args:
        encoded_json: Request body, forwarded unchanged as ``Body``. The
            request is sent with content type ``application/json``.

    Returns:
        The object returned by ``invoke_endpoint``, without any parsing.
    """
    runtime = boto3.client('runtime.sagemaker')
    # The target endpoint is taken from the module-level `endpoint_name`.
    return runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=encoded_json,
    )


def parse_response_multiple_texts(query_response):
    """Extract the embeddings from an endpoint response.

    Args:
        query_response: Mapping whose ``'Body'`` entry exposes a ``read()``
            method that yields a JSON document.

    Returns:
        The value stored under the ``'embedding'`` key of that JSON document.
    """
    raw_body = query_response['Body'].read()
    return json.loads(raw_body)['embedding']


# Batch all three example texts into a single request.
payload = {"text_inputs": [text1, text2, text3]}

# Serialize to UTF-8 JSON, invoke the endpoint once, then pull the embeddings
# out of the response body.
query_response = query_endpoint_with_json_payload(
    json.dumps(payload).encode('utf-8')
)
embeddings = parse_response_multiple_texts(query_response)


In [6]:
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from typing import Dict, List
class ContentHandler2(EmbeddingsContentHandler):
    """Translate between lists of texts and the endpoint's JSON wire format.

    Requests are serialized as ``{"text_inputs": [...], **model_kwargs}`` and
    embeddings are read from the ``"embedding"`` key of the JSON response.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            inputs: Texts to embed; sent under the ``"text_inputs"`` key.
            model_kwargs: Extra top-level entries merged into the payload.
                Because they are unpacked last, a ``"text_inputs"`` entry
                here would replace ``inputs``.

        Returns:
            The JSON document encoded as UTF-8 bytes.
        """
        input_str = json.dumps({"text_inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """Parse the endpoint response and return its embeddings.

        Args:
            output: Either an object with a ``read()`` method (for example a
                streaming response body), raw UTF-8 bytes, or an already
                decoded JSON string. The ``bytes`` annotation is kept from
                the existing signature.

        Returns:
            The value stored under the ``"embedding"`` key of the response.

        Raises:
            KeyError: If the response has no ``"embedding"`` key.
            json.JSONDecodeError: If the response is not valid JSON.
        """
        # Previously `.read()` was called unconditionally, which fails with
        # AttributeError when the caller passes bytes as the annotation says.
        raw = output.read() if hasattr(output, "read") else output
        if isinstance(raw, (bytes, bytearray)):
            raw = raw.decode("utf-8")
        response_json = json.loads(raw)
        return response_json["embedding"]

In [8]:
from langchain.embeddings import SagemakerEndpointEmbeddings

In [12]:
# Endpoint used by the LangChain embeddings client below (same value as the
# `endpoint_name` defined in the earlier boto3 cell).
endpoint_embedding = "jumpstart-dft-hf-textembedding-gpt-j-6b-fp16"

In [13]:
# Wire the JSON content handler into the LangChain SageMaker embeddings client.
# NOTE: this rebinds `embeddings`, which the earlier boto3 cell used for the
# parsed list of embedding vectors.
content_handler2 = ContentHandler2()
embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=endpoint_embedding,
    content_handler=content_handler2,
    region_name="us-east-1",
    # `credentials_profile_name` is intentionally left unset here.
)