In [0]:
%pip install mlflow xgboost
# NOTE(review): mlflow / xgboost versions are not pinned, so re-running this notebook later may
# install different releases than the ones the logged runs were trained with — consider pinning.

%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload; run %autoreload 0

# Restart the Python process so the freshly installed packages are importable.
# NOTE(review): restarting Python after `%load_ext autoreload` likely discards the autoreload
# setup above (it lives in the old process) — consider enabling autoreload in a cell after this one.
%restart_python

In [0]:

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import DataFrame, functions as F, types as T, Window

import builtins
from datetime import datetime
from typing import Optional, Dict, Union, List, Tuple, Any
import math
import random


import pandas as pd
import numpy as np
import sklearn

from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
import mlflow

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics


from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors, DenseVector, SparseVector, VectorUDT
from pyspark.ml import Pipeline, PipelineModel


from pyspark.ml.tuning import CrossValidatorModel, TrainValidationSplitModel, ParamGridBuilder, CrossValidator, TrainValidationSplit
from pyspark.storagelevel import StorageLevel

import matplotlib.pyplot as plt

from pyspark.sql.functions import round
import mlflow.spark
from mlflow.artifacts import download_artifacts


In [0]:
from src.config import *
from src.sampling import *
from src.tracking import *
# from src.tuning import * 
#from Deprecated.deprecated_tracking import *

In [0]:
# Analysis of the churn proportion for each candidate target (churn3 / churn5 / churn7):
# each rate = (#rows with label = 1) / (#rows with a non-null label).

FEATURES_TABLE_NAME = "teams.data_science.pp_churn_features_v4_cluster"

churn_rate_query = f"""select count(case when churn3 = 1 then 1 end) / count(case when churn3 is not null then 1 end) as churn3_rt,
       count(case when churn5 = 1 then 1 end) / count(case when churn5 is not null then 1 end) as churn5_rt,
       count(case when churn7 = 1 then 1 end) / count(case when churn7 is not null then 1 end) as churn7_rt
        from {FEATURES_TABLE_NAME}"""

display(spark.sql(churn_rate_query))


In [0]:
# ---- Target and training window ----
LABEL_COL = "churn7"          # churn label to model (churn3 / churn5 / churn7)
DATE_FILTER = "2025-10-26"    # last date of the window (inclusive: SQL BETWEEN)
DATE_INTERVAL = 30            # window length in days, counted back from DATE_FILTER

# ---- Population split ----
# None -> no split, "0" -> non-payers, "1,2" -> payers
# (the value is inlined into a SQL `payer_type_cd_idx in (...)` filter)
payer_split = "1,2"

# ---- Extra cluster features ----
# None -> no extra cluster columns. To enable them, use:
# cluster_vars = [f"cluster_{i}" for i in range(29)]   # cluster_0 .. cluster_28
cluster_vars = None

# ---- Rebalancing of the training split ----
# None -> no sampling, "up" -> oversample the minority class, "under" -> undersample the majority class
sampling_method = "up"

# EXPERIMENT_NAME is loaded from src.config. FEATURES_TABLE_NAME also has a config default,
# but it is set explicitly in the churn-rate cell above.

In [0]:
# Columns that need a StringIndexer before assembly (none at the moment).
string_features = []

# Numeric columns passed straight to the VectorAssembler; the order defines the
# feature-vector layout (and therefore the f0, f1, ... names in XGBoost importances).
other_features = [
    'unique_levels_played', 'market_idx', 'dayofweek', 'rounds_played',
    'avg_attempts', 'total_attempts', 'avg_moves', 'win_rate',
    'assist_success_rate', 'unassist_success_rate', 'assist_rate',
    'total_boosters_used', 'total_boosters_spent', 'used_boosters_rate', 'spend_boosters_rate',
    'avg_difficulty_score', 'rate_hard_levels', 'rate_superhard_levels',
    'min_room_id_int', 'max_room_id_int',
    'daily_win_rate_ref', 'daily_avg_boosters_used_ref', 'daily_avg_boosters_spent_ref',
    'attribution_source_cd_idx', 'country_cd_idx', 'payer_type_cd_idx',
    'iap_lifetime_amt', 'days_since_install', 'days_since_last_purchase',
    'ad_revenue_amt', 'iap_revenue_amt',
    'session_qty', 'total_session_length_qty', 'avg_session_length', 'sessions_per_round',
    'avg_population_wr_on_levels_played_today', 'avg_population_assisted_rate_today',
    'avg_population_attempts_today',
    'wr_diff_vs_population', 'attempts_diff_vs_population', 'assist_rate_diff_vs_population',
    'active_days_l7d', 'total_rounds_l7d', 'avg_rounds_l7d', 'avg_win_rate_l7d',
    'avg_attempts_l7d', 'boosters_used_l7d', 'avg_used_boosters_rate_l7d',
    'active_days_l14d', 'avg_rounds_l14d', 'avg_win_rate_l14d',
    'std_rounds_l14d', 'std_win_rate_l14d',
    'active_days_l30d', 'avg_rounds_l30d',
    'rounds_trend_weekly', 'win_rate_trend_weekly', 'boosters_usage_trend_weekly',
    'rounds_ratio_7d_vs_14_7d', 'frequency_ratio_7d_vs_14d',
    'levels_progressed_l7d', 'levels_progressed_l14d', 'levels_progressed_l30d',
    'days_on_current_max_level', 'level_diversity_ratio',
]

