# In [None]:
import mlflow
import mlflow.sklearn
import pandas as pd
import json
import logging
import sys
import os
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from custom_model import CustomModel, CustomModelWrapper

# Send log records both to train.log and to stdout. The log file is opened as
# UTF-8 explicitly (like the CSV reads below) so non-ASCII messages do not
# depend on the platform's default encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('train.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)

# Input CSV locations, relative to the current working directory.
PERCENTILES_PATH = 'data/raw/percentiles.csv'
TRANSACTIONS_PATH = 'data/raw/transactions.csv'


def load_percentiles():
    """Read the percentiles table from ``PERCENTILES_PATH``.

    The CSV is read as UTF-8 with ``,`` as separator. ``code`` and
    ``description`` are parsed as text; ``p25``, ``p50`` and ``p75`` are
    parsed as floats.

    Returns:
        pandas.DataFrame with the file's contents.
    """
    text_columns = ('code', 'description')
    value_columns = ('p25', 'p50', 'p75')

    dtypes = {name: str for name in text_columns}
    dtypes.update({name: float for name in value_columns})

    return pd.read_csv(PERCENTILES_PATH, sep=',', encoding='utf_8', dtype=dtypes)


def load_transactions():
    """Read the transactions table from ``TRANSACTIONS_PATH``.

    The CSV is read as UTF-8 with ``,`` as separator. ``description`` is
    parsed as text; ``quantity`` and ``price`` are parsed as floats.

    Returns:
        pandas.DataFrame with the file's contents.
    """
    column_types = dict(description=str, quantity=float, price=float)
    frame = pd.read_csv(
        TRANSACTIONS_PATH, sep=',', encoding='utf_8', dtype=column_types
    )
    return frame


def load_inference_example_data():
    """Return the first transaction row as a ``{column: value}`` dict.

    The full transactions table is loaded via ``load_transactions`` and its
    shape is logged; only the first row is returned.

    Returns:
        dict mapping each column name to the first row's value.

    Raises:
        ValueError: if the transactions file contains no rows.
        Any exception raised while reading the CSV is logged together with
        its traceback and re-raised unchanged.
    """
    try:
        dataset = load_transactions()
    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback.
        logger.exception(f"Error loading inference example data: {e}")
        raise

    # Fail with a clear message instead of an IndexError from `[0]` below.
    if dataset.empty:
        message = (
            f"Error loading inference example data: "
            f"{TRANSACTIONS_PATH} contains no rows")
        logger.error(message)
        raise ValueError(message)

    logger.info(
        f"Inference example data loaded successfully. Shape: {dataset.shape}")
    return dataset.head(1).to_dict('records')[0]


# Runs at import time: loads the transactions CSV and keeps its first row as a
# dict. Its values are used as the field defaults of TransactionInput.
example = load_inference_example_data()


# Pydantic class for schema definition.
# Each field's default is taken from `example` (the first transaction row), so
# `TransactionInput()` with no arguments reproduces that row.
# NOTE: plain comments are used instead of a class docstring on purpose —
# pydantic v2 copies a model docstring into `model_json_schema()` as its
# "description", which would change the exported schema file.
class TransactionInput(BaseModel):
    description: str = example['description']  # transaction text
    quantity: float = example['quantity']
    price: float = example['price']


# Create input_example: a one-row DataFrame holding TransactionInput's default
# values, later passed to mlflow.pyfunc.log_model.
input_example = pd.DataFrame([TransactionInput().model_dump()])

# Export TransactionInput's JSON schema to TransactionInput.json. The encoding
# is explicit so the file does not depend on the platform default.
schema = TransactionInput.model_json_schema()
with open("TransactionInput.json", "w", encoding="utf-8") as f:
    json.dump(schema, f, indent=2)

# Train model and log everything.
# The tracking server can be overridden with the standard MLFLOW_TRACKING_URI
# environment variable; the previously hard-coded server is the fallback.
mlflow.set_tracking_uri(
    os.environ.get("MLFLOW_TRACKING_URI", "http://20.120.201.119:5000"))
mlflow.set_experiment("minimal_model_experiment")

percentiles = load_percentiles()
transactions = load_transactions()

# Write the percentiles table to disk so it can be bundled with the logged
# model through `artifacts=` below.
os.makedirs("model_artifacts", exist_ok=True)
percentiles.to_csv("model_artifacts/percentiles.csv", index=False)

with mlflow.start_run():
    customModel = CustomModel()
    output = customModel.predict(transactions, percentiles)

    # Signature is inferred from the full transactions frame and the
    # predictions produced for it.
    signature = infer_signature(transactions, output)

    # Register model with input_example and signature; custom_model.py is
    # shipped alongside so the wrapper can be loaded elsewhere.
    mlflow.pyfunc.log_model(
        python_model=CustomModelWrapper(customModel),
        artifact_path="model",
        artifacts={"percentiles": "model_artifacts/percentiles.csv"},
        input_example=input_example,
        signature=signature,
        code_path=["custom_model.py"]
    )

    # Attach the exported JSON schema as an additional artifact.
    mlflow.log_artifact("TransactionInput.json", artifact_path="schemas")

    # Fixed mis-encoded emoji (was UTF-8 bytes decoded as cp1252: "âœ…").
    print(
        "✅ Model trained, input_example validated and schema.json registered."
    )