# SageMakerAsyncEndpoint

Amazon SageMaker Asynchronous Inference is one of the four deployment options in Amazon SageMaker, alongside real-time endpoints, batch inference, and serverless inference. It queues incoming requests and processes them asynchronously, which makes it well suited to requests with large payloads (up to 1 GB), long processing times, and near real-time latency requirements.

For large Foundation Models, especially during a Proof of Concept (PoC) or development cycles, its main value is that the endpoint can be configured to scale the instance count to zero when there are no requests to process, which saves costs. More information about SageMaker Asynchronous Inference can be found in the [AWS documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html).
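
Scaling to zero is configured through Application Auto Scaling rather than through LangChain. The following is a minimal sketch, not part of the original notebook: the helper names are hypothetical, and the variant name `AllTraffic` is an assumption that you should replace with the variant name of your own endpoint. Registering the target only permits a minimum of zero instances; a scaling policy (for example one based on the request backlog) is still needed to actually trigger scale-in and scale-out.

```python
from typing import Any, Dict


def scale_to_zero_target(
    endpoint_name: str, variant_name: str = "AllTraffic", max_capacity: int = 1
) -> Dict[str, Any]:
    """Build the arguments for Application Auto Scaling's register_scalable_target.

    `variant_name` defaults to "AllTraffic", which is an assumption; check the
    production variant name of your endpoint.
    """
    return {
        "ServiceNamespace": "sagemaker",
        "ResourceId": f"endpoint/{endpoint_name}/variant/{variant_name}",
        "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
        "MinCapacity": 0,  # allows the endpoint to scale in to zero instances
        "MaxCapacity": max_capacity,
    }


def register_scale_to_zero(endpoint_name: str, **kwargs: Any) -> None:
    """Apply the target. Requires boto3 and valid AWS credentials."""
    import boto3  # imported lazily so the helper above works without AWS access

    client = boto3.client("application-autoscaling")
    client.register_scalable_target(**scale_to_zero_target(endpoint_name, **kwargs))
```

Calling `register_scale_to_zero("endpoint-name")` would apply the configuration to a deployed endpoint; `scale_to_zero_target` can be inspected on its own without touching AWS.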

This notebook has been tested with the llama-7b-chat-hf model, deployed on an `ml.g5.2xlarge` instance.

In [None]:
!pip3 install langchain boto3 sagemaker

## Set up

You have to set the following required parameter of the `SagemakerAsyncEndpoint` call:
- `endpoint_name`: The name of the endpoint from the deployed SageMaker model.
    Must be unique within an AWS Region.

An asynchronous SageMaker endpoint with this name must already be deployed.

Optional, but worth noting:
- `credentials_profile_name`: The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
    has either access keys or role information specified.
    If not specified, the default credential profile or, if on an EC2 instance,
    credentials from IMDS will be used.
    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

Make sure that the profile has the permissions required to invoke the SageMaker endpoint.
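
As an illustration of what such permissions can look like, the sketch below builds a minimal IAM policy document that allows `sagemaker:InvokeEndpointAsync` on a single endpoint. The helper name, the account ID, and the endpoint name are hypothetical placeholders. Asynchronous inference reads its input from and writes its output to Amazon S3, so the profile will most likely also need access to those S3 locations; that part is not shown because it depends on how your endpoint is configured.

```python
import json
from typing import Any, Dict


def invoke_async_policy(
    region: str, account_id: str, endpoint_name: str
) -> Dict[str, Any]:
    """Build a minimal IAM policy allowing async invocation of one endpoint."""
    endpoint_arn = (
        f"arn:aws:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
    )
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": "sagemaker:InvokeEndpointAsync",
                "Resource": endpoint_arn,
            }
        ],
    }


# Placeholder values; replace them with your own region, account, and endpoint.
policy = invoke_async_policy("us-west-2", "123456789012", "endpoint-name")
print(json.dumps(policy, indent=2))
```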

## Example

In [None]:
import json
from typing import Dict

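from langchain.docstore.document import Document
# Base class required by the ContentHandler defined below
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate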

example_doc_1 = """
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she gets well.
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
"""

docs = [
    Document(
        page_content=example_doc_1,
    )
]


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"prompt": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


query = """How long was Elizabeth hospitalized?
"""

prompt_template = """Use the following pieces of context to answer the question at the end.

{context}

Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
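
The content handler above can be checked without a deployed endpoint. The sketch below repeats its two transformations as standalone functions (the names `encode_request` and `decode_response` are hypothetical) and feeds the decoder a fabricated response wrapped in `io.BytesIO`, which stands in for the stream-like object the handler calls `.read()` on. The response shape, a list with one `generated_text` entry, is the one the handler above expects; your model container may return something different.

```python
import io
import json
from typing import Dict


def encode_request(prompt: str, model_kwargs: Dict) -> bytes:
    """Same serialization as ContentHandler.transform_input."""
    return json.dumps({"prompt": prompt, "parameters": model_kwargs}).encode("utf-8")


def decode_response(output) -> str:
    """Same parsing as ContentHandler.transform_output."""
    return json.loads(output.read().decode("utf-8"))[0]["generated_text"]


request = encode_request(
    "How long was Elizabeth hospitalized?", {"temperature": 1e-10}
)

# Fabricated response body, used only to exercise the parsing logic.
fake_response = io.BytesIO(
    json.dumps([{"generated_text": "3 days"}]).encode("utf-8")
)

print(json.loads(request))
print(decode_response(fake_response))
```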

## Example using the system configured AWS profile

In [None]:
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import SagemakerAsyncEndpoint

content_handler = ContentHandler()

chain = load_qa_chain(
    llm=SagemakerAsyncEndpoint(
        endpoint_name="endpoint-name",
        credentials_profile_name="credentials-profile-name",
        region_name="us-west-2",
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

chain({"input_documents": docs, "question": query}, return_only_outputs=True)

## Example initializing with an external boto3 session

### For cross-account scenarios

In this case, pass an already initialized boto3 session to `SagemakerAsyncEndpoint` through the `session` parameter.

In [None]:
import boto3

roleARN = "arn:aws:iam::123456789012:role/cross-account-role"
sts_client = boto3.client("sts")
response = sts_client.assume_role(
    RoleArn=roleARN, RoleSessionName="CrossAccountSession"
)

session = boto3.Session(
    region_name="us-west-2",
    aws_access_key_id=response["Credentials"]["AccessKeyId"],
    aws_secret_access_key=response["Credentials"]["SecretAccessKey"],
    aws_session_token=response["Credentials"]["SessionToken"],
)

content_handler = ContentHandler()

chain = load_qa_chain(
    llm=SagemakerAsyncEndpoint(
        endpoint_name="endpoint-name",
        session=session,
        model_kwargs={"temperature": 1e-10},
        content_handler=content_handler,
    ),
    prompt=PROMPT,
)

chain({"input_documents": docs, "question": query}, return_only_outputs=True)
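
The credentials returned by `assume_role` are temporary, so a long-running notebook may need to assume the role again and rebuild the session. The helpers below are a minimal sketch with hypothetical names: one maps an `assume_role` response to the keyword arguments used for `boto3.Session` above, the other reports whether the credentials are close to their `Expiration` timestamp. The five-minute margin is an arbitrary assumption.

```python
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional


def session_kwargs(response: Dict[str, Any], region_name: str) -> Dict[str, str]:
    """Map an STS assume_role response to boto3.Session keyword arguments."""
    creds = response["Credentials"]
    return {
        "region_name": region_name,
        "aws_access_key_id": creds["AccessKeyId"],
        "aws_secret_access_key": creds["SecretAccessKey"],
        "aws_session_token": creds["SessionToken"],
    }


def expires_soon(
    response: Dict[str, Any],
    margin: timedelta = timedelta(minutes=5),  # arbitrary safety margin
    now: Optional[datetime] = None,
) -> bool:
    """Return True if the temporary credentials expire within `margin`."""
    now = now or datetime.now(timezone.utc)
    # Expiration is expected to be a timezone-aware datetime.
    return response["Credentials"]["Expiration"] - now <= margin
```

With these, `boto3.Session(**session_kwargs(response, "us-west-2"))` builds the same session as above, and `expires_soon(response)` can be checked before each call to decide whether to assume the role again.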