In [None]:
import os
import pyspark
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
import mlflow
import pandas as pd
from datetime import datetime


## Spark Initialization

In [None]:
spark = (
    SparkSession.builder
    .appName("churn_inference_notebook")
    .master("local[*]")
    # Few shuffle partitions: faster for small local debugging runs.
    .config("spark.sql.shuffle.partitions", "4")
    # getOrCreate() hands back an already-active session if one exists.
    .getOrCreate()
)

# Keep notebook output readable: only WARN and above from Spark's own logging.
spark.sparkContext.setLogLevel("WARN")
print("✅ Spark started")


## Load features from feature store

In [None]:
def load_features(spark, snapshot_date_str,
                  feature_store_path="/app/datamart/gold/inference_feature_store/"):
    """Load inference features for a single snapshot date.

    Args:
        spark: Active SparkSession used to read the feature store.
        snapshot_date_str: Snapshot date formatted as ``YYYY-MM-DD``.
        feature_store_path: Parquet location of the gold inference feature
            store. Defaults to the path this notebook has always read.

    Returns:
        A Spark DataFrame restricted to rows whose ``snapshot_date`` equals
        the requested date.

    Raises:
        ValueError: if ``snapshot_date_str`` is not in ``YYYY-MM-DD`` format.
    """
    # Parse first so a malformed date fails before Spark touches the store.
    snapshot_date = datetime.strptime(snapshot_date_str, "%Y-%m-%d").date()

    print(f"📂 Loading feature store: {feature_store_path}")

    df = (
        spark.read.parquet(feature_store_path)
        .filter(F.col("snapshot_date") == F.lit(snapshot_date))
    )

    # NOTE: count() is a Spark action, i.e. an extra job run only for this log line.
    print(f"✅ Loaded features: {df.count()} rows for snapshot_date={snapshot_date}")
    return df


## Load MLflow Model

def load_mlflow_model(model_name, tracking_uri="http://mlflow:5000", stage="Production"):
    """
    Load an MLflow sklearn-flavor model by registered name or full URI.

    Args:
        model_name: A full ``models:/...`` URI (used as-is) or a bare
            registered-model name, resolved to ``models:/<name>/<stage>``.
        tracking_uri: MLflow tracking server the client is pointed at before
            loading. Defaults to the previously hard-coded server.
        stage: Registry stage used only when ``model_name`` is a bare name.

    Returns:
        The model object returned by ``mlflow.sklearn.load_model``.

    Raises:
        Whatever ``mlflow.sklearn.load_model`` raises; the failure is logged
        with its traceback and re-raised unchanged.
    """
    # NOTE(review): uses a module-level `logger`; none is defined in the visible
    # notebook cells, so it presumably comes from the script this was copied from.
    mlflow.set_tracking_uri(tracking_uri)

    if model_name.startswith("models:/"):
        model_uri = model_name
    else:
        model_uri = f"models:/{model_name}/{stage}"

    logger.info("Loading MLflow model from: %s", model_uri)

    try:
        model = mlflow.sklearn.load_model(model_uri)
    except Exception:
        # logger.exception keeps the traceback that a plain error message loses.
        logger.exception("Failed to load MLflow model from: %s", model_uri)
        raise

    logger.info("✅ MLflow model loaded successfully")
    return model


In [None]:
def load_mlflow_model(model_name_or_uri, tracking_uri=None, stage="Production"):
    """Load an MLflow sklearn-flavor model by registered name or full model URI.

    Args:
        model_name_or_uri: A full ``models:/...`` URI (used as-is) or a bare
            registered-model name, resolved to ``models:/<name>/<stage>``.
        tracking_uri: Optional MLflow tracking server. When None (default) the
            client's current tracking configuration is left untouched.
        stage: Registry stage used only when a bare name is given.

    Returns:
        The model object returned by ``mlflow.sklearn.load_model``.
    """
    # Resolve from the argument only -- never from the notebook-level
    # `model_uri` variable defined in the entry-point cell.
    if model_name_or_uri.startswith("models:/"):
        model_uri = model_name_or_uri
    else:
        model_uri = f"models:/{model_name_or_uri}/{stage}"

    if tracking_uri is not None:
        mlflow.set_tracking_uri(tracking_uri)

    print(f"📦 Loading MLflow model: {model_uri}")

    # sklearn flavor (not pyfunc) so callers can use predict_proba when available.
    model = mlflow.sklearn.load_model(model_uri)

    print("✅ Model loaded successfully!")
    return model


## Save predictions to datamart

def save_predictions(spark, df_predictions, model_name, snapshot_date_str,
                     base_root="datamart/gold/model_predictions"):
    """
    Save predictions to parquet under:
    <base_root>/<model_name>/<model_name>_predictions_<YYYY_MM_DD>.parquet

    Args:
        spark: SparkSession used to convert and write the predictions.
        df_predictions: Predictions to write; passed to ``spark.createDataFrame``
            (``main`` passes a pandas DataFrame).
        model_name: Used as both the sub-directory and the filename prefix.
        snapshot_date_str: Snapshot date ``YYYY-MM-DD``; dashes become underscores.
        base_root: Root output directory. The default is relative, so it
            resolves against the current working directory.

    Returns:
        The parquet path written (any existing output there is overwritten).
    """
    # NOTE(review): relies on a module-level `logger`, not defined in the
    # visible notebook cells.
    base_dir = os.path.join(base_root, model_name)
    os.makedirs(base_dir, exist_ok=True)

    filename = f"{model_name}_predictions_{snapshot_date_str.replace('-', '_')}.parquet"
    filepath = os.path.join(base_dir, filename)

    (
        spark.createDataFrame(df_predictions)
        .write.mode("overwrite")
        .parquet(filepath)
    )

    logger.info("✅ Predictions saved: %s", filepath)
    return filepath


