# In [None]:
"""
================================================================================
MAYA FALLBACK ENDPOINT - PRODUCTION DEPLOYMENT
================================================================================
Version: 2.0 (Production Grade)
Purpose: WhatsApp HR AI Bot - Fallback endpoint for max 10 users
Instance: ml.g5.2xlarge (1x NVIDIA A10G, 24GB VRAM)
Region: ap-south-1 (Mumbai)

Run cells in order: 1 -> 2 -> 3 -> 4 (Cell 5 holds optional status and emergency controls)
================================================================================
"""

import json
import logging
import time
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Dict, Optional, Tuple

import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import (
    BotoCoreError,
    ClientError,
    NoCredentialsError,
    ParamValidationError,
)

# =============================================================================
# LOGGING SETUP (Production-grade)
# =============================================================================

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger("MayaDeployment")

# logging.basicConfig() is a no-op when the root logger already has handlers
# (common inside Jupyter / SageMaker notebook kernels). The root logger then
# stays at WARNING and every log_info/log_success line would be hidden, so
# give the deployment logger its own level and handler.
logger.setLevel(logging.INFO)
if not logger.handlers:  # re-running the cell must not stack duplicate handlers
    _handler = logging.StreamHandler()
    _handler.setFormatter(
        logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s', datefmt='%H:%M:%S')
    )
    logger.addHandler(_handler)
    # Our handler already emits the record; don't also print it via root.
    logger.propagate = False

# =============================================================================
# CONFIGURATION
# =============================================================================

class Config:
    """
    Single source of truth for every tunable of the Maya fallback endpoint.

    All values are plain class attributes read directly as ``Config.NAME``.
    """

    # --- Endpoint identity --------------------------------------------------
    ENDPOINT_NAME = "maya-prod-mumbai-1768132592"
    REGION = "ap-south-1"
    VARIANT_NAME = "AllTraffic"

    # --- Hardware -----------------------------------------------------------
    INSTANCE_TYPE = "ml.g5.2xlarge"

    # --- Hard capacity bounds registered with Application Auto Scaling ------
    MIN_INSTANCES = 1
    MAX_INSTANCES = 2

    # --- Target-tracking policy (deliberately conservative for a fallback) --
    INVOCATIONS_TARGET = 150.0      # invocations per instance
    SCALE_OUT_COOLDOWN = 5 * 60     # seconds (5 min)
    SCALE_IN_COOLDOWN = 10 * 60     # seconds (10 min)

    # --- Model artifact -----------------------------------------------------
    MODEL_S3_URI = "s3://sagemaker-ap-south-1-937127308917/maya-prod-v1-bf16/model.tar.gz"

    # --- retry_with_backoff defaults (seconds for the delays) ---------------
    MAX_RETRIES = 3
    RETRY_DELAY_BASE = 1.0
    RETRY_DELAY_MAX = 30.0

    # --- botocore client timeouts (seconds) ---------------------------------
    API_TIMEOUT_CONNECT = 10
    API_TIMEOUT_READ = 30

# =============================================================================
# RETRY DECORATOR (Exponential Backoff)
# =============================================================================

# ClientError codes that fail identically on every attempt, so retrying them
# only adds delay. Application Auto Scaling reports a missing target/policy as
# ObjectNotFoundException (not ResourceNotFoundException), so both are listed.
_NON_RETRYABLE_ERROR_CODES = frozenset({
    'ValidationException',
    'AccessDeniedException',
    'InvalidParameterException',
    'ResourceNotFoundException',
    'ObjectNotFoundException',
    'UnrecognizedClientException',
    'ExpiredTokenException',
})


def retry_with_backoff(max_retries: int = Config.MAX_RETRIES,
                       base_delay: float = Config.RETRY_DELAY_BASE,
                       max_delay: float = Config.RETRY_DELAY_MAX):
    """
    Decorator factory that retries a function with exponential backoff.

    Only AWS SDK errors (``ClientError`` / ``BotoCoreError``) are retried.
    Errors a retry cannot fix -- local parameter validation, missing
    credentials, and the ClientError codes in ``_NON_RETRYABLE_ERROR_CODES``
    -- are re-raised immediately. Any other exception propagates untouched.

    Args:
        max_retries: Retries after the first attempt (total calls = max_retries + 1).
            Negative values are treated as 0, so the function is always called once.
        base_delay: Delay in seconds before the first retry; doubles on each retry.
        max_delay: Upper bound in seconds for any single delay.

    Returns:
        A decorator. The wrapped function returns whatever ``func`` returns, or
        re-raises the last exception once all attempts are exhausted.
    """
    total_attempts = max(max_retries, 0) + 1

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(total_attempts):
                try:
                    return func(*args, **kwargs)
                except (ParamValidationError, NoCredentialsError):
                    # Both subclass BotoCoreError but are local/config
                    # problems: an identical retry cannot succeed.
                    raise
                except (ClientError, BotoCoreError) as e:
                    if isinstance(e, ClientError):
                        error_code = e.response.get('Error', {}).get('Code', '')
                        if error_code in _NON_RETRYABLE_ERROR_CODES:
                            raise

                    if attempt + 1 >= total_attempts:
                        logger.error(f"All {total_attempts} attempts failed")
                        raise

                    delay = min(base_delay * (2 ** attempt), max_delay)
                    logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay:.1f}s...")
                    time.sleep(delay)
        return wrapper
    return decorator

# =============================================================================
# AWS CLIENTS (cached per service, shared timeout/retry config)
# =============================================================================

