# ONNX Model Server Example Usage

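This notebook demonstrates how to interact with the ONNX model serving API. It assumes the server is already running and reachable at `BASE_URL` (`http://localhost:8000` by default).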

In [None]:
import json
import time
from typing import List, Dict, Any

import numpy as np
import requests

# API base URL — the notebook expects the model server to be reachable here
BASE_URL = "http://localhost:8000"

# Seed NumPy's global RNG so the random inputs generated in later cells
# (custom prediction, performance test) are identical on every
# Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)

print("ONNX Model Server API Demo")
print("=" * 30)

## 1. Health Check

In [None]:
# Check if the service is healthy.
# The timeout (seconds) keeps the notebook from hanging indefinitely if the
# server accepts the connection but never answers; raise_for_status() surfaces
# HTTP errors directly instead of as a confusing JSON decode error.
response = requests.get(f"{BASE_URL}/", timeout=10)
response.raise_for_status()
health_info = response.json()

print("Health Check:")
print(json.dumps(health_info, indent=2))

## 2. Model Information

In [None]:
# Get detailed model information.
# Timeout + raise_for_status(): fail fast and clearly if the endpoint is
# unresponsive or returns an HTTP error, rather than parsing an error body.
response = requests.get(f"{BASE_URL}/model/info", timeout=10)
response.raise_for_status()
model_info = response.json()

print("Model Information:")
print(json.dumps(model_info, indent=2))

## 3. Get Sample Data

In [None]:
# Get sample input data.
# Timeout + raise_for_status(): fail fast and clearly on an unresponsive
# server or an HTTP error, rather than parsing an error body as JSON.
response = requests.get(f"{BASE_URL}/sample", timeout=10)
response.raise_for_status()
sample_data = response.json()

print("Sample Data:")
print(json.dumps(sample_data, indent=2))

# Extract test input for predictions. Default to empty lists so the
# prediction cell can detect when the server did not provide them.
test_input = sample_data.get("test_input", [])
expected_output = sample_data.get("expected_output", [])

## 4. Make Predictions

In [None]:
# Make predictions using the sample data fetched in the previous cell.
# Non-200 responses are printed (not raised) so the server's error body is visible.
if test_input:
    prediction_request = {
        "data": test_input
    }

    response = requests.post(
        f"{BASE_URL}/predict",
        json=prediction_request,
        headers={"Content-Type": "application/json"},
        timeout=10,  # seconds; avoid hanging forever on an unresponsive server
    )

    if response.status_code == 200:
        predictions = response.json()

        print("Predictions:")
        print(json.dumps(predictions, indent=2))

        # Compare with expected output
        if expected_output:
            predicted_values = predictions["predictions"]
            # zip() silently truncates to the shorter sequence; flag a length
            # mismatch so missing or extra predictions are not hidden.
            if len(predicted_values) != len(expected_output):
                print(
                    f"\nWarning: got {len(predicted_values)} predictions but "
                    f"{len(expected_output)} expected values; comparing the overlap only"
                )
            print("\nComparison with Expected Output:")
            for i, (pred, expected) in enumerate(zip(predicted_values, expected_output)):
                # Exact equality. NOTE(review): fine for class labels; would need
                # a tolerance if the server returns float outputs.
                status = "✅" if pred == expected else "❌"
                print(f"Sample {i+1}: Predicted={pred}, Expected={expected} {status}")
    else:
        print(f"Error: {response.status_code} - {response.text}")
else:
    print("No test input available")

## 5. Custom Prediction

In [None]:
# Create custom random input data.
# NOTE(review): assumes the model expects 10 input features — confirm against
# the /model/info output above.
n_samples, n_features = 3, 10
custom_input = np.random.randn(n_samples, n_features).tolist()

custom_request = {
    "data": custom_input
}

response = requests.post(
    f"{BASE_URL}/predict",
    json=custom_request,
    headers={"Content-Type": "application/json"},
    timeout=10,  # seconds; avoid hanging forever on an unresponsive server
)

if response.status_code == 200:
    custom_predictions = response.json()

    print("Custom Predictions:")
    print(f"Input shape: {np.array(custom_input).shape}")
    print(f"Predictions: {custom_predictions['predictions']}")

    # .get() so a response that omits "probabilities" (or sets it to null)
    # simply skips this section instead of raising KeyError.
    probabilities = custom_predictions.get('probabilities')
    if probabilities:
        print("\nClass Probabilities:")
        for i, probs in enumerate(probabilities):
            print(f"Sample {i+1}: {[f'{p:.3f}' for p in probs]}")
else:
    print(f"Error: {response.status_code} - {response.text}")

## 6. Performance Test

In [None]:
# Sequential (single-client) latency test: send n_requests requests, each with
# batch_size random samples, one after another.
n_requests = 10
batch_size = 5
n_features = 10  # NOTE(review): assumes the model takes 10 input features — confirm via /model/info

print(f"Performance Test: {n_requests} requests with {batch_size} samples each")
print("-" * 50)

times = []
n_ok = 0
for i in range(n_requests):
    # Generate random input
    random_input = np.random.randn(batch_size, n_features).tolist()

    request_data = {"data": random_input}

    # perf_counter() is monotonic and high-resolution; time.time() is wall-clock
    # and can jump if the system clock is adjusted mid-measurement.
    start_time = time.perf_counter()
    response = requests.post(
        f"{BASE_URL}/predict",
        json=request_data,
        headers={"Content-Type": "application/json"},
        timeout=10,  # seconds; avoid hanging forever on an unresponsive server
    )
    end_time = time.perf_counter()

    request_time = (end_time - start_time) * 1000  # Convert to ms
    times.append(request_time)

    if response.status_code == 200:
        n_ok += 1
        print(f"Request {i+1}: {request_time:.2f}ms ✅")
    else:
        print(f"Request {i+1}: {request_time:.2f}ms ❌ ({response.status_code})")

if times:
    # Statistics include every request, failed ones too; the success count
    # below shows how many of them actually succeeded.
    avg_time = np.mean(times)
    min_time = np.min(times)
    max_time = np.max(times)
    std_time = np.std(times)

    print("\nPerformance Summary:")
    print(f"Successful requests: {n_ok}/{n_requests}")
    print(f"Average response time: {avg_time:.2f}ms")
    print(f"Min response time: {min_time:.2f}ms")
    print(f"Max response time: {max_time:.2f}ms")
    print(f"Standard deviation: {std_time:.2f}ms")
    # 1 / mean latency for a single sequential client — not a concurrent
    # throughput benchmark.
    print(f"Requests per second: ~{1000/avg_time:.1f} RPS")
else:
    print("\nNo requests were made; nothing to summarize.")