# Model training fake

In [1]:
#!/usr/bin/env python3
# coding: utf-8
"""
Generate synthetic data and run XGBoost training pipeline (test).
This mirrors the signatures used by your training script.
"""


import argparse
import logging
import os
import joblib

import mlflow
import mlflow.sklearn

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
import os
import sys
from pathlib import Path


# ==== Adjust these if you want a different test size ====
N_SAMPLES = 5000
N_FEATURES = 20
N_INFORMATIVE = 8
RANDOM_STATE = 42

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("fake-data-test")

In [2]:
import pyspark
import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import StringType, IntegerType, FloatType, DateType

import glob
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint

In [3]:
# MLflow configuration
mlflow_tracking_uri = 'http://mlflow:5000'
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment("kkbox-churn-prediction-fake")

print(f"MLflow Tracking URI: {mlflow.get_tracking_uri()}")
print(f"MLflow Experiment: {mlflow.get_experiment_by_name('kkbox-churn-prediction-fake').experiment_id}")

MLflow Tracking URI: http://mlflow:5000
MLflow Experiment: 2


In [4]:
## Set up config (using your 4-split approach)
model_train_date_str = "2016-04-01"
train_period_months = 8
val_period_months = 2
test_period_months = 2
oot_period_months = 2

config = {}
config["model_train_date_str"] = model_train_date_str
config["model_train_date"] = datetime.strptime(model_train_date_str, "%Y-%m-%d")

# Work backwards from model_train_date
# OOT: Most recent data before deployment (2016-02-01 to 2016-03-31)
config["oot_end_date"] = config['model_train_date'] - timedelta(days=1)
config["oot_start_date"] = config['model_train_date'] - relativedelta(months=oot_period_months)

# Test: Before OOT (2015-12-01 to 2016-01-31)
config["test_end_date"] = config["oot_start_date"] - timedelta(days=1)
config["test_start_date"] = config["oot_start_date"] - relativedelta(months=test_period_months)

# Validation: Before Test (2015-10-01 to 2015-11-30)
config["val_end_date"] = config["test_start_date"] - timedelta(days=1)
config["val_start_date"] = config["test_start_date"] - relativedelta(months=val_period_months)

# Training: Before Validation (2015-02-01 to 2015-09-30)
config["train_end_date"] = config["val_start_date"] - timedelta(days=1)
config["train_start_date"] = config["val_start_date"] - relativedelta(months=train_period_months)

# NEW: Overall date range for extraction (covers all splits)
config["data_start_date"] = config["train_start_date"]  # Earliest date needed
config["data_end_date"] = config["oot_end_date"]        # Latest date needed

pprint.pprint(config)

{'data_end_date': datetime.datetime(2016, 3, 31, 0, 0),
 'data_start_date': datetime.datetime(2015, 2, 1, 0, 0),
 'model_train_date': datetime.datetime(2016, 4, 1, 0, 0),
 'model_train_date_str': '2016-04-01',
 'oot_end_date': datetime.datetime(2016, 3, 31, 0, 0),
 'oot_start_date': datetime.datetime(2016, 2, 1, 0, 0),
 'test_end_date': datetime.datetime(2016, 1, 31, 0, 0),
 'test_start_date': datetime.datetime(2015, 12, 1, 0, 0),
 'train_end_date': datetime.datetime(2015, 9, 30, 0, 0),
 'train_start_date': datetime.datetime(2015, 2, 1, 0, 0),
 'val_end_date': datetime.datetime(2015, 11, 30, 0, 0),
 'val_start_date': datetime.datetime(2015, 10, 1, 0, 0)}


In [5]:
import sys
from pathlib import Path

# Add the "utils" folder to PYTHONPATH (works in notebooks)
sys.path.append(str(Path().resolve().parent.parent / "utils"))

from model_preprocessor import prepare_data_for_training

In [6]:
from argparse import Namespace

args = Namespace(
    train_date="2016-05-01",
    features_path="/app/datamart/gold/feature_store/",
    labels_path="/app/datamart/gold/label_store/",
    sample_frac=0.3,
    label_col="label",
    mlflow_tracking_uri="http://mlflow:5000",
    mlflow_experiment="kkbox-churn-prediction",
    n_iter=10,
    cv_folds=5,
    random_state=42,
    train_months=8,
    val_months=2,
    test_months=2,
    oot_months=2
)

data = prepare_data_for_training(args)