# Append the optional cluster columns (a new list; the base list above is left untouched).
if cluster_vars is not None:
    base_features = other_features
    other_features = base_features + cluster_vars


In [0]:
# Load the modelling data from the features table: rows with a non-null label whose
# `date` lies in [DATE_FILTER - DATE_INTERVAL days, DATE_FILTER] (BETWEEN is inclusive),
# optionally restricted to one payer segment via `payer_split`.
#
# The label column is quoted with backticks, i.e. as a column identifier. With single
# quotes it was a string literal, so "'churn7' is not null" was always true and rows
# with a missing label were NOT filtered out.
payer_filter = "" if payer_split is None else f"and payer_type_cd_idx in ({payer_split})"

churn_features = spark.sql(f"""select * from {FEATURES_TABLE_NAME}
                            where `{LABEL_COL}` is not null
                            {payer_filter}
                            and date between date_sub('{DATE_FILTER}',{DATE_INTERVAL}) AND '{DATE_FILTER}' """)\
    .withColumn("label", col(LABEL_COL))

# In the non-payer segment every row has payer_type_cd_idx = 0 (see the filter above),
# so the column carries no signal. Build a new list rather than calling list.remove(),
# which mutated other_features in place and raised ValueError when the cell was re-run.
if payer_split == "0":
    other_features = [feat for feat in other_features if feat != "payer_type_cd_idx"]


In [0]:
# Stratified train / validation / test split of the filtered data.
# P_TEST / P_VAL are presumably the test and validation fractions — confirm in stratified_sampling.
strat_train, strat_val, strat_test = stratified_sampling(
    churn_features,
    P_TEST=0.2,
    P_VAL=0.2,
)

In [0]:
# Rebalance the classes of the TRAINING split only.
# Resampling must never see the validation/test rows. Passing the full `churn_features`
# table here (as before) replaced the stratified train split with data that also contains
# the held-out rows, leaking them into CV and into the reported eval metrics.
# NOTE: this cell rebinds `strat_train`; re-running it resamples an already-resampled
# set, so re-run the split cell first.
if sampling_method == "under":
    # Undersample the majority class
    strat_train, train_info = undersample_majority(strat_train)
    print(train_info)
elif sampling_method == "up":
    # Upsample the minority class
    strat_train, train_info = upsample_minority(strat_train)
    print(train_info)
elif sampling_method is not None:
    # Fail fast: a typo would otherwise silently train without sampling.
    raise ValueError(
        f"Unknown sampling_method {sampling_method!r}; expected None, 'up' or 'under'."
    )


Build Pipeline for classification

In [0]:
# Send every subsequent MLflow run to the experiment configured in src.config.
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)

In [0]:
#TODO: would love to have a function that automatically sorts the columns by type
#drop_for_features = {"judi","date","churn3"} 
#feature_cols = [c for c in df.columns if c not in drop_for_features and c not in drop_cols]

