In [0]:
# Install runtime dependencies into this notebook's Python environment.
# NOTE(review): versions are unpinned, so reruns may pick up different mlflow/xgboost releases — consider pinning.
%pip install mlflow xgboost

%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

# Restart Python, presumably so the packages installed above become importable.
# NOTE(review): a Python restart likely discards the autoreload extension loaded above; confirm it is
# still active afterwards, or move the two autoreload lines into a cell that runs after this one.
%restart_python

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import DataFrame, functions as F, types as T, Window

import builtins
from datetime import datetime
from typing import Optional, Dict, Union, List, Tuple, Any
import math
import random


import pandas as pd
import numpy as np
import sklearn

from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
import mlflow

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics


from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors, DenseVector, SparseVector, VectorUDT
from pyspark.ml import Pipeline, PipelineModel


from pyspark.ml.tuning import CrossValidatorModel, TrainValidationSplitModel, ParamGridBuilder, CrossValidator, TrainValidationSplit
from pyspark.storagelevel import StorageLevel

import matplotlib.pyplot as plt

from pyspark.sql.functions import round
import mlflow.spark
from mlflow.artifacts import download_artifacts


In [0]:
from src.config import *
from src.sampling import *
from src.tracking_helpers import *
from src.tracking import *
from src.tuning import * 

In [0]:
# Run configuration. LABEL_COL (e.g. "churn7") comes from src.config, not from this cell.
# Feature rows are read from the window [DATE_FILTER - DATE_INTERVAL days, DATE_FILTER].
DATE_FILTER = "2025-10-26"   # window end date, yyyy-mm-dd
DATE_INTERVAL = 30           # look-back length in days



In [0]:
# Load the DATE_INTERVAL-day feature window ending on DATE_FILTER, keeping only rows whose label is set.
# The label column is quoted with backticks (a Spark SQL identifier). Single quotes would make it a
# string literal, which is never null, so the filter would silently keep every row (and null labels
# would later be encoded as 0).
df = (
    spark.sql(f"""select * from {FEATURES_TABLE_NAME}
                  where `{LABEL_COL}` is not null
                  and date between date_sub('{DATE_FILTER}', {DATE_INTERVAL}) AND '{DATE_FILTER}' """)
    # withColumn appends at the end, so this copy/drop/rename round-trip moves `market` to the
    # last column position. NOTE(review): the reason for the reorder is not visible here.
    .withColumn('market_name', col('market'))
    .drop('market')
    .withColumnRenamed('market_name', 'market')
)

In [0]:
from pyspark.sql.types import StringType, NumericType, BooleanType

string_features = []
numerical_features = []
churn_labels = []  # currently unused in this notebook section

drop_cols = ['judi','date','ts_last_updated','processed_date','churn3','churn5','churn7','churn14']

# Columns that must never become model inputs: the fixed drop list, plus the active label column and
# the derived 'label' target, so a LABEL_COL missing from drop_cols cannot leak into the features.
excluded_cols = set(drop_cols) | {LABEL_COL, 'label'}

# Split the remaining columns by Spark type. Columns of any other type (BooleanType, dates,
# timestamps, ...) are not used as features.
# NOTE(review): confirm boolean columns are meant to be ignored.
for field in df.schema.fields:
    if field.name in excluded_cols:
        continue
    if isinstance(field.dataType, StringType):
        string_features.append(field.name)
    elif isinstance(field.dataType, NumericType):
        numerical_features.append(field.name)
    

In [0]:
# Churn feature set: adds a binary target column where label = 1 when LABEL_COL is True and
# label = 0 otherwise (False or null both fall through to the otherwise branch).
churn_features = df.withColumn(
    'label',
    F.when(F.col(LABEL_COL) == True, 1).otherwise(0),
)

In [0]:
# Sampling switches forwarded to get_stratified_sets.
split_payers = True
upsample = True
undersample = True

if not split_payers:
    # One stratified collection over every row.
    all_sets = get_stratified_sets(churn_features, split=None, undersample=undersample, upsample=upsample)
else:
    payers = ['P', 'S']   # payer_type_cd values treated as payers
    non_payers = ['N']    # payer_type_cd values treated as non-payers

    # Three groups, concatenated in this order: all rows, payers only, non-payers only.
    unioned_sets = get_stratified_sets(
        churn_features,
        split=None,
        undersample=undersample,
        upsample=upsample,
    )
    payers_sets = get_stratified_sets(
        churn_features.filter(F.col('payer_type_cd').isin(payers)),
        split='payers',
        undersample=undersample,
        upsample=upsample,
    )
    nonpayers_sets = get_stratified_sets(
        churn_features.filter(F.col('payer_type_cd').isin(non_payers)),
        split='nonpayers',
        undersample=undersample,
        upsample=upsample,
    )
    all_sets = unioned_sets + payers_sets + nonpayers_sets

