In [None]:
import os
import boto3
import json
import time
from statistics import mean, stdev
from config import SERVERLESS_CONFIG, ENDPOINT_NAME


# Endpoint under test and the runtime client shared by every invocation.
endpoint_name = ENDPOINT_NAME
client = boto3.client("sagemaker-runtime")

# One sample passage, used as-is for the warmup request and repeated to build
# the multi-document payloads. Written as adjacent literals so the trailing
# space and the newline between the two sentences are explicit.
base_doc = (
    "Machine learning is a subset of artificial intelligence that enables "
    "systems to learn and improve from experience without being explicitly "
    "programmed. \n"
    "It uses algorithms and statistical models to analyze and draw inferences "
    "from patterns in data."
)


def invoke_endpoint(payload):
    """Send ``payload`` to the endpoint as JSON and time the call.

    Args:
        payload: JSON-serializable object used as the request body.

    Returns:
        A ``(elapsed_seconds, parsed_response)`` tuple. ``elapsed_seconds``
        covers only the ``client.invoke_endpoint`` call; the response body is
        read and decoded after the timer has stopped. ``parsed_response`` is
        the response body decoded as UTF-8 and parsed as JSON.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # (or go negative) when the wall clock is adjusted, unlike time.time().
    started = time.perf_counter()
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    elapsed = time.perf_counter() - started

    body_text = response["Body"].read().decode("utf-8")
    return elapsed, json.loads(body_text)


def measure_latency(num_docs, num_iterations=2, *, warmup_wait_s=15):
    """Measure sequential request latency for a payload of ``num_docs`` documents.

    A single-document warmup request is sent first, followed by a fixed pause,
    then ``num_iterations`` timed requests are issued one after another.

    Args:
        num_docs: Number of copies of ``base_doc`` placed in the payload; also
            used as the requested ``k``.
        num_iterations: Number of timed requests. Must be at least 1.
        warmup_wait_s: Seconds to sleep between the warmup request and the
            timed requests.

    Returns:
        A dict with ``"mean"``, ``"std"``, ``"min"`` and ``"max"`` of the
        latencies reported by ``invoke_endpoint`` (in seconds), plus
        ``"serverless_config"`` holding ``SERVERLESS_CONFIG``. ``"std"`` is 0
        when only one request was timed.

    Raises:
        ValueError: If ``num_iterations`` is less than 1.
    """
    # Fail before touching the endpoint: with no timed requests the statistics
    # below would be undefined, and the warmup call and pause would be wasted.
    if num_iterations < 1:
        raise ValueError(
            f"num_iterations must be at least 1, got {num_iterations}"
        )

    # Warmup - serverless endpoints need to cold start, so we send a test
    # request to wake it up before measuring anything.
    invoke_endpoint(
        {
            "query": "test",
            "docs": [base_doc],
            "doc_ids": ["doc1"],
            "k": 1,
        }
    )
    time.sleep(warmup_wait_s)

    # Test payload: the same document repeated, with 1-based ids.
    payload = {
        "query": "What is machine learning?",
        "docs": [base_doc] * num_docs,
        "doc_ids": [f"doc{i+1}" for i in range(num_docs)],
        "k": num_docs,
    }

    # Requests are issued strictly one at a time; only the elapsed time of
    # each call is kept, the parsed response is discarded.
    latencies = []
    for _ in range(num_iterations):
        elapsed, _response = invoke_endpoint(payload)
        latencies.append(elapsed)

    return {
        "mean": mean(latencies),
        # stdev() needs at least two samples.
        "std": stdev(latencies) if len(latencies) > 1 else 0,
        "min": min(latencies),
        "max": max(latencies),
        "serverless_config": SERVERLESS_CONFIG,
    }


# Payload sizes (documents per request) to benchmark, and the collected stats.
doc_counts = [10, 50, 100, 200, 500, 1000]
results = []

# Make sure the output directory is there before opening the log.
os.makedirs("test_results", exist_ok=True)

# Every run writes to its own timestamped log file.
timestamp = time.strftime("%Y%m%d-%H%M%S")
log_file = os.path.join("test_results", f"latency_test_{timestamp}.txt")

with open(log_file, "w") as f:
    # Header: the serverless settings this run was measured against.
    f.write("Serverless Configuration:\n")
    f.write(f"Memory Size: {SERVERLESS_CONFIG['memory_size_in_mb']} MB\n")
    f.write(f"Max Concurrency: {SERVERLESS_CONFIG['max_concurrency']}\n")
    f.write("\nTest Results:\n")

    for num_docs in doc_counts:
        print(f"\nTesting with {num_docs} documents...")
        result = measure_latency(num_docs)
        results.append(result)

        # Build the per-size report once; the same text goes to the console
        # and to the log file.
        report_lines = [f"Documents: {num_docs}"]
        report_lines.extend(
            f"{stat.capitalize()}: {result[stat]:.3f}s"
            for stat in ("mean", "std", "min", "max")
        )
        report_lines.append("-------------------")
        result_str = "\n".join(report_lines) + "\n"

        print(result_str)
        f.write(result_str)

print(f"\nResults have been logged to {log_file}")