In [0]:
def get_safe_works_repartition(df, max_workers=32, spark_session=None):
    """Repartition ``df`` to the number of XGBoost workers the cluster can run at once.

    Fitting a SparkXGB model fails when ``num_workers`` exceeds the available task
    slots, so the worker count is derived from the cluster as
    ``executor cores x executors``, at least 1 and capped at ``max_workers``.

    Args:
        df: Spark DataFrame to repartition.
        max_workers: Upper bound on the worker count (default 32, the previous cap).
        spark_session: SparkSession to inspect; defaults to the notebook's ``spark``.

    Returns:
        Tuple ``(repartitioned_df, safe_workers)``; the DataFrame has exactly
        ``safe_workers`` partitions, one per worker.
    """
    session = spark if spark_session is None else spark_session

    conf = session.sparkContext.getConf()
    cores_per_exec = int(conf.get("spark.executor.cores", "1"))
    # getExecutorMemoryStatus() lists every JVM, including the driver; with no
    # executors besides the driver this is 0, which is treated as 1 below.
    num_exec = session._jsc.sc().getExecutorMemoryStatus().size() - 1

    # Use builtins.max/min: the notebook's `from pyspark.sql.functions import *`
    # shadows the built-in max/min. The `builtins` module is used instead of
    # `__builtins__`, which is a plain dict (no .max) outside the __main__ module.
    slots = builtins.max(1, cores_per_exec * builtins.max(1, num_exec))
    safe_workers = builtins.max(1, builtins.min(slots, max_workers))

    return df.repartition(safe_workers), safe_workers

In [0]:
# SparkXGB fitting fails when num_workers exceeds the available task slots, so derive a
# safe worker count from the cluster and repartition the training data to match it.
strat_train, safe_workers = get_safe_works_repartition(df=strat_train)
print(safe_workers)

# Build Pipeline

In [0]:
# XGBoost does not need standardized features, so the pipeline is simply:
# StringIndexers (one per string feature) -> VectorAssembler -> SparkXGBClassifier.
indexed_cols = [f"{name}_index" for name in string_features]
indexers = [
    StringIndexer(inputCol=name, outputCol=indexed_name, handleInvalid="keep")
    for name, indexed_name in zip(string_features, indexed_cols)
]

# Assembler inputs: numeric columns first, then the indexed string columns.
inputs = other_features + indexed_cols
vec_assembler = VectorAssembler(inputCols=inputs, outputCol="features", handleInvalid="keep")

# Metric(s) XGBoost reports during training (alternatives: "auc", "logloss").
eval_metrics = ["aucpr"]

xgb = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    num_workers=safe_workers,
    eval_metric=eval_metrics,
)

pipeline = Pipeline().setStages([*indexers, vec_assembler, xgb])

Fit the pipeline with cross-validation and track the run (params, metrics and model) in MLflow.

In [0]:
# Random-search space for the XGBoost stage: name -> (distribution, low, high).
# Only max_depth is active, pinned to 8 (it was originally searched over 4..8).
# Commented-out alternatives: n_estimators int_uniform 50..1000, gamma uniform 0..0.2,
# learning_rate uniform 0.01..0.5, subsample uniform 0.7..0.9,
# colsample_bytree uniform 0.7..0.9, min_child_weight int_uniform 1..5,
# reg_alpha uniform 0..0.1, reg_lambda int_uniform 1..10, colsample_bylevel uniform 0..0.6
spec = {
    "max_depth": ("int_uniform", 8, 8),
}

# NOTE(review): with a single-point search space, 40 random samples are presumably 40
# identical param maps (x 5 folds = 200 identical fits) unless build_random_param_maps
# de-duplicates them — confirm, and lower n_samples while the space is pinned.
xgb_param_maps = build_random_param_maps(xgb, spec, n_samples=40, seed=7)

# 5-fold CV over the whole pipeline (the param maps target the xgb stage inside it).
# parallelism is left at the CrossValidator default (150 was left commented out).
cv_xgb = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=xgb_param_maps,
    numFolds=5,
    seed=7,
)


In [0]:
import logging

# Show MLflow's INFO-level log messages in the notebook output.
logger = logging.getLogger(name="mlflow")
logger.setLevel(level=logging.INFO)


In [0]:
# ------------------------------------------------------------------
# MLflow run: cross-validate, evaluate and log the best XGB pipeline model
# ------------------------------------------------------------------

import builtins
import json
import os
import tempfile

# PR AUC is both the CV selection metric and the hold-out metric. The XGBoost model
# writes class probabilities to the "probability" column, used as the evaluator input.
evaluator_pr = BinaryClassificationEvaluator(
    labelCol="label",
    rawPredictionCol="probability",
    metricName="areaUnderPR",
)
cv_xgb.setEvaluator(evaluator_pr)


