# ChatGLM2-6B inference (streaming supported) using an LMI container with SageMaker
Reference documentation: https://aws.amazon.com/blogs/machine-learning/elevating-the-generative-ai-experience-introducing-streaming-support-in-amazon-sagemaker-hosting/

## 1. Prepare code artifact tar ball

In [2]:
%%writefile serving.properties
# Settings for the serving container. model.py reads "model_id" and
# "prefix_checkpoint_path" from the properties passed to handle()
# (presumably the option.* keys below with the prefix stripped -- confirm
# against the DJL Serving documentation).
engine=Python
option.model_id=THUDM/chatglm2-6b
option.trust_remote_code=true
option.tensor_parallel_degree=1
option.enable_streaming=true
# NOTE: whenever prefix_checkpoint_path is set, model.py calls torch.load on it
# and loads the result into the model's prefix encoder. Remove the next line
# unless pytorch_model.bin is actually packaged into the tar ball (the `cp` in
# the packaging cell is commented out by default).
option.prefix_checkpoint_path=./pytorch_model.bin

Writing serving.properties


In [3]:
%%writefile model.py
from djl_python import Input
from djl_python import Output
from transformers import AutoTokenizer, AutoModel, AutoConfig

import json
import logging
import torch

class ChatGLM:
    """Hold a ChatGLM model and tokenizer and stream chat responses from them.

    The instance starts empty; ``initialize`` loads the tokenizer and model
    (optionally with a P-tuning prefix checkpoint) and moves the model to the
    GPU. ``get_stream_generator`` then turns one request into a generator of
    incremental text chunks.
    """

    # Keyword arguments that get_stream_generator passes to stream_chat itself.
    # Request "parameters" using these names are dropped so they can neither
    # collide with the explicit arguments nor change the shape of the tuples
    # that the streaming loop unpacks.
    _RESERVED_CHAT_KWARGS = ("history", "past_key_values", "return_past_key_values")

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.initialized = False
        self.config = None

    def initialize(self, properties: dict):
        """Load the tokenizer and model described by ``properties``.

        Args:
            properties: Serving properties. Keys read here:
                ``model_id`` or ``model_dir`` (model location),
                ``prefix_checkpoint_path`` (optional P-tuning checkpoint) and
                ``pre_seq_len`` (optional prefix length, defaults to 128).
        """
        model_id_or_path = properties.get("model_id") or properties.get("model_dir")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=True)
        prefix_checkpoint_path = properties.get("prefix_checkpoint_path")
        if prefix_checkpoint_path:
            logging.warning("Loading ptuning checkpoint from: {} with original model: {}".format(prefix_checkpoint_path, model_id_or_path))
            # Property values may arrive as strings, hence the int() conversion.
            pre_seq_len = int(properties.get("pre_seq_len", 128))
            self.config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=True, pre_seq_len=pre_seq_len)
            self.model = AutoModel.from_pretrained(model_id_or_path, config=self.config, trust_remote_code=True)
            # SECURITY: torch.load unpickles the file; only point this at a
            # checkpoint you trust. map_location keeps the tensors on the CPU
            # regardless of the device they were saved from; the whole model is
            # moved to the GPU below.
            prefix_state_dict = torch.load(prefix_checkpoint_path, map_location="cpu")
            # Keep only the prefix-encoder weights and strip the key prefix so
            # the names match the prefix_encoder sub-module's own state dict.
            key_prefix = "transformer.prefix_encoder."
            new_prefix_state_dict = {
                k[len(key_prefix):]: v
                for k, v in prefix_state_dict.items()
                if k.startswith(key_prefix)
            }
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            logging.warning("Loading original model from: {}".format(model_id_or_path))
            self.model = AutoModel.from_pretrained(model_id_or_path, trust_remote_code=True)
        self.model = self.model.cuda().eval()
        self.initialized = True

    def get_stream_generator(self, inputs):
        """Yield the newly generated text for one request, chunk by chunk.

        For ``application/json`` requests the body must contain ``inputs``
        (the prompt) and may contain ``history`` (defaults to an empty list)
        and ``parameters`` (extra keyword arguments forwarded to
        ``stream_chat``). Any other content type is treated as a plain-text
        prompt with no history.

        Args:
            inputs: Request object exposing ``get_property``, ``get_as_json``
                and ``get_as_string``.

        Yields:
            The text appended since the previous chunk.

        Raises:
            KeyError: If a JSON request has no ``inputs`` field.
        """
        history, params = [], {}
        content_type = inputs.get_property("content-type")
        if content_type == "application/json":
            request = inputs.get_as_json()
            logging.warning("JSON: {}".format(json.dumps(request)))
            prompt = request.pop("inputs")
            # Both fields are optional; a missing or null value falls back to
            # an empty default instead of raising KeyError.
            history = request.pop("history", None) or []
            params = request.pop("parameters", None) or {}
        else:
            prompt = inputs.get_as_string()
        logging.warning(prompt)
        logging.warning(history)

        chat_kwargs = {
            name: value
            for name, value in params.items()
            if name not in self._RESERVED_CHAT_KWARGS
        }
        emitted = 0
        # Assumes each yielded response is the full text generated so far, so
        # slicing from the previous length gives only the new part.
        for response, _history, _past_key_values in self.model.stream_chat(
                self.tokenizer, prompt, history=history,
                past_key_values=None,
                return_past_key_values=True,
                **chat_kwargs):
            yield response[emitted:]
            emitted = len(response)

# Shared module-level instance; handle() loads the model into it on the first call.
_model = ChatGLM()