class AWSClients:
    """
    Lazily created, per-service cache of boto3 clients.

    Every client is pinned to ``Config.REGION`` and shares one botocore config
    (connect/read timeouts from ``Config`` plus botocore's adaptive retry
    mode with 3 max attempts). Note that these SDK-level retries run inside
    each call that ``retry_with_backoff`` may retry again.
    """

    # service name -> boto3 client, shared across the whole class
    _clients: Dict[str, Any] = {}
    _boto_config = BotoConfig(
        retries={'mode': 'adaptive', 'max_attempts': 3},
        read_timeout=Config.API_TIMEOUT_READ,
        connect_timeout=Config.API_TIMEOUT_CONNECT,
    )

    @classmethod
    def _get_client(cls, service_name: str):
        """Return the cached client for ``service_name``, creating it on first use."""
        client = cls._clients.get(service_name)
        if client is None:
            client = boto3.client(
                service_name,
                region_name=Config.REGION,
                config=cls._boto_config,
            )
            cls._clients[service_name] = client
        return client

    @classmethod
    def sagemaker(cls):
        """SageMaker control-plane client (describe/update endpoints)."""
        return cls._get_client('sagemaker')

    @classmethod
    def runtime(cls):
        """SageMaker runtime client (invoke_endpoint)."""
        return cls._get_client('sagemaker-runtime')

    @classmethod
    def autoscaling(cls):
        """Application Auto Scaling client (targets and policies)."""
        return cls._get_client('application-autoscaling')

    @classmethod
    def cloudwatch(cls):
        """CloudWatch client."""
        return cls._get_client('cloudwatch')

    @classmethod
    def reset(cls):
        """Drop every cached client so the next accessor call builds a fresh one."""
        cls._clients.clear()

# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def get_resource_id() -> str:
    """Build the Application Auto Scaling ResourceId for the configured endpoint variant."""
    return "endpoint/{}/variant/{}".format(Config.ENDPOINT_NAME, Config.VARIANT_NAME)

def log_success(msg: str):
    """Log ``msg`` at INFO level with an ``[OK]`` prefix."""
    logger.info("[OK] {}".format(msg))

def log_error(msg: str):
    """Log ``msg`` at ERROR level with an ``[ERR]`` prefix."""
    logger.error("[ERR] {}".format(msg))

def log_warn(msg: str):
    """Log ``msg`` at WARNING level with an ``[!]`` prefix."""
    logger.warning("[!] {}".format(msg))

def log_info(msg: str):
    """Log ``msg`` at INFO level with a ``[>]`` prefix."""
    logger.info("[>] {}".format(msg))

# =============================================================================
# ENDPOINT STATUS CHECKER
# =============================================================================

@retry_with_backoff()
def get_endpoint_status() -> Tuple[str, int]:
    """
    Describe the endpoint and report its status and current instance count.

    Returns:
        ``(EndpointStatus, CurrentInstanceCount)``, where the count comes from
        the production variant named ``Config.VARIANT_NAME``. The count is 0
        when that variant is absent or reports no count.

    Raises:
        ClientError / BotoCoreError from ``describe_endpoint`` once
        ``retry_with_backoff`` stops retrying.
    """
    description = AWSClients.sagemaker().describe_endpoint(EndpointName=Config.ENDPOINT_NAME)
    endpoint_status = description["EndpointStatus"]

    # First variant whose name matches; later variants are never inspected.
    matching_counts = (
        variant.get("CurrentInstanceCount", 0)
        for variant in description.get("ProductionVariants", [])
        if variant["VariantName"] == Config.VARIANT_NAME
    )
    return endpoint_status, next(matching_counts, 0)

@retry_with_backoff()
def get_scaling_config() -> Optional[Dict[str, Any]]:
    """
    Read the registered scalable target for this endpoint variant.

    Returns:
        ``{"min_capacity": ..., "max_capacity": ...}`` taken from the first
        scalable target returned, or None when no target is registered.
    """
    targets = AWSClients.autoscaling().describe_scalable_targets(
        ServiceNamespace='sagemaker',
        ResourceIds=[get_resource_id()],
        ScalableDimension='sagemaker:variant:DesiredInstanceCount'
    )["ScalableTargets"]

    if not targets:
        return None

    first_target = targets[0]
    return {
        "min_capacity": first_target["MinCapacity"],
        "max_capacity": first_target["MaxCapacity"],
    }

@retry_with_backoff()
def get_scaling_policies() -> list:
    """Return the raw ScalingPolicies list for this endpoint variant (empty list if none)."""
    autoscaling = AWSClients.autoscaling()
    policies_response = autoscaling.describe_scaling_policies(
        ServiceNamespace='sagemaker',
        ResourceId=get_resource_id(),
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
    )
    return policies_response.get("ScalingPolicies", [])

# =============================================================================
# INITIALIZATION
# =============================================================================

# Startup banner: echo the effective configuration so every run's output
# records which endpoint, region and limits it targeted.
print("=" * 70, "MAYA FALLBACK ENDPOINT - PRODUCTION DEPLOYMENT", "=" * 70, sep="\n")
log_success("Configuration loaded")
log_info("Endpoint: {}".format(Config.ENDPOINT_NAME))
log_info("Region: {}".format(Config.REGION))
log_info("Scaling: Min={}, Max={}".format(Config.MIN_INSTANCES, Config.MAX_INSTANCES))
log_info("Scale trigger: >{} invocations/min per instance".format(Config.INVOCATIONS_TARGET))
log_info("Retry config: {} retries with exponential backoff".format(Config.MAX_RETRIES))
print("=" * 70)

# In [None]:
"""
================================================================================
CELL 2: PRE-FLIGHT CHECKS
================================================================================
Comprehensive validation before making any changes.
Safe to run multiple times.
"""