def _xgb_importance_rows(score_dict, feature_names):
    """Map XGBoost importance keys back to feature names, sorted by gain (descending).

    Args:
        score_dict: ``{xgb_feature_key: score}`` as returned by ``Booster.get_score``.
        feature_names: VectorAssembler input columns in vector order, or None.

    Returns:
        List of ``{"feature", "xgb_feature", "gain"}`` dicts. A key of the form
        ``f<index>`` with an in-range index is replaced by ``feature_names[index]``;
        any other key is kept unchanged.
    """
    rows = []
    for xgb_key, score in score_dict.items():
        feature = xgb_key
        if feature_names and xgb_key.startswith("f"):
            try:
                idx = int(xgb_key[1:])
            except ValueError:
                idx = -1  # not an f<index> key; keep the raw name
            # Check both ends: int() also accepts e.g. "-1", which would otherwise
            # silently index from the end of the list.
            if 0 <= idx < len(feature_names):
                feature = feature_names[idx]
        rows.append({"feature": feature, "xgb_feature": xgb_key, "gain": float(score)})
    rows.sort(key=lambda row: row["gain"], reverse=True)
    return rows


def _log_xgb_feature_importance(pipeline_model, top_n=50):
    """Log the XGBoost gain importances of ``pipeline_model`` to the active MLflow run.

    The top ``top_n`` features are logged as ``feat_gain__<feature>`` metrics, and the
    full table is logged as ``feature_importance/xgb_feature_importance.json``. If no
    booster can be found or extracted, a warning is printed and nothing is logged, so
    the run is never aborted by this step.
    """
    # The fitted XGBoost stage is a *model*, not the SparkXGBClassifier estimator, so an
    # isinstance(SparkXGBClassifier) check never matches; detect it by get_booster().
    xgb_stage = next((s for s in pipeline_model.stages if hasattr(s, "get_booster")), None)
    assembler = next((s for s in pipeline_model.stages if isinstance(s, VectorAssembler)), None)

    if xgb_stage is None:
        print("[WARN] No XGBoost stage with get_booster() found in best_model.stages; skipping feature importance logging.")
        return
    try:
        booster = xgb_stage.get_booster()
    except Exception as e:
        print(f"[WARN] Could not extract booster from XGB stage: {e}")
        return

    feature_names = list(assembler.getInputCols()) if assembler is not None else None
    # importance_type can be "gain", "weight", "cover", "total_gain", etc.
    rows = _xgb_importance_rows(booster.get_score(importance_type="gain"), feature_names)

    # Top features as metrics for quick inspection in the MLflow UI.
    for row in rows[:top_n]:
        safe_name = row["feature"].replace(" ", "_").replace(".", "_")
        mlflow.log_metric(f"feat_gain__{safe_name}", row["gain"])

    # Full table as a JSON artifact. A private temp dir avoids a fixed /tmp path shared by
    # every run on the driver; log_artifact copies the file before the dir is removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        importance_path = os.path.join(tmp_dir, "xgb_feature_importance.json")
        with open(importance_path, "w") as f:
            json.dump(rows, f, indent=2)
        mlflow.log_artifact(importance_path, artifact_path="feature_importance")
    print("[INFO] Logged feature importance JSON under artifact path 'feature_importance'.")


# Descriptive run name, e.g. "XGB_30_up_churn7_payer".
run_name_parts = ["XGB", str(DATE_INTERVAL), sampling_method or "no_sampling", LABEL_COL]
payer_suffix = {"1,2": "payer", "0": "non_payer"}.get(payer_split)
if payer_suffix:
    run_name_parts.append(payer_suffix)
run_name = "_".join(run_name_parts)

