In [1]:
%%writefile package.json
{
  "name": "fastapi-react-monorepo",
  "private": true,
  "scripts": {
    "install:all": "python -m venv .venv && .venv\\Scripts\\python.exe -m pip install uv && .venv\\Scripts\\uv.exe pip install -e api && .venv\\Scripts\\python.exe -m pip install bcrypt passlib[bcrypt] python-dotenv && (cd web && npm install)",
    "seed": ".venv\\Scripts\\python.exe api/scripts/seed_user.py",
    "dev": "concurrently -n \"API,WEB\" -c \"cyan,magenta\" \".venv\\Scripts\\python.exe -m uvicorn api.app.main:app --reload --env-file .env\" \"npm --prefix web run dev\"",
    "dev:backend": "python api/scripts/ensure_models.py && .venv\\Scripts\\python.exe -m uvicorn api.app.main:app --reload --env-file .env",
    "dev:full": "concurrently -n \"API,WEB\" -c \"cyan,magenta\" \"npm run dev:backend\" \"npm --prefix web run dev\"",
    "backend": ".venv\\Scripts\\python.exe -m uvicorn api.app.main:app --host 0.0.0.0 --port 8000 --env-file .env",
    "backend:dev": ".venv\\Scripts\\python.exe -m uvicorn api.app.main:app --reload --host 0.0.0.0 --port 8000 --env-file .env",
    "backend:fast": "set \"SKIP_BACKGROUND_TRAINING=1\" && .venv\\Scripts\\python.exe -m uvicorn api.app.main:app --host 0.0.0.0 --port 8000 --env-file .env",
    "frontend": "npm --prefix web run dev",
    "ensure:models": "python api/scripts/ensure_models.py",
    "test:self-healing": "python test_self_healing.py",
    "test:import": "python test_import.py",
    "build:web": "npm --prefix web run build",
    "debug": "timeout /T 3 && curl -s http://127.0.0.1:8000/api/v1/health && echo. && curl -s -X POST -d \"username=alice&password=secret\" -H \"Content-Type: application/x-www-form-urlencoded\" http://127.0.0.1:8000/api/v1/token"
  },
  "devDependencies": {
    "concurrently": "^8.2.2"
  }
} 


Overwriting package.json


In [2]:
%%writefile invoke.yml
# invoke.yml
# Command lists for the dev / test / lint workflows. Each entry is one shell
# command, listed in the order it is meant to run.
# NOTE(review): Invoke itself discovers tasks from a Python tasks module –
# confirm which tool actually consumes this `tasks:` mapping.
tasks:
  # Create the virtualenv, sync dependencies, then serve the API with reload.
  dev:
    - uv venv
    - uv sync
    - uvicorn api.app.main:app --reload
  # Install the test tooling, then run the test suite quietly.
  test:
    - uv pip install pytest coverage
    - pytest -q
  # Install the formatters/linter, then run them over the whole tree.
  lint:
    - uv pip install black isort flake8
    - black .
    - isort .
    - flake8


Overwriting invoke.yml


In [3]:
%%writefile .gitignore
# Local environment files (may hold secrets)
.env
dev.env
.devcontainer/.env.runtime

# MLflow tracking stores
mlruns/
mlflow_db/
mlruns_local/

# Node dependencies
node_modules/
frontend/node_modules/

# Local-only directories, virtualenv and lockfile
archive/
.venv
uv.lock

test_iris.json
#.env.template

# Railway CLI (never commit tokens)
.railway/config.json

Overwriting .gitignore


In [4]:
%%writefile env.template
# Environment template – copy to `.env` and fill in real values there.
# This file is committed, so it must never contain live credentials.
ENV_NAME="react_fastapi_railway"
CUDA_TAG="12.8.0"
DOCKER_BUILDKIT="1"
HOST_JUPYTER_PORT="8890"
HOST_TENSORBOARD_PORT="6008"
HOST_EXPLAINER_PORT="8050"
HOST_STREAMLIT_PORT="8501"
HOST_MLFLOW_PORT="5000"
HOST_APP_PORT="5100"
HOST_BACKEND_DEV_PORT="5002"
MLFLOW_TRACKING_URI="http://mlflow:5000"
MLFLOW_VERSION="2.12.2"
PYTHON_VER="3.10"
JAX_PLATFORM_NAME="gpu"
XLA_PYTHON_CLIENT_PREALLOCATE="true"
XLA_PYTHON_CLIENT_ALLOCATOR="platform"
XLA_PYTHON_CLIENT_MEM_FRACTION="0.95"
XLA_FLAGS="--xla_force_host_platform_device_count=1"
JAX_DISABLE_JIT="false"
JAX_ENABLE_X64="false"
TF_FORCE_GPU_ALLOW_GROWTH="false"
JAX_PREALLOCATION_SIZE_LIMIT_BYTES="8589934592"
# Railway API token: placeholder only. The value previously committed here
# should be treated as leaked and rotated in the Railway dashboard.
RAILWAY_TOKEN="replace-with-your-railway-token"
RAILWAY_VITE_API_URL="https://fastapi-production-1d13.up.railway.app"
VITE_API_URL=http://127.0.0.1:8000/api/v1
REACT_APP_API_URL="https://react-frontend-production-2805.up.railway.app"
# Development defaults – override all three before deploying.
SECRET_KEY="change-me-in-prod"
USERNAME_KEY="alice"
USER_PASSWORD="supersecretvalue"
DATABASE_URL="sqlite+aiosqlite:///./app.db"
RAILWAY_ENVIRONMENT="production"
RAILWAY_ENVIRONMENT_ID="fa10dc06-75ec-4c11-93d4-a0fde17996d0"
RAILWAY_ENVIRONMENT_NAME="production"
RAILWAY_PRIVATE_DOMAIN="empowering-appreciation.railway.internal"
RAILWAY_PROJECT_ID="fc9da558-31d6-4b28-9eda-2bbe56cc7390"
RAILWAY_PROJECT_NAME="responsible-abundance"
RAILWAY_SERVICE_ID="87c129ab-ba49-471a-88bb-853ace60180d"
RAILWAY_SERVICE_NAME="empowering-appreciation"



Overwriting env.template


In [5]:
%%writefile api/env.template
# Local development environment template
# Copy this to .env and modify as needed
# Keep comments on their own lines: api/start.sh exports this file through
# `grep | xargs`, which treats a trailing `# comment` as extra words.

# Database
DATABASE_URL=sqlite+aiosqlite:///./app.db

# Security (generate a secure key for production)
SECRET_KEY=your-secret-key-here
ACCESS_TOKEN_EXPIRE_MINUTES=30

# CORS
ALLOWED_ORIGINS=*

# Redis Configuration (for rate limiting)
REDIS_URL=redis://localhost:6379

# Rate Limiting Configuration
RATE_LIMIT_DEFAULT=60
RATE_LIMIT_CANCER=30
RATE_LIMIT_LOGIN=3
RATE_LIMIT_TRAINING=2
RATE_LIMIT_WINDOW=60
# 5 minutes for light endpoint (iris/predict)
RATE_LIMIT_WINDOW_LIGHT=300
RATE_LIMIT_LOGIN_WINDOW=20

# MLflow - use local file store for development
# Set to http://your-mlflow-server:5000 for production
# MLflow Configuration
MLFLOW_TRACKING_URI=file:./mlruns_local
MLFLOW_REGISTRY_URI=file:./mlruns_local


# Model Training Flags
SKIP_BACKGROUND_TRAINING=0
AUTO_TRAIN_MISSING=1
UNIT_TESTING=0

# Debug Flags (keep OFF in production)
DEBUG_RATELIMIT=0

# JAX/XLA Configuration
# Host has a single logical CPU device – prevents JAX allocating all cores
XLA_FLAGS=--xla_force_host_platform_device_count=1

# Force CPU backend for JAX (uncomment if GPU issues occur)
# JAX_PLATFORM_NAME=cpu

# PyTensor configuration (CPU only to avoid C++ compilation)
PYTENSOR_FLAGS=device=cpu,floatX=float32




Overwriting api/env.template


In [6]:
%%writefile logging.yaml
# Logging configuration in the logging.config.dictConfig schema (version 1).
version: 1
disable_existing_loggers: False
formatters:
  default:
    format: "[%(levelname).1s] %(asctime)s %(name)s ▶ %(message)s"
handlers:
  console:
    class: logging.StreamHandler
    formatter: default
  file:
    # FileHandler does not create missing directories: logs/ must exist.
    class: logging.FileHandler
    filename: logs/backend.log
    formatter: default
    # The format string and many messages contain non-ASCII characters;
    # without an explicit encoding the platform default is used.
    encoding: utf-8
loggers:
  # Every named logger below owns the same handlers as root, so each one
  # disables propagation; otherwise every record would be written twice.
  uvicorn.error:
    level: INFO
    handlers: [console, file]
    propagate: False
  uvicorn.access:
    level: INFO
    handlers: [console, file]
    propagate: False
  app:
    level: DEBUG
    handlers: [console, file]
    propagate: False
  app.services.ml.model_service:
    level: DEBUG
    handlers: [console, file]
    propagate: False
root:
  level: INFO
  handlers: [console, file]

Overwriting logging.yaml


In [7]:
%%writefile pyproject.toml
[project]
name = "react_fastapi_railway"
version = "0.1.0"
description = "Pytorch and Jax GPU docker container"
authors = [
  { name = "Geoffrey Hadfield" },
]
license = "MIT"
readme = "README.md"

# ─── Restrict to Python 3.10–3.12 ──────────────────────────────
requires-python = ">=3.10,<3.13"

dependencies = [
  # Core web framework
  "fastapi>=0.104.0",
  "uvicorn[standard]>=0.24.0",
  "python-dotenv>=1.0.0",

  # Settings and validation
  "pydantic>=2.0.0",
  "pydantic-settings>=2.0.0",

  # HTTP client and multipart parsing
  "httpx>=0.24.0",
  "python-multipart>=0.0.6",

  # Data & ML basics
  "numpy>=1.24.0",
  "pandas>=2.1.0",
  "scikit-learn>=1.3.0",
  "mlflow>=2.8.0",

  # (Your existing extras—keep if you still need them)
  "matplotlib>=3.4.0",
  "pymc>=5.0.0",
  "arviz>=0.14.0",
  "statsmodels>=0.13.0",
  "jupyterlab>=3.0.0",
  "seaborn>=0.11.0",
  "tabulate>=0.9.0",
  "shap>=0.40.0",
  "xgboost>=1.5.0",
  "lightgbm>=3.3.0",
  "catboost>=1.2.8,<1.3.0",
  "scipy>=1.7.0",
  "shapash[report]>=2.3.0",
  "shapiq>=0.1.0",
  "explainerdashboard==0.5.1",
  "ipywidgets>=8.0.0",
  "nutpie>=0.7.1",
  "numpyro>=0.18.0,<1.0.0",
  "jax==0.6.0",
  "jaxlib==0.6.0",
  "pytensor>=2.18.3",
  "aesara>=2.9.4",
  "tqdm>=4.67.0",
  "pyarrow>=12.0.0",
  "optuna>=3.0.0",
  "optuna-integration[mlflow]>=0.2.0",
  "omegaconf>=2.3.0,<2.4.0",
  "hydra-core>=1.3.2,<1.4.0",

  # Persistence and authentication
  "aiosqlite>=0.19.0",
  "python-jose[cryptography]>=3.3.0",
  "passlib[bcrypt]>=1.7.4",
  "bcrypt==4.0.1",  # Pin bcrypt version to resolve warning

  # Rate limiting (httpx is already declared above; the duplicate entry was removed)
  "fastapi-limiter>=0.1.5",
  "aioredis>=2.0.0",
]

[project.optional-dependencies]
dev = [
  "pytest>=7.0.0",
  "black>=23.0.0",
  "isort>=5.0.0",
  "flake8>=5.0.0",
  "mypy>=1.0.0",
  "invoke>=2.2",
]

cuda = [
  "cupy-cuda12x>=12.0.0",
]

# NOTE(review): this table requests device = "cuda", while
# api/app/utils/env_sanitizer.py forces PYTENSOR_FLAGS to device=cpu at
# startup. Confirm that something actually reads this table.
[tool.pytensor]
device    = "cuda"
floatX    = "float32"
allow_gc  = true
optimizer = "fast_run"



Overwriting pyproject.toml


In [8]:
%%writefile api/pyproject.toml
[project]
name = "api"
version = "1.0.0"
description = "FastAPI backend with React frontend"
# jax>=0.5.3 (below) does not support Python < 3.10, so the floor must match;
# the repository-level pyproject.toml already requires >=3.10.
requires-python = ">=3.10"
dependencies = [
    "fastapi>=0.104.0",
    "uvicorn>=0.24.0",
    "sqlalchemy>=2.0.23",
    "aiosqlite>=0.19.0",
    "python-jose[cryptography]>=3.3.0",
    "passlib[bcrypt]>=1.7.4",
    "python-multipart>=0.0.6",
    "pydantic>=2.4.2",
    "bcrypt==4.0.1",  # Pin bcrypt version to resolve warning
    # Rate limiting
    "fastapi-limiter>=0.1.5",
    "aioredis>=2.0.0",
    # app/db.py imports `redis.asyncio` directly, so declare it explicitly
    # instead of relying on it arriving transitively.
    "redis>=4.2.0",
    "httpx>=0.24.0",
    # ML dependencies
    "mlflow>=2.8.0",
    "scikit-learn>=1.3.0",
    "pandas>=2.0.0",
    "numpy>=1.24.0",
    "pymc>=5.7.0",
    "arviz>=0.15.0",
    "requests>=2.31.0",
    "jax[cpu]>=0.5.3,<0.7",
    "jaxlib>=0.5.3,<0.7",
    # NOTE(review): the repository-level pyproject.toml requires
    # numpyro>=0.18.0 – confirm this older range is intended for the API.
    "numpyro>=0.14.1,<0.16"
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-asyncio>=0.21.0",
    "httpx>=0.24.0"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["app"]







Overwriting api/pyproject.toml


In [9]:
%%writefile api/railway.json
{
  "$schema": "https://railway.app/railway.schema.json",
  "build": { "builder": "NIXPACKS" },
  "deploy": {
    "startCommand": "bash ./start.sh",
    "healthcheckPath": "/api/v1/health",
    "healthcheckInterval": 10,
    "healthcheckTimeout": 300,
    "restartPolicyType": "ON_FAILURE",
    "restartPolicyMaxRetries": 10
  }
}


Overwriting api/railway.json


In [10]:
%%writefile api/start.sh
#!/usr/bin/env bash
# Container entrypoint for the FastAPI service:
#   1. load an optional local .env,
#   2. verify the required variables,
#   3. seed the database,
#   4. replace this shell with uvicorn.
set -euo pipefail

# ── optional local .env ------------------------------------------------------
# Parses KEY=VALUE lines one at a time instead of `export $(grep … | xargs)`,
# so values containing spaces, quotes, `&` or a trailing `# comment` survive.
# Nothing in the file is evaluated as shell code.
load_dotenv() {
  local file="$1" line key val
  local re_assign='^[[:space:]]*(export[[:space:]]+)?([A-Za-z_][A-Za-z0-9_]*)=(.*)$'
  local re_dquoted='^"(.*)"[[:space:]]*(#.*)?$'
  local re_squoted="^'(.*)'[[:space:]]*(#.*)?\$"

  while IFS= read -r line || [[ -n "$line" ]]; do
    line="${line%$'\r'}"                          # tolerate CRLF files
    [[ $line =~ $re_assign ]] || continue         # skips blanks and comments
    key="${BASH_REMATCH[2]}"
    val="${BASH_REMATCH[3]}"
    if [[ $val =~ $re_dquoted ]] || [[ $val =~ $re_squoted ]]; then
      val="${BASH_REMATCH[1]}"                    # strip the surrounding quotes
    else
      val="${val%%[[:space:]]#*}"                 # drop an inline comment
      val="${val%"${val##*[![:space:]]}"}"        # trim trailing whitespace
    fi
    export "$key=$val"
  done < "$file"
}

# Loaded before the sanity checks so a local .env can satisfy them.
if [[ -f .env ]]; then
  load_dotenv .env
fi

# ── sanity ─────────────────────────────────────────────────────────
if [[ -z "${PORT:-}" ]]; then
  echo "❌  PORT not set – Railway always provides it." >&2
  exit 1
fi

if [[ -z "${SECRET_KEY:-}" ]]; then
  echo "❌  SECRET_KEY is not set for the backend service – aborting." >&2
  exit 1
fi

echo "🚀  FastAPI boot; PORT=$PORT  PY=$(python -V)"
# Diagnostic dump. Anything that looks like a credential is masked: variables
# whose name contains TOKEN/SECRET/PASSWORD/KEY, and the user-info part of URLs
# (e.g. the password inside DATABASE_URL).
env | grep -E 'RAILWAY_|PORT|DATABASE_URL' \
  | sed -E -e 's/^([^=]*(TOKEN|SECRET|PASSWORD|KEY)[^=]*)=.*/\1=***/' \
           -e 's#://[^/@[:space:]]+@#://***@#'

# ── one-shot DB migrate + seed (blocks until done) ---------------------------
python -m scripts.seed_user

# ── run the API --------------------------------------------------------------
# `exec` hands the PID over to uvicorn so it receives the platform's signals.
exec uvicorn app.main:app \
  --host 0.0.0.0 --port "$PORT" \
  --proxy-headers --forwarded-allow-ips="*" --log-level info



Overwriting api/start.sh


In [11]:
%%writefile api/scripts/seed_user.py
from pathlib import Path
from passlib.context import CryptContext
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy import select
import os, asyncio

# ── optional .env load (UTF-8 only) ───────────────────────────────
# The .env file is looked up two directories above this script.
ENV_PATH = Path(__file__).resolve().parents[2] / ".env"
if ENV_PATH.exists():
    try:
        from dotenv import load_dotenv
        load_dotenv(ENV_PATH, encoding="utf-8")
    except ImportError:
        # The load is optional; keep going with the process environment.
        print("⚠️  python-dotenv not installed – .env skipped")
    except UnicodeDecodeError:
        print("⚠️  .env not UTF-8 – skipped")

# ── model import (kept same) ──────────────────────────────────────
# Make the sibling `app` package importable when run as a plain script.
import sys; sys.path.append(str(Path(__file__).resolve().parents[1]))
from app.models import Base, User

USERNAME = os.getenv("USERNAME_KEY", "alice")
PASSWORD = os.getenv("USER_PASSWORD", "supersecretvalue")
# Seed the same database the application connects to (app/db.py reads the
# same variable with the same default) instead of a hard-coded SQLite file.
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./app.db")

pwd = CryptContext(schemes=["bcrypt"], deprecated="auto")
engine = create_async_engine(DATABASE_URL)
session_factory = async_sessionmaker(engine, expire_on_commit=False)


async def main():
    """Create the tables if needed, then create or update the seed user.

    An existing user with ``USERNAME`` gets its password hash replaced;
    otherwise a new row is inserted. The engine is disposed afterwards so
    no pooled connection outlives the script.
    """
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with session_factory() as db:
            result = await db.execute(select(User).where(User.username == USERNAME))
            user = result.scalar_one_or_none()
            hashed = pwd.hash(PASSWORD)

            if user:
                user.hashed_password = hashed
                action = "Updated"
            else:
                db.add(User(username=USERNAME, hashed_password=hashed))
                action = "Created"
            await db.commit()
            print(f"{action} user {USERNAME}")
    finally:
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())




Overwriting api/scripts/seed_user.py


In [12]:
%%writefile api/app/__init__.py
# api/app/__init__.py
# NOTE(review): this header and docstring previously described
# api/app/utils/__init__.py – confirm whether the `utils` sub-package also
# needs its own __init__.py.
"""
Top-level package of the FastAPI application.
"""

Overwriting api/app/__init__.py


In [13]:
%%writefile api/app/utils/env_sanitizer.py
# api/app/utils/env_sanitizer.py
"""
Early‑process clean‑up of env variables that mis‑configure JAX / PyTensor.
Import *before* anything touches JAX / PyMC.
"""

from __future__ import annotations
import os, logging, importlib.util

log = logging.getLogger(__name__)

_VALID_XLA_PREFIXES = ("--xla_", "--mmap_", "--tfrt_")

def _clean_xla_flags() -> None:
    """Remove invalid XLA_FLAGS tokens that cause crashes."""
    val = os.getenv("XLA_FLAGS")
    if not val:
        return
    tokens = [t for t in val.split() if t]
    bad = [t for t in tokens if not t.startswith(_VALID_XLA_PREFIXES)]
    if bad:
        log.warning("🧹 Removing invalid XLA_FLAGS tokens: %s", bad)
        tokens = [t for t in tokens if t not in bad]
    if tokens:
        os.environ["XLA_FLAGS"] = " ".join(tokens)
    else:        # was just '--'
        os.environ.pop("XLA_FLAGS", None)

def _downgrade_jax_backend() -> None:
    """Force JAX to use CPU if GPU is requested but not available."""
    # Check if GPU backend is explicitly requested
    platform_name = os.getenv("JAX_PLATFORM_NAME", "").lower()
    if platform_name in ("gpu", "cuda"):
        # Check if CUDA runtime is actually available
        cuda_spec = importlib.util.find_spec("jaxlib.cuda_extension")
        if cuda_spec is None:
            log.warning("⚠️ No CUDA runtime found – forcing JAX_PLATFORM_NAME=cpu")
            os.environ["JAX_PLATFORM_NAME"] = "cpu"
        else:
            log.info("✅ CUDA runtime detected, keeping GPU backend")

def _force_pytensor_cpu() -> None:
    """Force PyTensor to use CPU device to avoid C++ compilation issues."""
    # Only set if not already configured
    if "PYTENSOR_FLAGS" not in os.environ:
        os.environ["PYTENSOR_FLAGS"] = "device=cpu,floatX=float32"
        log.info("🔧 Set PyTensor to CPU device")
    
    # Also set legacy config for compatibility
    if "DEVICE" not in os.environ:
        os.environ["DEVICE"] = "cpu"

def _disable_pytensor_compilation() -> None:
    """Set the environment variables meant to keep PyTensor off the C compiler.

    All four variables are written unconditionally, so a ``PYTENSOR_FLAGS``
    value that was already present is replaced.

    NOTE(review): confirm that PyTensor actually reads the
    ``PYTENSOR_COMPILE_*`` / ``PYTENSOR_LINKER`` variables; if it only honours
    ``PYTENSOR_FLAGS``, those three assignments have no effect.
    """
    os.environ.update(
        {
            # CPU device, 32-bit floats
            "PYTENSOR_FLAGS": "device=cpu,floatX=float32",
            # Cheapest optimisation / compile mode
            "PYTENSOR_COMPILE_OPTIMIZER": "fast_compile",
            "PYTENSOR_COMPILE_MODE": "FAST_COMPILE",
            # Pure-Python linker (no C compilation)
            "PYTENSOR_LINKER": "py",
        }
    )
    log.info("🔧 Disabled PyTensor C compilation, using Python backend")

def _check_cuda_environment() -> None:
    """Log CUDA-related environment variables for debugging."""
    cuda_vars = {k: v for k, v in os.environ.items() 
                 if 'CUDA' in k or 'GPU' in k or 'JAX' in k}
    if cuda_vars:
        log.info("🔍 CUDA/JAX environment variables: %s", cuda_vars)

def fix_ml_backends() -> None:
    """
    Run every JAX/PyTensor environment sanitation step, in a fixed order.

    Intended to be called **once** at the very top of app.main, before any
    JAX or PyMC import.
    """
    log.info("🔧 Sanitizing ML backend configuration...")

    # Order matters: the last step overwrites PYTENSOR_FLAGS set by the
    # step before it.
    steps = (
        _check_cuda_environment,
        _clean_xla_flags,
        _downgrade_jax_backend,
        _force_pytensor_cpu,
        _disable_pytensor_compilation,
    )
    for step in steps:
        step()

    log.info("✅ ML backend sanitization complete")

# Legacy function for backward compatibility
def fix_xla_flags() -> None:
    """Backward-compatible alias that runs the full ``fix_ml_backends`` routine."""
    return fix_ml_backends()


Overwriting api/app/utils/env_sanitizer.py


In [14]:
%%writefile api/app/middleware/__init__.py
# Middleware package 

Overwriting api/app/middleware/__init__.py


In [15]:
%%writefile api/app/middleware/concurrency.py
"""
Concurrency limiting middleware for heavy endpoints.
Provides semaphore-based concurrency control to prevent resource exhaustion.
"""

import asyncio
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from fastapi import HTTPException, status
import logging

logger = logging.getLogger(__name__)

class ConcurrencyLimiter(BaseHTTPMiddleware):
    """
    Middleware that limits concurrent requests to heavy endpoints.

    This is useful for CPU-intensive operations like Bayesian inference
    that could overwhelm the server if too many requests are processed simultaneously.

    Requests whose path is in ``heavy_endpoints`` share one semaphore. A
    request that cannot obtain a slot within ``acquire_timeout`` seconds is
    answered with ``503`` and a ``Retry-After`` header. All other paths pass
    straight through.
    """

    def __init__(
        self,
        app,
        max_concurrent: int = 4,
        heavy_endpoints: set = None,
        acquire_timeout: float = 30.0,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            max_concurrent: Number of heavy requests allowed to run at once.
            heavy_endpoints: Exact request paths to limit; a falsy value
                selects the built-in default set.
            acquire_timeout: Seconds to wait for a free slot before replying
                503. ``None`` waits indefinitely.
        """
        super().__init__(app)
        self._sem = asyncio.Semaphore(max_concurrent)
        self._acquire_timeout = acquire_timeout
        self.heavy_endpoints = heavy_endpoints or {
            "/api/v1/cancer/predict",
            "/api/v1/iris/train",
            "/api/v1/cancer/train"
        }
        logger.info(f"Concurrency limiter initialized with max {max_concurrent} concurrent requests")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process request with concurrency limiting for heavy endpoints."""
        path = request.url.path

        # Light endpoints bypass concurrency control
        if path not in self.heavy_endpoints:
            return await call_next(request)

        # Only the wait for a slot is time-limited, so a slow downstream
        # handler can never be mistaken for a busy server.
        try:
            await asyncio.wait_for(self._sem.acquire(), timeout=self._acquire_timeout)
        except asyncio.TimeoutError:
            logger.warning(f"Concurrency timeout for {path}")
            # Return the response directly: an HTTPException raised inside
            # BaseHTTPMiddleware is not turned into a response by FastAPI's
            # exception handlers.
            from starlette.responses import JSONResponse

            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Server is busy processing other requests. Please try again in a moment."
                },
                headers={"Retry-After": "30"},
            )

        try:
            logger.debug(f"Processing heavy endpoint {path} with concurrency control")
            return await call_next(request)
        finally:
            # Always hand the slot back, even if the handler raised.
            self._sem.release()

Overwriting api/app/middleware/concurrency.py


In [16]:
%%writefile api/app/deps/__init__.py
# Rate limiting dependencies package 

Overwriting api/app/deps/__init__.py


In [17]:
%%writefile api/app/deps/limits.py
"""
Rate limiting dependencies for FastAPI endpoints.
✅ FIX: identifier must be async (fastapi-limiter expects await).
"""

from __future__ import annotations

from fastapi_limiter.depends import RateLimiter
from starlette.requests import Request
from ..core.config import settings

# ───────────────────────── helpers ──────────────────────────
async def _path_aware_ip(request: Request) -> str:
    """
    Return `<ip>:<path>` so each endpoint has its own bucket.
    Cheap & fully-sync, but declared *async* because fastapi-limiter
    always awaits the identifier.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    ip = (forwarded.split(",")[0].strip() if forwarded else request.client.host)
    return f"{ip}:{request.scope['path']}"

async def user_or_ip(request: Request) -> str:
    """
    Prefer the bearer token → keeps per-user buckets across NAT; otherwise
    fall back to IP+path.

    The token is never used verbatim: the identifier is
    ``token:<sha256 hex digest>``, so credentials do not end up inside
    rate-limit keys. An empty token falls back to IP+path instead of putting
    every such request into one shared bucket.

    Declared async for compatibility with fastapi-limiter.

    NOTE(review): the token is not verified here, so a client can obtain a
    fresh bucket by sending a different arbitrary bearer value each time.
    """
    import hashlib  # local import: only needed on the token branch

    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        token = auth[7:].strip()
        if token:
            return "token:" + hashlib.sha256(token.encode("utf-8")).hexdigest()
    return await _path_aware_ip(request)

# ──────────────────────── limiters ──────────────────────────

# Each limiter allows `times` calls per `seconds` window for one identifier.
# All thresholds come from `settings`, except the light limiter's count.

# General-purpose limit for ordinary endpoints.
default_limit = RateLimiter(
    times=settings.RATE_LIMIT_DEFAULT,
    seconds=settings.RATE_LIMIT_WINDOW,
    identifier=user_or_ip,
)

# NOTE(review): 120 is hard-coded here, while every other threshold is read
# from settings.
light_limit = RateLimiter(                # /iris/predict
    times=120,
    seconds=settings.RATE_LIMIT_WINDOW_LIGHT,  # Use dedicated light window
    identifier=user_or_ip,                # ← switched to token-based
)

heavy_limit = RateLimiter(                # /cancer/predict
    times=settings.RATE_LIMIT_CANCER,
    seconds=settings.RATE_LIMIT_WINDOW,
    identifier=user_or_ip,
)

# Keyed by IP+path only: callers are not authenticated yet at login time.
login_limit = RateLimiter(                # bad‑login attempts
    # `times` is exclusive – allow three failures, block 4th
    # NOTE(review): confirm against the installed fastapi-limiter that the
    # "+ 1" really yields "block the 4th" and not "block the 5th".
    times=settings.RATE_LIMIT_LOGIN + 1,
    seconds=settings.RATE_LIMIT_LOGIN_WINDOW,
    identifier=_path_aware_ip,
)

# Training uses a window five times longer than the default one.
training_limit = RateLimiter(             # /train endpoints
    times=settings.RATE_LIMIT_TRAINING,
    seconds=settings.RATE_LIMIT_WINDOW * 5,
    identifier=user_or_ip,
)

# Handy handle for debug & CI
def get_redis():
    """Return whatever ``FastAPILimiter.redis`` currently holds (for debug & CI)."""
    # Imported at call time rather than at module import.
    import fastapi_limiter

    return fastapi_limiter.FastAPILimiter.redis


Overwriting api/app/deps/limits.py


In [18]:
%%writefile api/app/core/config.py
"""
Core configuration settings for the FastAPI application.
Centralizes environment variables and provides sensible defaults.
"""

import os
from typing import Optional

def env_flag(name: str, default: str = "0") -> bool:
    """Read a boolean switch from the environment.

    Surrounding whitespace and letter case are ignored, and ``1``, ``true``,
    ``yes`` and ``on`` all count as enabled. This keeps a flag working when a
    shell appends a stray space (``set VAR=1 && …`` on Windows) or when the
    value is written as ``true``.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset.
    """
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


class Settings:
    """Application settings with environment-based configuration.

    Every attribute is evaluated once, when this module is imported; later
    changes to the environment are not picked up.
    """

    # Database
    DATABASE_URL: str = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./app.db")

    # Security (SECRET_KEY stays None when the variable is unset)
    SECRET_KEY: Optional[str] = os.getenv("SECRET_KEY")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))

    # CORS
    ALLOWED_ORIGINS: str = os.getenv("ALLOWED_ORIGINS", "*")

    # Rate Limiting
    REDIS_URL: str = os.getenv("REDIS_URL", "redis://localhost:6379")
    RATE_LIMIT_DEFAULT: int = int(os.getenv("RATE_LIMIT_DEFAULT", "60"))
    RATE_LIMIT_CANCER: int = int(os.getenv("RATE_LIMIT_CANCER", "30"))
    RATE_LIMIT_LOGIN: int = int(os.getenv("RATE_LIMIT_LOGIN", "3"))
    RATE_LIMIT_TRAINING: int = int(os.getenv("RATE_LIMIT_TRAINING", "2"))
    RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60"))  # seconds
    RATE_LIMIT_WINDOW_LIGHT: int = int(os.getenv("RATE_LIMIT_WINDOW_LIGHT", "300"))  # 5 minutes for light endpoint
    RATE_LIMIT_LOGIN_WINDOW: int = int(os.getenv("RATE_LIMIT_LOGIN_WINDOW", "20"))  # seconds

    # MLflow in local-file mode by default
    MLFLOW_TRACKING_URI: str = os.getenv(
        "MLFLOW_TRACKING_URI",
        "file:./mlruns_local"
    )
    # Falls back to the tracking URI resolved just above
    MLFLOW_REGISTRY_URI: str = os.getenv(
        "MLFLOW_REGISTRY_URI",
        MLFLOW_TRACKING_URI
    )

    # Model training flags
    SKIP_BACKGROUND_TRAINING: bool = env_flag("SKIP_BACKGROUND_TRAINING", "0")
    AUTO_TRAIN_MISSING: bool = env_flag("AUTO_TRAIN_MISSING", "1")
    UNIT_TESTING: bool = env_flag("UNIT_TESTING", "0")

    # Debug flags
    DEBUG_RATELIMIT: bool = env_flag("DEBUG_RATELIMIT", "0")

# Shared instance imported by the rest of the application
settings = Settings()


Overwriting api/app/core/config.py


In [19]:
%%writefile api/app/crud.py
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from .models import User

async def get_user_by_username(db: AsyncSession, username: str):
    """Look up the ``User`` row whose username equals ``username``.

    Returns the matching ``User``, or ``None`` when there is no match.
    """
    result = await db.execute(
        select(User).where(User.username == username)
    )
    return result.scalar_one_or_none()

Overwriting api/app/crud.py


In [20]:
%%writefile api/app/models.py
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

# Declarative base shared by every ORM model; its metadata is what the
# start-up code passes to `create_all`.
Base = declarative_base()

class User(Base):
    """Account row stored in the ``users`` table."""
    __tablename__ = "users"
    # Integer primary key
    id = Column(Integer, primary_key=True, index=True)
    # Login name; unique and indexed
    username = Column(String, unique=True, index=True)
    # Password hash (the seed script and security module hash with bcrypt)
    hashed_password = Column(String) 

Overwriting api/app/models.py


In [21]:
%%writefile api/app/db.py
# api/app/db.py
from contextlib import asynccontextmanager
import os, logging, asyncio
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    create_async_engine,
    async_sessionmaker,
)
from fastapi_limiter import FastAPILimiter
import redis.asyncio as redis
from .models import Base
from .services.ml.model_service import model_service
from .core.config import settings

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Database engine & session factory (module-level singletons – cheap & safe)
# ---------------------------------------------------------------------------
# NOTE(review): read straight from the environment although
# `settings.DATABASE_URL` carries the same value and default – consider
# using one source.
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./app.db")
engine = create_async_engine(DATABASE_URL, echo=False, future=True)
# expire_on_commit=False keeps loaded attributes usable after a commit.
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)

# Global readiness flag; set to True by `lifespan` once start-up has run.
_app_ready: bool = False

def get_app_ready() -> bool:
    """Return the current value of the module-level ``_app_ready`` flag."""
    return _app_ready

# ---------------------------------------------------------------------------
# FastAPI lifespan – runs ONCE at startup / shutdown
# ---------------------------------------------------------------------------
def _log_background_task_result(task) -> None:
    """Done-callback: report a failure of the background training task."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("❌ Background training task failed: %s", exc)


@asynccontextmanager
async def lifespan(app):
    """
    Application lifespan context-manager.

    Startup:
    * creates DB tables
    * wires the Redis-backed rate-limiter (a failure is logged, not fatal)
    * initialises ML models and launches background training
      (skipped when ``settings.UNIT_TESTING`` is set)
    * sets the global _app_ready flag – also when model start-up failed

    Shutdown:
    * cancels background training that is still running
    * closes the rate-limiter and disposes the DB engine
    """
    global _app_ready

    # asyncio only keeps weak references to tasks; holding the task in this
    # frame (alive until shutdown) prevents it from being garbage-collected.
    training_task = None

    # Render the URL with the password masked so credentials never reach logs.
    logger.info(
        "🗄️  Initializing database…  URL=%s",
        engine.url.render_as_string(hide_password=True),
    )
    try:
        async with engine.begin() as conn:
            # DDL is safe here; it blocks startup until complete
            await conn.run_sync(Base.metadata.create_all)
        logger.info("✅ Database tables created/verified successfully")

        # ── Initialize FastAPI-Limiter BEFORE serving traffic ───────────────
        try:
            redis_conn = redis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
            )
            await FastAPILimiter.init(redis_conn, prefix="ratelimit")
            logger.info("🚦 Rate-limiter initialised (Redis %s)", settings.REDIS_URL)

            # Optional: clean slate for CI
            if os.getenv("FLUSH_TEST_LIMITS") == "1":
                try:
                    flushed = await redis_conn.flushdb()
                    logger.info("🧹 Redis FLUSHDB executed for test run, status=%s", flushed)
                except Exception as e:
                    logger.warning("Could not flush Redis in test mode: %s", e)
        except Exception as e:
            logger.warning("⚠️  Rate-limiter init failed: %s – continuing without limits", e)

        # Initialize application readiness
        logger.info("🚀 Startup event starting - _app_ready=%s", _app_ready)

        if settings.UNIT_TESTING:
            logger.info("🔒 UNIT_TESTING=1 – startup hooks bypassed")
            _app_ready = True
            logger.info("✅ _app_ready set to True (unit testing)")
        else:
            try:
                # Initialize ModelService first
                logger.info("🔧 Initializing ModelService")
                await model_service.initialize()
                logger.info("✅ ModelService initialized successfully")

                # Start background training tasks
                logger.info("🔄 Starting background training tasks")
                training_task = asyncio.create_task(model_service.startup())
                training_task.add_done_callback(_log_background_task_result)
                logger.info("✅ Background training tasks started")

                # Set ready to true after initialization (models will load in background)
                _app_ready = True
                logger.info("🚀 FastAPI ready – _app_ready=%s, health probes will pass immediately", _app_ready)

            except Exception as e:
                logger.error("❌ Startup event failed: %s", e)
                import traceback
                logger.error("❌ Startup traceback: %s", traceback.format_exc())
                # Set ready to true anyway so the API can serve requests
                _app_ready = True
                logger.warning("⚠️  Setting _app_ready=True despite startup errors")

        logger.info("🎯 Lifespan startup complete - _app_ready=%s", _app_ready)
        yield
    finally:
        logger.info("🔒 Shutting down…")
        if training_task is not None and not training_task.done():
            training_task.cancel()
        try:
            await FastAPILimiter.close()           # graceful shutdown
        except Exception as e:
            # Best-effort: the limiter may never have been initialised.
            logger.debug("Rate-limiter close skipped: %s", e)
        await engine.dispose()

# ---------------------------------------------------------------------------
# Dependency injection helper
# ---------------------------------------------------------------------------
async def get_db() -> AsyncSession:
    """Yield a new DB session per request.

    The session comes from ``AsyncSessionLocal`` and is closed by the
    ``async with`` block once the consumer is done with it.
    """
    # NOTE(review): this is an async generator, so the precise return
    # annotation would be AsyncIterator[AsyncSession].
    async with AsyncSessionLocal() as session:
        yield session



Overwriting api/app/db.py


In [22]:
%%writefile api/app/security.py
from __future__ import annotations
import os, logging, secrets
from datetime import datetime, timedelta
from typing import Optional

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import BaseModel

log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# 1.  SECRET_KEY ***must*** be provided in the environment in production.
# ---------------------------------------------------------------------------
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    # Logged at CRITICAL, but start-up continues with a random per-process key.
    log.critical(
        "ENV variable SECRET_KEY is missing -- generating a temporary key. "
        "ALL issued JWTs will be invalid after a pod restart! "
        "Set it in Railway → Variables to disable this warning."
    )
    SECRET_KEY = secrets.token_urlsafe(32)   # fallback only for dev

# Signing algorithm and lifetime used for every access token.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30))

# bcrypt password hashing context and the OAuth2 bearer-token extractor.
pwd_ctx = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/token")

# Token payload model: an optional username.
class TokenData(BaseModel):
    username: Optional[str] = None

# Credentials accepted by the login endpoint (built by `get_credentials`).
class LoginPayload(BaseModel):
    username: str
    password: str

async def get_credentials(request: Request) -> LoginPayload:
    """
    Accept either JSON **or** classic form‑encoded credentials.

    Order of precedence:
    1. If the request media‑type is JSON → parse it with Pydantic.
    2. Else parse as form-encoded data.

    Raises:
        HTTPException: 422 when the body cannot be parsed or validated, or
            when the form lacks ``username`` or ``password``.
    """
    # Media types are case-insensitive.
    content_type = request.headers.get("content-type", "").lower()

    if content_type.startswith("application/json"):
        # JSON branch
        try:
            body = await request.json()
            return LoginPayload(**body)
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Invalid JSON credentials: {e}",
            )
    else:
        # Form-encoded branch
        try:
            form_data = await request.form()
            username = form_data.get("username")
            password = form_data.get("password")

            if not username or not password:
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="username and password are required"
                )

            return LoginPayload(username=username, password=password)
        except HTTPException:
            # Let the explicit "required" error through unchanged; the generic
            # handler below would otherwise re-wrap it and garble its detail.
            raise
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Invalid form credentials: {e}",
            )

def verify_password(raw: str, hashed: str) -> bool:
    """Check the plain-text password ``raw`` against the stored hash ``hashed``."""
    return pwd_ctx.verify(raw, hashed)

def get_password_hash(pw: str) -> str:
    """Hash the plain-text password ``pw`` with the module's password context."""
    return pwd_ctx.hash(pw)

def create_access_token(subject: str) -> str:
    """Issue a signed JWT for ``subject``.

    The token carries ``sub`` (the subject) and ``exp``, set
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` from now, and is signed with
    ``SECRET_KEY`` using ``ALGORITHM``.
    """
    from datetime import timezone  # local import keeps the block self-contained

    # Timezone-aware "now": datetime.utcnow() returns a naive value and is
    # deprecated as of Python 3.12.
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return jwt.encode({"sub": subject, "exp": expire}, SECRET_KEY, algorithm=ALGORITHM)

async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
    """Resolve the bearer token to the username stored in its ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            the token cannot be decoded or carries no subject.
    """
    # A 401 for a bearer-protected resource should tell the client which
    # authentication scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, headers=challenge
        ) from exc

    username: Optional[str] = payload.get("sub")
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, headers=challenge)
    return username


Overwriting api/app/security.py


# additional models

In [23]:
%%writefile api/app/schemas/cancer.py
from pydantic import BaseModel, Field
from typing import List, Optional

# One breast-cancer sample as sent by API clients: 30 required floats, i.e.
# ten measurements each reported as mean, standard error and "worst" value.
# Field names are snake_case; `_CANCER_COLMAP` in
# api/app/services/ml/model_service.py maps them to the column names the
# models were trained on.
class CancerFeatures(BaseModel):
    """Breast cancer diagnostic features."""
    # Mean features
    mean_radius: float = Field(..., description="Mean of distances from center to points on perimeter")
    mean_texture: float = Field(..., description="Standard deviation of gray-scale values")
    mean_perimeter: float = Field(..., description="Mean size of the core tumor")
    mean_area: float = Field(..., description="Mean area of the core tumor")
    mean_smoothness: float = Field(..., description="Mean of local variation in radius lengths")
    mean_compactness: float = Field(..., description="Mean of perimeter^2 / area - 1.0")
    mean_concavity: float = Field(..., description="Mean of severity of concave portions of the contour")
    mean_concave_points: float = Field(..., description="Mean for number of concave portions of the contour")
    mean_symmetry: float = Field(..., description="Mean symmetry")
    mean_fractal_dimension: float = Field(..., description="Mean for 'coastline approximation' - 1")
    
    # SE features (standard error)
    se_radius: float = Field(..., description="Standard error of radius")
    se_texture: float = Field(..., description="Standard error of texture")
    se_perimeter: float = Field(..., description="Standard error of perimeter")
    se_area: float = Field(..., description="Standard error of area")
    se_smoothness: float = Field(..., description="Standard error of smoothness")
    se_compactness: float = Field(..., description="Standard error of compactness")
    se_concavity: float = Field(..., description="Standard error of concavity")
    se_concave_points: float = Field(..., description="Standard error of concave points")
    se_symmetry: float = Field(..., description="Standard error of symmetry")
    se_fractal_dimension: float = Field(..., description="Standard error of fractal dimension")
    
    # Worst features
    worst_radius: float = Field(..., description="Worst radius")
    worst_texture: float = Field(..., description="Worst texture")
    worst_perimeter: float = Field(..., description="Worst perimeter")
    worst_area: float = Field(..., description="Worst area")
    worst_smoothness: float = Field(..., description="Worst smoothness")
    worst_compactness: float = Field(..., description="Worst compactness")
    worst_concavity: float = Field(..., description="Worst concavity")
    worst_concave_points: float = Field(..., description="Worst concave points")
    worst_symmetry: float = Field(..., description="Worst symmetry")
    worst_fractal_dimension: float = Field(..., description="Worst fractal dimension")

class CancerPredictRequest(BaseModel):
    """Cancer prediction request (allows 'rows' alias)."""

    # Model settings: `samples` may be supplied under its own name or under
    # the alias `rows`; undeclared keys are rejected.
    class Config:
        extra = "forbid"
        populate_by_name = True

    model_type: str = Field(
        default="bayes",
        description="Model type: 'bayes', 'logreg', or 'rf'",
    )
    samples: List[CancerFeatures] = Field(
        ...,
        alias="rows",
        description="Breast-cancer feature vectors",
    )
    # Optional; when given it must lie in [10, 10_000].
    posterior_samples: Optional[int] = Field(
        default=None,
        description="Posterior draws for uncertainty",
        ge=10,
        le=10_000,
    )

class CancerPredictResponse(BaseModel):
    """Cancer prediction response."""
    # The three required lists are described per field; `uncertainties` is the
    # only optional member and defaults to None.
    predictions: List[str] = Field(
        ...,
        description="Predicted diagnosis (M=malignant, B=benign)",
    )
    probabilities: List[float] = Field(
        ...,
        description="Probability of malignancy",
    )
    uncertainties: Optional[List[float]] = Field(
        default=None,
        description="Uncertainty estimates (if requested)",
    )
    input_received: List[CancerFeatures] = Field(
        ...,
        description="Echo of input features",
    )

Overwriting api/app/schemas/cancer.py


In [24]:
%%writefile api/app/schemas/iris.py
from pydantic import BaseModel, Field
from typing import List, Optional

class IrisFeatures(BaseModel):
    """Iris measurement features."""
    # All four measurements are required and validated to lie in [0, 10].
    sepal_length: float = Field(..., ge=0, le=10, description="Sepal length in cm")
    sepal_width: float = Field(..., ge=0, le=10, description="Sepal width in cm")
    petal_length: float = Field(..., ge=0, le=10, description="Petal length in cm")
    petal_width: float = Field(..., ge=0, le=10, description="Petal width in cm")

class IrisPredictRequest(BaseModel):
    """Iris prediction request (accepts legacy 'rows' alias)."""

    # Model settings: `samples` may be supplied under its own name or under
    # the alias `rows`; undeclared keys are rejected.
    class Config:
        extra = "forbid"
        populate_by_name = True

    model_type: str = Field(default="rf", description="Model type: 'rf' or 'logreg'")
    samples: List[IrisFeatures] = Field(
        ...,
        alias="rows",
        description="List of iris measurements",
    )

class IrisPredictResponse(BaseModel):
    """Iris prediction response."""
    # All three members are required lists.
    predictions: List[str] = Field(
        ...,
        description="Predicted iris species",
    )
    probabilities: List[List[float]] = Field(
        ...,
        description="Class probabilities",
    )
    input_received: List[IrisFeatures] = Field(
        ...,
        description="Echo of input features",
    )

Overwriting api/app/schemas/iris.py


In [25]:
%%writefile api/app/ml/__init__.py
"""
ML sub-package – exposes built-in trainers so the service can import
`app.ml.builtin_trainers` with an absolute import.
"""

from .builtin_trainers import (
    train_iris_random_forest,
    train_iris_logreg,
    train_breast_cancer_bayes,
    train_breast_cancer_stub,
)

__all__ = [
    "train_iris_random_forest",
    "train_iris_logreg",
    "train_breast_cancer_bayes",
    "train_breast_cancer_stub",
] 


Overwriting api/app/ml/__init__.py


In [26]:
%%writefile api/app/ml/utils.py
# app/ml/utils.py  (minimal, cross‑platform)

def configure_pytensor_compiler(*_, **__):  # noqa: D401,E501
    """
    Backward-compatibility shim that performs no configuration.

    Per the project's own notes, the **JAX backend** is used, so PyTensor does
    not need a C compiler. Any positional or keyword arguments are accepted
    and ignored; the result is always ``True``.
    """
    compiler_ready = True
    return compiler_ready


Overwriting api/app/ml/utils.py


In [27]:
%%writefile api/app/ml/builtin_trainers.py
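# api/app/ml/builtin_trainers.py
"""
Built-in trainers for the Iris models (random forest, logistic regression)
and the Breast-Cancer models (hierarchical Bayesian logistic regression and a
fast logistic-regression stub).
Executed automatically by ModelService when a model is missing.
"""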

import logging
logger = logging.getLogger(__name__)

from pathlib import Path
import mlflow, mlflow.sklearn, mlflow.pyfunc
from sklearn.datasets import load_iris, load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import numpy as np
import tempfile
import pickle
import warnings
import subprocess
import os
import platform

# Conditional imports for heavy dependencies: PyMC/ArviZ are imported only
# when neither UNIT_TESTING nor SKIP_BACKGROUND_TRAINING is "1"; otherwise the
# names are bound to None so the module still imports.
if os.getenv("UNIT_TESTING") != "1" and os.getenv("SKIP_BACKGROUND_TRAINING") != "1":
    import pymc as pm
    import arviz as az
else:
    pm = None
    az = None

# ── NEW: Configure MLflow to use local file storage ─────────────────────────
# Set MLflow to use local file storage instead of remote server
# (setdefault keeps any value already present in the environment).
os.environ.setdefault("MLFLOW_TRACKING_URI", "file:./mlruns_local")
os.environ.setdefault("MLFLOW_REGISTRY_URI", "file:./mlruns_local")

# Configure MLflow tracking URI immediately
# NOTE(review): this passes the literal local path, so a pre-set
# MLFLOW_TRACKING_URI environment variable is not what gets applied here –
# confirm that is intended.
mlflow.set_tracking_uri("file:./mlruns_local")
# ──────────────────────────────────────────────────────────────────────────────

MLFLOW_EXPERIMENT = "ml_fullstack_models"

# Only set experiment if not in unit test mode and after tracking URI is set
if os.getenv("UNIT_TESTING") != "1":
    try:
        mlflow.set_experiment(MLFLOW_EXPERIMENT)
    except Exception as e:
        # Best effort: a failure here must not break importing the module.
        # Use the module logger (not the root-level logging.warning, which
        # implicitly calls basicConfig at import time).
        logger.warning("Could not set MLflow experiment: %s", e)

# -----------------------------------------------------------------------------
#  IRIS – point-estimate Random-Forest (enhanced with better parameters)
# -----------------------------------------------------------------------------
def train_iris_random_forest(
    n_estimators: int = 300,
    max_depth: int | None = None,
    random_state: int = 42
) -> str:
    """
    Fit a Random-Forest on the Iris data, log it to MLflow and register it as
    ``iris_random_forest``.

    The data is split 75/25 with stratification; macro-averaged metrics are
    computed on the held-out part and logged together with the hyperparameters.

    Args:
        n_estimators: Number of trees.
        max_depth: Maximum tree depth (``None`` leaves it unbounded).
        random_state: Seed shared by the split and the forest.

    Returns:
        The MLflow run_id (string).
    """
    dataset = load_iris(as_frame=True)
    X, y = dataset.data, dataset.target
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_state
    )

    forest = RandomForestClassifier(
        n_estimators=n_estimators,
        max_depth=max_depth,
        random_state=random_state,
        n_jobs=-1,                # all available cores
        class_weight='balanced',  # re-weight classes by frequency
    )
    forest.fit(X_tr, y_tr)

    y_hat = forest.predict(X_te)
    metrics = dict(
        accuracy=accuracy_score(y_te, y_hat),
        f1_macro=f1_score(y_te, y_hat, average="macro"),
        precision_macro=precision_score(y_te, y_hat, average="macro"),
        recall_macro=recall_score(y_te, y_hat, average="macro"),
    )
    hyperparams = {
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "random_state": random_state,
    }

    # Custom pyfunc wrapper: predict() yields class probabilities, and the
    # extra methods expose probabilities / hard classes directly.
    class IrisRFWrapper(mlflow.pyfunc.PythonModel):
        def __init__(self, model):
            self.model = model

        def predict(self, model_input, params=None):
            # DataFrames are unwrapped to their underlying array first
            X = model_input.values if hasattr(model_input, 'values') else model_input
            return self.model.predict_proba(X)

        def predict_proba(self, X):
            X = X.values if hasattr(X, 'values') else X
            return self.model.predict_proba(X)

        def predict_classes(self, X):
            X = X.values if hasattr(X, 'values') else X
            return self.model.predict(X)

    with mlflow.start_run(run_name="iris_random_forest") as run:
        mlflow.log_params(hyperparams)
        mlflow.log_metrics(metrics)

        iris_wrapper = IrisRFWrapper(forest)

        # Signature is inferred from the wrapper's own output on the full data
        mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=iris_wrapper,
            registered_model_name="iris_random_forest",
            input_example=X.head(),
            signature=mlflow.models.signature.infer_signature(X, iris_wrapper.predict(X))
        )
        run_id = run.info.run_id
        return run_id

# -----------------------------------------------------------------------------
#  IRIS – logistic-regression trainer (NEW)
# -----------------------------------------------------------------------------

def train_iris_logreg(
    C: float = 1.0,
    max_iter: int = 400,
    random_state: int = 42,
) -> str:
    """
    Train and register a **multinomial Logistic Regression** model on the Iris
    dataset.  Returns the MLflow run_id so the caller can reload the model.

    Args:
        C: Inverse regularisation strength passed to ``LogisticRegression``.
        max_iter: Iteration cap for the ``lbfgs`` solver.
        random_state: Seed for the stratified 75/25 split and the classifier.

    Returns:
        The MLflow run_id of the logged model (registered as ``iris_logreg``).
    """
    # Only LogisticRegression needs a local import; load_iris and
    # train_test_split are already imported at module level.
    from sklearn.linear_model import LogisticRegression

    # Load and split data (stratified)
    iris = load_iris(as_frame=True)
    X, y = iris.data, iris.target
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.25, stratify=y, random_state=random_state
    )

    # Fit classifier.  ``multi_class`` is intentionally not passed: the
    # parameter is deprecated in recent scikit-learn, and with the ``lbfgs``
    # solver a multiclass target is already fitted as a multinomial model.
    clf = LogisticRegression(
        C=C,
        max_iter=max_iter,
        solver="lbfgs",
        n_jobs=-1,
        random_state=random_state,
    ).fit(X_tr, y_tr)

    # ------------------- wrap in consistent pyfunc --------------------------
    class IrisLogRegWrapper(mlflow.pyfunc.PythonModel):
        """Expose predict() as class probabilities so the service can rely on it."""

        def __init__(self, model):
            self.model = model

        def predict(self, model_input, params=None):  # noqa: D401 – MLflow signature
            X_ = model_input.values if hasattr(model_input, "values") else model_input
            return self.model.predict_proba(X_)

        # Explicit probability accessor on the wrapper itself
        def predict_proba(self, X):
            X_ = X.values if hasattr(X, "values") else X
            return self.model.predict_proba(X_)

    # Held-out, macro-averaged metrics
    preds = clf.predict(X_te)
    metrics = {
        "accuracy": accuracy_score(y_te, preds),
        "f1_macro": f1_score(y_te, preds, average="macro"),
        "precision_macro": precision_score(y_te, preds, average="macro"),
        "recall_macro": recall_score(y_te, preds, average="macro"),
    }

    with mlflow.start_run(run_name="iris_logreg") as run:
        mlflow.log_params({"C": C, "max_iter": max_iter, "random_state": random_state})
        mlflow.log_metrics(metrics)

        mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=IrisLogRegWrapper(clf),
            registered_model_name="iris_logreg",
            input_example=X.head(),
            signature=mlflow.models.signature.infer_signature(X, clf.predict_proba(X)),
        )
        return run.info.run_id

# -----------------------------------------------------------------------------
#  BREAST-CANCER STUB – ultra-fast fallback model
# -----------------------------------------------------------------------------
def train_breast_cancer_stub(random_state: int = 42) -> str:
    """
    *Ultra-fast* fallback –  < 100 ms on any laptop.
    Trains vanilla LogisticRegression so the API can
    answer probability queries while the PyMC model cooks.

    Args:
        random_state: Seed for the stratified 70/30 train/test split.

    Returns:
        The MLflow run_id of the logged model (registered as
        ``breast_cancer_stub``).

    NOTE(review): scikit-learn's ``load_breast_cancer`` encodes the target as
    0 = malignant, 1 = benign, so column 1 of ``predict_proba`` is the
    probability of class 1.  The "malignant" wording used around this value
    should be confirmed against how the service interprets it.
    """
    # Only LogisticRegression is missing from the module-level imports.
    from sklearn.linear_model import LogisticRegression

    X, y = load_breast_cancer(return_X_y=True, as_frame=True)
    Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.3,
                                          stratify=y, random_state=random_state)

    clf = LogisticRegression(max_iter=200, n_jobs=-1).fit(Xtr, ytr)

    class CancerStubWrapper(mlflow.pyfunc.PythonModel):
        """predict() returns column 1 of predict_proba(); predict_proba() returns both columns."""

        def __init__(self, model):
            self.model = model

        def _pp(self, X):
            X_ = X.values if hasattr(X, "values") else X
            return self.model.predict_proba(X_)

        def predict(self, model_input, params=None):
            # 1-D array: probability of class index 1 (see NOTE(review) above)
            return self._pp(model_input)[:, 1]

        def predict_proba(self, X):
            return self._pp(X)

    acc = accuracy_score(yte, clf.predict(Xte))
    wrapper = CancerStubWrapper(clf)

    # No temporary directory is needed: nothing is written to disk here.
    with mlflow.start_run(run_name="breast_cancer_stub") as run:
        mlflow.log_metric("accuracy", acc)
        mlflow.pyfunc.log_model(
            "model",
            python_model=wrapper,
            registered_model_name="breast_cancer_stub",
            input_example=X.head(),
            # Infer the signature from what predict() really returns (1-D),
            # not from the 2-D predict_proba matrix.
            signature=mlflow.models.signature.infer_signature(X, wrapper.predict(X)),
        )
        return run.info.run_id

# -----------------------------------------------------------------------------
#  BREAST-CANCER – hierarchical Bayesian logistic regression
# -----------------------------------------------------------------------------

def train_breast_cancer_bayes(
    draws: int = 1000,
    tune: int = 1000,
    target_accept: float = 0.90,
    random_seed: int | None = 42,
) -> str:
    """
    Hierarchical Bayesian logistic‑regression with varying intercepts by
    **mean_texture quintile**.

    * Uses **NumPyro NUTS** backend → **no C compilation** on Windows.
    * Logs model to MLflow exactly like before so FastAPI can reload it.

    Args:
        draws: Posterior draws per chain.
        tune: Tuning steps per chain.
        target_accept: Target acceptance rate passed to the sampler.
        random_seed: Seed forwarded to ``pm.sample`` so repeated runs are
            reproducible; pass ``None`` for unseeded sampling.

    Returns:
        The MLflow run_id of the logged model (registered as
        ``breast_cancer_bayes``).

    NOTE(review): scikit-learn's ``load_breast_cancer`` encodes the target as
    0 = malignant, 1 = benign and ``y`` is used as observed as-is, so the
    value returned by the wrapper is the probability of class 1.  Confirm
    that consumers interpret it accordingly.
    """

    import pymc as pm                      # PyMC ≥5.9
    import pandas as pd, numpy as np
    from sklearn.datasets import load_breast_cancer
    from sklearn.preprocessing import StandardScaler
    import mlflow, tempfile, pickle
    from pathlib import Path

    # Note: PyTensor config is set by env_sanitizer before import
    # No runtime config changes needed - they're already applied

    # 1️⃣  data ----------------------------------------------------------------
    X_df, y = load_breast_cancer(as_frame=True, return_X_y=True)
    quint, edges = pd.qcut(X_df["mean texture"], 5, labels=False, retbins=True)
    g        = np.asarray(quint, dtype="int64")          # 0‥4
    scaler   = StandardScaler().fit(X_df)
    Xs       = scaler.transform(X_df)

    # 2️⃣  model ---------------------------------------------------------------
    coords = {"group": np.arange(5)}
    with pm.Model(coords=coords) as m:
        α     = pm.Normal("α", 0.0, 1.0, dims="group")   # varying intercepts
        β     = pm.Normal("β", 0.0, 1.0, shape=Xs.shape[1])
        logit = α[g] + pm.math.dot(Xs, β)
        pm.Bernoulli("obs", logit_p=logit, observed=y)

        idata = pm.sample(
            draws=draws,
            tune=tune,
            chains=4,
            nuts_sampler="numpyro",        # <-- magic line
            target_accept=target_accept,
            random_seed=random_seed,
            progressbar=False,
        )

    # 3️⃣  lightweight pyfunc wrapper -----------------------------------------
    class _HierBayesWrapper(mlflow.pyfunc.PythonModel):
        def __init__(self, trace, sc, ed, cols):
            self.trace, self.scaler, self.edges, self.cols = trace, sc, ed, cols

        def _quint(self, df):
            tex = df["mean texture"].to_numpy()
            # right=True reproduces pd.qcut's right-closed bins, so a value
            # sitting exactly on an edge lands in the same group as in training.
            return np.clip(np.digitize(tex, self.edges, right=True), 0, 4)

        def predict(self, X, params=None):
            df  = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X, columns=self.cols)
            xs  = self.scaler.transform(df)
            g   = self._quint(df)
            # Point prediction from posterior medians of intercepts and slopes
            αg  = self.trace.posterior["α"].median(("chain", "draw")).values
            β   = self.trace.posterior["β"].median(("chain", "draw")).values
            log = αg[g] + np.dot(xs, β)
            return 1.0 / (1.0 + np.exp(-log))

    # Only the four interior edges are kept; digitize then yields 0‥4.
    wrapper = _HierBayesWrapper(idata, scaler, edges[1:-1], X_df.columns.tolist())
    # In-sample accuracy (computed on the same data the model was fitted on)
    acc     = float(((wrapper.predict(X_df) > .5).astype(int) == y).mean())

    # 4️⃣  MLflow logging ------------------------------------------------------
    with tempfile.TemporaryDirectory() as td, mlflow.start_run(run_name="breast_cancer_bayes") as run:
        sc_path = Path(td) / "scaler.pkl"
        # Close the handle deterministically so the temp dir can be removed
        # (an open handle blocks deletion on Windows).
        with open(sc_path, "wb") as fh:
            pickle.dump(scaler, fh)
        mlflow.log_params(dict(draws=draws, tune=tune, target_accept=target_accept,
                               random_seed=random_seed))
        mlflow.log_metric("accuracy", acc)
        mlflow.pyfunc.log_model(
            "model",
            python_model=wrapper,
            artifacts={"scaler": str(sc_path)},
            registered_model_name="breast_cancer_bayes",
            input_example=X_df.head(),
            signature=mlflow.models.signature.infer_signature(X_df, wrapper.predict(X_df)),
        )
        return run.info.run_id


Overwriting api/app/ml/builtin_trainers.py


In [28]:
%%writefile api/app/services/ml/model_service.py
"""
Model service – self-healing startup with background training.
"""

from __future__ import annotations
import asyncio, logging, os, time, socket
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, List, Tuple, Optional

import mlflow, pandas as pd, numpy as np
from mlflow.tracking import MlflowClient
from mlflow.exceptions import MlflowException

from app.core.config import settings
from app.ml.builtin_trainers import (
    train_iris_random_forest,
    train_iris_logreg,  # NEW
    train_breast_cancer_bayes,
    train_breast_cancer_stub,
)

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Cancer column mapping: Pydantic field names ➜ training column names
# ---------------------------------------------------------------------------
# Generated rather than spelled out: every measurement stem appears once per
# statistic.  Means map to "mean <stem>", SE fields to "<stem> error" and worst
# fields to "worst <stem>", with underscores in the stem turned into spaces.
# Key order is all means, then all SE, then all worst.
_CANCER_COLMAP: dict[str, str] = {
    f"{prefix}_{stem}": template.format(stem.replace("_", " "))
    for prefix, template in (
        ("mean", "mean {}"),
        ("se", "{} error"),
        ("worst", "worst {}"),
    )
    for stem in (
        "radius",
        "texture",
        "perimeter",
        "area",
        "smoothness",
        "compactness",
        "concavity",
        "concave_points",
        "symmetry",
        "fractal_dimension",
    )
}

def _rename_cancer_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return ``df`` with its columns renamed according to ``_CANCER_COLMAP``
    (API field names → training column names).

    Columns that have no entry in the mapping keep their current name.
    """
    renamed = df.rename(columns=_CANCER_COLMAP)
    return renamed

# Self-healing registry: model name → trainer callable (each is invoked
# without arguments by the service).
TRAINERS = dict(
    iris_random_forest=train_iris_random_forest,
    iris_logreg=train_iris_logreg,  # NEW
    breast_cancer_bayes=train_breast_cancer_bayes,
    breast_cancer_stub=train_breast_cancer_stub,
)

class ModelService:
    """
    Self-healing model service that loads existing models and schedules
    background training for missing ones.
    """

    _EXECUTOR = ThreadPoolExecutor(max_workers=2)

    def __init__(self) -> None:
        """Create empty in-memory state; no MLflow connection is opened here."""
        self.initialized = False
        self._unit_test_mode = settings.UNIT_TESTING

        # Both client handles start out unset; `mlflow_client` is assigned
        # later by initialize().
        self.client = None
        self.mlflow_client = None

        # Loaded model objects by name, plus a per-model status string that
        # starts as "missing" for every known model.
        self.models: Dict[str, Any] = {}
        self.status: Dict[str, str] = dict.fromkeys(
            (
                "iris_random_forest",
                "iris_logreg",  # NEW
                "breast_cancer_bayes",
                "breast_cancer_stub",
            ),
            "missing",
        )

    async def initialize(self) -> None:
        """
        Connect to MLflow – fall back to local file store if the configured
        tracking URI is unreachable *or* the client is missing critical methods
        (e.g. when mlflow-skinny accidentally shadows the full package).

        Idempotent: returns immediately once ``self.initialized`` is set.
        After a client has been chosen, every known model is loaded via
        ``_load_models()``.
        """
        if self.initialized:
            return

        # Log critical dependency versions for diagnostics
        try:
            import pytensor
            logger.info("📦 PyTensor version: %s", pytensor.__version__)
        except ImportError:
            logger.warning("⚠️  PyTensor not available")
        except Exception as e:
            logger.warning("⚠️  Could not determine PyTensor version: %s", e)

        def _needs_fallback(client) -> bool:
            # Probe the method this service actually calls.  The previous
            # probe, `list_experiments`, no longer exists in MLflow 2.x, which
            # made every start-up abandon the configured tracking URI.
            return not callable(getattr(client, "search_experiments", None))

        try:
            mlflow.set_tracking_uri(settings.MLFLOW_TRACKING_URI)
            self.mlflow_client = MlflowClient(settings.MLFLOW_TRACKING_URI)

            if _needs_fallback(self.mlflow_client):
                raise AttributeError("search_experiments not implemented – skinny build detected")

            # minimal probe (cheap & always present)
            self.mlflow_client.search_experiments(max_results=1)
            logger.info("🟢  Connected to MLflow @ %s", settings.MLFLOW_TRACKING_URI)

        # OSError is a superset of socket.gaierror and also covers refused or
        # timed-out connections that surface as OS-level errors.
        except (MlflowException, OSError, AttributeError) as exc:
            logger.warning("🔄  Falling back to local MLflow store – %s", exc)
            mlflow.set_tracking_uri("file:./mlruns_local")
            self.mlflow_client = MlflowClient("file:./mlruns_local")
            logger.info("📂  Using local file store ./mlruns_local")

        await self._load_models()
        self.initialized = True

    async def _load_models(self) -> None:
        """Load existing models from MLflow."""
        await self._try_load("iris_random_forest")
        await self._try_load("iris_logreg")      # NEW
        await self._try_load("breast_cancer_bayes")
        await self._try_load("breast_cancer_stub")

    async def startup(self, auto_train: bool | None = None) -> None:
        """
        Faster: serve stub immediately; heavy Bayesian job in background.

        Flow:
          * unit-test mode → return without touching MLflow;
          * ``initialize()`` connects and loads whatever artefacts exist;
          * with ``SKIP_BACKGROUND_TRAINING`` set → no training at all;
          * otherwise: if the Bayesian cancer model is missing, make sure the
            stub is available (training it in the executor if needed) and
            schedule the Bayesian trainer as a background task; then train any
            missing iris model in the executor.

        Args:
            auto_train: Overrides ``settings.AUTO_TRAIN_MISSING`` when not None.
                NOTE(review): the resolved flag is only logged – training of
                missing models currently happens regardless of its value.
                Confirm whether it is meant to gate the training steps.
        """
        if self._unit_test_mode:
            logger.info("🔒 UNIT_TESTING=1 – skipping model loading")
            return                      # 👉 nothing else runs

        # Initialize MLflow connection first (this also loads existing models)
        await self.initialize()

        async def _ensure_loaded(name: str) -> bool:
            # Models already in memory are not fetched from MLflow again;
            # only missing ones trigger another load attempt.
            if name in self.models:
                return True
            return bool(await self._try_load(name))

        if settings.SKIP_BACKGROUND_TRAINING:
            logger.warning("⏩ SKIP_BACKGROUND_TRAINING=1 – models will load on-demand")
            # We still *try* to load existing artefacts so prod works
            for name in ("iris_random_forest", "iris_logreg", "breast_cancer_bayes"):
                await _ensure_loaded(name)
            return

        auto = auto_train if auto_train is not None else settings.AUTO_TRAIN_MISSING
        logger.info("🔄 Model-service startup (auto_train=%s)", auto)

        loop = asyncio.get_running_loop()

        # 1️⃣ Bayesian cancer model – if it exists we're done with cancer
        if not await _ensure_loaded("breast_cancer_bayes"):
            # 2️⃣ Ensure stub is *synchronously* available
            if not await _ensure_loaded("breast_cancer_stub"):
                logger.info("Training stub cancer model …")
                await loop.run_in_executor(self._EXECUTOR, train_breast_cancer_stub)
                await self._try_load("breast_cancer_stub")

            # 3️⃣ Fire full PyMC build in background.  SKIP_BACKGROUND_TRAINING
            #    has already returned above, so no second check is needed.
            logger.info("Scheduling full Bayesian retrain in background")
            task = asyncio.create_task(
                self._train_and_reload("breast_cancer_bayes", train_breast_cancer_bayes)
            )
            # The event loop keeps only weak references to tasks; hold a
            # strong one until completion so the job cannot be collected
            # mid-run.
            pending = self.__dict__.setdefault("_background_tasks", set())
            pending.add(task)
            task.add_done_callback(pending.discard)

        # 4️⃣ Train iris models if missing
        if not await _ensure_loaded("iris_random_forest"):
            logger.info("Training iris random-forest …")
            await loop.run_in_executor(self._EXECUTOR, train_iris_random_forest)
            await self._try_load("iris_random_forest")

        if not await _ensure_loaded("iris_logreg"):
            logger.info("Training iris logistic-regression …")
            await loop.run_in_executor(self._EXECUTOR, train_iris_logreg)
            await self._try_load("iris_logreg")

    async def _try_load(self, name: str) -> None:
        """Try to load a model and update status."""
        model = await self._load_production_model(name)
        if model:
            self.models[name] = model
            self.status[name] = "loaded"
            logger.info("✅ %s loaded", name)
            return True
        return False

    async def _train_and_reload(self, name: str, trainer) -> None:
        """Train a model in background and reload it, with verbose phase logs."""
        try:
            t0 = time.perf_counter()
            logger.info("🏗️  BEGIN training %s", name)
            self.status[name] = "training"

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(self._EXECUTOR, trainer)

            logger.info("📦 Training %s complete in %.1fs – re-loading", name,
                        time.perf_counter() - t0)
            model = await self._load_production_model(name)
            if not model:
                raise RuntimeError(f"{name} trained but could not be re-loaded")

            self.models[name] = model
            self.status[name] = "loaded"
            logger.info("✅ %s trained & loaded", name)

        except Exception as exc:
            self.status[name] = "failed"
            logger.error("❌ %s failed: %s", name, exc, exc_info=True)  # ← keeps trace
            # NEW: persist last_error for UI / debug endpoint
            self.status[f"{name}_last_error"] = str(exc)

    async def _load_production_model(self, name: str) -> Optional[Any]:
        """
        1. Registry 'Production' stage → load.  
        2. Otherwise most recent run with runName == name.
        Returns None if not found.
        """
        try:
            versions = self.mlflow_client.search_model_versions(f"name='{name}'")
            prod = [v for v in versions if v.current_stage == "Production"]
            if prod:
                uri = f"models:/{name}/{prod[0].version}"
                logger.info("↪︎  Loading %s from registry:%s", name, prod[0].version)
                return mlflow.pyfunc.load_model(uri)
        except MlflowException:
            pass

        # Fallback – scan experiments for latest run
        runs = []
        for exp in self.mlflow_client.search_experiments():
            runs.extend(self.mlflow_client.search_runs(
                [exp.experiment_id],
                f"tags.mlflow.runName = '{name}'",
                order_by=["attributes.start_time DESC"],
                max_results=1))
        if runs:
            uri = f"runs:/{runs[0].info.run_id}/model"
            logger.info("↪︎  Loading %s from latest run:%s", name, runs[0].info.run_id)
            return mlflow.pyfunc.load_model(uri)
        return None

    # Manual training endpoints (for UI)
    async def train_iris(self) -> None:
        """Retrain the Iris random-forest with its registered trainer, then reload it."""
        model_name = "iris_random_forest"
        await self._train_and_reload(model_name, TRAINERS[model_name])

    async def train_cancer(self) -> None:
        """Retrain the Bayesian breast-cancer model with its registered trainer, then reload it."""
        model_name = "breast_cancer_bayes"
        await self._train_and_reload(model_name, TRAINERS[model_name])

    # Predict methods (unchanged from your previous version)
    async def predict_iris(
        self,
        features: List[Dict[str, float]],
        model_type: str = "rf",
    ) -> Tuple[List[str], List[List[float]]]:
        """
        Predict Iris species from measurements.

        Args:
            features: One dict per sample with the keys ``sepal_length``,
                ``sepal_width``, ``petal_length`` and ``petal_width``.
            model_type: ``'rf'`` (looks up ``iris_random_forest``, the default)
                or ``'logreg'`` (looks up ``iris_logreg``).

        Returns:
            Tuple of (predicted_class_names, class_probabilities). Each
            probability row is ordered setosa, versicolor, virginica. An empty
            ``features`` list yields ``([], [])``.

        Raises:
            ValueError: If ``model_type`` is neither ``'rf'`` nor ``'logreg'``.
            RuntimeError: If the requested model is not in ``self.models``.
            KeyError: If a sample lacks one of the four measurement keys.
        """
        import numpy as _np  # function-scope import, as in the original implementation

        if model_type not in ("rf", "logreg"):
            raise ValueError("model_type must be 'rf' or 'logreg'")

        model_name = "iris_random_forest" if model_type == "rf" else "iris_logreg"
        model = self.models.get(model_name)
        # Compare against None explicitly: estimator objects may define __len__,
        # so truthiness is not a reliable "is it loaded" test.
        if model is None:
            raise RuntimeError(f"{model_name} not loaded")

        # Nothing to score - avoid handing a column-less frame to the model.
        if not features:
            return [], []

        # Index -> species name; also fixes the width of the one-hot fallback.
        class_names = ["setosa", "versicolor", "virginica"]

        # Convert to DataFrame with proper column names (matching training data)
        X_df = pd.DataFrame([{
            "sepal length (cm)": sample["sepal_length"],
            "sepal width (cm)": sample["sepal_width"],
            "petal length (cm)": sample["petal_length"],
            "petal width (cm)": sample["petal_width"]
        } for sample in features])

        # Obtain probabilities in a backward-compatible way
        if hasattr(model, "predict_proba"):
            # Coerce to ndarray so list / DataFrame outputs also support
            # .argmax(axis=1) and .tolist() below.
            probs = _np.asarray(model.predict_proba(X_df), dtype=float)
        else:
            # Legacy artefact - derive 1-hot probas from class indices
            preds_idx = _np.asarray(model.predict(X_df)).astype(int).ravel()
            probs = _np.zeros((len(preds_idx), len(class_names)), dtype=float)
            probs[_np.arange(len(preds_idx)), preds_idx] = 1.0

        preds = probs.argmax(axis=1)                 # numerical class indices

        # Map numerical classes to species names
        pred_names = [class_names[i] for i in preds]

        return pred_names, probs.tolist()

    async def predict_cancer(
        self,
        features: List[Dict[str, float]],
        model_type: str = "bayes",
        posterior_samples: Optional[int] = None,
    ) -> Tuple[List[str], List[float], Optional[List[Tuple[float, float]]]]:
        """
        Predict breast cancer diagnosis from features using the hierarchical
        Bayesian model, falling back to the stub model when it is not loaded.

        Args:
            features: List of cancer measurements as dictionaries.
            model_type: ``'bayes'`` (with stub fallback) or ``'stub'``.
            posterior_samples: When truthy and the Bayesian model served the
                request, 95% credible intervals are computed from the model's
                stored posterior draws.
                NOTE(review): the number itself is only used as an on/off
                switch here - every stored draw is used.

        Returns:
            Tuple of (predicted_labels, probabilities, uncertainty_intervals).
            A sample is labelled "malignant" when its probability exceeds 0.5,
            otherwise "benign". ``uncertainty_intervals`` is ``None`` unless
            intervals were requested and computed successfully.

        Raises:
            ValueError: If ``model_type`` is neither ``'bayes'`` nor ``'stub'``.
            RuntimeError: If no usable model is loaded.
        """
        if model_type not in ("bayes", "stub"):
            raise ValueError("model_type must be 'bayes' or 'stub'")

        # Resolve the model exactly once. The original re-read self.models
        # later to decide "is this the Bayesian model?", which could disagree
        # with the object actually held in `model`.
        bayes_model = self.models.get("breast_cancer_bayes") if model_type == "bayes" else None
        use_bayes = bayes_model is not None
        if use_bayes:
            model = bayes_model
        else:
            model = self.models.get("breast_cancer_stub")
            if model is None:
                if model_type == "bayes":
                    raise RuntimeError("No cancer model available")
                raise RuntimeError("Stub cancer model not loaded")
            if model_type == "bayes":
                logger.info("Using stub cancer model (Bayesian model not ready)")

        # Nothing to score - skip the frame conversion and model call.
        if not features:
            return [], [], None

        # Convert to DataFrame with proper column names
        X_df_raw = pd.DataFrame(features)
        X_df = _rename_cancer_columns(X_df_raw)

        # Get predictions
        if use_bayes:
            raw_probs = model.predict(X_df)
        elif hasattr(model, "predict_proba"):
            # Column 1 holds the probability that is thresholded below.
            raw_probs = np.asarray(model.predict_proba(X_df))[:, 1]
        else:
            # Legacy artefact: model.predict returns hard class 0/1,
            # which is used as a deterministic probability.
            raw_probs = model.predict(X_df)
        # Flatten to a 1-D float array so iteration and .tolist() behave the
        # same for ndarray, Series or single-column frame outputs.
        probs = np.asarray(raw_probs, dtype=float).ravel()
        labels = ["malignant" if p > 0.5 else "benign" for p in probs]

        # Compute uncertainty intervals if requested (Bayesian model only)
        ci = None
        if posterior_samples and use_bayes:
            try:
                # Access the underlying python model to get the trace
                python_model = model.unwrap_python_model()
                draws = python_model.trace.posterior

                # stack() appends the new "samples" dimension LAST; move it to
                # the front so both arrays are laid out as (samples, ...).
                # The original indexed α_group as if samples came first, which
                # made the sum below fail and silently produced ci=None.
                alpha_group = (
                    draws["α_group"].stack(samples=("chain", "draw")).transpose("samples", ...)
                )
                beta = draws["β"].stack(samples=("chain", "draw")).transpose("samples", ...)

                # Get group indices and standardized features
                g = python_model._group_index(X_df)
                Xs = np.asarray(python_model.scaler.transform(X_df))

                # Posterior predictive probabilities, shape (S, N)
                logits = alpha_group.values[:, g] + np.dot(beta.values, Xs.T)
                pp = 1 / (1 + np.exp(-logits))

                # Compute 95% credible intervals
                lo, hi = np.percentile(pp, [2.5, 97.5], axis=0)
                ci = list(zip(lo.tolist(), hi.tolist()))

            except Exception as e:
                # Best effort: predictions are still returned without intervals.
                logger.warning(f"Failed to compute uncertainty intervals: {e}")
                ci = None

        return labels, probs.tolist(), ci


# Global singleton: one ModelService instance is created when this module is
# first imported; the API module imports it by this name.
model_service = ModelService()


Overwriting api/app/services/ml/model_service.py


In [29]:
%%writefile api/app/main.py
import logging
import os
import asyncio
from fastapi import FastAPI, Request, Depends, BackgroundTasks, status, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
import time

from pydantic import BaseModel

# ── NEW: Fix ML backend configuration before any JAX imports ───────────────────────────
from .utils.env_sanitizer import fix_ml_backends
fix_ml_backends()
# ──────────────────────────────────────────────────────────────────────────

# ── NEW: Rate limiting imports ─────────────────────────────────────────────────────────
from fastapi_limiter import FastAPILimiter
import redis.asyncio as redis
# ────────────────────────────────────────────────────────────────────────────────────────

# ── NEW: Concurrency limiting imports ────────────────────────────────────────────────
from .middleware.concurrency import ConcurrencyLimiter
# ────────────────────────────────────────────────────────────────────────────────────────

from .db import lifespan, get_db, get_app_ready
from .security import create_access_token, get_current_user, verify_password
from .crud import get_user_by_username
from .schemas.iris import IrisPredictRequest, IrisPredictResponse, IrisFeatures
from .schemas.cancer import CancerPredictRequest, CancerPredictResponse, CancerFeatures
from .services.ml.model_service import model_service
from .core.config import settings
from .deps.limits import default_limit, heavy_limit, login_limit, training_limit, light_limit
from .security import LoginPayload, get_credentials

# Guarantee the "logs" directory exists at import time. The path is relative
# to the current working directory; exist_ok=True makes this a no-op when the
# directory is already there.
os.makedirs("logs", exist_ok=True)
# NOTE(review): nothing in this block attaches a file handler, so the
# directory is presumably consumed by logging configured elsewhere - confirm.

# Configure logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pydantic models
class Payload(BaseModel):
    # Single integer field; nested in PredictionRequest and echoed back by
    # PredictionResponse.
    count: int

class PredictionRequest(BaseModel):
    # Request envelope: the actual payload lives under the "data" key.
    # NOTE(review): no route in this module references this model.
    data: Payload

class PredictionResponse(BaseModel):
    # Generic prediction result: a label plus a confidence value.
    # NOTE(review): no route in this module references this model.
    prediction: str
    confidence: float
    input_received: Payload  # Echo back the input for verification

class Token(BaseModel):
    # Response body of POST /api/v1/token: the JWT and its type; the login
    # handler always sets token_type to "bearer".
    access_token: str
    token_type: str

# The ASGI application. Swagger UI, ReDoc and the OpenAPI schema are all
# served under the /api/v1 prefix.
app = FastAPI(
    title="FastAPI + React ML App",
    version="1.0.0",
    docs_url="/api/v1/docs",
    redoc_url="/api/v1/redoc",
    openapi_url="/api/v1/openapi.json",
    swagger_ui_parameters={"persistAuthorization": True},
    lifespan=lifespan,  # startup/shutdown handler imported from .db
)

# ── Rate limiting is now initialized in lifespan() ────────────────────────────────────
# ────────────────────────────────────────────────────────────────────────────────────────

# Configure CORS with environment-based origins.
# settings.ALLOWED_ORIGINS is either "*" or a comma-separated list of origins.
origins_env = settings.ALLOWED_ORIGINS
origins: list[str] = (
    ["*"]
    if origins_env.strip() == "*"
    else [o.strip() for o in origins_env.split(",") if o.strip()]
)
if not origins:
    # An empty setting keeps the previous allow-all behaviour instead of
    # silently blocking every cross-origin request.
    origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    # Honour the configured origins. This used to be hard-coded to ["*"],
    # which ignored ALLOWED_ORIGINS entirely.
    # NOTE(review): a wildcard origin combined with allow_credentials=True is
    # unsafe for production - set ALLOWED_ORIGINS to explicit origins there.
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the project's ConcurrencyLimiter middleware with max_concurrent=4.
# NOTE(review): presumably this caps the number of requests handled at once;
# the limit is a hard-coded magic number - consider moving it into settings.
app.add_middleware(ConcurrencyLimiter, max_concurrent=4)
# ────────────────────────────────────────────────────────────────────────────────────────

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Time the downstream handling of a request and expose it as X-Process-Time."""
    started_at = time.perf_counter()
    response = await call_next(request)
    # Elapsed seconds, rendered with four decimal places.
    response.headers["X-Process-Time"] = f"{time.perf_counter() - started_at:.4f}"
    return response

# Health check endpoint
@app.get("/api/v1/health")
async def health_check():
    """Liveness probe: reports "healthy" plus the current Unix timestamp."""
    now = time.time()
    return dict(status="healthy", timestamp=now)

@app.get("/api/v1/hello")
async def hello(current_user: str = Depends(get_current_user)):
    """Greet the authenticated caller; useful for checking that a token is valid."""
    greeting = f"Hello {current_user}!"
    return {"message": greeting, "status": "authenticated"}

@app.get("/api/v1/ready")
async def ready():
    """Readiness probe: returns the flag reported by get_app_ready()."""
    is_ready = get_app_ready()
    return {"ready": is_ready}

@app.post("/api/v1/token", response_model=Token, dependencies=[Depends(login_limit)])
async def login(
    creds: LoginPayload = Depends(get_credentials),
    db: AsyncSession = Depends(get_db),
):
    """
    Issue a JWT. Accepts **either**
    - JSON {"username": "...", "password": "..."}  *or*
    - classic x-www-form-urlencoded.

    Raises:
        HTTPException 503: the app is not ready yet (includes Retry-After: 10).
        HTTPException 401: unknown user or wrong password.
    """
    # 1) Readiness gate: refuse logins until the app reports ready.
    if not get_app_ready():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Backend still loading models. Try again in a moment.",
            # The header name must be plain ASCII. The previous spelling used
            # a non-breaking hyphen (U+2011), which cannot be encoded as an
            # HTTP header and turned this 503 into a server error.
            headers={"Retry-After": "10"},
        )

    # 2) Verify credentials. The same 401 is returned for an unknown user and
    #    for a wrong password, so the response does not reveal which failed.
    user = await get_user_by_username(db, creds.username)
    if not user or not verify_password(creds.password, user.hashed_password):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Invalid credentials")

    # 3) Issue the token with the username as subject.
    token = create_access_token(subject=user.username)
    return Token(access_token=token, token_type="bearer")

@app.get("/api/v1/ready/full")
async def ready_full() -> dict:
    """
    Extended readiness probe.

    Response fields:
    - ready: API server is ready to accept requests (login allowed)
    - model_status: the model service's raw status mapping
    - all_models_loaded: true only if every expected model has status 'loaded'
    - models: {expected_model_name: present in the loaded-model registry}
    - training: names whose status is currently 'training'
    """
    # Allow login if API is ready, regardless of model status
    ready_for_login = get_app_ready()

    expected = {"iris_random_forest", "breast_cancer_bayes"}
    loaded = set(model_service.models.keys())
    status_by_name = model_service.status

    # Judge completeness against the expected models only. Iterating over all
    # status values reported True for an empty mapping and could never become
    # True once auxiliary keys (e.g. "*_last_error") were present.
    all_loaded = all(status_by_name.get(name) == "loaded" for name in expected)

    # Report only entries that are actually training; the old set difference
    # also listed failed models and auxiliary status keys.
    training = sorted(name for name, state in status_by_name.items() if state == "training")

    response = {
        "ready": ready_for_login,  # Allow login immediately
        "model_status": status_by_name,
        "all_models_loaded": all_loaded,
        "models": {m: (m in loaded) for m in expected},
        "training": training,
    }

    logger.debug("READY endpoint – _app_ready=%s, response=%s", ready_for_login, response)
    return response

# ── Alias routes (no auth, not shown in OpenAPI) ────────────────────────────
@app.get("/ready/full", include_in_schema=False)
async def ready_full_alias():
    """Un-prefixed alias of /api/v1/ready/full for front-end calls that omit /api/v1."""
    payload = await ready_full()
    return payload

@app.get("/health", include_in_schema=False)
async def health_alias():
    """Un-prefixed alias of /api/v1/health (the SPA calls it before it knows the prefix)."""
    payload = await health_check()
    return payload

@app.post("/token", include_in_schema=False, dependencies=[Depends(login_limit)])
async def login_alias(request: Request):
    """
    Alias: accept a form-encoded POST /token like /api/v1/token.

    Reads ``username`` and ``password`` from the form body and delegates to
    ``login``. The alias carries the same ``login_limit`` dependency as the
    primary route so it cannot be used to sidestep login rate limiting.

    Raises:
        HTTPException 422: username or password missing from the form.
        Anything ``login`` raises (503 while not ready, 401 on bad credentials).
    """
    # Parse form data manually to match OAuth2PasswordRequestForm behavior
    form_data = await request.form()
    username = form_data.get("username")
    password = form_data.get("password")

    if not username or not password:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="username and password are required"
        )

    # Minimal stand-in exposing the two attributes login() reads.
    class MockForm:
        def __init__(self, username, password):
            self.username = username
            self.password = password

    mock_form = MockForm(username, password)

    # get_db is driven by hand here (outside FastAPI's dependency system), so
    # the generator must be closed explicitly; otherwise its cleanup code
    # never runs and the session is leaked.
    db_gen = get_db()
    db = await db_gen.__anext__()
    try:
        return await login(mock_form, db)
    finally:
        await db_gen.aclose()

@app.post("/iris/predict", include_in_schema=False)
async def iris_predict_alias(request: Request):
    """
    Alias for /api/v1/iris/predict.

    NOTE(review): this alias skips authentication (the caller is recorded as
    "test_user") and has no rate-limit dependency - intended for testing only.

    Raises:
        HTTPException 422: the body is not valid JSON or fails schema validation.
    """
    # Parse and validate the JSON body. JSON decode errors and pydantic
    # validation errors are ValueError subclasses; a non-object body makes the
    # ** expansion raise TypeError. Report all of them as a client error.
    try:
        body = await request.json()
        iris_request = IrisPredictRequest(**body)
    except (ValueError, TypeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid request body: {exc}",
        ) from exc

    # Reuse the existing prediction logic without authentication for testing
    background_tasks = BackgroundTasks()
    current_user = "test_user"  # Skip authentication for alias endpoints
    response = await predict_iris(iris_request, background_tasks, current_user)

    # This BackgroundTasks object is not managed by FastAPI, so run the queued
    # audit task here; otherwise it would be dropped silently.
    await background_tasks()
    return response

@app.post("/cancer/predict", include_in_schema=False)
async def cancer_predict_alias(request: Request):
    """
    Alias for /api/v1/cancer/predict.

    NOTE(review): this alias skips authentication (the caller is recorded as
    "test_user") and has no rate-limit dependency - intended for testing only.

    Raises:
        HTTPException 422: the body is not valid JSON or fails schema validation.
    """
    # Parse and validate the JSON body. JSON decode errors and pydantic
    # validation errors are ValueError subclasses; a non-object body makes the
    # ** expansion raise TypeError. Report all of them as a client error.
    try:
        body = await request.json()
        cancer_request = CancerPredictRequest(**body)
    except (ValueError, TypeError) as exc:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Invalid request body: {exc}",
        ) from exc

    # Reuse the existing prediction logic without authentication for testing
    background_tasks = BackgroundTasks()
    current_user = "test_user"  # Skip authentication for alias endpoints
    response = await predict_cancer(cancer_request, background_tasks, current_user)

    # This BackgroundTasks object is not managed by FastAPI, so run the queued
    # audit task here; otherwise it would be dropped silently.
    await background_tasks()
    return response

# ----- on-demand training endpoints ----------------------------------
@app.post("/api/v1/iris/train", status_code=202, dependencies=[Depends(training_limit)])
async def train_iris(background_tasks: BackgroundTasks,
                     current_user: str = Depends(get_current_user)):
    """Queue Iris training as a background task and acknowledge with 202."""
    training_job = model_service.train_iris
    background_tasks.add_task(training_job)
    return dict(status="started")

@app.post("/api/v1/cancer/train", status_code=202, dependencies=[Depends(training_limit)])
async def train_cancer(background_tasks: BackgroundTasks,
                       current_user: str = Depends(get_current_user)):
    """Queue breast-cancer model training as a background task and acknowledge with 202."""
    training_job = model_service.train_cancer
    background_tasks.add_task(training_job)
    return dict(status="started")

@app.get("/api/v1/iris/ready")
async def iris_ready():
    """Report whether "iris_random_forest" is present in the loaded-model registry."""
    is_loaded = "iris_random_forest" in model_service.models
    return {"loaded": is_loaded}

@app.get("/api/v1/cancer/ready")
async def cancer_ready():
    """Report whether "breast_cancer_bayes" is present in the loaded-model registry."""
    is_loaded = "breast_cancer_bayes" in model_service.models
    return {"loaded": is_loaded}

@app.post(
    "/api/v1/iris/predict",
    response_model=IrisPredictResponse,
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(light_limit)]
)
async def predict_iris(
    request: IrisPredictRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user),
):
    """
    Classify iris samples with the requested model.

    Responds with 503 (Retry-After: 30) while the model backing the requested
    ``model_type`` is not yet loaded.

    Example request:
        {
            "model_type": "rf",
            "samples": [
                {
                    "sepal_length": 5.1,
                    "sepal_width": 3.5,
                    "petal_length": 1.4,
                    "petal_width": 0.2
                }
            ]
        }
    """
    samples = request.samples
    logger.info(f"User {current_user} called /iris/predict with {len(samples)} samples")
    logger.debug(f"→ Iris payload: {samples}")

    # Each known model_type needs its backing model in the registry.
    required_models = (("rf", "iris_random_forest"), ("logreg", "iris_logreg"))
    model_missing = any(
        request.model_type == kind and name not in model_service.models
        for kind, name in required_models
    )
    if model_missing:
        logger.warning("Iris model not ready - returning 503")
        raise HTTPException(
            status_code=503,
            detail="Iris model is still loading. Please try again in a few seconds.",
            headers={"Retry-After": "30"},
        )

    # Plain dicts are what the model service expects.
    features = []
    for sample in samples:
        features.append(sample.dict())
    logger.debug(f"→ Iris features: {features}")

    outcome = await model_service.predict_iris(
        features=features,
        model_type=request.model_type
    )
    predictions, probabilities = outcome
    for label, value in (("predictions", predictions), ("probabilities", probabilities)):
        logger.debug(f"← Iris {label}: {value}")

    # Audit line is emitted after the response via the background task queue.
    audit_line = f"[audit] user={current_user} endpoint=iris input={samples} output={predictions}"
    background_tasks.add_task(logger.info, audit_line)

    return IrisPredictResponse(
        predictions=predictions,
        probabilities=probabilities,
        input_received=samples,
    )

@app.post(
    "/api/v1/cancer/predict",
    response_model=CancerPredictResponse,
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(heavy_limit)]
)
async def predict_cancer(
    request: CancerPredictRequest,
    background_tasks: BackgroundTasks,
    current_user: str = Depends(get_current_user),
):
    """
    Classify breast-cancer samples.

    There is deliberately no early 503 here: the model service falls back to
    the stub model when the Bayesian model is not ready.

    Example request:
        {
            "model_type": "bayes",
            "samples": [
                {
                    "mean_radius": 17.99,
                    "mean_texture": 10.38,
                    ...
                }
            ],
            "posterior_samples": 1000  # optional
        }
    """
    samples = request.samples
    logger.info(f"User {current_user} called /cancer/predict with {len(samples)} samples")
    logger.debug(f"→ Cancer payload: {samples}")

    # Plain dicts are what the model service expects.
    features = []
    for sample in samples:
        features.append(sample.dict())
    logger.debug(f"→ Cancer features: {features}")

    outcome = await model_service.predict_cancer(
        features=features,
        model_type=request.model_type,
        posterior_samples=request.posterior_samples
    )
    predictions, probabilities, uncertainties = outcome
    for label, value in (
        ("predictions", predictions),
        ("probabilities", probabilities),
        ("uncertainties", uncertainties),
    ):
        logger.debug(f"← Cancer {label}: {value}")

    # Audit line is emitted after the response via the background task queue.
    audit_line = f"[audit] user={current_user} endpoint=cancer input={samples} output={predictions}"
    background_tasks.add_task(logger.info, audit_line)

    return CancerPredictResponse(
        predictions=predictions,
        probabilities=probabilities,
        uncertainties=uncertainties,
        input_received=samples,
    )

@app.get("/api/v1/debug/ready")
async def debug_ready():
    """Debug view of the readiness flag and the model service's internal state."""
    status_snapshot = model_service.status

    # Status entries whose key ends in "_last_error" are surfaced separately.
    last_errors = {}
    for key, value in status_snapshot.items():
        if key.endswith("_last_error"):
            last_errors[key] = value

    return dict(
        app_ready=get_app_ready(),
        model_service_initialized=model_service.initialized,
        models=list(model_service.models.keys()),
        status=status_snapshot,
        errors=last_errors,
    )

@app.get("/api/v1/debug/compiler")
async def debug_compiler():
    """
    Report the JAX / NumPyro / PyMC versions and the active JAX devices.

    Never raises: a missing package or any other failure is reported in the
    response body through the "status" and "error" fields.
    """
    try:
        # Imported lazily so the endpoint can report missing packages.
        import jax
        import numpyro
        import pymc as pm

        report = dict(
            backend="jax_numpyro",
            jax_version=jax.__version__,
            numpyro_version=numpyro.__version__,
            pymc_version=pm.__version__,
            jax_devices=str(jax.devices()),
            jax_platform=jax.default_backend(),
            status="jax_backend_configured",
        )
    except ImportError as e:
        report = dict(
            backend="unknown",
            error=f"Import error: {e}",
            status="missing_dependencies",
        )
    except Exception as e:
        report = dict(
            backend="unknown",
            error=f"Configuration error: {e}",
            status="configuration_failed",
        )
    return report

@app.get("/api/v1/test/401")
async def test_401():
    """Always answer 401 so clients can exercise their session-expiry handling."""
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Test 401 response",
    )
    raise unauthorized

# ── Debug‑only ratelimit helpers ─────────────────────────────────────────────
from .deps.limits import get_redis, user_or_ip

@app.post("/api/v1/debug/ratelimit/reset", include_in_schema=False)
async def rl_reset(request: Request):
    """
    Flush **all** rate-limit counters bound to the caller (JWT _or_ IP).

    Only available when settings.DEBUG_RATELIMIT is enabled; otherwise the
    endpoint answers 404. Left open, any client could wipe its own counters
    before every request and thereby bypass rate limiting entirely.

    We match every fragment that contains the identifier to survive
    future changes in FastAPI-Limiter's key schema.

    Raises:
        HTTPException 404: debug rate-limit helpers are disabled.
        HTTPException 503: the rate limiter's Redis client is not initialised.
    """
    if not settings.DEBUG_RATELIMIT:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

    r = get_redis()
    if not r:
        raise HTTPException(status_code=503, detail="Rate‑limiter not initialised")

    ident = await user_or_ip(request)
    keys = await r.keys(f"ratelimit:*{ident}*")        # <— broader pattern
    if keys:
        await r.delete(*keys)
    return {"reset": len(keys)}

if settings.DEBUG_RATELIMIT:          # OFF by default
    @app.get("/api/v1/debug/ratelimit/{bucket}", include_in_schema=False)
    async def rl_status(bucket: str, request: Request):
        """
        Inspect Redis keys for the current identifier + bucket.
        Handy for CI tests – **never enable in prod**.

        Returns a mapping of each matching Redis key to its stored value.

        Raises:
            HTTPException 503: the rate limiter's Redis client is not initialised.
        """
        r = get_redis()
        # Same guard as the reset endpoint: without it a missing client
        # surfaced as an AttributeError (HTTP 500).
        if not r:
            raise HTTPException(status_code=503, detail="Rate‑limiter not initialised")

        key_prefix = f"ratelimit:{bucket}:{await user_or_ip(request)}"
        keys = await r.keys(f"{key_prefix}*")
        values = await r.mget(keys) if keys else []
        return dict(zip(keys, values))


Overwriting api/app/main.py


# Tests

In [30]:
%%writefile api/scripts/ensure_models.py
#!/usr/bin/env python3
"""
Ensure models script - pre-trains all models before starting the API.
This can be used in development or CI to ensure models are ready.
"""

import asyncio
import logging
import sys
from pathlib import Path

# Add the api directory to the Python path
sys.path.insert(0, str(Path(__file__).parent.parent))

from app.services.ml.model_service import TRAINERS, ModelService

# Configure logging: INFO level, with timestamp, logger name and level on
# every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module logger used by main() and by the __main__ guard at the bottom.
logger = logging.getLogger(__name__)

async def main(max_wait: float = 300, poll_interval: float = 5) -> bool:
    """
    Ensure all models are trained and loaded.

    Starts a ModelService with auto-training enabled, then polls until every
    model named in TRAINERS is present in the service's model registry.

    Args:
        max_wait: Maximum number of seconds to wait (default: 5 minutes).
        poll_interval: Seconds to sleep between checks.

    Returns:
        True when every trainer's model is loaded; False on timeout or as
        soon as any model reports the status "failed".
    """
    logger.info("🚀 Starting model ensure script...")

    svc = ModelService()

    # Start the self-healing process
    await svc.startup(auto_train=True)

    # Inside a coroutine the running loop is the correct clock source.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait

    while True:
        # Check by NAME rather than by count: the registry may hold entries
        # that have no trainer, which would satisfy a length comparison while
        # a real model is still missing.
        missing = [name for name in TRAINERS if name not in svc.models]
        if not missing:
            break

        if loop.time() > deadline:
            logger.error("❌ Timeout waiting for models to load")
            return False

        loaded_count = len(TRAINERS) - len(missing)
        logger.info(f"⏳ Waiting for models... ({loaded_count}/{len(TRAINERS)} loaded)")

        # Check for failed models
        failed = [name for name, status in svc.status.items() if status == "failed"]
        if failed:
            logger.error(f"❌ Models failed to train: {failed}")
            return False

        await asyncio.sleep(poll_interval)

    logger.info("✅ All models loaded successfully!")
    return True

if __name__ == "__main__":
    # Exit status is 0 only when main() reports success; an interrupt, an
    # unexpected error or a falsy result all exit with 1.
    exit_code = 1
    try:
        if asyncio.run(main()):
            exit_code = 0
    except KeyboardInterrupt:
        logger.info("⏹️  Interrupted by user")
    except Exception as e:
        logger.error(f"❌ Unexpected error: {e}")
    sys.exit(exit_code)

Overwriting api/scripts/ensure_models.py


In [31]:
%%writefile api/__init__.py
# Create the "logs" directory as soon as the package is imported (e.g. by
# Uvicorn workers). The path is relative to the current working directory;
# exist_ok=True makes repeated imports harmless.
import os
os.makedirs("logs", exist_ok=True) 

Overwriting api/__init__.py


In [32]:
%%writefile tests/test_rate_limits.py
#!/usr/bin/env python3
"""
Test script for rate limiting functionality.
Run this to verify that rate limits are working correctly.
"""

import asyncio
import httpx
import time
import os
from typing import Optional

class RateLimitTester:
    """
    Small async HTTP client that probes the API's rate limits.

    Credentials are taken from the constructor arguments, else from the
    RATE_LIMIT_TEST_USERNAME / RATE_LIMIT_TEST_PASSWORD environment variables,
    else they fall back to the demo account this script previously hard-coded.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        username: Optional[str] = None,
        password: Optional[str] = None,
    ):
        self.base_url = base_url
        self.username = username or os.getenv("RATE_LIMIT_TEST_USERNAME", "alice")
        self.password = password or os.getenv("RATE_LIMIT_TEST_PASSWORD", "supersecretvalue")
        self.client = httpx.AsyncClient(timeout=10.0)
        self.token: Optional[str] = None  # bearer token, set by login()

    async def login(self) -> bool:
        """Login to get a JWT token; returns True and stores the token on success."""
        try:
            response = await self.client.post(
                f"{self.base_url}/api/v1/token",
                data={"username": self.username, "password": self.password}
            )
            if response.status_code == 200:
                data = response.json()
                self.token = data["access_token"]
                print("✅ Login successful")
                return True
            else:
                print(f"❌ Login failed: {response.status_code}")
                return False
        except Exception as e:
            print(f"❌ Login error: {e}")
            return False

    async def test_endpoint(self, endpoint: str, payload: dict, name: str, expected_limit: int):
        """
        Test rate limiting on a specific endpoint.

        POSTs ``payload`` up to ``expected_limit + 5`` times and stops at the
        first 429, the first non-2xx response, or the first exception.

        Returns:
            Tuple ``(success_count, rate_limited_count)``.
        """
        print(f"\n🔍 Testing {name} endpoint: {endpoint}")
        print(f"Expected limit: {expected_limit} requests per window")

        headers = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        success_count = 0
        rate_limited_count = 0

        for i in range(expected_limit + 5):  # Try a few extra requests
            try:
                response = await self.client.post(
                    f"{self.base_url}{endpoint}",
                    json=payload,
                    headers=headers
                )

                # Check rate limit headers
                remaining = response.headers.get("X-RateLimit-Remaining")
                limit = response.headers.get("X-RateLimit-Limit")

                # Any 2xx counts as success: the training endpoints answer
                # 202 Accepted, which a strict "== 200" check misreported as
                # an error and aborted the loop before any 429 could be seen.
                if 200 <= response.status_code < 300:
                    success_count += 1
                    print(f"  ✅ Request {i+1}: Success (Remaining: {remaining}/{limit})")
                elif response.status_code == 429:
                    rate_limited_count += 1
                    retry_after = response.headers.get("Retry-After", "unknown")
                    print(f"  🚫 Request {i+1}: Rate limited (Retry-After: {retry_after}s)")
                    break
                else:
                    print(f"  ❌ Request {i+1}: Error {response.status_code}")
                    break

            except Exception as e:
                print(f"  ❌ Request {i+1}: Exception {e}")
                break

        print(f"📊 Results: {success_count} successful, {rate_limited_count} rate limited")
        return success_count, rate_limited_count

    async def test_login_rate_limit(self):
        """
        Test login rate limiting.

        Sends up to 10 logins with a deliberately wrong password and returns
        True if one of them is answered with 429.
        """
        print("\n🔍 Testing login rate limiting")

        rate_limited_count = 0
        for i in range(10):  # Try more than the limit
            try:
                response = await self.client.post(
                    f"{self.base_url}/api/v1/token",
                    data={"username": self.username, "password": "wrongpassword"}
                )

                if response.status_code == 401:
                    print(f"  ✅ Login attempt {i+1}: Expected 401 (invalid credentials)")
                elif response.status_code == 429:
                    rate_limited_count += 1
                    retry_after = response.headers.get("Retry-After", "unknown")
                    print(f"  🚫 Login attempt {i+1}: Rate limited (Retry-After: {retry_after}s)")
                    break
                else:
                    print(f"  ❌ Login attempt {i+1}: Unexpected {response.status_code}")
                    break

            except Exception as e:
                print(f"  ❌ Login attempt {i+1}: Exception {e}")
                break

        print(f"📊 Login rate limit results: {rate_limited_count} rate limited")
        return rate_limited_count > 0

    async def close(self):
        """Close the HTTP client."""
        await self.client.aclose()

async def main():
    """
    Run all rate limiting tests.

    Order matters: the real login happens FIRST. The login rate-limit probe
    deliberately exhausts the login limiter, so running it before logging in
    made the subsequent login fail with 429 exactly when rate limiting works.

    Returns:
        True when every limiter was observed rejecting a request, else False
        (including when the initial login fails).
    """
    print("🚀 Starting rate limiting tests...")

    tester = RateLimitTester()

    try:
        # Login to get token for authenticated endpoints
        if not await tester.login():
            print("❌ Cannot proceed without login")
            return False

        # Test iris prediction (light limit)
        _iris_success, iris_rate_limited = await tester.test_endpoint(
            "/api/v1/iris/predict",
            {
                "model_type": "rf",
                "samples": [{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}]
            },
            "Iris Prediction",
            120  # Should be 2x default limit
        )

        # Test cancer prediction (heavy limit)
        _cancer_success, cancer_rate_limited = await tester.test_endpoint(
            "/api/v1/cancer/predict",
            {
                "model_type": "bayes",
                "samples": [{"mean_radius": 17.99, "mean_texture": 10.38, "mean_perimeter": 122.8, "mean_area": 1001, "mean_smoothness": 0.1184, "mean_compactness": 0.2776, "mean_concavity": 0.3001, "mean_concave_points": 0.1471, "mean_symmetry": 0.2419, "mean_fractal_dimension": 0.07871, "se_radius": 1.095, "se_texture": 0.9053, "se_perimeter": 8.589, "se_area": 153.4, "se_smoothness": 0.006399, "se_compactness": 0.04904, "se_concavity": 0.05373, "se_concave_points": 0.01587, "se_symmetry": 0.03003, "se_fractal_dimension": 0.006193, "worst_radius": 25.38, "worst_texture": 17.33, "worst_perimeter": 184.6, "worst_area": 2019, "worst_smoothness": 0.1622, "worst_compactness": 0.6656, "worst_concavity": 0.7119, "worst_concave_points": 0.2654, "worst_symmetry": 0.4601, "worst_fractal_dimension": 0.1189}]
            },
            "Cancer Prediction",
            30  # Should be cancer limit
        )

        # Test training endpoints (very limited)
        _training_success, training_rate_limited = await tester.test_endpoint(
            "/api/v1/iris/train",
            {},
            "Iris Training",
            2  # Should be training limit
        )

        # Exhaust the login limiter LAST, once the token is no longer needed.
        login_rate_limited = await tester.test_login_rate_limit()

        # Summary
        print("\n📋 Test Summary:")
        print(f"  Login rate limiting: {'✅ Working' if login_rate_limited else '❌ Not working'}")
        print(f"  Iris prediction rate limiting: {'✅ Working' if iris_rate_limited else '❌ Not working'}")
        print(f"  Cancer prediction rate limiting: {'✅ Working' if cancer_rate_limited else '❌ Not working'}")
        print(f"  Training rate limiting: {'✅ Working' if training_rate_limited else '❌ Not working'}")

        all_working = bool(
            login_rate_limited and iris_rate_limited and cancer_rate_limited and training_rate_limited
        )
        print(f"\n🎯 Overall: {'✅ All rate limits working' if all_working else '❌ Some rate limits not working'}")
        return all_working

    finally:
        await tester.close()

if __name__ == "__main__":
    # Script entry point: run the async test driver on a fresh event loop.
    asyncio.run(main()) 

Overwriting tests/test_rate_limits.py
