# In [None]:
import os
import json
import ray
import pandas as pd
import mlflow.pyfunc
from fastapi import FastAPI, Request
from json.decoder import JSONDecodeError
from jsonschema import validate, ValidationError
from mlflow.tracking import MlflowClient
from mlflow.exceptions import MlflowException
from ray import serve
from typing import Any, Dict
import logging
import sys
import numpy as np

# Root logging setup: records at INFO level and above are formatted with
# timestamp, logger name and level, and handed to a file handler
# ('import_model.log') as well as a stdout stream handler.
_log_handlers = [
    logging.FileHandler('import_model.log'),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
    level=logging.INFO,
)

# Logger used by the rest of this module.
logger = logging.getLogger(__name__)

# Attach to the Ray cluster, unless this process already holds a connection.
_ray_connected = ray.is_initialized()
if not _ray_connected:
    ray.init(address="ray://raycluster-kuberay-head-svc:10001")

# Bring up Ray Serve (detached) with its HTTP proxy bound to all interfaces
# on port 8000.
_http_options = dict(host="0.0.0.0", port=8000)
serve.start(detached=True, http_options=_http_options)

# Deployment settings read elsewhere in this file:
#   NAME / VERSION  -> combined into the "models:/<NAME>/<VERSION>" model URI
#   INPUT_SCHEMA    -> file name of the JSON Schema inside the "schemas" artifacts
#   ROUTE_PREFIX    -> route prefix handed to serve.run
MODEL_CONFIG = dict(
    NAME="minimal_model",
    VERSION="5",
    INPUT_SCHEMA="TransactionInput.json",
    ROUTE_PREFIX="/v1",
)

# Host of the MLflow tracking server (used with port 5000 when loading the model).
MLFLOW_IP = "20.120.201.119"
# Host shown in the "model deployed" message printed after deployment.
RAY_IP = "20.29.159.59"

# FastAPI application that the Serve deployment uses as its HTTP ingress.
app = FastAPI(
    description="Import Pipeline online inference",
    title="Minimal Predictor API",
)


@serve.deployment(ray_actor_options={"num_cpus": 0.1})
@serve.ingress(app)
class MinimalModel:
    """Ray Serve deployment that scores one JSON record with an MLflow model.

    On construction the model ``models:/<NAME>/<VERSION>`` (taken from
    ``MODEL_CONFIG``) is loaded, followed by the JSON Schema stored in the
    ``schemas`` artifact directory of the model's run. A POST to ``/`` is
    parsed, validated against that schema, converted into a one-row
    DataFrame and passed to ``model.predict``.
    """

    def __init__(self):
        # _load_model must run first: it sets the global MLflow tracking URI
        # and yields the run id that _load_schema needs.
        self.model = self._load_model(
            f"models:/{MODEL_CONFIG['NAME']}/{MODEL_CONFIG['VERSION']}")
        self.schema = self._load_schema(self.model.metadata.run_id)
        # Column order of the DataFrame follows the order of the schema properties.
        self.columns = list(self.schema["properties"].keys())

    @staticmethod
    def _load_model(model_uri: str):
        """Point MLflow at the tracking server and load ``model_uri`` as a pyfunc model."""
        mlflow.set_tracking_uri(f"http://{MLFLOW_IP}:5000")
        return mlflow.pyfunc.load_model(model_uri)

    @staticmethod
    def _load_schema(run_id: str) -> Dict[str, Any]:
        """Download the ``schemas`` artifacts of ``run_id`` and return the input schema.

        The schema file name comes from ``MODEL_CONFIG['INPUT_SCHEMA']``.
        """
        # No tracking URI is passed explicitly; the client relies on the one
        # set globally by _load_model, which __init__ calls beforehand.
        client = MlflowClient()
        schema_dir = client.download_artifacts(run_id, "schemas")
        schema_path = os.path.join(schema_dir, MODEL_CONFIG['INPUT_SCHEMA'])
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(schema_path, encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    async def _parse_request(request: Request) -> Dict[str, Any]:
        """Return the decoded JSON body of ``request``.

        Raises:
            ValueError: if the body is not valid JSON.
        """
        try:
            return await request.json()
        except JSONDecodeError as e:
            raise ValueError(
                f"❌ Invalid JSON: {e.msg} at position {e.pos}") from e

    def _validate_input(self, data: Dict[str, Any]) -> None:
        """Validate ``data`` against the loaded JSON Schema.

        Raises:
            ValueError: if ``data`` does not conform to the schema.
        """
        try:
            validate(instance=data, schema=self.schema)
        except ValidationError as e:
            raise ValueError(
                f"❌ JSON Schema validation error: {e.message}") from e

    def _to_dataframe(self, data: Dict[str, Any]) -> pd.DataFrame:
        """Build a one-row DataFrame with one column per schema property.

        Values whose schema ``type`` is exactly ``"number"`` are cast to
        ``float``; all other values are passed through unchanged.

        Raises:
            ValueError: if a property is missing from ``data`` or a value
                cannot be converted.
        """
        try:
            properties = self.schema["properties"]
            values = [
                float(data[col])
                if properties[col].get("type") == "number"
                else data[col]
                for col in self.columns
            ]
            return pd.DataFrame([values], columns=self.columns)
        except Exception as e:
            raise ValueError(
                f"❌ Error converting data to DataFrame: {e}") from e

    @app.post("/")
    async def predict(self, request: Request):
        """Score the JSON record in the request body.

        Returns a dict with ``code``, ``message``, ``data`` and ``errors`` on
        success. Failures are not raised: they are logged and reported in the
        response body as a dict with an ``error`` key.
        """
        try:
            data = await self._parse_request(request)
            self._validate_input(data)
            row = self._to_dataframe(data)
            raw_pred = self.model.predict(row)
            logger.info("Inference prediction: %s", raw_pred)
            # Normalise to an ndarray so the NaN check and .tolist() below
            # also work when the model returns a list, Series or DataFrame.
            pred = np.asarray(raw_pred)
            if np.isnan(pred).any():
                logger.warning("⚠️ Prediction contains NaN values, replacing with 0.0")
                pred = np.nan_to_num(pred, nan=0.0)
            return {
                "code": 200,
                "message": "Score calculated successfully.",
                "data": {"import_event_probability": pred.tolist()},
                "errors": None
            }
        except ValueError as ve:
            # Raised by the parse/validate/convert helpers for bad input
            # (and by anything else in the try body that raises ValueError).
            logger.warning("Request failed with ValueError: %s", ve)
            return {"error": str(ve)}
        except MlflowException as me:
            logger.exception("MLflow error during prediction")
            return {
                "error": "❌ Prediction error",
                # Keep only the text after the last "Error:" marker, if any.
                "message": str(me).split("Error:")[-1].strip()
            }
        except Exception as e:
            logger.exception("Unexpected error during prediction")
            return {"error": f"❌ Unexpected error: {str(e)}"}


# Deploy the model as a Serve application named "<NAME>_app" under the
# configured route prefix, then print the URL it was deployed at.
_route_prefix = MODEL_CONFIG['ROUTE_PREFIX']
_app_name = f"{MODEL_CONFIG['NAME']}_app"
serve.run(MinimalModel.bind(), name=_app_name, route_prefix=_route_prefix)

print(f"✅ Model deployed at http://{RAY_IP}:8000{_route_prefix}")