def preflight_check() -> Dict[str, Any]:
    """
    Read-only validation of the endpoint before any changes are made.

    Checks:
    1. Endpoint exists and is InService
    2. Auto-scaling limits match Config.MIN_INSTANCES / Config.MAX_INSTANCES
    3. Existing scaling policies (flags latency-based ones and explicitly
       configured scale-out cooldowns under 60s)
    4. Running instance count does not exceed Config.MAX_INSTANCES

    Only describe/read calls are made, so it is safe to run repeatedly.

    Returns:
        Dict with the raw findings plus "issues" (must fix), "warnings"
        (worth a look) and "ready_for_fix" (True when the endpoint is
        InService and Cell 3 can be run).
    """

    results = {
        "timestamp": datetime.now().isoformat(),
        "endpoint_exists": False,
        "endpoint_status": None,
        "current_instances": 0,
        "scaling_config": None,
        "scaling_policies": [],
        "issues": [],
        "warnings": [],
        "ready_for_fix": False
    }

    print("=" * 70)
    print("PRE-FLIGHT CHECK")
    print("=" * 70)

    # -------------------------------------------------------------------------
    # CHECK 1: Endpoint Status (fatal on failure - nothing else can be checked)
    # -------------------------------------------------------------------------
    log_info("Checking endpoint status...")
    try:
        status, instances = get_endpoint_status()
        results["endpoint_exists"] = True
        results["endpoint_status"] = status
        results["current_instances"] = instances

        if status == "InService":
            log_success(f"Endpoint InService with {instances} instance(s)")
        elif status == "Updating":
            log_warn("Endpoint is Updating - wait for it to complete")
            results["warnings"].append("Endpoint is currently updating")
        else:
            log_error(f"Endpoint status: {status}")
            results["issues"].append(f"Endpoint not InService: {status}")

    except ClientError as e:
        # A missing endpoint is recognised by its error-message text.
        if "Could not find endpoint" in str(e):
            log_error("Endpoint does not exist!")
            results["issues"].append("Endpoint not found - needs to be created first")
        else:
            log_error(f"AWS Error: {e}")
            results["issues"].append(f"AWS Error: {str(e)}")
        print("=" * 70)
        return results
    except Exception as e:
        log_error(f"Unexpected error: {e}")
        results["issues"].append(f"Unexpected error: {str(e)}")
        print("=" * 70)
        return results

    # -------------------------------------------------------------------------
    # CHECK 2: Scaling Configuration
    # -------------------------------------------------------------------------
    log_info("Checking auto-scaling configuration...")
    try:
        scaling = get_scaling_config()
        results["scaling_config"] = scaling

        if scaling:
            min_cap = scaling["min_capacity"]
            max_cap = scaling["max_capacity"]
            log_info(f"Current scaling: Min={min_cap}, Max={max_cap}")

            # Higher-than-target limits cost money -> issue.
            if min_cap > Config.MIN_INSTANCES:
                results["issues"].append(
                    f"MinCapacity={min_cap} is higher than target {Config.MIN_INSTANCES}"
                )
            if max_cap > Config.MAX_INSTANCES:
                results["issues"].append(
                    f"MaxCapacity={max_cap} is higher than target {Config.MAX_INSTANCES}"
                )
            # Lower-than-target limits are not dangerous but still differ from
            # what Cell 3 applies (and what Cell 4 verifies) -> warning.
            if min_cap < Config.MIN_INSTANCES:
                results["warnings"].append(
                    f"MinCapacity={min_cap} is lower than target {Config.MIN_INSTANCES}"
                )
            if max_cap < Config.MAX_INSTANCES:
                results["warnings"].append(
                    f"MaxCapacity={max_cap} is lower than target {Config.MAX_INSTANCES}"
                )
            if min_cap == Config.MIN_INSTANCES and max_cap == Config.MAX_INSTANCES:
                log_success("Scaling limits are correctly configured")
        else:
            log_warn("No auto-scaling configured")
            results["warnings"].append("Auto-scaling not configured")

    except Exception as e:
        log_warn(f"Could not check scaling config: {e}")
        results["warnings"].append(f"Scaling check failed: {str(e)}")

    # -------------------------------------------------------------------------
    # CHECK 3: Scaling Policies
    # -------------------------------------------------------------------------
    log_info("Checking scaling policies...")
    try:
        policies = get_scaling_policies()
        results["scaling_policies"] = [p["PolicyName"] for p in policies]

        for policy in policies:
            policy_name = policy["PolicyName"]
            config = policy.get("TargetTrackingScalingPolicyConfiguration", {})

            # Check for dangerous latency-based policy
            custom_metric = config.get("CustomizedMetricSpecification", {})
            if custom_metric.get("MetricName") == "ModelLatency" or "Latency" in policy_name:
                log_error(f"DANGEROUS: Latency-based policy '{policy_name}'")
                results["issues"].append(
                    f"Latency policy '{policy_name}' causes phantom scaling - MUST REMOVE"
                )
            else:
                log_info(f"Policy: {policy_name}")

            # Only judge a cooldown that is explicitly set. A missing key means
            # AWS applies its default cooldown (not 0s), and step-scaling
            # policies have no TargetTrackingScalingPolicyConfiguration at all;
            # defaulting to 0 produced false "aggressive cooldown" issues.
            scale_out = config.get("ScaleOutCooldown")
            if scale_out is not None and scale_out < 60:
                results["issues"].append(
                    f"Policy '{policy_name}' has aggressive ScaleOutCooldown={scale_out}s"
                )

        if not policies:
            log_info("No scaling policies found")

    except Exception as e:
        log_warn(f"Could not check policies: {e}")
        results["warnings"].append(f"Policy check failed: {str(e)}")

    # -------------------------------------------------------------------------
    # CHECK 4: Instance Count Sanity
    # -------------------------------------------------------------------------
    if results["current_instances"] > Config.MAX_INSTANCES:
        results["issues"].append(
            f"OVER-PROVISIONED: {results['current_instances']} instances running "
            f"(max should be {Config.MAX_INSTANCES})"
        )
        log_error(f"Too many instances: {results['current_instances']}")

    # -------------------------------------------------------------------------
    # SUMMARY
    # -------------------------------------------------------------------------
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)

    if results["issues"]:
        log_error(f"Found {len(results['issues'])} issue(s):")
        for i, issue in enumerate(results["issues"], 1):
            print(f"   {i}. {issue}")

    if results["warnings"]:
        log_warn(f"Found {len(results['warnings'])} warning(s):")
        for i, warn in enumerate(results["warnings"], 1):
            print(f"   {i}. {warn}")

    if not results["issues"] and not results["warnings"]:
        log_success("All checks passed - endpoint is healthy!")

    # Cell 3 may only run against an InService endpoint that exists.
    results["ready_for_fix"] = (
        results["endpoint_status"] == "InService" and
        not any("not found" in i.lower() for i in results["issues"])
    )

    if results["issues"] and results["ready_for_fix"]:
        print("\n" + "-" * 70)
        log_info("Run Cell 3 to fix the issues above")

    print("=" * 70)
    return results

# Execute pre-flight check
preflight_results = preflight_check()