2025-11-07 09:13:13,643 INFO Config date windows computed
2025-11-07 09:13:13,643 INFO Train: 2015-03-01 ‚Üí 2015-10-31
2025-11-07 09:13:13,644 INFO Val:   2015-11-01 ‚Üí 2015-12-31
2025-11-07 09:13:13,645 INFO Test:  2016-01-01 ‚Üí 2016-02-29
2025-11-07 09:13:13,645 INFO OOT:   2016-03-01 ‚Üí 2016-04-30
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/07 09:13:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/11/07 09:13:27 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
2025-11-07 09:13:33,931 INFO Filtered features rows: 20158616                   
2025-11-07 09:13:34,462 INFO Filtered labels rows: 10079308
2025-11-07 

In [9]:
# assume out = prepare_data_for_training(args)
def inspect_prepared(out):
    expected_keys = {
        "X_train_lr","y_train","X_val_lr","y_val","X_test_lr","y_test",
        "X_oot_lr","y_oot","X_train_tree","X_val_tree","X_test_tree","X_oot_tree",
        "scaler","lr_cols","categorical_cols","class_weight_dict","scale_pos_weight","cv"
    }
    missing = expected_keys - set(out.keys())
    extra = set(out.keys()) - expected_keys

    print("Missing keys:", missing)
    print("Extra keys:", extra)
    # quick shapes / types
    for k in ["X_train_lr","X_val_lr","X_test_lr","X_oot_lr"]:
        v = out.get(k)
        print(f"{k}: type={type(v)}, shape={getattr(v,'shape',None)}")
    for k in ["y_train","y_val","y_test","y_oot"]:
        v = out.get(k)
        print(f"{k}: type={type(v)}, length={len(v) if v is not None else None}, dtype={getattr(v,'dtype',None)}")
    print("lr_cols len:", len(out.get("lr_cols", [])))
    print("categorical_cols:", out.get("categorical_cols"))
    print("class_weight_dict:", out.get("class_weight_dict"))
    print("scale_pos_weight:", out.get("scale_pos_weight"))
    print("cv:", out.get("cv"))
    # spot-check class balance
    if "y_train" in out:
        print("y_train value counts:\n", out["y_train"].value_counts(normalize=False).to_dict())

# call:
inspect_prepared(data)


Missing keys: set()
Extra keys: set()
X_train_lr: type=<class 'pandas.core.frame.DataFrame'>, shape=(3046492, 37)
X_val_lr: type=<class 'pandas.core.frame.DataFrame'>, shape=(975699, 37)
X_test_lr: type=<class 'pandas.core.frame.DataFrame'>, shape=(1006214, 37)
X_oot_lr: type=<class 'pandas.core.frame.DataFrame'>, shape=(1020422, 37)
y_train: type=<class 'pandas.core.series.Series'>, length=3046492, dtype=int64
y_val: type=<class 'pandas.core.series.Series'>, length=975699, dtype=int64
y_test: type=<class 'pandas.core.series.Series'>, length=1006214, dtype=int64
y_oot: type=<class 'pandas.core.series.Series'>, length=1020422, dtype=int64
lr_cols len: 37
categorical_cols: ['registered_via', 'city_clean']
class_weight_dict: {0: np.float64(0.5770500721290808), 1: np.float64(3.7446432961305867)}
scale_pos_weight: 6.489286592261173
cv: StratifiedKFold(n_splits=5, random_state=42, shuffle=True)
y_train value counts:
 {0: 2639712, 1: 406780}


In [10]:
lr = LogisticRegression(class_weight=data["class_weight_dict"], random_state=args.random_state, n_jobs=-1)

In [21]:
data["X_train_lr"] = data["X_train_lr"].sample(frac=0.2, random_state=42)
data["y_train"] = data["y_train"].loc[data["X_train_lr"].index]

In [22]:
import os
import pickle
import json
import xgboost as xgb
import mlflow
import mlflow.xgboost
from sklearn.metrics import classification_report, roc_auc_score

# --- Safety checks & DMatrix creation (use correct keys)
X_train = data["X_train_lr"]
y_train = data["y_train"]        # <-- use this key (not y_train_lr unless you have it)
X_val   = data["X_val_lr"]
y_val   = data["y_val"]
X_test  = data["X_test_lr"]
y_test  = data["y_test"]

# optional: if you downsampled earlier and indices are not reset, reset them now
X_train = X_train.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)

assert len(X_train) == len(y_train), f"train mismatch: {len(X_train)} vs {len(y_train)}"
assert len(X_val)   == len(y_val),   f"val mismatch:   {len(X_val)}   vs {len(y_val)}"
assert len(X_test)  == len(y_test),  f"test mismatch:  {len(X_test)}  vs {len(y_test)}"

dtrain = xgb.DMatrix(X_train, label=y_train.values)
dval   = xgb.DMatrix(X_val,   label=y_val.values)
dtest  = xgb.DMatrix(X_test,  label=y_test.values)

