# In [None]:
import mlflow
import mlflow.sklearn
import pandas as pd
import json
import logging
import sys
import os
import re
import numpy as np
import mlflow.pyfunc
from mlflow.models.signature import infer_signature
from pydantic import BaseModel


class CustomModel:
    """Rule-based scorer comparing each transaction's unit FOB with its p25.

    For every transaction the merchandise code is normalised, the matching
    ``p25`` reference is looked up, and a ratio is derived from the shortfall
    ``p25 - fob_unitario``.
    """

    # Multiplier applied to (p25 - fob_unitario) * cnt_merc.
    # Presumably a tax rate -- confirm with the business-rule owner.
    FACTOR_TRIBUTO = 0.18
    # Amount the differential is divided by to obtain the ratio.
    # Meaning of the value is not documented in this file -- TODO confirm.
    DIVISOR_RATIO = 700

    # Uppercase accented letters folded to their plain ASCII letter.
    _TABLA_TILDES = str.maketrans({
        "Á": "A", "É": "E", "Í": "I", "Ó": "O", "Ú": "U",
        "Â": "A", "Ê": "E", "Î": "I", "Ô": "O", "Û": "U",
        "Ä": "A", "Ë": "E", "Ï": "I", "Ö": "O", "Ü": "U",
        "À": "A", "È": "E", "Ì": "I", "Ò": "O", "Ù": "U",
        "Ñ": "N", "Ý": "Y", "Ç": "C", "Ã": "A", "Õ": "O",
    })
    # Anything that is not an ASCII letter, a digit or whitespace.
    _NO_ALFANUMERICO = re.compile(r'[^a-zA-Z0-9\s]')

    def reemplazar_tildes(self, col_name):
        """Return ``col_name`` with uppercase accented letters unaccented.

        Only the uppercase letters in the mapping are replaced; lowercase
        accented letters are left untouched.
        """
        # One translate pass replaces the per-letter substitution loop.
        return col_name.translate(self._TABLA_TILDES)

    def reemplazar_otro(self, col_name):
        """Return ``col_name`` without non-alphanumeric, non-whitespace chars."""
        return self._NO_ALFANUMERICO.sub('', col_name)

    def reemplazar_espacios(self, col_name):
        """Return ``col_name`` with every space character removed."""
        return col_name.replace(' ', '')

    def _normalizar_codigo(self, codigo):
        """Apply the three clean-up steps, in order, to an uppercased code."""
        sin_tildes = self.reemplazar_tildes(codigo)
        return self.reemplazar_espacios(self.reemplazar_otro(sin_tildes))

    def preprocess_transacciones(self, transacciones):
        """Return a cleaned, de-duplicated copy of ``transacciones``.

        ``codigo_mercancia`` is uppercased, unaccented and stripped of
        punctuation and spaces. Missing codes are passed through unchanged.
        The caller's DataFrame is not modified.

        Args:
            transacciones: DataFrame with a string ``codigo_mercancia`` column.

        Returns:
            A new DataFrame with normalised codes and duplicate rows dropped.

        Raises:
            KeyError: if ``codigo_mercancia`` is absent.
        """
        codigos = transacciones['codigo_mercancia'].str.upper()
        # na_action='ignore' keeps one missing code from aborting the batch.
        codigos = codigos.map(self._normalizar_codigo, na_action='ignore')
        # assign() builds a new frame, so the input is left as received.
        return transacciones.assign(codigo_mercancia=codigos).drop_duplicates()

    def predict(self, transacciones, percentiles):
        """Score transactions against their p25 reference value.

        ratio = (p25 - fob_unitario) * cnt_merc * FACTOR_TRIBUTO / DIVISOR_RATIO,
        capped at 1. Codes with no percentile row yield NaN.

        NOTE(review): duplicate input rows are dropped during preprocessing,
        so the output can have fewer rows than ``transacciones``.
        NOTE(review): only the upper bound is capped; the ratio is negative
        when ``fob_unitario`` exceeds ``p25`` -- confirm this is intended.

        Args:
            transacciones: DataFrame with ``codigo_mercancia``, ``cnt_merc``
                and ``fob_unitario`` columns.
            percentiles: DataFrame with ``codigo_mercancia`` and ``p25``.

        Returns:
            ``numpy.ndarray`` of shape ``(n_rows, 1)`` holding the ratio.

        Raises:
            KeyError: if a required column is missing.
        """
        limpio = self.preprocess_transacciones(transacciones)
        # Only p25 is needed; leaving other percentile columns out avoids
        # suffix clashes with same-named transaction columns.
        referencia = percentiles[['codigo_mercancia', 'p25']]
        merged = pd.merge(limpio, referencia, on='codigo_mercancia', how='left')

        monto_diferencia = merged['p25'] - merged['fob_unitario']
        tributo_diferencial = (
            monto_diferencia * merged['cnt_merc'] * self.FACTOR_TRIBUTO
        )
        ratio = (tributo_diferencial / self.DIVISOR_RATIO).clip(upper=1)
        return ratio.to_numpy().reshape(-1, 1)


class CustomModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc adapter that pairs a model with its percentile table."""

    def __init__(self, model):
        """Keep a reference to the wrapped model object."""
        self.model = model

    def load_context(self, context):
        """Read the percentile CSV registered under the ``percentiles`` artifact."""
        percentiles_csv = context.artifacts["percentiles"]
        self.percentiles_df = pd.read_csv(percentiles_csv)

    def predict(self, context, transacciones):
        """Delegate to the wrapped model, supplying the loaded percentiles."""
        scorer, reference = self.model, self.percentiles_df
        return scorer.predict(transacciones, reference)


# Input CSV locations (relative paths).
PERCENTILES_PATH = 'data/raw/percentiles.csv'
TRANSACCIONES_PATH = 'data/raw/transacciones.csv'

# Emit INFO-and-above records to both train.log and stdout.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[
        logging.FileHandler('train.log'),
        logging.StreamHandler(sys.stdout),
    ],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def load_percentiles(path=None):
    """Read the percentile reference table from a CSV file.

    Code columns are read as strings so that values such as leading zeros
    are kept verbatim; the percentile columns are read as floats.

    Args:
        path: CSV location. Defaults to ``PERCENTILES_PATH`` when omitted.

    Returns:
        DataFrame with the file's columns.
    """
    # Resolve the default at call time so the module constant stays the
    # single source of truth for the standard location.
    origen = PERCENTILES_PATH if path is None else path
    return pd.read_csv(
        origen,
        sep=',',
        encoding='utf_8',
        dtype={
            'cod_mercancia': str,
            'codigo_mercancia': str,
            'p25': float,
            'p50': float,
            'p75': float,
        },
    )


def load_transacciones(path=None):
    """Read the transactions table from a CSV file.

    ``codigo_mercancia`` is read as a string; ``cnt_merc`` and
    ``fob_unitario`` are read as floats.

    Args:
        path: CSV location. Defaults to ``TRANSACCIONES_PATH`` when omitted.

    Returns:
        DataFrame with the file's columns.
    """
    # Resolve the default at call time so the module constant stays the
    # single source of truth for the standard location.
    origen = TRANSACCIONES_PATH if path is None else path
    return pd.read_csv(
        origen,
        sep=',',
        encoding='utf_8',
        dtype={
            'codigo_mercancia': str,
            'cnt_merc': float,
            'fob_unitario': float,
        },
    )


def load_import_example_data():
    """Return the first transaction row as a ``{column: value}`` dict.

    Any failure while loading or extracting the row is logged and re-raised.
    """
    try:
        frame = load_transacciones()
        logger.info(
            f"Import example data loaded successfully. Shape: {frame.shape}")
        first_row = frame.head(1)
        records = first_row.to_dict('records')
        return records[0]
    except Exception as e:
        logger.error(f"Error loading import example data: {e}")
        raise


# First row of the transactions CSV, as a dict; used below as field defaults.
example = load_import_example_data()


# 1️⃣ Pydantic class describing the model's input schema.
# Every field defaults to the value found in `example`, so calling
# `TransactionInput()` with no arguments yields a ready-made sample record.
# NOTE: documented with comments rather than a docstring because pydantic
# copies a model's docstring into the generated JSON schema description.
class TransactionInput(BaseModel):
    codigo_mercancia: str = example['codigo_mercancia']
    cnt_merc: float = example['cnt_merc']
    fob_unitario: float = example['fob_unitario']


# 2️⃣ Build input_example: a one-row DataFrame from the schema's defaults.
input_example = pd.DataFrame([TransactionInput().model_dump()])

# 3️⃣ Export the JSON schema of TransactionInput to TransactionInput.json.
schema = TransactionInput.model_json_schema()
with open("TransactionInput.json", "w") as f:
    json.dump(schema, f, indent=2)

# 4️⃣ Point MLflow at the tracking server and experiment, then score and log.
mlflow.set_tracking_uri("http://mlflow.default.svc.cluster.local:5000")
mlflow.set_experiment("modelo_minimas_experiment")

percentiles = load_percentiles()
transacciones = load_transacciones()

# Write the percentile table to a local file so it can be attached to the
# logged model as the "percentiles" artifact.
os.makedirs("model_artifacts", exist_ok=True)
percentiles.to_csv("model_artifacts/percentiles.csv", index=False)

with mlflow.start_run():
    # Score the full transaction set; the result serves as the output sample
    # for signature inference.
    customModel = CustomModel()
    output = customModel.predict(transacciones, percentiles)

    # NOTE(review): the signature is inferred from the full `transacciones`
    # frame, while `input_example` carries only the three TransactionInput
    # fields -- confirm the CSV has no additional columns.
    signature = infer_signature(transacciones, output)

    # 5️⃣ Log the model together with input_example and signature.
    mlflow.pyfunc.log_model(
        python_model=CustomModelWrapper(customModel),
        artifact_path="model",
        # Key must match the one CustomModelWrapper.load_context reads.
        artifacts={"percentiles": "model_artifacts/percentiles.csv"},
        input_example=input_example,
        signature=signature
    )

    # 6️⃣ Attach the exported schema file as an extra artifact under schemas/.
    mlflow.log_artifact("TransactionInput.json", artifact_path="schemas")

    print(
        "✅ Modelo entrenado, input_example válido y schema.json registrado."
    )