# Amazon API Gateway WebSocket API

> A WebSocket API in API Gateway is a collection of WebSocket routes that are integrated with backend HTTP endpoints, Lambda functions, or other AWS services. You can use API Gateway features to help you with all aspects of the API lifecycle, from creation through monitoring your production APIs. For more information, see [Working with WebSocket APIs](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api.html).

This notebook shows how a LangChain callback handler can stream an LLM's response to a client through an API Gateway WebSocket API.


## Streaming Callback Handler

Suppose we are building an LLM chat application with API Gateway and AWS Lambda. Users open a WebSocket connection with API Gateway and send their input over it. API Gateway forwards each message to a Lambda function, which uses LangChain to orchestrate the LLM and sends the response back to the client through API Gateway.
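Under the hood, a Lambda function pushes data to a connected client by calling the API Gateway Management API's `PostToConnection` operation against the `https://{domain}/{stage}` callback URL. The sketch below shows one way a streaming handler could do this. It is an assumption, not the implementation of `StreamingAmazonAPIGatewayWebSocketCallbackHandler`. The names `WebSocketTokenSender` and `make_management_client` are hypothetical. The method names mirror LangChain's `on_llm_new_token`, `on_llm_end`, and `on_llm_error` callback hooks. The boto3 client is injected so the class can be exercised without AWS.

```python
import json
from typing import Any, Callable


def make_management_client(domain: str, stage: str) -> Any:
    """Build a boto3 client for the API Gateway Management API (hypothetical helper)."""
    import boto3  # imported lazily so the sender below can be tested with a fake client

    return boto3.client(
        "apigatewaymanagementapi", endpoint_url=f"https://{domain}/{stage}"
    )


class WebSocketTokenSender:
    """Hypothetical sender that pushes each formatted message to one connection.

    Method names mirror LangChain's callback hooks. To register it with LangChain,
    it would need to subclass langchain.callbacks.base.BaseCallbackHandler.
    """

    def __init__(
        self,
        client: Any,
        connection_id: str,
        on_token: Callable[[str], str] = lambda t: json.dumps({"kind": "token", "chunk": t}),
        on_end: Callable[[], str] = lambda: json.dumps({"kind": "end"}),
        on_err: Callable[[BaseException], str] = lambda e: json.dumps({"kind": "error"}),
    ) -> None:
        self.client = client
        self.connection_id = connection_id
        self.on_token = on_token
        self.on_end = on_end
        self.on_err = on_err

    def _send(self, message: str) -> None:
        # PostToConnection takes the payload as bytes in the Data parameter.
        self.client.post_to_connection(
            ConnectionId=self.connection_id, Data=message.encode("utf-8")
        )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Called once per generated chunk when the LLM runs with streaming enabled.
        self._send(self.on_token(token))

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        # Tell the client the reply is complete.
        self._send(self.on_end())

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        # Report the failure without leaking the exception text to the client.
        self._send(self.on_err(error))
```

Inside a handler, you would construct it as `WebSocketTokenSender(make_management_client(domain, stage), connection_id)`. Keeping the message formatters as plain callables lets the wire format change without touching the transport code.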

Here is an example of a Lambda handler that uses a LangChain callback to stream the LLM's response, chunk by chunk, to the client through the API Gateway WebSocket API.


```python
import json
import os

import boto3
from langchain.callbacks import StreamingAmazonAPIGatewayWebSocketCallbackHandler
from langchain.chains import ConversationChain
from langchain.llms.openai import OpenAI
from langchain.memory import ConversationBufferMemory, DynamoDBChatMessageHistory

# Environment variables
session_table_name = os.environ["SessionTableName"]

# Initialize dependencies outside the handler so warm invocations reuse them.
# Streaming must be enabled for the LLM to emit tokens as they are generated.
llm = OpenAI(streaming=True)
boto3_session = boto3.session.Session()


def handler(event, context):
    # Parse the API Gateway WebSocket event
    domain = event["requestContext"]["domainName"]
    stage = event["requestContext"]["stage"]
    connection_id = event["requestContext"]["connectionId"]
    body = json.loads(event["body"])

    # Set the callback handler so that every time the model generates
    # a chunk of the response, it is sent to the client
    callback = StreamingAmazonAPIGatewayWebSocketCallbackHandler(
        boto3_session,
        # Connection callback URL, see https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-how-to-call-websocket-api-connections.html
        f"https://{domain}/{stage}",
        connection_id,
        # Format each WebSocket message sent to the client
        on_token=lambda t: json.dumps({"kind": "token", "chunk": t}),
        on_end=lambda: json.dumps({"kind": "end"}),
        on_err=lambda e: json.dumps({"kind": "error"}),
    )
    # The LLM object is shared across invocations, so its callbacks are replaced
    # on every request; a Lambda execution environment handles one request at a time.
    llm.callbacks = [callback]

    history = DynamoDBChatMessageHistory(
        table_name=session_table_name,
        # Use connection_id as session_id for simplicity.
        # In production, design your own session_id scheme.
        session_id=connection_id,
        boto3_session=boto3_session,
    )
    memory = ConversationBufferMemory(chat_memory=history)
    conversation = ConversationChain(llm=llm, memory=memory)

    # The callback streams the response to the client,
    # so the return value is not needed here
    conversation.predict(input=body["input"])

    return {"statusCode": 200}
```
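On the client side, each WebSocket message carries one of the JSON payloads built by the `on_token`, `on_end`, and `on_err` formatters above. The function below is a hypothetical sketch of how a client could fold those messages back into the full reply. It assumes exactly the `{"kind": ...}` format used in the handler. It leaves out the WebSocket transport itself (for example, a browser `WebSocket` or a Python WebSocket library) and takes the received message strings as an iterable.

```python
import json
from typing import Iterable


def assemble_reply(frames: Iterable[str]) -> str:
    """Concatenate streamed token chunks until an "end" message arrives.

    Assumes the handler's message format:
    {"kind": "token", "chunk": ...}, {"kind": "end"}, {"kind": "error"}.
    """
    chunks = []
    for frame in frames:
        message = json.loads(frame)
        kind = message.get("kind")
        if kind == "token":
            chunks.append(message["chunk"])
        elif kind == "end":
            return "".join(chunks)
        elif kind == "error":
            raise RuntimeError("the LLM call failed on the server")
        # Messages of any other kind are ignored.
    # The stream closed without an "end" message, so the reply is incomplete.
    raise ConnectionError("stream ended before the 'end' message")


frames = [
    json.dumps({"kind": "token", "chunk": "Hello"}),
    json.dumps({"kind": "token", "chunk": ", world"}),
    json.dumps({"kind": "end"}),
]
print(assemble_reply(frames))  # Hello, world
```

Because the server sends an explicit `end` message, the client can tell a finished reply apart from a connection that dropped mid-stream.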