# Test model in notebook

In [None]:
from typing import Any, Dict
from sklearn.pipeline import Pipeline
import pandas as pd
import joblib

In [None]:
def validate_payload(payload: Dict[str, Any]) -> None:
    """
    Validate the input payload for stroke prediction.

    The payload must contain exactly the keys of the schema below: no key
    may be missing and no additional key is accepted. Range fields must be
    real numbers inside an inclusive (low, high) interval; choice fields
    must equal one of the listed values.

    Args:
        payload: Mapping of feature name to raw value.

    Raises:
        ValueError: If keys are missing or unexpected, if a range field is
            not numeric or lies outside its range, or if a choice field
            holds a value that is not allowed.
    """
    # Expected schema: a tuple is an inclusive (low, high) numeric range,
    # a list enumerates the allowed values.
    expected_fields = {
        "gender": ["Male", "Female", "Other"],
        "age": (0, 120),
        "ever_married": ["Yes", "No"],
        "work_type": ["Private", "Self-employed", "Govt_job", "children", "Never_worked"],
        "Residence_type": ["Urban", "Rural"],
        "avg_glucose_level": (0, 400),
        "bmi": (10, 100),
        "smoking_status": ["formerly smoked", "never smoked", "smokes"],
        "hypertension": [0, 1],
        "heart_disease": [0, 1],
    }

    # Report every missing key at once, in schema order.
    missing = [k for k in expected_fields if k not in payload]
    if missing:
        raise ValueError(f"Missing keys in payload: {missing}")

    # Report every unexpected key at once, in payload order.
    extras = [k for k in payload if k not in expected_fields]
    if extras:
        raise ValueError(f"Unexpected keys in payload: {extras}")

    # Per-field type and value validation.
    for key, rule in expected_fields.items():
        value = payload[key]
        if isinstance(rule, tuple):  # numeric range
            # bool is a subclass of int, so it has to be excluded explicitly;
            # otherwise True/False would be accepted as 1/0.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(f"{key} must be numeric, got {type(value).__name__}")
            low, high = rule
            # NaN fails both comparisons, so it is rejected here as well.
            if not (low <= value <= high):
                raise ValueError(f"{key}={value} outside plausible range {rule}")
        elif isinstance(rule, list):  # categorical choices
            if value not in rule:
                raise ValueError(f"{key} must be one of {rule}, got '{value}'")
        else:
            # Guards against a malformed schema entry added in the future.
            raise ValueError(f"Internal schema error for {key}")


def get_stroke_prob(model: Pipeline, payload: Dict[str, Any]) -> float:
    """
    Score a single payload and return the probability of class index 1.

    The payload is wrapped in a one-row DataFrame (one column per key) and
    passed to ``model.predict_proba``; the value in the second column of
    the first (only) row is returned as a plain Python float.
    Presumably class index 1 is the "stroke" class; confirm against the
    training code.
    """
    single_row = pd.DataFrame([payload])
    class_probabilities = model.predict_proba(single_row)
    return float(class_probabilities[0, 1])

In [None]:
# Load the serialized model from disk; the relative path is resolved against
# the notebook's current working directory.
# NOTE(review): joblib files are pickle-based, so only load artifacts that
# come from a trusted source.
model = joblib.load("../models/log_reg_model.joblib")

In [None]:
# Example patient record used to exercise the model here and the API below.
# It carries every key listed in validate_payload's schema, including the
# binary flags "hypertension" and "heart_disease", so the same payload
# passes validation.
example_payload = {
    "gender": "Male",
    "age": 67.0,
    "ever_married": "Yes",
    "work_type": "Self-employed",
    "Residence_type": "Urban",
    "avg_glucose_level": 228.69,
    "bmi": 36.6,
    "smoking_status": "smokes",
    "hypertension": 0,
    "heart_disease": 0,
}
get_stroke_prob(model, example_payload)

# Test API

In [None]:
import requests

# Endpoint under test: local dev server by default; swap in the deployed URL
# below to test the hosted service instead.
API_URL = "http://127.0.0.1:8000/predict"
# API_URL = "https://stroke-example-1085259940267.us-central1.run.app/predict"

# requests has no default timeout; without one this cell would block forever
# if the server is down or unresponsive.
response = requests.post(API_URL, json=example_payload, timeout=10)
print("Status code:", response.status_code)
try:
    print("Response JSON:", response.json())
except ValueError:
    # Error responses are not always JSON; show the raw body instead of
    # failing with a decode error.
    print("Response text:", response.text)