In [35]:

import argparse
import logging
import os
import joblib

import mlflow
import mlflow.sklearn

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import os
import sys
from pathlib import Path

In [36]:
# MLflow configuration
mlflow_tracking_uri = 'http://mlflow:5000'
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment("kkbox-churn-prediction-fake")

print(f"MLflow Tracking URI: {mlflow.get_tracking_uri()}")
print(f"MLflow Experiment: {mlflow.get_experiment_by_name('kkbox-churn-prediction-fake').experiment_id}")

MLflow Tracking URI: http://mlflow:5000
MLflow Experiment: 2


In [37]:
import sys
from pathlib import Path

# Add the "utils" folder to PYTHONPATH (works in notebooks)
sys.path.append(str(Path().resolve().parent.parent / "utils"))
from model_preprocessor import preprocess_features_for_lr


In [38]:
SNAPSHOT = "2016-04-01"
FEATURE_BASE = "/app/datamart/gold/feature_store/2016-04-01"
PARTITION_PATH = os.path.join(FEATURE_BASE, f"snapshot_date={SNAPSHOT}")
MODEL_PKL = "/app/mlflow/models/lr_churn_model_latest.pkl"

In [39]:
feature_store = "/app/datamart/gold/feature_store/snapshot_date="+"2016-04-01"

In [40]:
from pyspark.sql import SparkSession

# Init Spark
spark = (
    SparkSession.builder
    .appName("Inference")
    .getOrCreate()
)



In [41]:
# Load parquet
df = spark.read.parquet(feature_store)

# Quick checks
df.printSchema()
print("Rows:", df.count())
df.show(3, truncate=False)


root
 |-- msno: string (nullable = true)
 |-- city_clean: integer (nullable = true)
 |-- registered_via: integer (nullable = true)
 |-- registration_date: date (nullable = true)
 |-- tenure_days_at_snapshot: integer (nullable = true)
 |-- registered_via_freq: double (nullable = true)
 |-- city_freq: double (nullable = true)
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- num_unq_w30_sum: long (nullable = true)
 |-- sum_secs_w30: double (nullable = true)
 |-- active_days_w30: long (nullable = true)
 |-- complete_rate_w30: double (nullable = true)
 |-- sum_secs_w7: double (nullable = true)
 |-- engagement_ratio_7_30: double (nullable = true)
 |-- days_since_last_play: integer (nullable = true)
 |-- trend_secs_w30: double (nullable = true)
 |-- tenure_days: integer (nullable = true)
 |-- last_is_auto_renew: integer (nullable = true)
 |-- last_plan_list_price: integer (nullable = true)
 |-- auto_renew_share: double (nullable = true)

Rows: 2
+---------------

In [42]:
import pickle

with open("../05_model_training/models/xgb_model_20251103_122821.pkl", "rb") as f:
    artifact = pickle.load(f)

# unpack components
model = artifact["model"]
scaler = artifact["scaler"]
feature_cols = artifact["feature_columns"]
numeric_cols = artifact.get("numeric_columns", [])
print("Loaded model type:", artifact["model_type"])
print("Num features expected:", len(feature_cols))


Loaded model type: XGBoost
Num features expected: 38


In [43]:
# robust_inference.py — drop-in snippet
import os
import joblib
import pandas as pd
import pyspark
from pyspark.sql.functions import col

# CONFIG
SNAPSHOT = "2016-04-01"
FEATURE_BASE = "/app/datamart/gold/feature_store/"
PARTITION = os.path.join(FEATURE_BASE, f"snapshot_date={SNAPSHOT}")
MODEL_PKL = "/app/mlflow/models/lr_churn_model_latest.pkl"

# Start Spark (quiet)
spark = pyspark.sql.SparkSession.builder.master("local[*]").appName("inference_safe").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

# Load partition (fallback to base+filter)
if os.path.exists(PARTITION):
    sdf = spark.read.parquet(PARTITION)
else:
    if not os.path.exists(FEATURE_BASE):
        raise FileNotFoundError(FEATURE_BASE)
    sdf = spark.read.parquet(FEATURE_BASE)
    if "snapshot_date" in sdf.columns:
        sdf = sdf.filter(col("snapshot_date") == SNAPSHOT)

if sdf.count() == 0:
    raise ValueError("No rows for snapshot")

pdf = sdf.toPandas()

# Ensure id/snapshot_date exist
if "snapshot_date" not in pdf.columns:
    pdf["snapshot_date"] = SNAPSHOT
# rename common id synonyms to msno
for c in list(pdf.columns):
    if c.lower() in ("msno", "member_id", "user_id", "userid", "id", "memberid") and c != "msno":
        pdf = pdf.rename(columns={c: "msno"})
        break
if "msno" not in pdf.columns:
    raise KeyError("msno not found")

# Load model artifact (dict)
artifact = joblib.load(MODEL_PKL)
model = artifact.get("model")
scaler = artifact.get("scaler", None)
feature_cols = artifact.get("feature_columns")           # final OHE columns expected by model
numeric_cols_meta = artifact.get("numeric_columns", [])  # numeric columns list saved at training
orig_feature_list = artifact.get("original_feature_columns") or artifact.get("original_feature_list")

