In [16]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs

In [17]:
##update region based on where you are deploying
container_uri = '763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124'

In [18]:
instance_type = "ml.p5en.48xlarge"

In [19]:
%%writefile serving.properties
engine=Python
option.quantize=fp8
option.dtype=fp16
option.trust_remote_code=True
option.tensor_parallel_degree=max
option.gpu_memory_utilization=.87
option.max_model_len=29392
option.model_id=deepseek-ai/DeepSeek-R1
option.max_rolling_batch_size=2
option.rolling_batch=vllm

Writing serving.properties


Try vllm==0.7.0



In [20]:
%%writefile requirements.txt
vllm==0.7.0

Writing requirements.txt


In [21]:
%%sh
mkdir mymodel
mv serving.properties mymodel/
mv requirements.txt mymodel/
tar czvf mymodel.tar.gz mymodel/

mkdir: cannot create directory ‘mymodel’: File exists


mymodel/
mymodel/serving.properties
mymodel/requirements.txt


In [22]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)

In [23]:
model = Model(image_uri=container_uri,
              model_data=code_artifact,
              role=role,)

In [24]:
endpoint_name = sagemaker.utils.name_from_base("DeepSeek-R1")

In [None]:
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout = 2800
)

-----------------------------------------------------------------------------------------

In [15]:
import io
import json
import time
import boto3
from IPython.display import clear_output

# SageMaker Runtime client
smr_client = boto3.client("sagemaker-runtime")

class LineIterator:
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                print("Unknown event type:" + chunk)
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])

def format_deepseek_chat_template(user_input, chat_history=None):
    """
    Format input according to DeepSeek R1 chat template
    
    Args:
    - user_input (str): Current user message
    - chat_history (list, optional): Previous conversation turns
    
    Returns:
    - str: Formatted chat input with special tokens
    """
    # Start with the beginning of sentence token
    formatted_input = "<｜begin▁of▁sentence｜>"
    
    # Add chat history if provided
    if chat_history:
        for turn in chat_history:
            formatted_input += f"<｜User｜>{turn['user']}<｜Assistant｜>{turn['assistant']}"
    
    # Add current user input
    formatted_input += f"<｜User｜>{user_input}<｜Assistant｜>"
    
    return formatted_input

def stream_chat_response(endpoint_name, inputs, max_new_tokens=8192):
    # Format the input using the DeepSeek chat template
    formatted_inputs = format_deepseek_chat_template(inputs)
    
    body = {
        "inputs": formatted_inputs,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
        },
        "stream": True,
    }

    resp = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(body),
        ContentType="application/json",
    )

    event_stream = resp["Body"]
    start_json = b"{"
    full_response = ""
    start_time = time.time()
    token_count = 0

    for line in LineIterator(event_stream):
        if line != b"" and start_json in line:
            data = json.loads(line[line.find(start_json):].decode("utf-8"))
            token_text = data["token"]["text"]
            full_response += token_text
            token_count += 1

            # Calculate tokens per second
            elapsed_time = time.time() - start_time
            tps = token_count / elapsed_time if elapsed_time > 0 else 0

            # Clear the output and reprint everything
            clear_output(wait=True)
            print("Bot:", full_response)
            print(f"\nTokens per Second: {tps:.2f}", end="")

    print("\n") # Add a newline after response is complete
    return full_response

def chat(endpoint_name):
    print("Welcome to the SageMaker Streaming Chat! Type 'exit' to quit.")
    chat_history = []
    while True:
        user_input = input("\nYou: ")
        if user_input.lower() == "exit":
            break
        bot_response = stream_chat_response(endpoint_name, user_input)
        
        # Update chat history
        chat_history.append({
            'user': user_input,
            'assistant': bot_response
        })

# Replace with your SageMaker endpoint name
endpoint_name = "DeepSeek-R1-2025-02-04-23-24-01-082"

# Start the chat
chat(endpoint_name)

Bot: <think>

</think>

Greetings! I'm DeepSeek-R1, an artificial intelligence assistant created by DeepSeek. I'm at your service and would be delighted to assist you with any inquiries or tasks you may have.

Tokens per Second: 23.16



In [None]:
"""Write python code that can call and stream from a SageMaker Real time endpoint hosting an llm. This will be run in a jupyter notebook cell, and provide a chat experience in the std out of the notebook cell for a user to chat back and forth with a reasoning model. There should be code to support constantly showing the tps of a given stream back from the llm as it typewriters out the tokens to the stdout in the cell. This should be situated at the top of the stdout in refresh to avoid preventing strange behavior. Include the deepseek r1 chat template to format user requests with. The model takes raw input so the chat template will need to be applied."""