In [None]:
import mlflow
import mlflow.sklearn
import pandas as pd
import json
import logging
import sys
import os
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from custom_model import CustomModel, CustomModelWrapper

# Raw CSV inputs, given relative to the working directory.
TRANSACCIONES_PATH = 'data/raw/transacciones.csv'
PERCENTILES_PATH = 'data/raw/percentiles.csv'

# Logging setup: INFO and above, with one handler writing to train.log and
# another writing to stdout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [
    logging.FileHandler('train.log'),
    logging.StreamHandler(sys.stdout),
]
logging.basicConfig(
    handlers=_log_handlers,
    format=_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def load_percentiles(path=None):
    """Load the percentiles table from a CSV file.

    Args:
        path: Location of the CSV file to read. When omitted (``None``) the
            module-level ``PERCENTILES_PATH`` is used, which is the same
            file the zero-argument call has always read.

    Returns:
        pandas.DataFrame: The parsed table. The merchandise-code column is
        read as text (so values such as ``"0012"`` keep their leading
        zeros) and ``p25``, ``p50`` and ``p75`` are read as floats.
    """
    csv_path = PERCENTILES_PATH if path is None else path
    return pd.read_csv(
        csv_path,
        sep=',',
        encoding='utf_8',
        dtype={
            # Two spellings of the code column are listed; presumably the
            # header name has varied between extracts -- TODO confirm.
            'cod_mercancia': str,
            'codigo_mercancia': str,
            'p25': float,
            'p50': float,
            'p75': float,
        },
    )


def load_transacciones(path=None):
    """Load the transactions table from a CSV file.

    Args:
        path: Location of the CSV file to read. When omitted (``None``) the
            module-level ``TRANSACCIONES_PATH`` is used, which is the same
            file the zero-argument call has always read.

    Returns:
        pandas.DataFrame: The parsed table, with ``codigo_mercancia`` read
        as text (leading zeros are kept) and ``cnt_merc`` and
        ``fob_unitario`` read as floats.
    """
    csv_path = TRANSACCIONES_PATH if path is None else path
    return pd.read_csv(
        csv_path,
        sep=',',
        encoding='utf_8',
        dtype={
            'codigo_mercancia': str,
            'cnt_merc': float,
            'fob_unitario': float,
        },
    )


def load_import_example_data():
    """Return the first row of the transactions table as a plain dict.

    The table is read through ``load_transacciones`` and its shape is
    logged; only the first row is kept.

    Returns:
        dict: Column name -> value for the first transactions row.

    Raises:
        IndexError: If the transactions table has no rows.
        Exception: Any error raised while loading is logged at ERROR level
            together with its traceback and then re-raised unchanged.
    """
    try:
        dataset = load_transacciones()
        logger.info(
            "Import example data loaded successfully. Shape: %s",
            dataset.shape)
        if dataset.empty:
            # Without this guard the failure would be a bare
            # "list index out of range" from the [0] below.
            raise IndexError(
                "transactions dataset is empty; cannot build an example "
                "record")
        return dataset.head(1).to_dict('records')[0]
    except Exception as e:
        # logger.exception keeps the traceback in the log before re-raising.
        logger.exception("Error loading import example data: %s", e)
        raise


# First transactions row, loaded once when this statement runs; it provides
# the default values for the TransactionInput fields below.
example = load_import_example_data()


# 1️⃣ Pydantic class for the schema
# NOTE: documented with comments instead of a class docstring on purpose:
# Pydantic copies a model's docstring into the "description" of the generated
# JSON schema, which would change the exported TransactionInput.json.
class TransactionInput(BaseModel):
    # Field defaults are read from `example` once, when the class is defined.
    # Field order is also the key order of model_dump() and therefore the
    # column order of the DataFrame built from it.
    codigo_mercancia: str = example['codigo_mercancia']
    cnt_merc: float = example['cnt_merc']
    fob_unitario: float = example['fob_unitario']


# 2️⃣ Build input_example: a one-row DataFrame holding the model's defaults
default_row = TransactionInput().model_dump()
input_example = pd.DataFrame([default_row])

# 3️⃣ Export the model's JSON schema to TransactionInput.json
schema = TransactionInput.model_json_schema()
schema_text = json.dumps(schema, indent=2)
with open("TransactionInput.json", "w") as schema_file:
    schema_file.write(schema_text)

# 4️⃣ Train the model and log everything
# NOTE(review): no fit step is visible here -- CustomModel.predict is applied
# directly to the loaded tables.
# Every run started below is recorded on this tracking server, under this
# experiment name.
mlflow.set_tracking_uri("http://mlflow.default.svc.cluster.local:5000")
mlflow.set_experiment("modelo_minimas_experiment")

percentiles = load_percentiles()
transacciones = load_transacciones()

# Write a copy of the percentiles table to a local folder so it can be
# attached to the logged model as the "percentiles" artifact.
os.makedirs("model_artifacts", exist_ok=True)
percentiles.to_csv("model_artifacts/percentiles.csv", index=False)

with mlflow.start_run():
    customModel = CustomModel()
    output = customModel.predict(transacciones, percentiles)

    # The signature is inferred from the full transactions table and the
    # prediction output, whereas input_example carries only the three
    # TransactionInput fields.
    # NOTE(review): assumes transacciones has exactly those three columns --
    # TODO confirm, otherwise signature and input_example would disagree.
    signature = infer_signature(transacciones, output)

    # 5️⃣ Log the model with input_example and signature
    mlflow.pyfunc.log_model(
        python_model=CustomModelWrapper(customModel),
        artifact_path="model",
        artifacts={"percentiles": "model_artifacts/percentiles.csv"},
        input_example=input_example,
        signature=signature,
        # Ship custom_model.py with the logged model.
        code_path=["custom_model.py"]
    )

    # 6️⃣ Attach schema.json as an additional artifact
    # (TransactionInput.json is written earlier in this script.)
    mlflow.log_artifact("TransactionInput.json", artifact_path="schemas")

    print(
        "✅ Modelo entrenado, input_example válido y schema.json registrado."
    )