with mlflow.start_run(run_name=run_name):

    # ------------ Config / data context ------------------
    mlflow.log_params({
        "label_col": LABEL_COL,
        "features_table": FEATURES_TABLE_NAME,
        "date_filter": DATE_FILTER,
        "date_interval_days": DATE_INTERVAL,
        "payer_split": payer_split or "None",
        "sampling_method": sampling_method or "None",
        "num_workers": safe_workers,
        "cluster_vars": ",".join(cluster_vars) if cluster_vars else "None",
    })
    # Search space: one dist/low/high triple per hyper-parameter.
    for param_name, (dist, low, high) in spec.items():
        mlflow.log_params({
            f"search_{param_name}_dist": dist,
            f"search_{param_name}_low": low,
            f"search_{param_name}_high": high,
        })

    # ------------ Fit CrossValidator ---------------------
    cv_model = cv_xgb.fit(strat_train)
    best_model = cv_model.bestModel  # PipelineModel

    # CrossValidator keeps the param map with the highest avgMetrics value (PR AUC is
    # larger-is-better), so the max is the best model's mean CV score over the folds.
    avg_cv_aupr = float(builtins.max(cv_model.avgMetrics))
    mlflow.log_metric("cv_mean_areaUnderPR", avg_cv_aupr)

    # ------------ Evaluate once on validation + test -----
    eval_df = strat_val.unionByName(strat_test)
    eval_preds = best_model.transform(eval_df)
    eval_aupr = evaluator_pr.evaluate(eval_preds)
    mlflow.log_metric("eval_areaUnderPR", eval_aupr)

    # Split sizes. unionByName keeps duplicates (UNION ALL semantics), so the eval size
    # is the sum of its parts and needs no extra Spark job.
    n_val_rows = strat_val.count()
    n_test_rows = strat_test.count()
    mlflow.log_metrics({
        "n_train_rows": strat_train.count(),
        "n_val_rows": n_val_rows,
        "n_test_rows": n_test_rows,
        "n_eval_rows": n_val_rows + n_test_rows,
    })

    # ------------ Feature importance from XGBoost --------
    _log_xgb_feature_importance(best_model)

    # ------------ Log the full Spark pipeline model ------
    mlflow.spark.log_model(
        spark_model=best_model,
        artifact_path="model",
    )

    # Tags are set last: the comparison cells filter on tags.model_type, so only runs
    # that reached this point (model logged) are picked up.
    mlflow.set_tag("model_type", "SparkXGBClassifier")
    mlflow.set_tag("label", LABEL_COL)


In [0]:
# Display results (based on CV performance only): the `search_results.csv` artifact
# logged by each listed run, stacked into one table and tagged with the run's name.
experiment_lst = [
    {"run_id": "a679118511bd46c3b13b76bb22e24972", "name": "XGB_30_up_churn7"},
    {"run_id": "d68b7ac6064c4ce7989a5716fe4b2322", "name": "XGB_30_up_churn3"},
    {"run_id": "3fdd4de3b3bb42b9a765a50d1e2254e3", "name": "XGB_30_up_churn5"},
    {"run_id": "0ae51ea70f3b48cab60c0dc9ea8049fc", "name": "XGB_30_up_churn7_non_payer"},
    {"run_id": "954ff7ad42d04d0eaa29eee31d0de69b", "name": "XGB_30_up_churn7_payer"},
    {"run_id": "a7c6fab2d77846a8918e49368f9c7e1c", "name": "XGB_30_up_churn7_non_payer_cluster"},
    {"run_id": "d92501999472499d85e1e794716fb5b5", "name": "XGB_30_up_churn7_payer_cluster"},
    {"run_id": "9b478d4b0ac94f9f8175fccc21fed8ee", "name": "XGB_30_under_churn7_payer"},
    {"run_id": "13fb5876ecd346de9f012bf9529545d8", "name": "XGB_30_under_churn7_non_payer"},
    {"run_id": "82dd7b0b6a3f427ca70a0ef47c04ca54", "name": "XGB_30_no_sampling_churn7_non_payer"},
    {"run_id": "aa2f0431a0434aa18eb5d5737ae59c11", "name": "XGB_30_no_sampling_churn7_payer"},
]

df_lst = []
for experiment in experiment_lst:
    csv_path = download_artifacts(artifact_uri=f"runs:/{experiment['run_id']}/search_results.csv")
    df_lst.append(pd.read_csv(csv_path).assign(model=experiment["name"]))

# ignore_index: every CSV starts its own 0..n index, which would otherwise leave
# duplicate index labels in the combined table.
df = pd.concat(df_lst, axis=0, ignore_index=True)
display(df.sort_values(["params", "model"]))


In [0]:
# Display results (based on CV performance and test performance)

import mlflow