def handle(inputs: Input):
    """Entry point invoked for every request.

    Loads the shared model on first use, answers an empty (initialization)
    request with ``None``, and otherwise returns an ``Output`` whose stream
    content is the generator produced for this request.
    """
    model_ready = _model.initialized
    if not model_ready:
        _model.initialize(inputs.get_properties())

    # A request without a payload only exists to trigger initialization.
    if not inputs.is_empty():
        streamed = Output()
        return streamed.add_stream_content(_model.get_stream_generator(inputs))

    return None

Writing model.py


In [5]:
import boto3
import sagemaker
from sagemaker import Model, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
# Use the public property rather than the private `_region_name` attribute,
# which is an implementation detail of the SDK and may change between releases.
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment
bucket = sess.default_bucket()  # bucket to house artifacts

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [None]:
# Uncomment the following lines if you want to download ptuning checkpoint from s3
# s3_model_prefix = "chatglm2-6b/checkpoint-3000"
# s3_model_path = sess.download_data(path="./", bucket=bucket, key_prefix=s3_model_prefix)

In [4]:
%%sh
# Stop at the first failing command so an incomplete directory is never archived.
set -e
# -p keeps a re-run from failing if an earlier, aborted attempt left the directory behind.
mkdir -p chatglm2-6b
# Uncomment the following line if you want to use a ptuning checkpoint.
# serving.properties sets option.prefix_checkpoint_path=./pytorch_model.bin, so
# either package the checkpoint here or remove that option from serving.properties.
# cp pytorch_model.bin chatglm2-6b/
mv serving.properties chatglm2-6b/
# Optional files: remove the following lines if not needed
mv model.py chatglm2-6b/
# mv requirements.txt chatglm2-6b/
tar czvf chatglm2-6b.tar.gz chatglm2-6b/
# The staging directory is no longer needed once the archive exists.
rm -rf chatglm2-6b

chatglm2-6b/
chatglm2-6b/pytorch_model.bin
chatglm2-6b/model.py
chatglm2-6b/serving.properties


In [6]:
# Key prefix under which the packaged tar ball is stored in the bucket.
s3_code_prefix = "chatglm2-6b-lmi/code"
# upload_data returns the S3 URI of the uploaded object; it is used as the
# model data location when the SageMaker Model is created.
code_artifact = sess.upload_data(
    path="chatglm2-6b.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

S3 Code or Model tar ball uploaded to --- > s3://sagemaker-us-east-1-822306801866/chatglm2-6b-lmi/code/chatglm2-6b.tar.gz


## 2. Initiate SageMaker Model and Endpoint

In [7]:
# Look up the DJL DeepSpeed serving image for the session's region.
image_uri = sagemaker.image_uris.retrieve(
    version="0.23.0",
    region=sess.boto_session.region_name,
    framework="djl-deepspeed",
)

# Endpoint settings; name_from_base derives the endpoint name from the given base.
endpoint_name = sagemaker.utils.name_from_base("chatglm2-6b")
instance_type = "ml.g5.2xlarge"

# The model is defined by the serving image plus the uploaded code tar ball.
model = Model(
    role=role,
    model_data=code_artifact,
    image_uri=image_uri,
    sagemaker_session=sess,
)

# Create the endpoint with an explicit container startup health-check timeout.
model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
    container_startup_health_check_timeout=900,
)

-----------------------------!

## 3. Test inferencing

In [8]:
import io

class LineIterator:
    r"""Iterate over the newline-delimited lines of a response event stream.

    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```

    While usually each PayloadPart event from the event stream will contain a
    byte array with a full json, this is not guaranteed and some of the json
    objects may be split across PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by appending the bytes of every PayloadPart
    event to an internal buffer and returning one complete line (without its
    trailing '\n') per iteration step. It remembers the position of the last
    returned byte so that earlier bytes are never exposed again. If the stream
    ends with bytes that are not terminated by '\n', they are returned as a
    final line. Events without a 'PayloadPart' key are reported and skipped.
    """

    def __init__(self, stream):
        """Wrap ``stream``, an iterable of event dicts."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line as bytes, without the trailing newline.

        Raises:
            StopIteration: When the stream is exhausted and every buffered
                byte has been returned.
        """
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b'\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # No more events will arrive. Whatever is left in the buffer is
                # an unterminated final line: return it once and advance past
                # it, so the next call ends the iteration instead of looping
                # forever on the same bytes.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # chunk is a dict, so it must be formatted rather than
                # concatenated to the message string.
                print('Unknown event type: {!r}'.format(chunk))
                continue
            # Always append at the end; the read position is restored by the
            # seek at the top of the loop.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])

In [16]:
import json

# Request payload: the prompt plus an (empty) conversation history.
body = { "inputs": "你爱我吗", "history": [] }
smr = boto3.client(
    "sagemaker-runtime", region_name=region)
resp = smr.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    Body=json.dumps(body),
    ContentType="application/json",
)
event_stream = resp['Body']

for line in LineIterator(event_stream):
    if not line.strip():
        # Nothing to parse on a blank line.
        continue
    try:
        # A separate name keeps `resp` bound to the endpoint response.
        payload = json.loads(line)
    except json.JSONDecodeError:
        # Best effort: skip a line that is not valid JSON rather than
        # aborting the rest of the stream.
        continue
    outputs = payload.get("outputs") if isinstance(payload, dict) else None
    if outputs is not None:
        # flush so each chunk is displayed as soon as it arrives.
        print(outputs, end='', flush=True)

爱，宝贝，就像流水和阳光，平凡而真实。