## Deploying Yi-34B-Chat-4bits with LMI rolling batch (vLLM backend)

### In this tutorial, you will use the vLLM backend of the Large Model Inference (LMI) DLC to deploy Yi-34B-Chat-4bits and run inference against it.

Make sure the following permissions are granted before running the notebook:

* S3 bucket write access
* SageMaker access

### Step 1: Upgrade the SageMaker SDK and import dependencies

In [None]:
%pip install sagemaker --upgrade  --quiet

In [None]:
%pip install transformers sentencepiece --upgrade  --quiet

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()

### Step 2: Prepare the model artifacts

The LMI container expects the following artifacts to set up the model:

* serving.properties (required): defines the model server settings
* model.py (optional): a Python file that defines custom inference logic
* requirements.txt (optional): any additional pip packages to install

In [None]:
%%writefile serving.properties
engine=Python
option.model_id=01-ai/Yi-34B-Chat-4bits
option.tensor_parallel_degree=4
option.max_rolling_batch_size=64
option.rolling_batch=vllm
option.quantize=awq
option.dtype=fp16
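
If you want to vary these settings (for example, `option.tensor_parallel_degree`) without editing the file by hand, a small helper can render the same `serving.properties` from a Python dict. This is a minimal sketch: `write_serving_properties` is a hypothetical helper, not part of the LMI container or the SageMaker SDK, and it simply writes one `key=value` line per entry using the keys from the cell above.

```python
from pathlib import Path


def write_serving_properties(options: dict, path: str = "serving.properties") -> str:
    """Write one key=value line per option (hypothetical helper, not an LMI API)."""
    lines = [f"{key}={value}" for key, value in options.items()]
    text = "\n".join(lines) + "\n"
    Path(path).write_text(text, encoding="utf-8")
    return text


# Same settings as the %%writefile cell above
settings = {
    "engine": "Python",
    "option.model_id": "01-ai/Yi-34B-Chat-4bits",
    "option.tensor_parallel_degree": 4,
    "option.max_rolling_batch_size": 64,
    "option.rolling_batch": "vllm",
    "option.quantize": "awq",
    "option.dtype": "fp16",
}
print(write_serving_properties(settings))
```

Use it in place of the `%%writefile` cell, not in addition to it; both produce the same file.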

In [None]:
%%sh
mkdir -p mymodel
mv serving.properties mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
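
The shell cell above stages `serving.properties` in a `mymodel/` directory and packs it as `mymodel.tar.gz`. If you prefer to stay in Python (for example, to add `model.py` or `requirements.txt` later), the sketch below does the same with the standard library. `package_model` is a hypothetical helper; unlike the shell cell, it copies the files instead of moving them, and the temporary staging directory is removed automatically.

```python
import shutil
import tarfile
import tempfile
from pathlib import Path


def package_model(paths, archive="mymodel.tar.gz", top_dir="mymodel"):
    """Copy files into top_dir/ and pack them as a gzip tarball, like the %%sh cell."""
    with tempfile.TemporaryDirectory() as tmp:
        staging = Path(tmp) / top_dir
        staging.mkdir()
        for path in paths:
            shutil.copy(path, staging / Path(path).name)
        with tarfile.open(archive, "w:gz") as tar:
            # arcname keeps the "mymodel/..." prefix inside the archive
            tar.add(str(staging), arcname=top_dir)
    return archive
```

Run `package_model(["serving.properties"])` right after writing the file, as an alternative to the `%%sh` cell.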

### Step 3: Build the SageMaker endpoint

#### Getting the container image URI

In [None]:
image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.26.0",
)

#### Upload the artifacts to S3 and create the SageMaker model

In [None]:
model_name = "01-ai/Yi-34B-Chat-4bits"
s3_code_prefix = f"large-model-vllm/{model_name}/code"  # e.g. large-model-vllm/01-ai/Yi-34B-Chat-4bits/code
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 code or model tarball uploaded to ---> {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

#### Create SageMaker endpoint with a specified instance type

In [None]:
%%time
instance_type = "ml.g4dn.12xlarge"  # 4 x NVIDIA T4 GPUs, matching option.tensor_parallel_degree=4
endpoint_name = sagemaker.utils.name_from_base(f"lmi-model-{model_name.replace('/', '-')}")
print(f"endpoint_name: {endpoint_name}")

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout=1800,  # allow up to 30 minutes to download and load the model
)

# Requests are serialized as JSON. No deserializer is set, so responses come back as raw bytes
# (the Predictor default) and are decoded in the inference cells below.
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
)
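
`model.deploy()` waits for the endpoint by default, but if you come back to the notebook in a new kernel you may want to confirm the endpoint is usable before invoking it. A minimal sketch built on the SageMaker `describe_endpoint` call; `endpoint_is_ready` is a hypothetical helper, and the client is passed in so any object with a compatible `describe_endpoint` method works.

```python
def endpoint_is_ready(sm_client, endpoint_name: str) -> bool:
    """Return True once the endpoint is InService; raise if creation failed.

    sm_client is expected to be a boto3 "sagemaker" client (or anything with a
    compatible describe_endpoint method).
    """
    desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = desc["EndpointStatus"]
    if status == "Failed":
        reason = desc.get("FailureReason", "unknown reason")
        raise RuntimeError(f"Endpoint {endpoint_name} failed: {reason}")
    # Other states (Creating, Updating, ...) mean "not ready yet"
    return status == "InService"
```

For example, `endpoint_is_ready(boto3.client("sagemaker", region_name=region), endpoint_name)`. If you would rather block until the endpoint is ready, the same client's built-in waiter, `get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`, does the polling for you.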

### Step 4: Run inference

In [None]:
from transformers import AutoTokenizer

MODEL_DIR = model_name  # Hugging Face Hub ID: 01-ai/Yi-34B-Chat-4bits
# Only the tokenizer is loaded locally, to build the chat prompt; the model itself runs on the endpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=False)

In [None]:
import json

# Multi-turn example: include the previous assistant reply before the next user turn.
# messages = [
#     {"role": "user", "content": "世界上第二高的山峰是哪座"},  # "Which is the second-highest mountain in the world?"
#     {"role": "assistant", "content": "<previous assistant reply>"},
#     {"role": "user", "content": "一个句子总结"},  # "Summarize it in one sentence."
# ]
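messages = [
    {"role": "user", "content": "世界上第二高的山峰是哪座"},  # "Which is the second-highest mountain in the world?"
]

# Render the conversation with the model's chat template and open an assistant turn
input_text = tokenizer.apply_chat_template(conversation=messages, tokenize=False, add_generation_prompt=True)
print(json.dumps(input_text, ensure_ascii=False))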
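
The stop sequence used later in this guide, `<|im_end|>`, suggests Yi-34B-Chat uses a ChatML-style template. If you cannot download the tokenizer where the notebook runs, a hand-rolled formatter like the one below may produce an equivalent prompt. This is a sketch under the assumption that the template is plain ChatML (`<|im_start|>role\ncontent<|im_end|>\n`) with no default system prompt; `to_chatml` is a hypothetical helper, so compare its output with `tokenizer.apply_chat_template` before relying on it.

```python
def to_chatml(messages, add_generation_prompt=True):
    """Format chat messages as ChatML (assumed to match Yi-34B-Chat's template)."""
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
        for m in messages
    ]
    if add_generation_prompt:
        # Open an assistant turn so the model continues from here
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)
```

A quick check is `print(to_chatml(messages) == input_text)`; if it prints `False`, keep using the tokenizer's template.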

In [None]:
%%time
parameters = {
    "max_new_tokens": 128,
    "do_sample": True,
    "temperature": 0.6,
    "top_p": 0.8,
    "eos_token_id": 7,
}
response = predictor.predict(
    {"inputs": input_text, "parameters": parameters}
)

text = response.decode("utf-8")  # the default deserializer returns raw bytes
print(text)
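
The printed body is JSON. To get just the generated string, you can parse it; the sketch below assumes the LMI rolling-batch response carries the text under a top-level `generated_text` key (and also accepts a one-element list, in case your container version wraps the result). `extract_generated_text` is a hypothetical helper; check the key against the raw output printed above.

```python
import json


def extract_generated_text(raw: bytes) -> str:
    """Decode a JSON response body and return its "generated_text" field (assumed shape)."""
    body = json.loads(raw.decode("utf-8"))
    # Accept either {"generated_text": ...} or [{"generated_text": ...}]
    if isinstance(body, list):
        body = body[0]
    return body["generated_text"]
```

With the cell above, `print(extract_generated_text(response))` prints only the model's answer.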

### Step 5: Stream the response

In [None]:
import json
import boto3
from utils.LineIterator import LineIterator  # local helper module, not included in this guide

# Use the same region as the SageMaker session
smr_client = boto3.client("sagemaker-runtime", region_name=region)


def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):
    """Invoke the endpoint and return the streaming response."""
    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    return response_stream


def print_response_stream(response_stream):
    """Print each line of the streamed body as it arrives."""
    event_stream = response_stream.get("Body")
    for line in LineIterator(event_stream):
        print(line, end="")
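
The cell above imports `LineIterator` from a local `utils` module that this guide does not include. If you don't have it, the sketch below is one way to write it, saved as `utils/LineIterator.py`. The class name comes from the import; the implementation is an assumption modeled on the common pattern of buffering the `PayloadPart` bytes that `invoke_endpoint_with_response_stream` returns, because a single line may arrive split across several events. It returns decoded lines without the trailing newline.

```python
import io


class LineIterator:
    """Yield newline-delimited lines from a SageMaker response event stream.

    Assumes each event is a dict like {"PayloadPart": {"Bytes": b"..."}} and
    that one logical line may be split across several events.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Try to read one complete line starting where we left off
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                # Decode only complete lines so multi-byte characters are never split
                return line[:-1].decode("utf-8")
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # Stream finished: flush a trailing line without a newline, then stop
                if line:
                    self.read_pos += len(line)
                    return line.decode("utf-8")
                raise
            if "PayloadPart" not in chunk:
                # Other event types (e.g. stream errors) are not handled in this sketch
                continue
            # Append the new bytes at the end of the buffer
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
```

Buffering by byte offset rather than decoding each event immediately is what keeps Chinese output intact when a multi-byte character straddles two events.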

In [None]:
parameters = {
    "max_new_tokens": 1024,
    "do_sample": True,
    "temperature": 0.6,
    "top_p": 0.8,
    "repetition_penalty": 1.2,
    "stop": "<|im_end|>",  # end-of-turn marker in Yi's chat template
}
payload = {
    "inputs": input_text,
    "parameters": parameters,
    "stream": True,  # request a streamed response
}
response_stream = get_realtime_response_stream(smr_client, endpoint_name, payload)
print_response_stream(response_stream)

### Step 6: Clean up resources

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()