# SageMaker Endpoint Inference with Streaming Support

This notebook demonstrates how to interact with a SageMaker endpoint for text generation, including both regular and streaming inference capabilities.

## Setup and Dependencies

First, we import the libraries used throughout the notebook.

In [1]:
import os
import json
import boto3

## AWS Configuration

Here we create a Boto3 session from a named profile and region, then use it to build the `sagemaker-runtime` client that sends requests to the endpoint.

In [13]:
profile_name = 'dev'
region_name = "us-east-1"
ENDPOINT_NAME = "webinar-model-endpoint-VLLM-1747432398"

# Create a Boto3 session using the specified profile
boto3_session = boto3.Session(profile_name=profile_name, region_name=region_name)

# Create the SageMaker Runtime client used to invoke the endpoint
sagemaker_runtime = boto3_session.client('sagemaker-runtime', region_name=region_name)
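
Before sending requests, it can help to confirm that the endpoint is actually running. The sketch below assumes a separate `sagemaker` (control-plane) client, which this notebook does not create. `endpoint_in_service` is a hypothetical helper name; `describe_endpoint` and its `EndpointStatus` field come from the Boto3 SageMaker client.

```python
def endpoint_in_service(sagemaker_client, endpoint_name):
    """Return True if the endpoint reports the 'InService' status.

    `sagemaker_client` is assumed to be a Boto3 'sagemaker' client
    (not the 'sagemaker-runtime' client used for inference).
    """
    description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    return description["EndpointStatus"] == "InService"


# Possible usage with the session created above (requires AWS credentials):
# sagemaker_client = boto3_session.client("sagemaker")
# if not endpoint_in_service(sagemaker_client, ENDPOINT_NAME):
#     raise RuntimeError("Endpoint is not ready yet")
```

Taking the client as an argument keeps the check easy to exercise with a stub object when no AWS credentials are available.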

## Basic Inference Example

This section demonstrates a simple chat completion request to the model. We'll:
1. Set up the initial messages
2. Configure the generation parameters
3. Make the request to the endpoint
4. Process the response

In [5]:
messages = [
    {"role": "system", "content": "please be helpful"},
    {"role": "user", "content": "hello"}
]

payload = {
    "messages": messages,
    "parameters": {
        "do_sample": True,
        "max_new_tokens": 1024,
        "temperature": 0.2
    }
}

response = sagemaker_runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType='application/json',
    Body=json.dumps(payload)
)
 
result = response['Body'].read().decode('utf-8')
result = json.loads(result)
result

{'id': 'chatcmpl-140571528865312',
 'object': 'chat.completion',
 'created': 1747433007,
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': "Hello! How can I assist you today? Do you have a specific question or topic you'd like to discuss, or are you looking for general information"},
   'logprobs': None,
   'finish_reason': 'length'}],
 'usage': {'prompt_tokens': 20, 'completion_tokens': 30, 'total_tokens': 50}}
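
The response above reports `finish_reason: 'length'` after only 30 completion tokens, even though the payload asked for `max_new_tokens` of 1024. That suggests the reply was cut off by a lower token limit. Whether the endpoint honors the nested `parameters` keys depends on how the container was deployed, which this notebook does not show. A minimal sketch for pulling out the reply text and flagging truncation follows; `extract_reply` is a hypothetical helper, and the sample dict is a shortened stand-in for the real response.

```python
def extract_reply(result):
    """Return (text, truncated) from a chat.completion-style response dict."""
    choice = result["choices"][0]
    text = choice["message"]["content"]
    # 'length' means generation stopped at a token limit, not at a natural end
    truncated = choice.get("finish_reason") == "length"
    return text, truncated


# Shortened stand-in with the same shape as the response shown above
sample = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello! How can I assist"},
            "finish_reason": "length",
        }
    ],
    "usage": {"prompt_tokens": 20, "completion_tokens": 30, "total_tokens": 50},
}

text, truncated = extract_reply(sample)
if truncated:
    print("Warning: the reply hit the token limit and may be incomplete.")
print(text)
```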

## Conversation Continuation

This example shows how to maintain a conversation by:
1. Appending the previous response to the message history
2. Adding a new user message
3. Making another request to continue the conversation

In [10]:
messages.append(result['choices'][0]['message'])
messages.append({"role": "user", "content": "thank you!"})

payload = {
    "messages": messages,
    "parameters": {
        "do_sample": True,
        "max_new_tokens": 1024,
        "temperature": 0.2
    }
}

response = sagemaker_runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType='application/json',
    Body=json.dumps(payload)
)
 
result = response['Body'].read().decode('utf-8')
result = json.loads(result)
result

{'id': 'chatcmpl-140668638153600',
 'object': 'chat.completion',
 'created': 1747433104,
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': 'You\'re welcome! It was a pleasure to chat with you, even if it was just a simple "hello"! If you ever need any help'},
   'logprobs': None,
   'finish_reason': 'length'}],
 'usage': {'prompt_tokens': 63, 'completion_tokens': 30, 'total_tokens': 93}}
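
The append-then-invoke pattern above repeats for every turn, so it can be wrapped in one function. This is a sketch rather than part of the original notebook: `chat_turn` is a hypothetical name, and it assumes the endpoint returns the same `choices[0].message` shape shown in the outputs above.

```python
import json


def chat_turn(runtime_client, endpoint_name, messages, user_text, parameters=None):
    """Send one user turn and record both sides of the exchange in `messages`.

    The history is only updated after the request succeeds, so a failed
    call does not leave a dangling user message behind.
    """
    user_message = {"role": "user", "content": user_text}
    payload = {
        "messages": messages + [user_message],
        "parameters": parameters or {},
    }
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    result = json.loads(response["Body"].read().decode("utf-8"))
    reply = result["choices"][0]["message"]
    messages.extend([user_message, reply])
    return reply["content"]


# Possible usage with the objects defined earlier in the notebook:
# history = [{"role": "system", "content": "please be helpful"}]
# print(chat_turn(sagemaker_runtime, ENDPOINT_NAME, history, "hello"))
# print(chat_turn(sagemaker_runtime, ENDPOINT_NAME, history, "thank you!"))
```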

## Streaming Support

The `LineIterator` class is a utility for handling streaming responses from the SageMaker endpoint. It:
- Parses byte stream input
- Handles partial JSON objects that might be split across multiple events
- Maintains a buffer to ensure complete messages are processed
- Provides an iterator interface for easy consumption of the stream

In [11]:
import io


class LineIterator:
    """
    A helper class for parsing the byte stream input. 
    
    The endpoint is expected to emit newline-delimited JSON, for example:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by appending incoming bytes to an internal buffer
    and returning complete lines (ending with a '\n' character) from '__next__'.
    It tracks the position of the last read so that previous bytes are not
    returned again.
    """
    
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
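            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # Flush trailing bytes that lack a final newline; retrying
                    # here would loop forever because the stream is exhausted
                    self.buffer.seek(self.read_pos)
                    line = self.buffer.read()
                    self.read_pos += len(line)
                    return line
                raise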
            if 'PayloadPart' not in chunk:
                print('Unknown event type: ' + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
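
To see why the buffering matters, the sketch below reimplements the same idea as a compact generator and feeds it hand-made events in which one JSON object is split across two `PayloadPart` events. `iter_lines` is a hypothetical stand-in written for this illustration, not the class above, and the events are fabricated test input rather than real endpoint output.

```python
def iter_lines(events):
    """Yield complete lines (without the trailing newline) from PayloadPart events."""
    buffer = b""
    for event in events:
        part = event.get("PayloadPart")
        if part is None:
            continue  # this sketch silently skips non-payload events
        buffer += part["Bytes"]
        # Emit every complete line currently in the buffer
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            yield line
    if buffer:
        yield buffer  # trailing data without a final newline


# One JSON object split across two events, followed by a second object
events = [
    {"PayloadPart": {"Bytes": b'{"outputs": '}},
    {"PayloadPart": {"Bytes": b'[" problem"]}\n{"outputs": [" solved"]}\n'}},
]

lines = list(iter_lines(events))
for line in lines:
    print(line)
```

Each yielded line is a complete JSON document, so it can be passed to `json.loads` without worrying about where the event boundaries fell.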

## Streaming Inference Example

This final example uses `invoke_endpoint_with_response_stream`, which returns tokens as they are generated instead of waiting for the full completion. The example:
- Sends the same message format with `"stream": True` added to the payload
- Adds stop sequences to the generation parameters
- Reads the event stream through `LineIterator`
- Prints each content delta as soon as it arrives

In [14]:
messages = [
    {"role": "system", "content": "please be helpful"},
    {"role": "user", "content": "hello"}
]

payload = {
    "messages": messages,
    "parameters": {
        "do_sample": True,
        "max_new_tokens": 2048,
        "temperature": 0.2,
        "stop": [ "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>" ]
    },
    "stream": True
}

response = sagemaker_runtime.invoke_endpoint_with_response_stream(
    EndpointName=ENDPOINT_NAME, 
    Body=json.dumps(payload), 
    ContentType="application/json"
)
event_stream = response['Body']

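for line in LineIterator(event_stream):
    if not line.strip():
        continue  # skip blank lines
    resp = json.loads(line)
    delta = resp["choices"][0].get("delta") or {}
    # Some chunks carry no text (for example, a role-only delta), so avoid printing None
    print(delta.get("content") or "", end='')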

Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat? I'm here to assist
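
Printing deltas as they arrive does not keep the reply, so continuing the conversation afterwards requires collecting the pieces. The sketch below joins the `content` deltas into one string. `collect_stream` is a hypothetical helper. Its handling of a `data:` prefix and a `[DONE]` marker is an assumption based on the server-sent-events format some OpenAI-compatible servers use; the endpoint in this notebook appears to send plain JSON lines, in which case those branches are simply never taken.

```python
import json


def collect_stream(lines):
    """Concatenate the delta contents from an iterable of byte lines."""
    pieces = []
    for raw in lines:
        line = raw.strip()
        if line.startswith(b"data:"):  # assumption: optional SSE prefix
            line = line[len(b"data:"):].strip()
        if not line or line == b"[DONE]":  # assumption: SSE end-of-stream marker
            continue
        choices = json.loads(line).get("choices") or []
        if not choices:
            continue
        content = (choices[0].get("delta") or {}).get("content")
        if content:
            pieces.append(content)
    return "".join(pieces)


# Possible usage with the objects from the cell above:
# reply_text = collect_stream(LineIterator(event_stream))
# messages.append({"role": "assistant", "content": reply_text})
```

Because an event stream can only be consumed once, a caller that wants both live output and the final text would need to print inside the loop rather than run the two loops separately.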