In [None]:
!pip install sagemaker --quiet --upgrade --force-reinstall
!pip install ipywidgets==7.0.0 --quiet
!pip install langchain --quiet --upgrade

# SageMaker Asynchronous Endpoints

Amazon SageMaker Asynchronous Inference is one of the four deployment options in Amazon SageMaker, together with real-time endpoints, batch transform, and serverless inference. It queues incoming requests and processes them asynchronously, which makes it a good fit for requests with large payloads (up to 1 GB), long processing times, and near real-time latency requirements. However, its main value when working with large Foundation Models, especially during a Proof of Concept (POC) or development cycles, is that you can configure it to scale the instance count to zero when there are no requests to process, thereby saving costs. More information about SageMaker Asynchronous Inference can be found in the [AWS documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html).
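
As a rough illustration of the scale-to-zero setup, the sketch below builds the arguments for Application Auto Scaling's `register_scalable_target` call with a minimum capacity of zero. The helper name `scale_to_zero_target`, the variant name `AllTraffic`, and the maximum capacity of 1 are assumptions for this example and not part of the notebook; a scaling policy (for instance one tracking the endpoint's backlog) would also be needed for the endpoint to scale back out.

```python
def scale_to_zero_target(endpoint_name, variant_name="AllTraffic", max_capacity=1):
    """Build kwargs for Application Auto Scaling's register_scalable_target.

    `variant_name` and `max_capacity` are assumed defaults; adjust to your endpoint.
    """
    return {
        "ServiceNamespace": "sagemaker",
        "ResourceId": f"endpoint/{endpoint_name}/variant/{variant_name}",
        "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
        "MinCapacity": 0,  # lets the endpoint scale in to zero instances when idle
        "MaxCapacity": max_capacity,
    }


target = scale_to_zero_target("async-sagemaker-endpoint-name")

# To apply it (requires AWS credentials and an existing asynchronous endpoint):
# import boto3
# boto3.client("application-autoscaling").register_scalable_target(**target)
```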

This notebook has been tested with the TII Falcon-40B model from SageMaker JumpStart, deployed on an `ml.g5.12xlarge` instance.

## Invoking the asynchronous endpoint with LangChain

To invoke the endpoint, you place the request payload in Amazon S3 and pass a pointer to that payload as part of the `InvokeEndpointAsync` request. On invocation, SageMaker queues the request for processing and returns an identifier and an output location. Once the request has been processed, SageMaker writes the result to that Amazon S3 location. You can optionally choose to receive success or error notifications through Amazon SNS.
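
The flow described above can be sketched without LangChain as follows. This is a minimal illustration, not the code used by the class introduced later: the helper name `submit_async_request` is hypothetical, and the two clients are assumed to be boto3 `s3` and `sagemaker-runtime` clients passed in by the caller.

```python
import json
import uuid


def submit_async_request(s3_client, sm_runtime_client, endpoint_name, bucket, prefix, payload):
    """Upload the payload to S3, then queue it on the asynchronous endpoint."""
    # Use a unique object key so concurrent requests do not overwrite each other.
    key = f"{prefix}/{uuid.uuid4()}.json"
    s3_client.put_object(
        Bucket=bucket,
        Key=key,
        Body=json.dumps(payload).encode("utf-8"),
        ContentType="application/json",
    )
    response = sm_runtime_client.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=f"s3://{bucket}/{key}",
        ContentType="application/json",
    )
    # The response only says where the result will be written; it is not the result itself.
    return response["InferenceId"], response["OutputLocation"]


# Example usage (requires AWS credentials and a deployed asynchronous endpoint):
# import boto3
# inference_id, output_location = submit_async_request(
#     boto3.client("s3"), boto3.client("sagemaker-runtime"),
#     "async-sagemaker-endpoint-name", "my-bucket", "async-sagemaker-tests/inputs",
#     {"inputs": "Hello", "parameters": {"max_new_tokens": 50}},
# )
```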

In [5]:
import sagemaker

# Provide the endpoint name
endpoint_name = "async-sagemaker-endpoint-name"
# Get the default bucket - optional, you can also set your own
bucket = sagemaker.Session().default_bucket()
# Set the prefix to async-sagemaker-tests/inputs  - optional, you can also set your own
prefix = "async-sagemaker-tests/inputs"

In [6]:
payload = {
    "inputs": "Write a program to compute factorial in python:", 
    "parameters": {
        "max_new_tokens": 500
    }
}

To use an asynchronous endpoint with LangChain, we defined a new class, `SagemakerAsyncEndpoint`, that extends the `SagemakerEndpoint` class already available in LangChain. In addition to what `SagemakerEndpoint` needs, it requires:

- The Amazon S3 bucket and prefix where async inference will store the inputs (and outputs)
- A maximum number of seconds to wait before timing out
- An updated `_call()` function to query the endpoint with `invoke_endpoint_async()` instead of `invoke_endpoint()`
- A way to "wake up" the asynchronous endpoint if it's in cold-start (it scaled down to zero)
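
The timeout and cold-start points in the list above could be handled by polling the output location until the result appears. The sketch below is one possible approach under stated assumptions, not the implementation inside `SagemakerAsyncEndpoint`: the function name `wait_for_output` and its default values are hypothetical, and `s3_client` is assumed to be a boto3 S3 client.

```python
import time
from urllib.parse import urlparse


def wait_for_output(s3_client, output_location, max_wait_s=600, poll_s=5):
    """Poll S3 until the asynchronous result exists or the timeout elapses."""
    parsed = urlparse(output_location)  # e.g. s3://bucket/path/to/result.out
    bucket, key = parsed.netloc, parsed.path.lstrip("/")
    deadline = time.monotonic() + max_wait_s
    while True:
        try:
            return s3_client.get_object(Bucket=bucket, Key=key)["Body"].read()
        except s3_client.exceptions.NoSuchKey:
            # Not written yet: the request may still be queued, or the
            # endpoint may be scaling out from zero instances.
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"No output at {output_location} after {max_wait_s}s"
                )
            time.sleep(poll_s)
```

This sketch only watches the output location; a failed request would not produce an object there, so it would surface as a timeout.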

In [None]:
# First step: define the content handler
import json
from typing import Dict

from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    content_type:str = "application/json"
    accepts:str = "application/json"
    len_prompt:int = 0

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        self.len_prompt = len(prompt)
        # Generation parameters are fixed here; model_kwargs is accepted but not used.
        input_str = json.dumps(
            {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 100,
                    "do_sample": False,
                    "repetition_penalty": 1.1,
                },
            }
        )
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        # `output` is a file-like response body, so read it before parsing the JSON.
        response_json = output.read()
        res = json.loads(response_json)
        ans = res[0]['generated_text']
        return ans

In [None]:
# Second step: define the LLM
from langchain.llms.sagemaker_async_endpoint import SagemakerAsyncEndpoint

llm = SagemakerAsyncEndpoint(
    input_bucket=bucket,
    input_prefix=prefix,
    endpoint_name=endpoint_name,
    region_name=sagemaker.Session().boto_region_name,
    content_handler=ContentHandler(),
)

In [None]:
# Final step: define the chain and run
from langchain import PromptTemplate
from langchain.chains import LLMChain

chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["query"],
        template="{query}",
    ),
)

print(chain.run(payload['inputs']))

## Cleanup

Once you're done testing inference against the endpoint, remember to delete it to avoid incurring extra charges.

In [None]:
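# No `predictor` object was created in this notebook, so build one from the endpoint name.
sagemaker.Predictor(endpoint_name=endpoint_name).delete_endpoint()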