# SageMaker Chat Integration with LangGraph

This notebook demonstrates how to invoke a DeepSeek SageMaker endpoint using LangGraph for stateful chat management.

In [None]:
# Install dependencies from requirements.txt
!pip install -r requirements.txt -q
# Install sagemaker separately to avoid strict numpy pinning on Python 3.13
!pip install 'sagemaker==2.251.1' --no-deps

In [None]:
import sagemaker
import boto3
import json
from sagemaker.predictor import retrieve_default
from typing import Annotated, TypedDict, List, Dict, Any
from langgraph.graph import StateGraph, START, END

# Configuration: AWS profile to authenticate with and the endpoint to invoke
PROFILE_NAME = 'default'
ENDPOINT_NAME = "jumpstart-dft-deepseek-llm-r1-disti-20251206-053241"

# Read the region from a session bound to PROFILE_NAME; if that fails for any
# reason, read it from a boto3 session created without an explicit profile.
try:
    REGION_NAME = boto3.Session(profile_name=PROFILE_NAME).region_name
except Exception:
    REGION_NAME = boto3.Session().region_name

# Echo the effective settings so they are visible in the cell output
print(
    f"Using Endpoint: {ENDPOINT_NAME}",
    f"Region: {REGION_NAME}",
    sep="\n",
)

In [None]:
# Initialize Predictor
# Wrap a boto3 session bound to PROFILE_NAME in a SageMaker session. If either
# step fails, report the error and continue without an explicit session.
try:
    boto_session = boto3.Session(profile_name=PROFILE_NAME)
    sagemaker_session = sagemaker.Session(boto_session=boto_session)
except Exception as e:
    print(f"Failed to use profile {PROFILE_NAME}, falling back to default. Error: {e}")
    sagemaker_session = None
else:
    print(f"Authenticated with profile: {PROFILE_NAME}")

# retrieve_default automatically configures the predictor based on the endpoint.
# The session argument is only passed when one was created above, so the
# library's own default applies otherwise.
predictor = (
    retrieve_default(ENDPOINT_NAME, sagemaker_session=sagemaker_session)
    if sagemaker_session
    else retrieve_default(ENDPOINT_NAME)
)

In [None]:
# Define Graph State
class State(TypedDict):
    """State passed between the nodes of the chat graph."""
    # Messages list stores the conversation history as {"role": ..., "content": ...} dicts
    messages: List[Dict[str, str]]


def _decode_response(response: Any) -> Any:
    """Turn the raw value returned by ``predictor.predict`` into parsed data.

    ``bytes``/``bytearray`` payloads are decoded as UTF-8 and text payloads are
    parsed as JSON. Text that is not valid JSON is returned unchanged so it can
    still be shown to the user. Any other value (for example an already-parsed
    dict or list) is returned as-is.
    """
    if isinstance(response, (bytes, bytearray)):
        response = response.decode('utf-8')
    if isinstance(response, str):
        try:
            return json.loads(response)
        except ValueError:
            # Not JSON: keep the raw text rather than failing the whole turn
            return response
    return response


def _extract_content(response_data: Any) -> Any:
    """Pick the reply text out of a parsed endpoint response.

    Two layouts are recognised:
      * a dict with a ``choices`` key -> ``choices[0]["message"]["content"]``
      * a non-empty list whose first item is a dict with ``generated_text``
    Anything else falls back to ``str(response_data)``.
    """
    if isinstance(response_data, dict) and 'choices' in response_data:
        return response_data['choices'][0]['message']['content']

    if isinstance(response_data, list) and response_data:
        first_item = response_data[0]
        # Only index by key when the item really is a mapping; for a plain
        # string `in` would be a substring test and the lookup would fail.
        if isinstance(first_item, dict) and 'generated_text' in first_item:
            return first_item['generated_text']

    # Fallback
    return str(response_data)


# Define the Chat Node
def call_model(state: State):
    """Graph node: send the conversation to the endpoint and append its reply.

    Args:
        state: Current graph state; ``state["messages"]`` is the conversation
            so far. The list is not modified.

    Returns:
        A dict with a single ``"messages"`` key holding a new list: the input
        messages followed by one assistant message. If invoking the endpoint
        or reading its response raises, the error is printed and the assistant
        message contains a fixed error text instead (best-effort: the node
        itself never raises for endpoint failures).
    """
    print("Invoking model...")
    messages = state["messages"]

    # Prepare payload for DeepSeek model
    payload = {
        "messages": messages,
        "max_tokens": 512,  # Adjust token limit as needed
        "temperature": 0.7,
        "top_p": 0.9
    }

    try:
        response_data = _decode_response(predictor.predict(payload))
        # print(f"Raw Response: {response_data}") # Uncomment for debugging
        content = _extract_content(response_data)
    except Exception as e:
        print(f"Error invoking endpoint: {e}")
        content = "Error: Failed to get response."

    # Return updated state (a new list; the caller's list is left untouched)
    return {"messages": messages + [{"role": "assistant", "content": content}]}

# Build the Graph: one chat node wired START -> node -> END
CHAT_NODE = "deepseek_agent"

workflow = StateGraph(State)

# Register the node, then connect it to the graph's entry and exit points
workflow.add_node(CHAT_NODE, call_model)
workflow.add_edge(START, CHAT_NODE)
workflow.add_edge(CHAT_NODE, END)

# Compile the graph into the object the following cells invoke
app = workflow.compile()

In [None]:
# Test the Graph with a single turn
initial_state = {"messages": [{"role": "user", "content": "Create a Flappy Bird game in Python."}]}
output = app.invoke(initial_state)

# Print only the last message of the returned history
print("\n--- Final Output ---")
print(output["messages"][-1]["content"])

In [None]:
# Interactive Chat Function
def chat_session():
    """Run an interactive multi-turn chat on the console using the compiled graph.

    Reads user input in a loop, appends it to the running history, invokes
    ``app`` with that history and prints the last message of the result.

    The loop ends when the user types ``quit`` or ``exit`` (case-insensitive,
    surrounding whitespace ignored), or when input is closed or interrupted
    (``EOFError`` / ``KeyboardInterrupt``). Blank input is skipped instead of
    being sent to the model as an empty user turn.

    Returns:
        None.
    """
    print("Starting chat session. Type 'quit' to exit.")
    conversation_history = []

    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted prompt ends the chat like 'quit'
            break

        command = user_input.strip()
        if command.lower() in ["quit", "exit"]:
            break
        if not command:
            # Nothing typed: prompt again without calling the endpoint
            continue

        conversation_history.append({"role": "user", "content": user_input})

        # Run graph
        result = app.invoke({"messages": conversation_history})

        # Update history with the result (includes the newly added reply)
        conversation_history = result["messages"]

        print(f"Assistant: {conversation_history[-1]['content']}")

# chat_session() # Uncomment to run