# In [None]:
"""
================================================================================
CELL 3: APPLY BULLETPROOF AUTO-SCALING
================================================================================
Production-grade scaling configuration that WILL NOT crash.

Features:
- Removes ALL dangerous policies (latency-based)
- Sets hard limits from Config (currently Min=1, Max=2)
- Single conservative policy
- Exponential backoff retries
- Idempotent (safe to run multiple times)
- Automatic scale-down if over-provisioned
================================================================================
"""

@retry_with_backoff()
def _delete_scaling_policy(policy_name: str, resource_id: str) -> bool:
    """Delete scaling policy ``policy_name`` on ``resource_id`` (retried); True once the call returns."""
    autoscaling = AWSClients.autoscaling()
    autoscaling.delete_scaling_policy(
        PolicyName=policy_name,
        ServiceNamespace='sagemaker',
        ResourceId=resource_id,
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
    )
    return True

@retry_with_backoff()
def _register_scaling_target(resource_id: str, min_cap: int, max_cap: int) -> bool:
    """Register (or re-register) ``resource_id`` with the given capacity bounds (retried)."""
    autoscaling = AWSClients.autoscaling()
    target_spec = {
        'ServiceNamespace': 'sagemaker',
        'ResourceId': resource_id,
        'ScalableDimension': 'sagemaker:variant:DesiredInstanceCount',
        'MinCapacity': min_cap,
        'MaxCapacity': max_cap,
    }
    autoscaling.register_scalable_target(**target_spec)
    return True

@retry_with_backoff()
def _put_scaling_policy(resource_id: str) -> bool:
    """
    Create or update the 'Maya-Fallback-Conservative' target-tracking policy (retried).

    Tracks SageMakerVariantInvocationsPerInstance against
    Config.INVOCATIONS_TARGET, using the cooldowns from Config.
    """
    autoscaling = AWSClients.autoscaling()
    tracking_config = {
        'TargetValue': Config.INVOCATIONS_TARGET,
        'PredefinedMetricSpecification': {
            'PredefinedMetricType': 'SageMakerVariantInvocationsPerInstance'
        },
        'ScaleInCooldown': Config.SCALE_IN_COOLDOWN,
        'ScaleOutCooldown': Config.SCALE_OUT_COOLDOWN,
        'DisableScaleIn': False
    }
    autoscaling.put_scaling_policy(
        PolicyName='Maya-Fallback-Conservative',
        ServiceNamespace='sagemaker',
        ResourceId=resource_id,
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
        PolicyType='TargetTrackingScaling',
        TargetTrackingScalingPolicyConfiguration=tracking_config,
    )
    return True

@retry_with_backoff()
def _force_scale_down(target_instances: int) -> bool:
    """Ask SageMaker to set the variant's DesiredInstanceCount to ``target_instances`` (retried)."""
    sagemaker = AWSClients.sagemaker()
    capacity_update = {
        'VariantName': Config.VARIANT_NAME,
        'DesiredInstanceCount': target_instances
    }
    sagemaker.update_endpoint_weights_and_capacities(
        EndpointName=Config.ENDPOINT_NAME,
        DesiredWeightsAndCapacities=[capacity_update],
    )
    return True


def _remove_other_scaling_policies(resource_id: str, keep_policy: str) -> list:
    """
    Delete every scaling policy on ``resource_id`` except ``keep_policy``.

    Deletion failures are logged and collected instead of raised, so one
    stubborn policy does not stop the others from being removed.

    Returns: Names of policies that could NOT be removed (empty list on success).
    Raises: Whatever get_scaling_policies() raises if listing fails.
    """
    removed = 0
    failed = []
    for policy in get_scaling_policies():
        policy_name = policy["PolicyName"]
        if policy_name == keep_policy:
            continue
        try:
            _delete_scaling_policy(policy_name, resource_id)
            log_info(f"   Removed: {policy_name}")
            removed += 1
        except Exception as e:
            log_warn(f"   Could not remove {policy_name}: {e}")
            failed.append(policy_name)

    if removed > 0:
        log_success(f"Removed {removed} policy(ies)")
    elif not failed:
        log_info("No other policies to remove")
    return failed


def _scale_down_if_over_provisioned(initial_instances: int) -> None:
    """Step 5: force an over-provisioned endpoint back down to Config.MIN_INSTANCES (best effort)."""
    if initial_instances <= Config.MAX_INSTANCES:
        log_success(f"Step 5/5: Instance count OK ({initial_instances} <= {Config.MAX_INSTANCES})")
        return

    log_info(f"Step 5/5: Scaling down from {initial_instances} to {Config.MIN_INSTANCES}...")
    try:
        # Re-read the status: registering the new limits in step 2 can make
        # Application Auto Scaling adjust capacity itself, leaving the endpoint
        # "Updating", in which case a second capacity update would be rejected.
        status, _ = get_endpoint_status()
        if status != "InService":
            log_info(f"Endpoint is '{status}' - a capacity change is already in progress; "
                     f"skipping forced scale-down")
            return
        _force_scale_down(Config.MIN_INSTANCES)
        log_success("Scale-down initiated (takes 5-10 min to complete)")
    except Exception as e:
        log_warn(f"Could not force scale-down: {e}")
        log_info("The new policy will handle this automatically")