# --- Params
params = {
    "objective": "binary:logistic",
    "eval_metric": "auc",
    "scale_pos_weight": data.get("scale_pos_weight", 1.0),
    "tree_method": "hist",
    "max_depth": 6,
    "eta": 0.1,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "seed": 42
}
num_boost_round = 300
early_stopping_rounds = 20

# --- Directory to save local artifacts that we'll log to MLflow
local_artifact_dir = "mlflow_artifacts"
os.makedirs(local_artifact_dir, exist_ok=True)

# Optionally set MLflow tracking URI if you use a remote server:
# mlflow.set_tracking_uri("http://your-mlflow-server:5000")

with mlflow.start_run(run_name="xgb_churn") as run:
    # log params
    mlflow.log_params({**params, "num_boost_round": num_boost_round, "early_stopping_rounds": early_stopping_rounds})

    # train
    evals = [(dtrain, "train"), (dval, "val")]
    model = xgb.train(
        params, dtrain,
        num_boost_round=num_boost_round,
        evals=evals,
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=10
    )

    # log model with MLflow XGBoost flavor
    mlflow.xgboost.log_model(model, artifact_path="xgb_model")

    # predictions + metrics
    y_pred_proba = model.predict(dtest)
    y_pred = (y_pred_proba > 0.5).astype(int)

    auc = float(roc_auc_score(y_test, y_pred_proba))
    report = classification_report(y_test, y_pred, digits=4)
    mlflow.log_metric("test_auc", auc)

    # Save classification report text and log it
    report_path = os.path.join(local_artifact_dir, "classification_report.txt")
    with open(report_path, "w") as f:
        f.write(report)
        f.write("\nTest ROC-AUC: {:.6f}\n".format(auc))
    mlflow.log_artifact(report_path)

    # Save feature importance (as JSON) and log it
    fmap = model.get_score(importance_type="weight")  # {feature:score}
    fi_path = os.path.join(local_artifact_dir, "feature_importance.json")
    with open(fi_path, "w") as f:
        json.dump(fmap, f, indent=2)
    mlflow.log_artifact(fi_path)

    # Bundle model + scaler + lr_cols into a single pickle and log as artifact
    bundle = {
        "xgb_model": model,               # Booster object
        "lr_cols": data["lr_cols"],
        "scaler": data.get("scaler", None),
        "threshold": 0.5
    }
    pkl_path = os.path.join(local_artifact_dir, "xgb_bundle.pkl")
    with open(pkl_path, "wb") as f:
        pickle.dump(bundle, f)
    mlflow.log_artifact(pkl_path)

    # (Optional) also save the raw X_test head for quick debugging
    X_test.head(1000).to_parquet(os.path.join(local_artifact_dir, "X_test_head.parquet"), index=False)
    mlflow.log_artifact(os.path.join(local_artifact_dir, "X_test_head.parquet"))

    # final logging of run id
    run_id = run.info.run_id
    print("MLflow run id:", run_id)
    print("Test ROC-AUC:", auc)
    print("Classification report:\n", report)


The git executable must be specified in one of the following ways:
    - be included in your $PATH
    - be set via $GIT_PYTHON_GIT_EXECUTABLE
    - explicitly set via git.refresh(<full-path-to-git-executable>)

All git commands will error until this is rectified.

This initial message can be silenced or aggravated in the future by setting the
$GIT_PYTHON_REFRESH environment variable. Use one of the following values:
    - quiet|q|silence|s|silent|none|n|0: for no message or exception
    - error|e|exception|raise|r|2: for a raised exception

Example:
    export GIT_PYTHON_REFRESH=quiet



[0]	train-auc:0.80897	val-auc:0.67438
[10]	train-auc:0.83227	val-auc:0.68234
[20]	train-auc:0.84130	val-auc:0.68151
[30]	train-auc:0.85199	val-auc:0.68252


  xgb_model.save_model(model_data_path)


MLflow run id: 2ddc6524115b4d0ebd2d3161c7081289
Test ROC-AUC: 0.742001436518704
Classification report:
               precision    recall  f1-score   support

           0     0.9375    0.6560    0.7719    853682
           1     0.2818    0.7554    0.4104    152532

    accuracy                         0.6710   1006214
   macro avg     0.6097    0.7057    0.5912   1006214
weighted avg     0.8381    0.6710    0.7171   1006214

üèÉ View run xgb_churn at: http://mlflow:5000/#/experiments/2/runs/2ddc6524115b4d0ebd2d3161c7081289
üß™ View experiment at: http://mlflow:5000/#/experiments/2


# Inference

In [23]:
model.


<xgboost.core.Booster at 0x77d679c76e40>