## Deploy Llama 3.1 through vLLM on a SageMaker Endpoint

Original code
- https://github.com/aws-samples/aws-ai-ml-workshop-kr/blob/master/genai/aws-gen-ai-kr/40_inference/31-Llama-3-1-Inference/1_llama-3-1-deploy-djl-lmi.ipynb

In [1]:
!pip install --quiet --upgrade sagemaker

### 1. Deploy model on SageMaker

In [2]:
import boto3
import sagemaker
from sagemaker import get_execution_role



sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


In [3]:
role = get_execution_role()
region = boto3.Session().region_name
sagemaker_session = sagemaker.session.Session()
sm_client = boto3.client("sagemaker", region_name=region)
sm_runtime_client = boto3.client("sagemaker-runtime")

#### 1-1. Setup Configuration

In [4]:
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

In [5]:
container_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi", version="0.29.0", region=region
)
container_uri
# container_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124"

'763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124'

In [11]:
instance_type = "ml.g5.8xlarge"
# instance_type = "ml.g5.12xlarge"
container_startup_health_check_timeout = 600

endpoint_name = sagemaker.utils.name_from_base("Meta-Llama-3-1-8B-Instruct")

print (f'container_uri: {container_uri}')
print (f'container_startup_health_check_timeout: {container_startup_health_check_timeout}')
print (f'instance_type: {instance_type}')
print (f'endpoint_name: {endpoint_name}')

container_uri: 763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124
container_startup_health_check_timeout: 600
instance_type: ml.g5.8xlarge
endpoint_name: Meta-Llama-3-1-8B-Instruct-2024-11-20-17-26-23-851


#### 1-2. Create model with env variables

Available options
- https://github.com/deepjavalibrary/djl-serving/tree/master/serving/docs/lmi

When the "vllm" option is selected, paged attention is applied by default.
- The backend can be configured through vLLM environment variables (e.g., `VLLM_ATTENTION_BACKEND`)
- https://docs.vllm.ai/en/latest/serving/env_vars.html
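As a minimal sketch of the note above, a backend override can be merged into the container environment as one more string entry. `with_vllm_backend` is a hypothetical helper, not part of the SageMaker SDK or DJL. `XFORMERS` is the value that appears commented out in this notebook's configuration cell; which values are valid depends on the vLLM version inside the container.

```python
def with_vllm_backend(env, backend=None):
    """Return a copy of env with VLLM_ATTENTION_BACKEND set when a backend is given.

    Hypothetical helper for illustration only. Container environment
    variables are strings, so every value is coerced to str.
    """
    merged = {key: str(value) for key, value in env.items()}
    if backend is not None:
        merged["VLLM_ATTENTION_BACKEND"] = backend
    return merged


# Example input; the keys mirror the options used in this notebook
base_env = {"OPTION_ROLLING_BATCH": "vllm", "OPTION_MAX_ROLLING_BATCH_SIZE": 2}
env_with_backend = with_vllm_backend(base_env, backend="XFORMERS")
print(env_with_backend)
```

Returning a copy keeps the base configuration reusable, so the same dictionary can be deployed with and without the override for comparison.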



In [12]:
HF_TOKEN = "<YOUR_HF_TOKEN>"  # Replace with your own Hugging Face access token; do not commit real tokens

In [13]:
deploy_env = {
    "HF_MODEL_ID": model_id,
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "2",
    "OPTION_DTYPE":"fp16",
    "OPTION_MAX_MODEL_LEN": "1024",
    # "VLLM_ATTENTION_BACKEND": "XFORMERS",
    "HF_TOKEN": HF_TOKEN
}

In [14]:
model = sagemaker.Model(
    image_uri=container_uri, 
    role=role,
    env=deploy_env,
)

#### 1-3. Deploy model

In [15]:
model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout=container_startup_health_check_timeout
)
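`model.deploy` waits for the endpoint by default. If the notebook kernel is restarted, or the call is made with `wait=False`, the status can be checked through the `sm_client` created earlier. The following is a minimal sketch under that assumption; `wait_until_in_service` and its default intervals are hypothetical, while `describe_endpoint` and the `EndpointStatus` field are part of the boto3 SageMaker client.

```python
import time


def wait_until_in_service(sm_client, endpoint_name, poll_seconds=30, timeout_seconds=1800):
    """Poll describe_endpoint until the endpoint is InService, fails, or times out.

    Hypothetical helper; sm_client is expected to be a boto3 "sagemaker" client.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Endpoint {endpoint_name} still {status} after {timeout_seconds}s"
            )
        time.sleep(poll_seconds)
```

It would be called as `wait_until_in_service(sm_client, endpoint_name)`. boto3 also ships an `endpoint_in_service` waiter that covers the same case with less code.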

### 2. Invocation (Generate Text using the endpoint)

#### 2-1. Get a predictor for your endpoint

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

#### 2-2. Make a prediction with your endpoint

- Normal Question

In [None]:
outputs = predictor.predict(
    {
        "inputs": "write a quick sort algorithm in python and description",
        "parameters": {"do_sample": True, "max_new_tokens": 2048},
    }
)

print(outputs["generated_text"])
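The request above asks for `max_new_tokens` of 2048 while `OPTION_MAX_MODEL_LEN` was set to `"1024"`. A minimal sketch of a budget check follows, assuming the limit covers prompt tokens plus generated tokens, as vLLM's maximum model length does. `clamp_max_new_tokens` is a hypothetical helper, and the prompt length of 12 tokens is an assumed figure for illustration; in practice it would come from the model's tokenizer. How the container responds to an oversized request is not shown in this notebook.

```python
def clamp_max_new_tokens(prompt_tokens, requested_new_tokens, max_model_len=1024):
    """Cap the generation budget so prompt + output fits within max_model_len.

    Hypothetical helper; max_model_len mirrors OPTION_MAX_MODEL_LEN above.
    """
    available = max_model_len - prompt_tokens
    if available <= 0:
        raise ValueError(
            f"Prompt uses {prompt_tokens} tokens; the limit is {max_model_len}"
        )
    return min(requested_new_tokens, available)


# Assumed prompt length, for illustration only
budget = clamp_max_new_tokens(prompt_tokens=12, requested_new_tokens=2048)
payload = {
    "inputs": "write a quick sort algorithm in python and description",
    "parameters": {"do_sample": True, "max_new_tokens": budget},
}
print(payload["parameters"]["max_new_tokens"])  # 1012
```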

- Chat template
    - [DJL Chat Completions API Schema](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html)

In [None]:
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works! anyway, write a quick sort algorithm in python and description"},
]

result = predictor.predict(
    {"messages": chat, "max_tokens": 1024}
)
result
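To print only the assistant's reply rather than the whole response object, the text can be pulled out of the first choice. This is a minimal sketch, assuming the response follows the chat completions schema linked above, with the reply under `choices[0].message.content`. `extract_chat_text` is a hypothetical helper and `sample_result` is a made-up stand-in, not real endpoint output.

```python
def extract_chat_text(result):
    """Return the assistant message text from a chat-completions style response."""
    choices = result.get("choices") or []
    if not choices:
        raise ValueError("Response contains no choices")
    return choices[0]["message"]["content"]


# Made-up response used only to exercise the helper
sample_result = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Here is quicksort..."},
            "finish_reason": "stop",
        }
    ]
}
print(extract_chat_text(sample_result))
```

With a live endpoint, the call would be `print(extract_chat_text(result))`.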

### 3. Streaming

#### 3-1. Normal Question

In [None]:
import json
import random 

In [None]:
# List of prompts covering a variety of coding tasks
prompts = [
    "write a quick sort algorithm in python.",
    "Write a Python function to implement a binary search algorithm.",
    "Create a JavaScript function to flatten a nested array.",
    "Implement a simple REST API using Flask in Python.",
    "Write a SQL query to find the top 5 customers by total purchase amount.",
    "Create a React component for a todo list with basic CRUD operations.",
    "Implement a depth-first search algorithm for a graph in C++.",
    "Write a bash script to find and delete files older than 30 days.",
    "Create a Python class to represent a deck of cards with shuffle and deal methods.",
    "Write a regular expression to validate email addresses.",
    "Implement a basic CI/CD pipeline using GitHub Actions."
]

def generate_payload():
    # Pick a prompt at random
    prompt = random.choice(prompts)
    
    # Build the JSON payload
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 400,
            # "return_full_text": False  # This does not work with Phi3
        },
        "stream": True,
    }
    
    # Serialize the JSON to a string and encode it as bytes
    return json.dumps(body).encode('utf-8')

In [None]:
%%time
# Invoke the endpoint
resp = sm_runtime_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, 
    Body=generate_payload(), 
    ContentType="application/json"
)
print("Generated response:")
print("-" * 40)

buffer = ""
for event in resp['Body']:
    if 'PayloadPart' in event:
        chunk = event['PayloadPart']['Bytes'].decode()
        buffer += chunk
        try:
            # Try to parse the buffer as JSON
            data = json.loads(buffer)
            if 'token' in data:
                print(data['token']['text'], end='', flush=True)
            buffer = ""  # Clear the buffer after successful parsing
        except json.JSONDecodeError:
            # If parsing fails, keep the buffer for the next iteration
            pass

print("\n" + "-" * 40)
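The loop above parses the buffer only when it holds exactly one JSON object, so a chunk carrying two objects would never parse. A more defensive variant is sketched below, assuming the stream is newline-delimited JSON. `iter_json_lines` is a hypothetical helper and `fake_events` simulates the event stream, so the sketch runs without an endpoint. Buffering bytes instead of text also avoids decoding a multi-byte character that was split across chunks.

```python
import json


def iter_json_lines(events):
    """Yield one parsed JSON object per complete line in a stream of PayloadPart events."""
    buffer = b""
    for event in events:
        part = event.get("PayloadPart")
        if part is None:
            continue
        buffer += part["Bytes"]
        # Everything before the last newline is complete; keep the remainder
        *lines, buffer = buffer.split(b"\n")
        for line in lines:
            if line.strip():
                yield json.loads(line.decode("utf-8"))
    # Flush a final object that was not newline-terminated
    if buffer.strip():
        yield json.loads(buffer.decode("utf-8"))


# Simulated events: one object split across chunks, then two objects in one chunk
fake_events = [
    {"PayloadPart": {"Bytes": b'{"token": {"te'}},
    {"PayloadPart": {"Bytes": b'xt": "def"}}\n'}},
    {"PayloadPart": {"Bytes": b'{"token": {"text": " quick"}}\n{"token": {"text": "sort"}}\n'}},
]
text = "".join(
    obj["token"]["text"] for obj in iter_json_lines(fake_events) if "token" in obj
)
print(text)  # def quicksort
```

Against the live endpoint, `iter_json_lines(resp['Body'])` would take the place of the manual buffer handling.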

#### 3-2. With chat template

In [None]:
# Chat history to send through the chat template
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works! anyway, write a quick sort algorithm in python and description"},
]

# Non-streaming request with the same messages (result is not used below)
result = predictor.predict(
    {"messages": chat, "max_tokens": 1024}
)

def generate_payload():
    # Build the JSON payload for a streaming chat request
    body = {
        "messages": chat,
        "max_tokens": 1024,
        "stream": True,
    }
    
    # Serialize the JSON to a string and encode it as bytes
    return json.dumps(body).encode('utf-8')

In [None]:
%%time
# Invoke the endpoint
resp = sm_runtime_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, 
    Body=generate_payload(), 
    ContentType="application/json"
)
print("Generated response:")
print("-" * 40)

buffer = ""
for event in resp['Body']:
    if 'PayloadPart' in event:
        chunk = event['PayloadPart']['Bytes'].decode()
        buffer += chunk
        try:
            # Try to parse the buffer as JSON
            data = json.loads(buffer)
            if 'choices' in data:
                # A chunk's delta may carry no content (e.g., the final one), so fall back to ''
                print(data['choices'][0]['delta'].get('content') or '', end='', flush=True)
            buffer = ""  # Clear the buffer after successful parsing
        except json.JSONDecodeError:
            # If parsing fails, keep the buffer for the next iteration
            pass

print("\n" + "-" * 40)
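When the full reply is needed as a single string rather than printed piece by piece, the streamed deltas can be accumulated. This is a minimal sketch, assuming each parsed chunk carries its text under `choices[i].delta.content` as in the loop above. `collect_chat_deltas` is a hypothetical helper and `fake_chunks` is made-up data standing in for parsed stream chunks.

```python
def collect_chat_deltas(chunks):
    """Concatenate delta content from parsed chat-completion stream chunks."""
    pieces = []
    for chunk in chunks:
        for choice in chunk.get("choices", []):
            # Role-only or final chunks may have no content; skip those
            content = choice.get("delta", {}).get("content")
            if content:
                pieces.append(content)
    return "".join(pieces)


# Made-up chunks used only to exercise the helper
fake_chunks = [
    {"choices": [{"delta": {"role": "assistant", "content": ""}}]},
    {"choices": [{"delta": {"content": "Quick"}}]},
    {"choices": [{"delta": {"content": "sort"}}]},
    {"choices": [{"delta": {}, "finish_reason": "stop"}]},
]
full_text = collect_chat_deltas(fake_chunks)
print(full_text)  # Quicksort
```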

### 4. Delete the SageMaker Endpoint

In [None]:
# predictor.delete_predictor()
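An endpoint keeps incurring charges while it is running, so it is worth removing once testing is done. A minimal sketch follows, assuming the `predictor` created earlier and using the SageMaker SDK's `Predictor.delete_model` and `Predictor.delete_endpoint` methods. `cleanup` is a hypothetical wrapper. The model is deleted first because the SDK looks up the model name through the endpoint.

```python
def cleanup(predictor):
    """Delete the model behind the endpoint, then the endpoint itself.

    Hypothetical wrapper; predictor is expected to be a sagemaker.Predictor.
    """
    # Must run while the endpoint still exists so the model can be resolved
    predictor.delete_model()
    # By default this also removes the endpoint configuration
    predictor.delete_endpoint()
```

It would be called as `cleanup(predictor)` at the end of the session.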