def fix_autoscaling() -> bool:
    """
    Apply the conservative auto-scaling configuration.

    The steps are ordered so that a failure part-way leaves the endpoint with a
    working policy rather than none:
      1. Verify the endpoint is InService
      2. Register the scalable target with Config.MIN/MAX_INSTANCES
      3. Put (create or update) the 'Maya-Fallback-Conservative' policy
      4. Delete every other policy, latency-based ones included
      5. Force a scale-down if more than Config.MAX_INSTANCES were running

    Re-running is safe: the target and policy are updated in place.

    Returns:
        True if every step succeeded and no other policies remain.
        False if the endpoint was not ready, a required step failed, or old
        policies could not be removed or verified.
    """

    print("=" * 70)
    print("APPLYING BULLETPROOF AUTO-SCALING CONFIGURATION")
    print("=" * 70)

    resource_id = get_resource_id()
    # Must match the PolicyName used by _put_scaling_policy.
    conservative_policy = 'Maya-Fallback-Conservative'

    # =========================================================================
    # STEP 1: Verify endpoint is ready
    # =========================================================================
    log_info("Step 1/5: Verifying endpoint status...")
    try:
        status, current_instances = get_endpoint_status()
    except Exception as e:
        log_error(f"Failed to verify endpoint: {e}")
        return False

    if status != "InService":
        log_error(f"Endpoint is '{status}' - cannot modify. Wait for InService.")
        return False
    log_success(f"Endpoint InService with {current_instances} instance(s)")

    # =========================================================================
    # STEP 2: Set scaling limits (before touching policies)
    # =========================================================================
    log_info(f"Step 2/5: Setting limits (Min={Config.MIN_INSTANCES}, Max={Config.MAX_INSTANCES})...")
    try:
        _register_scaling_target(resource_id, Config.MIN_INSTANCES, Config.MAX_INSTANCES)
    except Exception as e:
        log_error(f"Failed to set scaling limits: {e}")
        return False
    log_success("Scaling limits configured")

    # =========================================================================
    # STEP 3: Apply conservative scaling policy
    # =========================================================================
    log_info("Step 3/5: Applying conservative scaling policy...")
    try:
        _put_scaling_policy(resource_id)
    except Exception as e:
        log_error(f"Failed to apply policy: {e}")
        return False
    log_success(f"Policy applied: scale at >{Config.INVOCATIONS_TARGET} invocations/min")
    log_info(f"   Scale-out cooldown: {Config.SCALE_OUT_COOLDOWN}s")
    log_info(f"   Scale-in cooldown: {Config.SCALE_IN_COOLDOWN}s")

    # =========================================================================
    # STEP 4: Remove every other policy (only now that a replacement exists)
    # =========================================================================
    log_info("Step 4/5: Removing other scaling policies...")
    try:
        leftover = _remove_other_scaling_policies(resource_id, conservative_policy)
    except Exception as e:
        log_warn(f"Error checking policies: {e}")
        leftover = None  # unknown: listing failed
    policies_clean = leftover == []

    # =========================================================================
    # STEP 5: Force scale-down if needed
    # =========================================================================
    _scale_down_if_over_provisioned(current_instances)

    # =========================================================================
    # COMPLETE
    # =========================================================================
    print("\n" + "=" * 70)
    print("CONFIGURATION COMPLETE")
    print("=" * 70)
    if policies_clean:
        log_success("Auto-scaling is now bulletproof:")
        latency_status = "DISABLED (was causing issues)"
    elif leftover:
        log_warn(f"Could not remove policies: {leftover} - re-run this cell")
        latency_status = f"UNKNOWN ({len(leftover)} old policy(ies) still attached)"
    else:
        log_warn("Could not list existing policies - old (possibly latency-based) policies may remain")
        latency_status = "UNKNOWN (policy cleanup could not be verified)"
    print(f"""
   Configuration Applied:
   ----------------------
   Min instances:        {Config.MIN_INSTANCES}
   Max instances:        {Config.MAX_INSTANCES} (HARD CAP)
   Scale-out trigger:    >{Config.INVOCATIONS_TARGET} invocations/min per instance
   Scale-out cooldown:   {Config.SCALE_OUT_COOLDOWN}s ({Config.SCALE_OUT_COOLDOWN // 60} min)
   Scale-in cooldown:    {Config.SCALE_IN_COOLDOWN}s ({Config.SCALE_IN_COOLDOWN // 60} min)
   Latency scaling:      {latency_status}
   
   Safety Features:
   ----------------
   - Exponential backoff retries on all AWS calls
   - Single policy (no conflicts)
   - High threshold (won't trigger for 10 users)
   - Long cooldowns (prevents rapid scaling)
    """)
    print("-" * 70)
    log_info("Run Cell 4 to verify endpoint health")
    print("=" * 70)

    return policies_clean


# Execute the fix
fix_success = fix_autoscaling()

# In [None]:
"""
================================================================================
CELL 4: HEALTH CHECK & INFERENCE TEST
================================================================================
Comprehensive verification that everything is working.
"""

@retry_with_backoff()
def _test_inference() -> Tuple[bool, float, str]:
    """
    Run one short test generation against the endpoint (retried on AWS errors).

    Sends a Llama-3-style chat prompt with greedy decoding and max 10 new tokens.

    Returns:
        (True, latency_ms, generated_text_stripped). latency_ms covers the
        request plus reading the full response body.

    Raises:
        ClientError / BotoCoreError from invoke_endpoint.
        ValueError if the JSON response has no recognisable "generated_text"
        (e.g. an empty list); this is not retried.
    """
    test_prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are Maya. Reply with exactly one word.<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Hi<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    payload = {
        "inputs": test_prompt,
        "parameters": {
            "max_new_tokens": 10,
            "temperature": 0.1,
            "do_sample": False,
            "return_full_text": False
        }
    }

    # perf_counter is monotonic and meant for measuring intervals.
    start = time.perf_counter()
    response = AWSClients.runtime().invoke_endpoint(
        EndpointName=Config.ENDPOINT_NAME,
        ContentType="application/json",
        Body=json.dumps(payload)
    )
    raw_body = response["Body"].read()
    # Stop the clock only after the body is fully read so the number covers
    # the whole round trip, not just the response headers.
    latency_ms = (time.perf_counter() - start) * 1000

    result = json.loads(raw_body.decode("utf-8"))
    if isinstance(result, list):
        if not result:
            raise ValueError("Endpoint returned an empty list")
        result = result[0]
    if not isinstance(result, dict) or "generated_text" not in result:
        raise ValueError(f"Unexpected response shape: {str(result)[:200]}")

    return True, latency_ms, str(result["generated_text"]).strip()