def get_model_by_name(run_name, model_path="model"):
    """Return the ``runs:/`` URI of a model logged by the newest run named ``run_name``.

    No experiment ids are passed to ``mlflow.search_runs``, so MLflow's default
    (the active experiment) is searched.

    Args:
        run_name: Value of the run's ``run_name`` attribute to look up.
        model_path: Artifact path of the model inside that run (default ``"model"``).

    Returns:
        A model URI of the form ``runs:/<run_id>/<model_path>``.

    Raises:
        ValueError: If no run with that name is found.
    """
    # Newest first, and only the newest one is needed.
    matches = mlflow.search_runs(
        filter_string=f"run_name = '{run_name}'",
        order_by=["attribute.start_time DESC"],
        max_results=1,
    )
    if matches.empty:
        raise ValueError(f"No run found with name: {run_name}")

    latest_run_id = matches["run_id"].iloc[0]
    print(f"Found Run ID: {latest_run_id} for name: {run_name}")
    return "runs:/{}/{}".format(latest_run_id, model_path)

# Compare the mean CV PR AUC with the held-out (val + test) PR AUC for the churn7
# payer / non-payer runs: XGBoost runs started on or after 2025-11-24 only.
run_name_lst = [
    "XGB_30_no_sampling_churn7_non_payer",
    "XGB_30_no_sampling_churn7_payer",
    "XGB_30_under_churn7_non_payer",
    "XGB_30_under_churn7_payer",
    "XGB_30_up_churn7_non_payer",
    "XGB_30_up_churn7_payer",
]
report_cols = ["tags.mlflow.runName", "metrics.cv_mean_areaUnderPR", "metrics.eval_areaUnderPR"]

df_lst = []
for run_name in run_name_lst:
    runs_found = mlflow.search_runs(filter_string=f"run_name = '{run_name}'")
    is_recent = runs_found["start_time"] >= "2025-11-24"
    is_xgb = runs_found["tags.model_type"] == "SparkXGBClassifier"
    df_lst.append(runs_found.loc[is_recent & is_xgb, report_cols])

df = pd.concat(df_lst, axis=0)
display(df)


In [0]:
# Check new models trained using the definitive table

import mlflow

# `get_model_by_name(run_name, model_path="model")` is defined in the previous cell and is
# deliberately not redefined here: a second, identical definition silently shadows the
# first and lets the two copies drift apart. Run that cell before this one.

# PR AUC of the runs trained on the definitive table, per payer segment and sampling
# strategy, keeping only runs started on or after RUNS_SINCE.
RUNS_SINCE = "2025-11-24"

run_name_lst = ["XGB_date_interval_30_nonpayer_churnlabel_churn7_sampling_stratified",
                "XGB_date_interval_30_payer_churnlabel_churn7_sampling_stratified",
                "XGB_date_interval_30_None_churnlabel_churn7_sampling_stratified",
                "XGB_date_interval_30_nonpayer_churnlabel_churn7_sampling_upsample",
                "XGB_date_interval_30_payer_churnlabel_churn7_sampling_upsample",
                "XGB_date_interval_30_None_churnlabel_churn7_sampling_upsample",
                "XGB_date_interval_30_nonpayer_churnlabel_churn7_sampling_undersample",
                "XGB_date_interval_30_payer_churnlabel_churn7_sampling_undersample",
                "XGB_date_interval_30_None_churnlabel_churn7_sampling_undersample"]
df_lst = []

for run_name in run_name_lst:
    df_temp = mlflow.search_runs(filter_string=f"run_name = '{run_name}'")
    # Compare start_time as a datetime (pandas parses the date string). Casting to str
    # first made this a lexicographic comparison in which a missing start time ("NaT")
    # sorts after every date and was therefore kept.
    recent = df_temp["start_time"] >= RUNS_SINCE
    df_lst.append(df_temp.loc[recent, ["tags.mlflow.runName", "metrics.area_under_pr"]])

df = pd.concat(df_lst, axis=0)
display(df)


In [0]:
# Load the newest model logged under each run name, for a probability-threshold check.
# NOTE: `model` is rebound on every iteration, so only the last listed model is kept.
model_name_lst = ["XGB_date_interval_180_payer_churnlabel_churn7_sampling_upsample"]

for model_name in model_name_lst:
    model_uri = get_model_by_name(model_name, model_path="best_model/spark-model")
    model = mlflow.spark.load_model(model_uri=model_uri)
    

In [0]:
# Names of the parameters exposed by the final pipeline stage (presumably the XGBoost model).
final_stage = model.stages[-1]
print([param.name for param in final_stage.params])