## Deploying Mistral 7B with vLLM through the SageMaker LMI container and streaming outputs

### 1. Import required packages, set up


In [None]:
!pip install -U boto3 sagemaker --quiet

In [None]:
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with AWS APIs
region = sess.boto_region_name  # region of the current SageMaker Studio environment
account_id = sess.account_id()  # account ID of the current SageMaker Studio environment

### 2. Build SageMaker endpoint

In this step, we build a SageMaker endpoint from scratch using the AWS Large Model Inference (LMI) container.


#### 2.1. Get the container image URI


[All available LMI container images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)

In [None]:
image_uri = image_uris.retrieve(
    # framework="djl-tensorrtllm",
    framework="djl-deepspeed",  # use this container for the vLLM, DeepSpeed, and lmi-dist backends
    region=sess.boto_session.region_name,
    version="0.26.0",
)

#### 2.2. Set up required environment variables and create the SageMaker Model


The model to load and all common and backend-specific LMI configurations can be set through environment variables. Instead of pulling the model from the Hugging Face Model Hub, you can pull it from an S3 bucket. You can also set these parameters in a `serving.properties` file that you pass to the endpoint inside a model folder, which lets you ship other artifacts (including the model artifacts themselves) along with it.

For more details on server configuration, see the [container and model configurations guide](https://github.com/deepjavalibrary/djl-serving/blob/master/serving/docs/lmi_new/deployment_guide/configurations.md#container-and-model-configurations). For guidance on which backend to select, see the [backend selection guide](https://github.com/deepjavalibrary/djl-serving/blob/master/serving/docs/lmi_new/deployment_guide/backend-selection.md). vLLM-specific guidance is in the [vLLM user guide](https://github.com/deepjavalibrary/djl-serving/blob/master/serving/docs/lmi_new/user_guides/vllm_user_guide.md).

In [None]:
env = {
    'HF_MODEL_ID':'mistralai/Mistral-7B-Instruct-v0.1',
    'OPTION_ROLLING_BATCH':'vllm',
    'TENSOR_PARALLEL_DEGREE': 'max',
    'OPTION_MAX_MODEL_LEN':'4000'
}

model = Model(image_uri=image_uri, role=role, env=env)
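The relationship between the environment variables above and a `serving.properties` file can be sketched as follows. This is a minimal illustration, not part of the LMI tooling: the helper name `env_to_serving_properties` is hypothetical, and the key mapping (`HF_MODEL_ID` to `option.model_id`, `OPTION_X` to `option.x`, `TENSOR_PARALLEL_DEGREE` to `option.tensor_parallel_degree`) is an assumption based on the naming convention described in the configuration guide linked above. A real file may need additional keys, so check the guide before relying on the output.

```python
def env_to_serving_properties(env):
    """Translate LMI-style environment variables into serving.properties lines.

    The mapping below is an assumption based on the LMI naming convention;
    verify it against the configuration guide for your container version.
    """
    lines = []
    for key, value in env.items():
        if key == "HF_MODEL_ID":
            lines.append(f"option.model_id={value}")
        elif key == "TENSOR_PARALLEL_DEGREE":
            lines.append(f"option.tensor_parallel_degree={value}")
        elif key.startswith("OPTION_"):
            # OPTION_MAX_MODEL_LEN -> option.max_model_len
            lines.append(f"option.{key[len('OPTION_'):].lower()}={value}")
        else:
            raise ValueError(f"No known serving.properties mapping for {key!r}")
    return "\n".join(lines)


example_env = {
    "HF_MODEL_ID": "mistralai/Mistral-7B-Instruct-v0.1",
    "OPTION_ROLLING_BATCH": "vllm",
    "TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_MAX_MODEL_LEN": "4000",
}
properties_text = env_to_serving_properties(example_env)
print(properties_text)
```

Raising on unknown keys is a deliberate choice here: a silently dropped setting is harder to debug than an explicit error.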

#### 2.3. Create SageMaker real-time endpoint

We will deploy our model to an `ml.g5.xlarge` instance, which is backed by a single A10G GPU with 24 GB of memory.

In [None]:
instance_type = "ml.g5.xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    # container_startup_health_check_timeout=3600
)

### 3. Create prompt building utility functions


Let's write a function that builds a prompt in the format that induces instruction-following behaviour. Note that this format is model dependent: the Mistral model was instruction-tuned with the `[INST]` tag, but other models may have seen different special tokens.

In [None]:
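def build_mistral_prompt(instructions):
    """Render a list of {"role", "content"} messages into a single [INST]-style prompt."""
    stop_token = "</s>"
    start_token = "<s>"
    start_prompt = f"{start_token}[INST] "
    end_prompt = " [/INST]"
    conversation = []
    for index, instruction in enumerate(instructions):
        if instruction["role"] == "system" and index == 0:
            # A leading system message is wrapped in <<SYS>> markers inside the first [INST] block
            conversation.append(f"<<SYS>>\n{instruction['content']}\n<</SYS>>\n")
        elif instruction["role"] == "user":
            conversation.append(instruction["content"].strip())
        else:
            # Any other role is treated as an assistant turn: close the open [INST] block,
            # add the reply, then open a new [INST] block for the next user message
            conversation.append(f"{end_prompt} {instruction['content'].strip()} {stop_token}{start_prompt}")

    return start_prompt + "".join(conversation) + end_prompt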
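For comparison, here is a minimal sketch of a multi-turn prompt in the plain `[INST]` layout, with no `<<SYS>>` block. The function name `build_plain_mistral_prompt` is hypothetical, and folding the system message into the first user turn is an assumption made for illustration, not something the notebook's helper does. Whether the serving stack adds the `<s>` token on its own is also worth checking before using a format like this.

```python
def build_plain_mistral_prompt(messages):
    """Render alternating user/assistant messages as <s>[INST] ... [/INST] reply</s>[INST] ..."""
    bos, eos = "<s>", "</s>"
    pending_system = ""  # system text waiting to be merged into the next user turn (assumption)
    parts = []
    for message in messages:
        role = message["role"]
        content = message["content"].strip()
        if role == "system":
            pending_system = content + "\n\n"
        elif role == "user":
            parts.append(f"[INST] {pending_system}{content} [/INST]")
            pending_system = ""
        else:
            # Assistant turn: the reply follows the closing tag and ends with the stop token
            parts.append(f" {content}{eos}")
    return bos + "".join(parts)


example_messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Bye"},
]
example_prompt = build_plain_mistral_prompt(example_messages)
print(example_prompt)
```

Printing the rendered prompt for a short conversation like this is a quick way to check where each special token lands before sending anything to the endpoint.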

And another function to combine our base system prompt with the actual user request:

In [None]:
def get_instructions(user_content):

    '''
    Note: the instruction list is rebuilt on every call, so content from earlier
    requests does not carry over when you run inference several times with a new ask each time.
    '''

    system_content = '''
    You are a friendly and knowledgeable email marketing agent, Mr.MightyMark, working at AnyCompany. 
    Your goal is to send email to subscribers to help them understand the value of the new product and generate excitement for the launch.

    Here are some tips on how to achieve your goal:

    Be personal. Address each subscriber by name and use a friendly and conversational tone.
    Be informative. Explain the key features and benefits of the new product in a clear and concise way.
    Be persuasive. Highlight how the new product can solve the subscriber's problems or improve their lives.
    Be engaging. Use emojis to make your emails more visually appealing and interesting to read.

    By following these tips, you can use email marketing to help your company launch a successful software product.
    '''

    instructions = [
        { "role": "system","content": f"{system_content} "},
    ]
    
    instructions.append({"role": "user", "content": f"{user_content}"})
    
    return instructions

Now we build and print our final prompt

In [None]:
user_ask_1 = '''
AnyCompany recently announced new service launch named AnyCloud Internet Service.
Write a short email about the product launch with Call to action to Alice Smith, whose email is alice.smith@example.com
Mention the Coupon Code: EARLYB1RD to get 20% for 1st 3 months.
'''
instructions = get_instructions(user_ask_1)
prompt = build_mistral_prompt(instructions)
print(prompt)

### 4. Test inference by streaming model output

First, we define a function that wraps the SageMaker runtime client's `invoke_endpoint_with_response_stream` method.

In [None]:
import boto3
import json

sagemaker_runtime = boto3.client('sagemaker-runtime')
sagemaker_client = boto3.client('sagemaker')

In [None]:
def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):
    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload), 
        ContentType="application/json",
        # CustomAttributes='accept_eula=false'
    )
    return response_stream

We will use a utility function to parse the output buffer from the SageMaker endpoint and print tokens as they are streamed

In [None]:
from utils.LineIterator import LineIterator

def print_response_stream(response_stream):
    event_stream = response_stream.get('Body')
    for line in LineIterator(event_stream):
        print(line, end='')
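
The `LineIterator` class is imported from a local `utils` module that is not shown here. As a rough idea of what such a helper has to do, the sketch below buffers the `PayloadPart` byte chunks of the event stream and yields complete newline-terminated lines. The generator name `iter_stream_lines` is hypothetical, and the actual `LineIterator` may behave differently (for example, it may also decode each line into token text).

```python
def iter_stream_lines(event_stream):
    """Yield decoded text lines from an iterable of response-stream events.

    Each event is expected to look like {"PayloadPart": {"Bytes": b"..."}}.
    Chunks can split a line at any byte, so bytes are buffered until a
    newline arrives.
    """
    buffer = b""
    for event in event_stream:
        part = event.get("PayloadPart")
        if part is None:
            continue  # skip events that carry no payload bytes
        buffer += part["Bytes"]
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            yield line.decode("utf-8")
    if buffer:
        # Flush a final line that was not newline-terminated
        yield buffer.decode("utf-8")


# Simulated events: the first JSON line is split across two chunks
fake_events = [
    {"PayloadPart": {"Bytes": b'{"a":'}},
    {"PayloadPart": {"Bytes": b'1}\n{"b"'}},
    {"PayloadPart": {"Bytes": b":2}"}},
]
stream_lines = list(iter_stream_lines(fake_events))
print(stream_lines)
```

Decoding only complete lines avoids errors when a multi-byte UTF-8 character is split across two chunks.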

Set your selected generation parameters, and create a payload

In [None]:
inference_params = {
    "do_sample": True,
    "top_p": 0.6,
    "temperature": 0.9,
    "top_k": 50,
    "max_new_tokens": 512,
}

payload = {
    "inputs":  prompt,
    "parameters": inference_params
}

In [None]:
resp = get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload)
print_response_stream(resp)
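
If you want the full generated text as a single string rather than printed output, the streamed lines can be accumulated. The sketch below assumes each line is a JSON object of the form `{"token": {"text": "..."}}`. That record shape is an assumption, not something this notebook confirms, so inspect the raw lines from your endpoint first and adjust the field names. The function name `collect_generated_text` is hypothetical.

```python
import json


def collect_generated_text(lines):
    """Join the token text found in an iterable of JSON-lines strings.

    Assumes records shaped like {"token": {"text": "..."}}; records without
    that field contribute nothing.
    """
    pieces = []
    for line in lines:
        if not line.strip():
            continue  # ignore blank separator lines
        record = json.loads(line)
        token = record.get("token") or {}
        pieces.append(token.get("text", ""))
    return "".join(pieces)


# Simulated stream output, for illustration only
sample_lines = [
    '{"token": {"id": 1, "text": "Hello"}}',
    "",
    '{"token": {"id": 2, "text": " world"}}',
]
full_text = collect_generated_text(sample_lines)
print(full_text)
```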

### 5. Clean up endpoint

Finally, we delete the endpoint so that it stops consuming resources and incurring charges.

In [None]:
predictor = sagemaker.predictor.Predictor(endpoint_name)
predictor.delete_endpoint()