def health_check() -> Dict[str, Any]:
    """
    End-to-end health check, including a live test inference.

    Checks:
    1. Endpoint status is InService
    2. Registered Min/Max capacity equals Config.MIN_INSTANCES / MAX_INSTANCES
    3. No latency-based scaling policies (zero policies also counts as a pass)
    4. A test inference succeeds (latency recorded)

    Each check adds 1 to "checks_passed" or "checks_failed". Exceptions are
    logged and counted as failures, never raised.

    Returns: Dict with counters, per-area booleans and a "details" dict.
    """

    results = {
        "timestamp": datetime.now().isoformat(),
        "checks_passed": 0,
        "checks_failed": 0,
        "endpoint_healthy": False,
        "scaling_correct": False,
        "inference_working": False,
        "inference_latency_ms": None,
        "details": {}
    }

    print("=" * 70)
    print("HEALTH CHECK")
    print("=" * 70)

    # -------------------------------------------------------------------------
    # CHECK 1: Endpoint Status
    # -------------------------------------------------------------------------
    log_info("Check 1/4: Endpoint status...")
    try:
        status, instances = get_endpoint_status()
        results["details"]["endpoint_status"] = status
        results["details"]["instance_count"] = instances

        if status == "InService":
            log_success(f"InService with {instances} instance(s)")
            results["checks_passed"] += 1
            results["endpoint_healthy"] = True
        else:
            log_error(f"Status: {status}")
            results["checks_failed"] += 1
    except Exception as e:
        log_error(f"Failed: {e}")
        results["checks_failed"] += 1

    # -------------------------------------------------------------------------
    # CHECK 2: Scaling Configuration
    # -------------------------------------------------------------------------
    log_info("Check 2/4: Scaling configuration...")
    try:
        scaling = get_scaling_config()

        if scaling:
            min_ok = scaling["min_capacity"] == Config.MIN_INSTANCES
            max_ok = scaling["max_capacity"] == Config.MAX_INSTANCES
            results["details"]["scaling"] = scaling

            if min_ok and max_ok:
                log_success(f"Correct: Min={scaling['min_capacity']}, Max={scaling['max_capacity']}")
                results["checks_passed"] += 1
                results["scaling_correct"] = True
            else:
                log_warn(f"Mismatch: Min={scaling['min_capacity']}, Max={scaling['max_capacity']}")
                log_info(f"Expected: Min={Config.MIN_INSTANCES}, Max={Config.MAX_INSTANCES}")
                results["checks_failed"] += 1
        else:
            log_warn("No scaling configured")
            results["checks_failed"] += 1
    except Exception as e:
        log_error(f"Failed: {e}")
        results["checks_failed"] += 1

    # -------------------------------------------------------------------------
    # CHECK 3: No Dangerous Policies
    # -------------------------------------------------------------------------
    log_info("Check 3/4: Scaling policies...")
    try:
        policies = get_scaling_policies()
        policy_names = [p["PolicyName"] for p in policies]
        results["details"]["policies"] = policy_names

        # Same latency heuristic as the pre-flight check: metric name or policy name.
        dangerous = []
        for policy in policies:
            name = policy["PolicyName"]
            config = policy.get("TargetTrackingScalingPolicyConfiguration", {})
            custom = config.get("CustomizedMetricSpecification", {})

            if custom.get("MetricName") == "ModelLatency" or "Latency" in name:
                dangerous.append(name)

        if dangerous:
            log_error(f"DANGEROUS policies found: {dangerous}")
            results["checks_failed"] += 1
        elif len(policies) == 1 and policies[0]["PolicyName"] == "Maya-Fallback-Conservative":
            log_success(f"Correct policy active: {policies[0]['PolicyName']}")
            results["checks_passed"] += 1
        elif len(policies) == 0:
            log_warn("No scaling policies")
            results["checks_passed"] += 1  # Not necessarily bad
        else:
            log_info(f"Policies: {policy_names}")
            results["checks_passed"] += 1
    except Exception as e:
        log_error(f"Failed: {e}")
        results["checks_failed"] += 1

    # -------------------------------------------------------------------------
    # CHECK 4: Test Inference
    # -------------------------------------------------------------------------
    log_info("Check 4/4: Test inference...")
    try:
        _, latency_ms, response = _test_inference()
        results["inference_latency_ms"] = latency_ms
        results["details"]["inference_response"] = response[:100]

        # Only mark the preview as truncated when it actually is.
        preview = response if len(response) <= 50 else response[:50] + "..."
        log_success(f"Inference OK: {latency_ms:.0f}ms")
        log_info(f"   Response: \"{preview}\"")
        results["checks_passed"] += 1
        results["inference_working"] = True

    except Exception as e:
        log_error(f"Inference failed: {e}")
        results["checks_failed"] += 1

    # -------------------------------------------------------------------------
    # SUMMARY
    # -------------------------------------------------------------------------
    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)

    total = results["checks_passed"] + results["checks_failed"]

    if results["checks_failed"] == 0:
        log_success(f"ALL {total} CHECKS PASSED")
        print("""
   Endpoint Status:  HEALTHY
   Scaling Config:   CORRECT
   Policies:         SAFE
   Inference:        WORKING
   
   Your endpoint is production-ready!
        """)
    else:
        log_error(f"{results['checks_failed']}/{total} checks failed")
        if results["checks_passed"] > 0:
            log_info(f"{results['checks_passed']}/{total} checks passed")

    print("=" * 70)
    return results


# Run health check
health_results = health_check()

# In [None]:
"""
================================================================================
CELL 5: EMERGENCY CONTROLS
================================================================================
Quick commands for emergency situations.
All functions have retry logic and are safe to run.
"""

def quick_status():
    """
    Print a short status summary: endpoint name, status, instance count and
    registered scaling limits.

    Both AWS lookups happen before anything is printed. Any exception is
    printed as "Error: ..." instead of raised, so this is safe to run anytime.
    """
    divider = "-" * 50
    print(divider)
    print("QUICK STATUS")
    print(divider)
    try:
        status, instances = get_endpoint_status()
        scaling = get_scaling_config()

        print(f"Endpoint:  {Config.ENDPOINT_NAME}")
        print(f"Status:    {status}")
        print(f"Instances: {instances}")
        scaling_line = (
            f"Scaling:   Min={scaling['min_capacity']}, Max={scaling['max_capacity']}"
            if scaling
            else "Scaling:   Not configured"
        )
        print(scaling_line)
    except Exception as e:
        print(f"Error: {e}")
    print(divider)