# --- Build X aligned to feature_cols ---
if feature_cols:
    # Determine raw categorical columns to OHE:
    if orig_feature_list:
        # choose candidates that actually exist in pdf
        candidates = [c for c in orig_feature_list if c in pdf.columns]
    else:
        candidates = [c for c in ("registered_via", "city_clean") if c in pdf.columns]

    # pick those candidate columns that are non-numeric (categorical)
    cat_cols = [c for c in candidates if not pd.api.types.is_numeric_dtype(pdf[c])]

    # Run one-hot (drop_first=True to match training)
    if cat_cols:
        pdf_ohe = pd.get_dummies(pdf, columns=cat_cols, drop_first=True, dtype=int)
    else:
        pdf_ohe = pdf.copy()

    # Add any missing expected model cols with zeros
    for c in feature_cols:
        if c not in pdf_ohe.columns:
            pdf_ohe[c] = 0

    # Keep only model columns (in saved order)
    X = pdf_ohe[feature_cols].copy()

else:
    # No final feature list — fallback to numeric-only selection excluding id/date
    drop = {"msno", "snapshot_date"}
    X = pdf[[c for c in pdf.columns if c not in drop and pd.api.types.is_numeric_dtype(pdf[c])]].copy()
    if X.shape[1] == 0:
        raise RuntimeError("No usable features and no feature_columns in pickle")

# --- Convert to numeric safely and fill NaNs ---
# coerce every column to numeric (non-numeric -> NaN), then fill with 0
for c in X.columns:
    X[c] = pd.to_numeric(X[c], errors="coerce").fillna(0.0)

# --- Apply scaler safely ---
if scaler is not None:
    try:
        # prefer scaler.feature_names_in_ if available (ensures correct order)
        if hasattr(scaler, "feature_names_in_"):
            scaler_cols = [c for c in scaler.feature_names_in_ if c in X.columns]
            if scaler_cols:
                X.loc[:, scaler_cols] = scaler.transform(X[scaler_cols])
        elif numeric_cols_meta:
            scaler_cols = [c for c in numeric_cols_meta if c in X.columns]
            if scaler_cols:
                X.loc[:, scaler_cols] = scaler.transform(X[scaler_cols])
        else:
            # fallback: scale all numeric columns
            num_cols = X.select_dtypes(include=["number"]).columns.tolist()
            if num_cols:
                X.loc[:, num_cols] = scaler.transform(X[num_cols])
    except Exception as e:
        # scaling failed — safe fallback: continue with filled but unscaled features
        print("Warning: scaler.transform failed:", str(e))

# Final safety net: ensure no NaN remains
if X.isnull().values.any():
    X = X.fillna(0.0)

# --- Predict probabilities ---
if hasattr(model, "predict_proba"):
    probs = model.predict_proba(X)[:, 1]
elif hasattr(model, "predict"):
    print("predict_proba not available; using predict as fallback")
    probs = pd.Series(model.predict(X)).astype(float).values
else:
    raise RuntimeError("Model has no predict/probability methods")

# Attach and show
pdf["churn_proba"] = probs
print(pdf[["msno", "snapshot_date", "churn_proba"]].to_string(index=False))

spark.stop()


                                        msno snapshot_date  churn_proba
jn/lbZ3oyU0QhViguMecdTZbmG49VQp3C8H3DXI70to=    2016-04-01     0.576900
kRqgRrKQ/dPESVdL9W9yUyzYu8JMi7exmtJ6VoZg3hA=    2016-04-01     0.221884


  X.loc[:, scaler_cols] = scaler.transform(X[scaler_cols])
  X.loc[:, scaler_cols] = scaler.transform(X[scaler_cols])
  X.loc[:, scaler_cols] = scaler.transform(X[scaler_cols])


In [45]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

pred_path = "/app/datamart/gold/predictions/predictions_2016_04_01.parquet"  # change if needed
df_pred = spark.read.parquet(pred_path)

df_pred.show(20, truncate=False)
df_pred.printSchema()


+--------------------------------------------+-------------+------------------+-----------------------------+
|msno                                        |snapshot_date|churn_proba       |model_file                   |
+--------------------------------------------+-------------+------------------+-----------------------------+
|kRqgRrKQ/dPESVdL9W9yUyzYu8JMi7exmtJ6VoZg3hA=|2016-04-01   |0.4923136532306671|xgb_model_20251103_122821.pkl|
|jn/lbZ3oyU0QhViguMecdTZbmG49VQp3C8H3DXI70to=|2016-04-01   |0.5832123756408691|xgb_model_20251103_122821.pkl|
|kRqgRrKQ/dPESVdL9W9yUyzYu8JMi7exmtJ6VoZg3hA=|2016-04-01   |0.2218837027480165|lr_churn_model_latest.pkl    |
|jn/lbZ3oyU0QhViguMecdTZbmG49VQp3C8H3DXI70to=|2016-04-01   |0.5768996750374289|lr_churn_model_latest.pkl    |
+--------------------------------------------+-------------+------------------+-----------------------------+

root
 |-- msno: string (nullable = true)
 |-- snapshot_date: string (nullable = true)
 |-- churn_proba: double (nullabl