# 🔮 Inference with AWS SageMaker Endpoint

**Model**: `openai/gpt-oss-120b`  
**AWS Profile**: `default`

This notebook demonstrates how to perform inference on a deployed SageMaker endpoint.

---

## 📋 Table of Contents

1. [Prerequisites & Setup](#1-prerequisites--setup)
2. [Connect to Endpoint](#2-connect-to-endpoint)
3. [Basic Inference](#3-basic-inference)
4. [Advanced Inference Options](#4-advanced-inference-options)
5. [Batch Inference](#5-batch-inference)
6. [Performance Testing](#6-performance-testing)
7. [Endpoint Monitoring](#7-endpoint-monitoring)
8. [Cleanup](#8-cleanup)

---

## 1. Prerequisites & Setup

Install and import required libraries.

In [None]:
# Step 1.1: Install required packages
%pip install sagemaker boto3 pandas matplotlib tqdm

print("\n✅ Packages installed successfully!")

In [None]:
# Step 1.2: Import required libraries
import os
import json
import time
import datetime
from typing import List, Dict, Any, Optional

import boto3
import sagemaker
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

print("✅ Libraries imported successfully!")
print(f"📦 SageMaker version: {sagemaker.__version__}")

---

## 2. Connect to Endpoint

Configure AWS credentials and connect to the deployed SageMaker endpoint.
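
The cells in this section connect through the SDK's `Predictor` class. As a point of comparison, the same request can be sent with the lower-level `sagemaker-runtime` client from boto3. The sketch below is a minimal illustration, not part of the original notebook: the helper name `invoke_json` and the endpoint name in the usage comment are hypothetical, and it assumes the endpoint accepts and returns JSON.

```python
import json
from typing import Any, Dict


def invoke_json(runtime_client: Any, endpoint_name: str, payload: Dict[str, Any]) -> Any:
    """Send a JSON payload to an endpoint and decode the JSON reply.

    `runtime_client` is expected to be a boto3 "sagemaker-runtime" client.
    """
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload),
    )
    # "Body" is a streaming object; read it fully before decoding.
    return json.loads(response["Body"].read())


# Usage (needs valid AWS credentials and a live endpoint; names are placeholders):
#   import boto3
#   runtime = boto3.Session(profile_name="default").client("sagemaker-runtime")
#   invoke_json(runtime, "my-endpoint-name", {"inputs": "Hello"})
```

Passing the client in as an argument keeps the helper easy to exercise with a stub object when no endpoint is available.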

In [None]:
# Step 2.1: Configure AWS Profile

# ============================================
# 📌 CONFIGURATION - EDIT HERE
# ============================================

AWS_PROFILE = "default"

# Set the endpoint name from deployment notebook
# Option 1: Manually set the endpoint name
ENDPOINT_NAME = None  # Set this if you know your endpoint name, e.g., "openai-gpt-oss-120b-20231204-121500"

# Option 2: Load from saved configuration file
CONFIG_FILE = "endpoint_config.json"

# ============================================

# Set environment variables
os.environ['AWS_PROFILE'] = AWS_PROFILE
os.environ['AWS_DEFAULT_PROFILE'] = AWS_PROFILE

print(f"🔐 Using AWS Profile: {AWS_PROFILE}")

In [None]:
# Step 2.2: Verify AWS credentials and create session

try:
    # Create session with profile
    boto_session = boto3.Session(profile_name=AWS_PROFILE)
    sts_client = boto_session.client('sts')
    
    # Get caller identity to verify credentials
    identity = sts_client.get_caller_identity()
    
    print("✅ AWS credentials verified successfully!")
    print(f"   👤 Account ID: {identity['Account']}")
    print(f"   🆔 User ARN: {identity['Arn']}")
    
    # Create SageMaker session
    sagemaker_session = sagemaker.Session(boto_session=boto_session)
    region = sagemaker_session.boto_region_name
    print(f"   🌍 AWS Region: {region}")
    
except Exception as e:
    print(f"❌ Error verifying AWS credentials: {e}")
    raise

In [None]:
# Step 2.3: Get endpoint name

# Load from config file if not manually set
if ENDPOINT_NAME is None:
    if os.path.exists(CONFIG_FILE):
        with open(CONFIG_FILE, "r") as f:
            config = json.load(f)
        ENDPOINT_NAME = config.get("endpoint_name")
        print(f"📄 Loaded endpoint from config: {ENDPOINT_NAME}")
    else:
        # List available endpoints and let user choose
        print("⚠️ No endpoint name configured. Listing available endpoints...")
        print()
        
        sagemaker_client = boto_session.client('sagemaker')
        response = sagemaker_client.list_endpoints(
            SortBy='CreationTime',
            SortOrder='Descending',
            MaxResults=10,
            StatusEquals='InService'
        )
        
        endpoints = response.get('Endpoints', [])
        
        if endpoints:
            print("📋 Available endpoints (InService):")
            for i, ep in enumerate(endpoints, 1):
                print(f"   {i}. {ep['EndpointName']} (Created: {ep['CreationTime']})")
            
            # Use the most recent one by default
            ENDPOINT_NAME = endpoints[0]['EndpointName']
            print(f"\n🎯 Using most recent endpoint: {ENDPOINT_NAME}")
            print("   (Change ENDPOINT_NAME variable above if needed)")
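        else:
            print("❌ No InService endpoints found!")
            print("   Please run the deployment notebook first.")
            # Fail fast: later cells cannot run without an endpoint name
            raise RuntimeError("No InService SageMaker endpoint available.")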
else:
    print(f"📛 Using configured endpoint: {ENDPOINT_NAME}")

In [None]:
# Step 2.4: Create predictor client

try:
    # Verify endpoint exists and is InService
    sagemaker_client = boto_session.client('sagemaker')
    endpoint_info = sagemaker_client.describe_endpoint(EndpointName=ENDPOINT_NAME)
    
    status = endpoint_info['EndpointStatus']
    
    if status != 'InService':
        print(f"⚠️ Endpoint status: {status}")
        print("   Waiting for endpoint to be InService...")
        
        # Wait for endpoint to be ready
        waiter = sagemaker_client.get_waiter('endpoint_in_service')
        waiter.wait(
            EndpointName=ENDPOINT_NAME,
            WaiterConfig={'Delay': 30, 'MaxAttempts': 60}
        )
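        # Refresh the status so the summary below reports the current state
        status = sagemaker_client.describe_endpoint(EndpointName=ENDPOINT_NAME)['EndpointStatus']
        print("   ✅ Endpoint is now InService!")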
    
    # Create predictor
    predictor = Predictor(
        endpoint_name=ENDPOINT_NAME,
        sagemaker_session=sagemaker_session,
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer()
    )
    
    print(f"\n✅ Connected to endpoint: {ENDPOINT_NAME}")
    print(f"   📊 Status: {status}")
    print(f"   📅 Created: {endpoint_info['CreationTime']}")
    
except Exception as e:
    print(f"❌ Error connecting to endpoint: {e}")
    raise

---

## 3. Basic Inference

Send simple requests to the endpoint and get responses.
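
The response shape depends on the serving container: some return a dict with a `generated_text` key, others wrap that dict in a list, and some return a plain string. The cells in this section repeat the same `isinstance` checks to cope with that. A single helper can hold that logic in one place. The sketch below is an illustration under that assumption; `extract_text` is a hypothetical name and the shapes it handles are examples, not a guaranteed contract of the endpoint.

```python
from typing import Any


def extract_text(response: Any) -> str:
    """Return the generated text from a few common response shapes."""
    # Shape: [{"generated_text": "..."}] or ["...", "..."] -> use the first item
    if isinstance(response, list):
        return extract_text(response[0]) if response else ""
    # Shape: {"generated_text": "..."} or {"generated_text": ["...", "..."]}
    if isinstance(response, dict):
        text = response.get("generated_text")
        if text is None:
            # Unknown schema: fall back to the raw dict so nothing is hidden
            return str(response)
        return extract_text(text)
    # Plain string (or any other scalar)
    return str(response)
```

With this in place, a cell could print `extract_text(result["response"])` instead of branching on the type each time.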

In [None]:
# Step 3.1: Simple text generation

def generate_text(prompt: str, max_new_tokens: int = 256, **kwargs) -> Dict[str, Any]:
    """
    Generate text using the deployed model.
    
    Args:
        prompt: Input text prompt
        max_new_tokens: Maximum tokens to generate
        **kwargs: Additional generation parameters
        
    Returns:
        dict: Response from the model
    """
    payload = {
        "inputs": prompt,
        "max_new_tokens": max_new_tokens,
        **kwargs
    }
    
    start_time = time.time()
    response = predictor.predict(payload)
    latency = time.time() - start_time
    
    return {
        "response": response,
        "latency_seconds": latency
    }


# Test with a simple prompt
print("🧪 Testing basic inference...")
print()

test_prompt = "Hello! Can you tell me a short story about a brave robot?"

result = generate_text(test_prompt, max_new_tokens=200)

print(f"📝 Prompt: {test_prompt}")
print()
print(f"💬 Response:")

if isinstance(result["response"], dict):
    generated_text = result["response"].get("generated_text", result["response"])
    if isinstance(generated_text, list):
        for text in generated_text:
            print(f"   {text}")
    else:
        print(f"   {generated_text}")
else:
    print(f"   {result['response']}")

print()
print(f"⏱️ Latency: {result['latency_seconds']:.2f} seconds")

In [None]:
# Step 3.2: Multiple prompts

prompts = [
    "What is the capital of France?",
    "Explain machine learning in simple terms.",
    "Write a haiku about programming.",
    "What are the benefits of cloud computing?"
]

print("🧪 Testing with multiple prompts...")
print("="*60)

for i, prompt in enumerate(prompts, 1):
    print(f"\n📌 Test {i}/{len(prompts)}")
    print(f"📝 Prompt: {prompt}")
    print()
    
    result = generate_text(prompt, max_new_tokens=150)
    
    if isinstance(result["response"], dict):
        generated_text = result["response"].get("generated_text", [result["response"]])
        if isinstance(generated_text, list):
            print(f"💬 Response: {generated_text[0]}")
        else:
            print(f"💬 Response: {generated_text}")
    else:
        print(f"💬 Response: {result['response']}")
    
    print(f"⏱️ Latency: {result['latency_seconds']:.2f}s")
    print("-"*60)

print("\n✅ Multiple prompts testing completed!")

---

## 4. Advanced Inference Options

Explore different generation parameters to control the model output.
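
Before sending sampling parameters, it can help to validate them locally and to build the payload in one place. The sketch below is a hypothetical helper (`build_payload` is not part of the SDK). It mirrors the flat payload used in this notebook and, through the assumed `nested` flag, also shows the `{"inputs": ..., "parameters": {...}}` layout that some serving containers expect instead. Which layout your endpoint accepts depends on how the model was deployed.

```python
from typing import Any, Dict, Optional


def build_payload(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    nested: bool = False,
) -> Dict[str, Any]:
    """Validate generation parameters and assemble a request payload."""
    if max_new_tokens <= 0:
        raise ValueError("max_new_tokens must be positive")
    if temperature is not None and temperature < 0:
        raise ValueError("temperature must be non-negative")
    if top_p is not None and not 0 < top_p <= 1:
        raise ValueError("top_p must be in the interval (0, 1]")

    params: Dict[str, Any] = {"max_new_tokens": max_new_tokens}
    if temperature is not None:
        params["temperature"] = temperature
    if top_p is not None:
        params["top_p"] = top_p
    # The cells in this notebook enable sampling whenever they set these knobs
    if temperature is not None or top_p is not None:
        params["do_sample"] = True

    if nested:
        return {"inputs": prompt, "parameters": params}
    return {"inputs": prompt, **params}
```

For example, `build_payload("Hello", 100, temperature=0.7)` yields a flat payload that can be passed directly to `predictor.predict`.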

In [None]:
# Step 4.1: Temperature variations

print("🌡️ Testing different temperature settings...")
print("   (Higher temperature = more creative/random)")
print("="*60)

prompt = "Complete this story: Once upon a time in a magical forest, there lived a"

temperatures = [0.1, 0.5, 0.9]

for temp in temperatures:
    print(f"\n🌡️ Temperature: {temp}")
    print(f"📝 Prompt: {prompt}")
    
    result = generate_text(
        prompt,
        max_new_tokens=100,
        temperature=temp,
        do_sample=True
    )
    
    response = result["response"]
    if isinstance(response, dict):
        text = response.get("generated_text", [str(response)])
        if isinstance(text, list):
            text = text[0]
    else:
        text = str(response)
    
    print(f"💬 Response: {text}")
    print("-"*60)

In [None]:
# Step 4.2: Top-p (nucleus sampling) variations

print("🎯 Testing different top_p settings...")
print("   (Lower top_p = more focused/deterministic)")
print("="*60)

prompt = "The best advice for learning to code is"

top_p_values = [0.5, 0.9, 0.95]

for top_p in top_p_values:
    print(f"\n🎯 Top-p: {top_p}")
    
    result = generate_text(
        prompt,
        max_new_tokens=100,
        temperature=0.7,
        top_p=top_p,
        do_sample=True
    )
    
    response = result["response"]
    if isinstance(response, dict):
        text = response.get("generated_text", [str(response)])
        if isinstance(text, list):
            text = text[0]
    else:
        text = str(response)
    
    print(f"💬 Response: {text}")
    print("-"*60)

In [None]:
# Step 4.3: Custom generation with all parameters

def generate_custom(
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
    do_sample: bool = True
) -> str:
    """
    Generate text with full control over parameters.
    
    Args:
        prompt: Input text
        max_new_tokens: Max tokens to generate
        temperature: Sampling temperature (0.0-2.0)
        top_p: Nucleus sampling probability
        top_k: Top-k sampling
        repetition_penalty: Penalty for repeated tokens
        do_sample: Enable sampling (vs greedy decoding)
        
    Returns:
        str: Generated text
    """
    payload = {
        "inputs": prompt,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample
    }
    
    response = predictor.predict(payload)
    
    if isinstance(response, dict):
        text = response.get("generated_text", response)
        if isinstance(text, list):
            return text[0] if text else str(response)
        return str(text)
    return str(response)


# Test with custom parameters
print("⚙️ Custom generation example:")
print()

custom_response = generate_custom(
    prompt="Write a professional email requesting a meeting:",
    max_new_tokens=200,
    temperature=0.5,  # More focused
    top_p=0.85,
    repetition_penalty=1.2  # Reduce repetition
)

print(custom_response)

---

## 5. Batch Inference

Process a list of prompts one after another and collect each response, its latency, and any error in a DataFrame.
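
The batch function in this section sends requests one at a time. If the endpoint can serve several requests in parallel, a thread pool may shorten the total wall-clock time. The sketch below is one possible approach, not the notebook's method: `run_concurrently` and its `max_workers` default are assumptions, and it takes the prediction callable as an argument (for example `predictor.predict`) so that it does not depend on a live endpoint.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List


def run_concurrently(
    predict_fn: Callable[[Dict[str, Any]], Any],
    prompts: List[str],
    max_new_tokens: int = 150,
    max_workers: int = 4,
) -> List[Dict[str, Any]]:
    """Send one request per prompt using a small pool of worker threads."""

    def _one(prompt: str) -> Dict[str, Any]:
        try:
            response = predict_fn({"inputs": prompt, "max_new_tokens": max_new_tokens})
            return {"prompt": prompt, "response": response, "status": "success"}
        except Exception as exc:
            # Record the failure instead of aborting the whole batch
            return {"prompt": prompt, "response": None, "status": f"error: {exc}"}

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() yields results in the same order as the input prompts
        return list(pool.map(_one, prompts))
```

A call such as `pd.DataFrame(run_concurrently(predictor.predict, batch_prompts))` would give a table comparable to the sequential version. Keep `max_workers` modest, since too many parallel requests can cause throttling or timeouts on a single-instance endpoint.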

In [None]:
# Step 5.1: Batch processing function

def batch_inference(
    prompts: List[str],
    max_new_tokens: int = 150,
    show_progress: bool = True
) -> pd.DataFrame:
    """
    Process multiple prompts and return results as DataFrame.
    
    Args:
        prompts: List of input prompts
        max_new_tokens: Max tokens per response
        show_progress: Show progress bar
        
    Returns:
        pd.DataFrame: Results with prompts, responses, and latencies
    """
    results = []
    
    iterator = tqdm(prompts, desc="Processing") if show_progress else prompts
    
    for prompt in iterator:
        try:
            start_time = time.time()
            
            response = predictor.predict({
                "inputs": prompt,
                "max_new_tokens": max_new_tokens
            })
            
            latency = time.time() - start_time
            
            # Extract text from response
            if isinstance(response, dict):
                text = response.get("generated_text", str(response))
                if isinstance(text, list):
                    text = text[0] if text else ""
            else:
                text = str(response)
            
            results.append({
                "prompt": prompt,
                "response": text,
                "latency_seconds": latency,
                "status": "success"
            })
            
        except Exception as e:
            results.append({
                "prompt": prompt,
                "response": None,
                "latency_seconds": None,
                "status": f"error: {str(e)}"
            })
    
    return pd.DataFrame(results)


# Test batch inference
batch_prompts = [
    "What is Python?",
    "Explain REST APIs.",
    "What is Docker?",
    "Define microservices architecture.",
    "What is CI/CD?"
]

print("📦 Running batch inference...")
results_df = batch_inference(batch_prompts)

print("\n📊 Results:")
print(results_df[["prompt", "latency_seconds", "status"]].to_string())

In [None]:
# Step 5.2: Display batch results

print("\n📋 Detailed Batch Results:")
print("="*80)

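for idx, row in results_df.iterrows():
    print(f"\n📌 Prompt {idx + 1}: {row['prompt']}")
    # Failed requests are stored as None/NaN, and NaN is truthy, so test with pd.notna
    response_text = str(row['response'])[:300] + "..." if pd.notna(row['response']) else "N/A"
    latency_text = f"{row['latency_seconds']:.2f}s" if pd.notna(row['latency_seconds']) else "N/A"
    print(f"💬 Response: {response_text}")
    print(f"⏱️ Latency: {latency_text}")
    print("-"*80)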

---

## 6. Performance Testing

Measure endpoint performance with multiple requests.
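
Two details tend to matter when timing requests: the first call is often slower than the rest (connection setup, caches), and a monotonic clock such as `time.perf_counter` is better suited to measuring intervals than `time.time`. The sketch below illustrates both points. `time_calls` and its `warmup` parameter are hypothetical, and the function times any zero-argument callable so it can be tried without an endpoint.

```python
import statistics
import time
from typing import Callable, Dict, List


def time_calls(call: Callable[[], object], num_requests: int = 10, warmup: int = 1) -> Dict[str, float]:
    """Time repeated calls, discarding the first `warmup` calls."""
    if num_requests < 1:
        raise ValueError("num_requests must be at least 1")

    # Warm-up calls are executed but not measured
    for _ in range(warmup):
        call()

    latencies: List[float] = []
    for _ in range(num_requests):
        start = time.perf_counter()
        call()
        latencies.append(time.perf_counter() - start)

    # quantiles(n=10) returns nine cut points; the last one is the 90th percentile.
    # The "inclusive" method interpolates between observed values.
    if len(latencies) >= 2:
        p90 = statistics.quantiles(latencies, n=10, method="inclusive")[-1]
    else:
        p90 = latencies[0]

    return {
        "min": min(latencies),
        "max": max(latencies),
        "mean": statistics.fmean(latencies),
        "p50": statistics.median(latencies),
        "p90": p90,
    }
```

Against a real endpoint this could be called as `time_calls(lambda: predictor.predict({"inputs": "Hi", "max_new_tokens": 50}))`. With only a handful of requests the percentiles are rough indicators, so larger samples give more meaningful tail figures.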

In [None]:
# Step 6.1: Latency test

def run_latency_test(
    prompt: str,
    num_requests: int = 10,
    max_new_tokens: int = 100
) -> Dict[str, Any]:
    """
    Run latency test with multiple requests.
    
    Args:
        prompt: Test prompt
        num_requests: Number of requests to make
        max_new_tokens: Tokens per request
        
    Returns:
        dict: Latency statistics
    """
    latencies = []
    
    print(f"🧪 Running {num_requests} requests...")
    
    for i in tqdm(range(num_requests)):
        start_time = time.time()
        
        predictor.predict({
            "inputs": prompt,
            "max_new_tokens": max_new_tokens
        })
        
        latency = time.time() - start_time
        latencies.append(latency)
    
    # Use interpolated quantiles: picking sorted[int(n * q)] overstates
    # percentiles on small samples (with 10 requests, p90 would be the max)
    series = pd.Series(latencies)
    return {
        "min": float(series.min()),
        "max": float(series.max()),
        "avg": float(series.mean()),
        "p50": float(series.quantile(0.50)),
        "p90": float(series.quantile(0.90)),
        "p99": float(series.quantile(0.99)),
        "all_latencies": latencies
    }


# Run latency test
test_prompt = "Explain the concept of cloud computing in one paragraph."
latency_stats = run_latency_test(test_prompt, num_requests=10)

print("\n📊 Latency Statistics:")
print(f"   ⬇️ Min: {latency_stats['min']:.2f}s")
print(f"   ⬆️ Max: {latency_stats['max']:.2f}s")
print(f"   📊 Avg: {latency_stats['avg']:.2f}s")
print(f"   📏 P50: {latency_stats['p50']:.2f}s")
print(f"   📏 P90: {latency_stats['p90']:.2f}s")

In [None]:
# Step 6.2: Visualize latency distribution

plt.figure(figsize=(12, 5))

# Histogram
plt.subplot(1, 2, 1)
plt.hist(latency_stats['all_latencies'], bins=10, edgecolor='black', alpha=0.7)
plt.axvline(latency_stats['avg'], color='r', linestyle='--', label=f'Mean: {latency_stats["avg"]:.2f}s')
plt.xlabel('Latency (seconds)')
plt.ylabel('Frequency')
plt.title('Latency Distribution')
plt.legend()

# Line plot
plt.subplot(1, 2, 2)
plt.plot(range(1, len(latency_stats['all_latencies']) + 1), latency_stats['all_latencies'], 'b-o')
plt.axhline(latency_stats['avg'], color='r', linestyle='--', label=f'Mean: {latency_stats["avg"]:.2f}s')
plt.xlabel('Request #')
plt.ylabel('Latency (seconds)')
plt.title('Latency Over Time')
plt.legend()

plt.tight_layout()
plt.savefig('latency_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n📊 Chart saved to: latency_analysis.png")

In [None]:
# Step 6.3: Token throughput test

print("🔄 Testing token throughput at different lengths...")
print()

token_lengths = [50, 100, 200, 300]
throughput_results = []

for max_tokens in token_lengths:
    start_time = time.time()
    
    response = predictor.predict({
        "inputs": "Write a detailed explanation of artificial intelligence.",
        "max_new_tokens": max_tokens
    })
    
    latency = time.time() - start_time
    
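    # Upper-bound estimate: assumes the model generated all max_tokens tokens,
    # so it overstates throughput whenever generation stops early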
    tokens_per_second = max_tokens / latency
    
    throughput_results.append({
        "max_tokens": max_tokens,
        "latency": latency,
        "tokens_per_second": tokens_per_second
    })
    
    print(f"   📏 {max_tokens} tokens: {latency:.2f}s ({tokens_per_second:.1f} tokens/sec)")

# Create throughput DataFrame
throughput_df = pd.DataFrame(throughput_results)
print("\n📊 Throughput Summary:")
print(throughput_df.to_string(index=False))

---

## 7. Endpoint Monitoring

Monitor endpoint metrics and CloudWatch statistics.
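
SageMaker publishes `ModelLatency` and `OverheadLatency` in microseconds, and CloudWatch returns one datapoint per period. Averaging those per-period averages directly treats a quiet period the same as a busy one. A request-weighted mean avoids that, provided `SampleCount` is requested alongside `Average` (for example `Statistics=['Average', 'SampleCount']` in `get_metric_statistics`). The helper below is a sketch under that assumption, and `mean_latency_seconds` is a hypothetical name.

```python
from typing import Any, Dict, List, Optional

MICROSECONDS_PER_SECOND = 1_000_000


def mean_latency_seconds(datapoints: List[Dict[str, Any]]) -> Optional[float]:
    """Request-weighted mean latency in seconds, or None when there is no data.

    Each datapoint is expected to carry "Average" (microseconds) and "SampleCount".
    """
    total_samples = sum(dp["SampleCount"] for dp in datapoints)
    if not total_samples:
        return None
    # Weight each period's average by the number of requests it covers
    weighted_sum_us = sum(dp["Average"] * dp["SampleCount"] for dp in datapoints)
    return weighted_sum_us / total_samples / MICROSECONDS_PER_SECOND
```

For instance, one period with a single 1-second request and another with three 3-second requests give a weighted mean of 2.5 seconds, whereas the unweighted mean of the two period averages would be 2.0 seconds.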

In [None]:
# Step 7.1: Get endpoint details

sagemaker_client = boto_session.client('sagemaker')

try:
    endpoint_response = sagemaker_client.describe_endpoint(EndpointName=ENDPOINT_NAME)
    endpoint_config = sagemaker_client.describe_endpoint_config(
        EndpointConfigName=endpoint_response['EndpointConfigName']
    )
    
    print("📊 Endpoint Details:")
    print(f"   📛 Name: {endpoint_response['EndpointName']}")
    print(f"   🔄 Status: {endpoint_response['EndpointStatus']}")
    print(f"   📅 Created: {endpoint_response['CreationTime']}")
    print(f"   🔄 Last Modified: {endpoint_response['LastModifiedTime']}")
    print()
    
    print("⚙️ Configuration:")
    for variant in endpoint_config['ProductionVariants']:
        print(f"   💻 Instance Type: {variant['InstanceType']}")
        print(f"   🔢 Instance Count: {variant.get('InitialInstanceCount', 'N/A')}")
        print(f"   📊 Variant Name: {variant['VariantName']}")
        
except Exception as e:
    print(f"❌ Error getting endpoint details: {e}")

In [None]:
# Step 7.2: CloudWatch metrics

cloudwatch = boto_session.client('cloudwatch')

# Define time range (last 1 hour)
end_time = datetime.datetime.now(datetime.timezone.utc)  # utcnow() is deprecated since Python 3.12
start_time = end_time - datetime.timedelta(hours=1)

metrics = [
    ('Invocations', 'Sum'),
    ('ModelLatency', 'Average'),
    ('OverheadLatency', 'Average'),
    ('Invocation4XXErrors', 'Sum'),
    ('Invocation5XXErrors', 'Sum')
]

print("📊 CloudWatch Metrics (Last Hour):")
print("="*60)

for metric_name, stat in metrics:
    try:
        response = cloudwatch.get_metric_statistics(
            Namespace='AWS/SageMaker',
            MetricName=metric_name,
            Dimensions=[
                {'Name': 'EndpointName', 'Value': ENDPOINT_NAME},
                # 'AllTraffic' is the SageMaker SDK's default variant name; change it if yours differs
                {'Name': 'VariantName', 'Value': 'AllTraffic'}
            ],
            StartTime=start_time,
            EndTime=end_time,
            Period=300,
            Statistics=[stat]
        )
        
        datapoints = response.get('Datapoints', [])
        if datapoints:
            value = sum(dp[stat] for dp in datapoints)
            if 'Latency' in metric_name:
                value = value / len(datapoints) / 1_000_000  # SageMaker reports latency in microseconds; convert to seconds
                print(f"   {metric_name}: {value:.3f}s")
            else:
                print(f"   {metric_name}: {value:.0f}")
        else:
            print(f"   {metric_name}: No data")
            
    except Exception as e:
        print(f"   {metric_name}: Error - {e}")

print("="*60)

---

## 8. Cleanup

Delete the endpoint when you're done to avoid unnecessary costs.

> **⚠️ Warning**: Only run this when you no longer need the endpoint!
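
Deleting an endpoint does not remove its endpoint configuration or the model objects it references. One way to clean up all three is to look up the related names first and then delete in order. The sketch below is illustrative only: `delete_endpoint_resources` and the `delete_models` flag are hypothetical, and `sm_client` is assumed to be a boto3 `sagemaker` client. Leave model deletion off if other endpoints share the same model.

```python
from typing import Any, List


def delete_endpoint_resources(sm_client: Any, endpoint_name: str, delete_models: bool = False) -> List[str]:
    """Delete an endpoint, its config and (optionally) its models.

    Returns a list of labels describing what was deleted, in order.
    """
    # Resolve related names first: the endpoint-to-config link is lost
    # once the endpoint itself has been deleted.
    config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [v["ModelName"] for v in config["ProductionVariants"] if "ModelName" in v]

    deleted: List[str] = []
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    deleted.append(f"endpoint:{endpoint_name}")

    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    deleted.append(f"config:{config_name}")

    if delete_models:
        for name in model_names:
            sm_client.delete_model(ModelName=name)
            deleted.append(f"model:{name}")
    return deleted
```

The returned list makes it easy to log exactly which resources were removed.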

In [None]:
# Step 8.1: List all endpoints

print("📋 Active SageMaker Endpoints:")
print("="*60)

response = sagemaker_client.list_endpoints(
    SortBy='CreationTime',
    SortOrder='Descending',
    MaxResults=20
)

for endpoint in response['Endpoints']:
    print(f"   📛 {endpoint['EndpointName']}")
    print(f"      Status: {endpoint['EndpointStatus']}")
    print(f"      Created: {endpoint['CreationTime']}")
    print()

In [None]:
# Step 8.2: Delete endpoint function

def delete_endpoint(endpoint_name: str, force: bool = False) -> bool:
    """
    Delete a SageMaker endpoint.
    
    Args:
        endpoint_name: Name of the endpoint to delete
        force: Skip confirmation prompt
        
    Returns:
        bool: True if deleted successfully
    """
    if not force:
        confirm = input(f"⚠️ Delete endpoint '{endpoint_name}'? (yes/no): ")
        if confirm.lower() != 'yes':
            print("❌ Deletion cancelled.")
            return False
    
    try:
        # Look up the config name before deleting; it cannot be reliably
        # derived from the endpoint name
        config_name = sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name
        )['EndpointConfigName']

        # Delete endpoint
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
        print(f"✅ Endpoint '{endpoint_name}' deletion initiated.")

        # Delete the associated endpoint config
        try:
            sagemaker_client.delete_endpoint_config(EndpointConfigName=config_name)
            print(f"✅ Endpoint config '{config_name}' deleted.")
        except Exception as e:
            print(f"⚠️ Could not delete endpoint config '{config_name}': {e}")

        return True

    except Exception as e:
        print(f"❌ Error deleting endpoint: {e}")
        return False


print("🗑️ Delete Endpoint Function Ready")
print(f"   To delete: delete_endpoint('{ENDPOINT_NAME}')")

In [None]:
# ⚠️ UNCOMMENT THE LINE BELOW TO DELETE THE ENDPOINT
# Only run this when you're COMPLETELY done with the endpoint!

# delete_endpoint(ENDPOINT_NAME)

# Or to skip confirmation:
# delete_endpoint(ENDPOINT_NAME, force=True)

---

## 📝 Summary

This notebook covered:

1. **Setup**: Installing packages and connecting to AWS
2. **Connection**: Connecting to the deployed SageMaker endpoint
3. **Basic Inference**: Simple text generation
4. **Advanced Options**: Temperature, top-p, and other parameters
5. **Batch Processing**: Processing multiple prompts efficiently
6. **Performance Testing**: Latency and throughput analysis
7. **Monitoring**: CloudWatch metrics and endpoint details
8. **Cleanup**: Deleting endpoints to save costs

### Quick Reference

```python
# Simple inference
response = predictor.predict({
    "inputs": "Your prompt here",
    "max_new_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.9
})
```

### Useful Links

- [SageMaker Python SDK](https://sagemaker.readthedocs.io/)
- [Hugging Face on SageMaker](https://huggingface.co/docs/sagemaker/)
- [AWS SageMaker Documentation](https://docs.aws.amazon.com/sagemaker/)