# Run status check
quick_status()


# =============================================================================
# EMERGENCY FUNCTIONS (defined here; uncomment a call below to run one)
# =============================================================================

@retry_with_backoff()
def force_scale_to_one():
    """
    EMERGENCY: pin the endpoint variant to exactly one instance.

    Intended for unexpected cost spikes. Performs two steps, in order:
      1. Re-registers the Application Auto Scaling target with Min=1, Max=1
         so auto-scaling cannot add capacity back.
      2. Asks SageMaker directly for DesiredInstanceCount=1 on the variant.

    Wrapped in ``retry_with_backoff`` (defined elsewhere); if that decorator
    re-invokes the function, both steps run again.
    """
    banner = "=" * 50
    print(banner)
    print("EMERGENCY: FORCING SCALE TO 1 INSTANCE")
    print(banner)

    target_count = 1

    # Step 1: clamp the auto-scaling range to a single instance.
    scaling_client = AWSClients.autoscaling()
    scaling_client.register_scalable_target(
        ServiceNamespace='sagemaker',
        ResourceId=get_resource_id(),
        ScalableDimension='sagemaker:variant:DesiredInstanceCount',
        MinCapacity=target_count,
        MaxCapacity=target_count,
    )
    log_success("Scaling locked to Min=1, Max=1")

    # Step 2: request the capacity change on the endpoint itself.
    sm_client = AWSClients.sagemaker()
    sm_client.update_endpoint_weights_and_capacities(
        EndpointName=Config.ENDPOINT_NAME,
        DesiredWeightsAndCapacities=[
            {'VariantName': Config.VARIANT_NAME, 'DesiredInstanceCount': target_count},
        ],
    )
    log_success("Scale-down initiated (takes 5-10 min)")
    print(banner)

# Uncomment to force scale to 1:
# force_scale_to_one()


@retry_with_backoff()
def disable_autoscaling():
    """
    Completely disable auto-scaling for the endpoint variant.

    Deletes every scaling policy returned by ``get_scaling_policies()``, then
    deregisters the scalable target. The endpoint keeps whatever instance
    count it has at that moment.

    Individual failures are logged as warnings rather than raised, so the
    teardown proceeds as far as possible. The "FIXED at current value"
    message is printed only when deregistration actually succeeded.
    """
    print("=" * 50)
    print("DISABLING AUTO-SCALING")
    print("=" * 50)

    resource_id = get_resource_id()

    # Remove policies explicitly so each removal is visible in the log.
    # (Deregistering the target also deletes its policies, per the
    # Application Auto Scaling API, but this still runs if that step fails.)
    for policy in get_scaling_policies() or []:
        name = policy["PolicyName"]
        try:
            _delete_scaling_policy(name, resource_id)
            log_info(f"Removed policy: {name}")
        except Exception as e:
            # Best-effort: keep going, but never hide the failure.
            log_warn(f"Could not remove policy {name}: {e}")

    # Deregister target
    try:
        AWSClients.autoscaling().deregister_scalable_target(
            ServiceNamespace='sagemaker',
            ResourceId=resource_id,
            ScalableDimension='sagemaker:variant:DesiredInstanceCount'
        )
    except Exception as e:
        log_warn(f"Could not deregister: {e}")
        print("Auto-scaling state unclear - verify with quick_status()")
    else:
        log_success("Auto-scaling disabled")
        print("Instance count is now FIXED at current value")
    print("=" * 50)

# Uncomment to disable auto-scaling:
# disable_autoscaling()


def delete_endpoint():
    """
    NUCLEAR OPTION: delete the SageMaker endpoint ``Config.ENDPOINT_NAME``.

    Prompts on stdin and proceeds only when the reply is exactly ``DELETE``
    (case-sensitive). If no input is available (stdin closed, e.g. a
    non-interactive run), the prompt is treated as a cancellation instead of
    raising ``EOFError``.

    Only ``delete_endpoint`` is called. This function does not touch the
    endpoint config, the model, or the auto-scaling registration. API errors
    are logged, not raised.
    """
    print("=" * 50)
    print("WARNING: ENDPOINT DELETION")
    print("=" * 50)
    print(f"Endpoint: {Config.ENDPOINT_NAME}")
    print("This will PERMANENTLY delete the endpoint.")
    print("All instances will be terminated.")
    print("-" * 50)

    try:
        confirm = input("Type 'DELETE' to confirm: ")
    except EOFError:
        # No interactive input: never delete without explicit consent.
        confirm = ""

    if confirm == "DELETE":
        try:
            AWSClients.sagemaker().delete_endpoint(EndpointName=Config.ENDPOINT_NAME)
            log_success("Endpoint deletion initiated")
            print("Billing will stop once deletion completes (~5 min)")
        except Exception as e:
            log_error(f"Deletion failed: {e}")
    else:
        print("Cancelled.")

    print("=" * 50)

# Uncomment to delete endpoint (requires confirmation):
# delete_endpoint()

In [None]:
"""
================================================================================
CELL 6: MONITORING DASHBOARD
================================================================================
View endpoint metrics to verify scaling behavior.
"""

@retry_with_backoff()
def _get_metric(metric_name: str, stat: str, hours: int, period: int = 3600) -> list:
    """
    Fetch datapoints for one SageMaker endpoint metric from CloudWatch.

    Args:
        metric_name: Metric name in the ``AWS/SageMaker`` namespace
            (e.g. ``"Invocations"``, ``"ModelLatency"``).
        stat: Statistic to request (``"Sum"``, ``"Average"``, ...); each
            returned datapoint carries its value under this key.
        hours: Look-back window ending now.
        period: Aggregation period in seconds (default: one hour).

    Returns:
        The raw ``Datapoints`` list from ``get_metric_statistics``, in no
        guaranteed order. It is empty (never ``None``) when there is no data.
    """
    from datetime import timezone  # local: the module imports only datetime/timedelta

    # Timezone-aware UTC; naive datetime.utcnow() is deprecated since Python 3.12.
    end_time = datetime.now(timezone.utc)
    start_time = end_time - timedelta(hours=hours)

    resp = AWSClients.cloudwatch().get_metric_statistics(
        Namespace='AWS/SageMaker',
        MetricName=metric_name,
        Dimensions=[
            {'Name': 'EndpointName', 'Value': Config.ENDPOINT_NAME},
            {'Name': 'VariantName', 'Value': Config.VARIANT_NAME}
        ],
        StartTime=start_time,
        EndTime=end_time,
        Period=period,
        Statistics=[stat]
    )
    return resp.get('Datapoints', [])


