In [None]:
import mlflow
import mlflow.sklearn
import pandas as pd
import json
import logging
import sys
import os
import re
import numpy as np
import mlflow.pyfunc
from mlflow.models.signature import infer_signature
from pydantic import BaseModel


class CustomModel:
    """Rule-based scorer comparing transaction prices with reference percentiles.

    Transaction descriptions are normalised (upper-cased, listed accented
    letters replaced, non-alphanumeric characters and spaces removed) and then
    left-joined against the ``description`` column of a percentiles table.
    """

    # Upper-case accented letters and their plain replacements. Only upper-case
    # forms are listed: ``preprocess_transactions`` upper-cases descriptions
    # before applying this mapping.
    _ACCENT_MAPPING = {
        "Á": "A", "É": "E", "Í": "I", "Ó": "O", "Ú": "U",
        "Â": "A", "Ê": "E", "Î": "I", "Ô": "O", "Û": "U",
        "Ä": "A", "Ë": "E", "Ï": "I", "Ö": "O", "Ü": "U",
        "À": "A", "È": "E", "Ì": "I", "Ò": "O", "Ù": "U",
        "Ñ": "N", "Ý": "Y", "Ç": "C", "Ã": "A", "Õ": "O",
    }
    # Built once so every description is translated in a single pass instead
    # of running one regex substitution per mapped character.
    _ACCENT_TABLE = str.maketrans(_ACCENT_MAPPING)

    # Factor applied to ``(p25 - price) * quantity`` to get the differential tax.
    DIFFERENTIAL_TAX_RATE = 0.18
    # Divisor that turns the differential tax into the returned ratio.
    # NOTE(review): the business meaning of 700 is not visible here - confirm.
    RATIO_DENOMINATOR = 700
    # Upper bound applied to the returned ratio.
    MAX_RATIO = 1

    def remove_accents(self, col_name):
        """Return *col_name* with the mapped upper-case accented letters replaced."""
        return col_name.translate(self._ACCENT_TABLE)

    def remove_special_chars(self, col_name):
        """Return *col_name* without characters other than ASCII letters, digits and whitespace."""
        return re.sub(r'[^a-zA-Z0-9\s]', '', col_name)

    def remove_spaces(self, col_name):
        """Return *col_name* with every space character removed."""
        return re.sub(r' ', '', col_name)

    def _normalize_description(self, description):
        """Apply the accent, special-character and space clean-up steps in order."""
        without_accents = self.remove_accents(description)
        alphanumeric = self.remove_special_chars(without_accents)
        return self.remove_spaces(alphanumeric)

    def preprocess_transactions(self, transactions):
        """Return a cleaned, de-duplicated copy of *transactions*.

        The ``description`` column is upper-cased and normalised; fully
        duplicated rows are then dropped. The caller's DataFrame is left
        untouched. Missing descriptions are passed through unchanged instead
        of raising.

        Args:
            transactions: DataFrame with a ``description`` column.

        Returns:
            A new DataFrame with the same columns.

        Raises:
            KeyError: If ``description`` is not a column of *transactions*.
        """
        description = transactions["description"].str.upper().map(
            self._normalize_description, na_action="ignore"
        )
        # ``assign`` builds a new frame, so the input is never mutated.
        return transactions.assign(description=description).drop_duplicates()

    def predict(self, transactions, percentiles):
        """Compute the capped differential-tax ratio for each transaction.

        Args:
            transactions: DataFrame with ``description``, ``quantity`` and
                ``price`` columns.
            percentiles: DataFrame with ``description`` and ``p25`` columns.

        Returns:
            A NumPy array of shape ``(rows, 1)`` holding the ratio. The row
            count follows the cleaned transactions after the left join, so it
            can differ from ``len(transactions)`` (duplicate rows are dropped
            before joining). Descriptions without a match yield ``NaN``.
        """
        cleaned = self.preprocess_transactions(transactions)
        merged = pd.merge(cleaned, percentiles, on=["description"], how="left")

        amount_difference = merged["p25"] - merged["price"]
        differential_tax = (
            amount_difference * merged["quantity"] * self.DIFFERENTIAL_TAX_RATE
        )
        # NOTE(review): only the upper bound is enforced; prices above p25
        # produce negative ratios - confirm that this is intended.
        ratio = (differential_tax / self.RATIO_DENOMINATOR).clip(upper=self.MAX_RATIO)
        return ratio.to_numpy().reshape(-1, 1)


class CustomModelWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc adapter around a model whose ``predict`` needs a percentiles table."""

    def __init__(self, model):
        # Wrapped object; it must expose ``predict(transactions, percentiles)``.
        self.model = model

    def load_context(self, context):
        """Read the ``percentiles`` artifact (a CSV file) into ``self.percentiles_df``."""
        percentiles_csv = context.artifacts["percentiles"]
        self.percentiles_df = pd.read_csv(percentiles_csv)

    def predict(self, context, transactions):
        """Delegate to the wrapped model, passing the table read by ``load_context``."""
        reference_table = self.percentiles_df
        return self.model.predict(transactions, reference_table)

# Send INFO-and-above records to both train.log and stdout, sharing one format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('train.log'), logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Raw CSV inputs; relative paths are resolved against the working directory.
PERCENTILES_PATH = 'data/raw/percentiles.csv'
TRANSACTIONS_PATH = 'data/raw/transactions.csv'


def load_percentiles(path=None):
    """Load the percentiles reference table from a CSV file.

    Args:
        path: Location of the CSV file. When omitted, the module-level
            ``PERCENTILES_PATH`` is used, so existing calls keep working.

    Returns:
        A DataFrame in which ``code`` and ``description`` are read as text
        and ``p25``, ``p50`` and ``p75`` as floats.
    """
    if path is None:
        # Resolved at call time so callers can point the loader elsewhere.
        path = PERCENTILES_PATH
    column_types = {
        'code': str,
        'description': str,
        'p25': float,
        'p50': float,
        'p75': float,
    }
    return pd.read_csv(path, sep=',', encoding='utf_8', dtype=column_types)


def load_transactions(path=None):
    """Load the transactions table from a CSV file.

    Args:
        path: Location of the CSV file. When omitted, the module-level
            ``TRANSACTIONS_PATH`` is used, so existing calls keep working.

    Returns:
        A DataFrame in which ``description`` is read as text and
        ``quantity`` and ``price`` as floats.
    """
    if path is None:
        # Resolved at call time so callers can point the loader elsewhere.
        path = TRANSACTIONS_PATH
    column_types = {
        'description': str,
        'quantity': float,
        'price': float,
    }
    return pd.read_csv(path, sep=',', encoding='utf_8', dtype=column_types)


def load_inference_example_data():
    """Return the first transactions row as a ``{column: value}`` dict.

    Any failure while loading or extracting the row is logged and re-raised.
    """
    try:
        dataset = load_transactions()
        logger.info(
            f"Inference example data loaded successfully. Shape: {dataset.shape}")
        first_row_records = dataset.iloc[:1].to_dict(orient='records')
        return first_row_records[0]
    except Exception as e:
        logger.error(f"Error loading inference example data: {e}")
        raise


# First row of the transactions file as a {column: value} dict; it supplies
# the field defaults of TransactionInput below.
example = load_inference_example_data()


# Pydantic model describing a single transaction record.
# Every field defaults to the matching value of `example`, so
# TransactionInput() can be instantiated without arguments.
# NOTE(review): the defaults are taken as-is from the data file - confirm that
# exposing a sample row as schema defaults is intended.
class TransactionInput(BaseModel):
    description: str = example['description']
    quantity: float = example['quantity']
    price: float = example['price']


# Build a one-row DataFrame from the TransactionInput field defaults; it is
# passed to log_model below as the model's input example.
input_example = pd.DataFrame([TransactionInput().model_dump()])

# Write the JSON schema of TransactionInput to TransactionInput.json in the
# working directory; the file is attached to the run further down.
schema = TransactionInput.model_json_schema()
with open("TransactionInput.json", "w") as f:
    json.dump(schema, f, indent=2)

# Select the MLflow tracking server and experiment used by the run below.
# NOTE(review): the tracking server address is hard-coded - consider reading
# it from configuration.
mlflow.set_tracking_uri("http://20.120.201.119:5000")
mlflow.set_experiment("minimal_model_experiment")

# Load both input tables from their default CSV locations.
percentiles = load_percentiles()
transactions = load_transactions()

# Save the percentiles table to a local file so it can be logged with the
# model as the "percentiles" artifact.
os.makedirs("model_artifacts", exist_ok=True)
percentiles.to_csv("model_artifacts/percentiles.csv", index=False)

with mlflow.start_run():
    # The model is rule-based: no fitting happens here, it is only
    # instantiated and run once over the loaded transactions.
    customModel = CustomModel()
    output = customModel.predict(transactions, percentiles)

    # Derive the model signature from the transactions frame and the
    # prediction array.
    signature = infer_signature(transactions, output)

    # Log the wrapped model together with the percentiles artifact, the input
    # example and the signature.
    mlflow.pyfunc.log_model(
        python_model=CustomModelWrapper(customModel),
        artifact_path="model",
        artifacts={"percentiles": "model_artifacts/percentiles.csv"},
        input_example=input_example,
        signature=signature
    )

    # Attach the exported JSON schema to the run under "schemas".
    mlflow.log_artifact("TransactionInput.json", artifact_path="schemas")

    print(
        "✅ Model trained, input_example validated and schema.json registered."
    )