In [None]:
def save_predictions(spark, predictions_pdf, model_name, snapshot_date_str,
                     output_dir="/app/datamart/gold/inference_output/"):
    """Save inference prediction results to parquet.

    Output path: ``<output_dir>/<model_name>_predictions_<YYYYMMDD>.parquet``.

    Args:
        spark: SparkSession used to convert and write the predictions.
        predictions_pdf: pandas DataFrame of predictions.
        model_name: Filename prefix.
        snapshot_date_str: Snapshot date ``YYYY-MM-DD``; dashes are stripped.
        output_dir: Target directory, created if missing. Defaults to the
            notebook's gold inference-output location.

    Returns:
        The parquet path written (any existing output there is overwritten).
    """
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(
        output_dir,
        f"{model_name}_predictions_{snapshot_date_str.replace('-', '')}.parquet",
    )

    # Convert pandas → spark so the result is written by Spark's parquet writer.
    preds_sdf = spark.createDataFrame(predictions_pdf)
    preds_sdf.write.mode("overwrite").parquet(output_path)

    print(f"✅ Predictions saved to: {output_path}")
    return output_path


## Main Inference Pipeline

def main(snapshot_date_str, model_name):
    """Script-style batch inference: score one snapshot and persist the results.

    Gets a local SparkSession, loads the features for ``snapshot_date_str``,
    scores them with ``predict_proba`` from the MLflow model ``model_name`` and
    saves ``msno``, ``snapshot_date``, ``model_name`` and ``model_predictions``
    through ``save_predictions``. The Spark session is stopped even if a step fails.

    NOTE(review): relies on a module-level `logger`, which the visible notebook
    cells never define. Also, ``getOrCreate()`` returns an already-active
    session, so calling this inside the notebook would stop the shared `spark`.
    """
    logger.info("=== Starting Model Inference Job ===")

    spark = (
        pyspark.sql.SparkSession.builder
        .appName("inference")
        .master("local[*]")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("ERROR")

    try:
        features_sdf = load_features(spark, snapshot_date_str)

        logger.info("Feature schema:")
        features_sdf.printSchema()

        features_pdf = features_sdf.toPandas()
        logger.info("Converted Spark → Pandas: shape=%s", features_pdf.shape)

        # Presumably the model's training feature set, in training order -- keep in sync.
        feature_cols = ['tenure_days_at_snapshot', 'registered_via', 'city_clean',
                        'sum_secs_w30', 'active_days_w30', 'complete_rate_w30',
                        'sum_secs_w7', 'engagement_ratio_7_30', 'days_since_last_play',
                        'trend_secs_w30', 'auto_renew_share', 'last_is_auto_renew']
        X_inference = features_pdf[feature_cols]

        model = load_mlflow_model(model_name)

        # Column 1 = probability of classes_[1] (the positive class for 0/1 labels).
        y_proba = model.predict_proba(X_inference)[:, 1]

        output = features_pdf[["msno", "snapshot_date"]].copy()
        output["model_name"] = model_name
        output["model_predictions"] = y_proba

        save_predictions(spark, output, model_name, snapshot_date_str)
    finally:
        # Release the session even when loading, scoring or saving fails.
        spark.stop()

    logger.info("=== Inference Job Completed ===")


In [None]:
def run_inference(snapshot_date_str, model_uri):
    """
    Full inference pipeline:
    - load features
    - convert to pandas
    - load MLflow model
    - predict_proba
    - save output parquet
    """

    # --- Load features ---
    features_sdf = load_features(spark, snapshot_date_str)

    # Convert to pandas for ML model
    features_pdf = features_sdf.toPandas()

    # Identify feature columns (assuming fe_ prefix)
    # feature_cols = [c for c in features_pdf.columns if c.startswith("fe_")]
    feature_cols = ['tenure_days_at_snapshot',
                'registered_via',
                'city_clean', 
                'sum_secs_w30',
                'active_days_w30',
                'complete_rate_w30',
                'sum_secs_w7',
                'engagement_ratio_7_30',
                'days_since_last_play',
                'trend_secs_w30',
                'auto_renew_share',
                'last_is_auto_renew']
    X = features_pdf[feature_cols]

    # --- Load MLflow model ---
    model = load_mlflow_model(model_uri)

    # --- Inference ---
    preds = model.predict(X)
    if hasattr(model, "predict_proba"):
        proba = model.predict_proba(X)[:, 1]
    else:
        proba = preds  # fallback for regressors / non-proba models

    # --- Build output dataframe ---
    output = features_pdf[["msno", "snapshot_date"]].copy()
    output["model_name"] = model_uri
    output["prediction"] = preds
    output["probability"] = proba

    # --- Save to parquet ---
    save_predictions(spark, output, model_uri.replace("/", "_"), snapshot_date_str)

    return output


## Entry Point

In [None]:
# Snapshot to score, as YYYY-MM-DD.
snapshot_date = "2016-05-01"

# Registered MLflow model to score with (other candidates noted by the author:
# models:/XGBoost/1, models:/RandomForest/1).
model_uri = "models:/LogisticRegression/1"

results = run_inference(snapshot_date_str=snapshot_date, model_uri=model_uri)
results.head()