def _weighted_latency_ms(hours: int) -> Optional[float]:
    """
    Return the mean ModelLatency in ms over the last *hours*, weighted by request count.

    Averaging CloudWatch's per-period Averages again would let a quiet hour
    with a few slow requests count as much as a busy hour. Instead, this
    divides the total latency Sum by the total SampleCount. Returns None
    when CloudWatch reports no samples in the window.
    """
    sums = _get_metric("ModelLatency", "Sum", hours) or []
    counts = _get_metric("ModelLatency", "SampleCount", hours) or []
    total_samples = sum(d["SampleCount"] for d in counts)
    if not sums or total_samples <= 0:
        return None
    # ModelLatency is reported in microseconds; convert to ms.
    return sum(d["Sum"] for d in sums) / total_samples / 1000


def show_metrics(hours: int = 6):
    """
    Display endpoint metrics for the last *hours* hours.

    Shows:
    - Total invocations, plus the mean per hourly period that reported data
    - Request-weighted average model latency (ms)
    - 4XX / 5XX error counts, marked "(!)" when non-zero

    It then prints the current endpoint status and the auto-scaling bounds.
    A failure on one metric is printed inline and does not stop the report.
    """

    print("=" * 70)
    print(f"METRICS: Last {hours} hours")
    print("=" * 70)
    print(f"Endpoint: {Config.ENDPOINT_NAME}")
    print(f"Region:   {Config.REGION}")
    print("-" * 70)

    metrics = [
        ("Invocations", "Sum", "requests"),
        ("ModelLatency", "Average", "ms"),  # computed as a weighted mean, see below
        ("Invocation4XXErrors", "Sum", "errors"),
        ("Invocation5XXErrors", "Sum", "errors")
    ]

    for metric_name, stat, unit in metrics:
        try:
            if metric_name == "ModelLatency":
                # Request-weighted mean instead of a mean of hourly Averages.
                latency_ms = _weighted_latency_ms(hours)
                if latency_ms is None:
                    print(f"{metric_name:25} No data")
                else:
                    print(f"{metric_name:25} avg {latency_ms:,.0f} {unit}")
                continue

            datapoints = _get_metric(metric_name, stat, hours)
            if not datapoints:
                print(f"{metric_name:25} No data")
                continue

            values = [d[stat] for d in datapoints]
            total = sum(values)

            if "Error" in metric_name:
                if total > 0:
                    print(f"{metric_name:25} {total:,.0f} {unit} (!)")
                else:
                    print(f"{metric_name:25} 0 {unit}")
            else:
                # Mean over the periods that reported data (non-empty here).
                avg = total / len(values)
                print(f"{metric_name:25} {total:,.0f} {unit} ({avg:,.1f}/hr)")

        except Exception as e:
            print(f"{metric_name:25} Error: {e}")

    print("-" * 70)

    # Show current scaling status
    try:
        status, instances = get_endpoint_status()
        scaling = get_scaling_config()

        print(f"Current Status:           {status}")
        print(f"Current Instances:        {instances}")
        if scaling:
            print(f"Scaling Config:           Min={scaling['min_capacity']}, Max={scaling['max_capacity']}")
    except Exception as e:
        print(f"Status Error: {e}")

    print("=" * 70)


def show_scaling_history(hours: int = 24):
    """
    Show hourly InvocationsPerInstance (the scaling metric) for the last *hours* hours.

    Rows are printed oldest first, and every period in the requested window
    is shown. Each row gives the hourly Average and Maximum. Rows whose
    Average exceeds 70% of ``Config.INVOCATIONS_TARGET`` are flagged
    "<-- near threshold". Timestamps are the UTC values CloudWatch returns.
    Errors are printed, not raised.
    """
    from datetime import timezone  # local: the module imports only datetime/timedelta

    print("=" * 70)
    print(f"SCALING ACTIVITY: Last {hours} hours")
    print("=" * 70)

    try:
        # Timezone-aware UTC; naive datetime.utcnow() is deprecated since Python 3.12.
        end_time = datetime.now(timezone.utc)
        start_time = end_time - timedelta(hours=hours)

        # Get invocations per instance (the scaling metric)
        resp = AWSClients.cloudwatch().get_metric_statistics(
            Namespace='AWS/SageMaker',
            MetricName='InvocationsPerInstance',
            Dimensions=[
                {'Name': 'EndpointName', 'Value': Config.ENDPOINT_NAME},
                {'Name': 'VariantName', 'Value': Config.VARIANT_NAME}
            ],
            StartTime=start_time,
            EndTime=end_time,
            Period=3600,
            Statistics=['Average', 'Maximum']
        )

        # CloudWatch returns datapoints in no guaranteed order.
        datapoints = sorted(resp.get('Datapoints', []), key=lambda x: x['Timestamp'])
        near_threshold = Config.INVOCATIONS_TARGET * 0.7

        if datapoints:
            print(f"{'Time':20} {'Avg/Instance':15} {'Max/Instance':15}")
            print("-" * 50)
            # Show the whole requested window (was silently capped at 12 rows).
            for dp in datapoints:
                time_str = dp['Timestamp'].strftime('%Y-%m-%d %H:%M')
                avg = dp.get('Average', 0)
                max_val = dp.get('Maximum', 0)

                # Flag if close to scaling threshold
                flag = " <-- near threshold" if avg > near_threshold else ""
                print(f"{time_str:20} {avg:>12.1f}   {max_val:>12.1f}{flag}")
        else:
            print("No scaling data available")

    except Exception as e:
        print(f"Error: {e}")

    print("=" * 70)


# Print metrics for the last 6 hours whenever this cell is executed.
# show_scaling_history() above is not called automatically; run it manually.
show_metrics(6)