In [0]:
def get_safe_works_repartition(df, max_workers=32, spark_session=None):
    """Repartition ``df`` to the number of executor task slots, capped at ``max_workers``.

    Args:
        df: Spark DataFrame to repartition.
        max_workers: Upper bound on the partition/worker count (was hard-coded to 32).
        spark_session: Session whose cluster is inspected; defaults to the notebook-global ``spark``.

    Returns:
        Tuple ``(repartitioned_df, safe_workers)`` where ``safe_workers`` is always >= 1.
    """
    session = spark if spark_session is None else spark_session

    conf = session.sparkContext.getConf()
    cores_per_exec = int(conf.get("spark.executor.cores", "1"))
    # executors = all JVMs except the driver
    num_exec = session._jsc.sc().getExecutorMemoryStatus().size() - 1

    # Call the builtins module explicitly: `from pyspark.sql.functions import *` shadows max/min,
    # and `__builtins__` is only the builtins module in __main__ (a dict elsewhere).
    slots = builtins.max(1, cores_per_exec * builtins.max(1, num_exec))
    safe_workers = builtins.max(1, builtins.min(slots, max_workers))

    return df.repartition(safe_workers), safe_workers  # match partitions to workers

In [0]:
# Repartition every training set to the executor slot count.
# NOTE(review): possibly unnecessary when the cluster has only one worker.
for entry in all_sets:
    entry['dataset'], safe_workers = get_safe_works_repartition(entry['dataset'])


In [0]:
# XGBoost is tree-based, so string features are only index-encoded and assembled; nothing is standardized.
indexed_cols = [f"{name}_index" for name in string_features]
indexers = [
    StringIndexer(inputCol=name, outputCol=out_name, handleInvalid="keep")
    for name, out_name in zip(string_features, indexed_cols)
]

inputs = numerical_features + indexed_cols
vec_assembler = VectorAssembler(inputCols=inputs, outputCol='features', handleInvalid='keep')

# Only the area under the PR curve is requested; other options tried: "auc", "logloss".
eval_metrics = ["aucpr"]

# Pinned to one XGBoost worker; this rebinds the value produced by the repartition loop.
safe_workers = 1

xgb = SparkXGBClassifier(
    features_col="features",
    label_col="label",
    num_workers=safe_workers,
    eval_metric=eval_metrics,
)

# Full process: string indexers -> vector assembler -> XGBoost classifier.
pipeline = Pipeline(stages=[*indexers, vec_assembler, xgb])

In [0]:
# Search space for build_random_param_maps.
# max_depth is currently pinned to 8 (originally ("int_uniform", 4, 8)).
spec = {
    "max_depth":  ("int_uniform", 8, 8),
}


def _dedupe_param_maps(param_maps):
    """Return ``param_maps`` without exact duplicates, preserving first-seen order.

    Random sampling over a narrow (here single-point) range repeats combinations, and every
    repeat costs CrossValidator another ``numFolds`` fits without adding information.
    Assumes each map is a {Param: value} dict, as CrossValidator's estimatorParamMaps requires.
    """
    seen = set()
    unique_maps = []
    for param_map in param_maps:
        # Identify a Param by its owner uid + name; repr() makes list-valued params hashable.
        key = tuple(sorted((param.parent, param.name, repr(value)) for param, value in param_map.items()))
        if key not in seen:
            seen.add(key)
            unique_maps.append(param_map)
    return unique_maps


# Build random xgb param maps, then drop duplicates before they reach CrossValidator.
xgb_param_maps = _dedupe_param_maps(build_random_param_maps(xgb, spec, n_samples=40, seed=7))


cv_xgb = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=xgb_param_maps,
    numFolds=2,
    seed=7,
    # parallelism=150
)

In [0]:
import logging

# Raise the "mlflow" logger to INFO so its informational messages are emitted.
logger = logging.getLogger(name="mlflow")
logger.setLevel(level=logging.INFO)

In [0]:
# Run metadata passed to run_spark_ml_training for every training set.
extra_tags = {
    'label': LABEL_COL,
    'safe_workers': safe_workers,
    'date_filter': DATE_FILTER,
    'date_interval': DATE_INTERVAL,
    'source_table_name': FEATURES_TABLE_NAME,
}

for entry in all_sets:
    # Dataset-specific info is applied last, so it wins on any key collision.
    merged_tags = dict(extra_tags)
    merged_tags.update(entry['dataset_info'])
    entry['extra_tags'] = merged_tags

In [0]:
mlflow.set_experiment(experiment_name=EXPERIMENT_NAME)  # experiment name comes from src.config

In [0]:
#all_sets[6]

In [0]:
# Index into all_sets of the one training set to run in this cell.
TRAIN_SET_INDEX = 6
assert 0 <= TRAIN_SET_INDEX < len(all_sets), (
    f"TRAIN_SET_INDEX={TRAIN_SET_INDEX} is out of range for {len(all_sets)} training sets"
)
selected_set = all_sets[TRAIN_SET_INDEX]

# Reset on every run of this cell.
results_list = []
best_estimators = []

results, best_estimator = run_spark_ml_training(
    estimator=cv_xgb,
    train_df=selected_set["dataset"],
    test_df=selected_set["relevant_test_set"],
    val_df=selected_set["relevant_val_set"],
    extra_tags=selected_set["extra_tags"],
)
results_list.append(results)
best_estimators.append(best_estimator)

In [0]:

"""
results_list = []
best_estimators = []
#need index 6,7,8 still


for i in all_sets[6:]:
    #print(f"Starting run on set {ix+7} out of {len(all_sets)}")
    #print(f"With dataset and run info:", i["extra_tags"])

    ### Strat train up: 
    results, best_estimator = run_spark_ml_training(estimator = cv_xgb, 
                        train_df = i["dataset"], 
                        test_df = i["relevant_test_set"], 
                        val_df = i["relevant_val_set"], 
                        extra_tags = i["extra_tags"])
    results_list.append(results)
    best_estimators.append(best_estimator)

mlflow.end_run()
"""
