In [None]:
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline

# Raw label column and the raw feature columns to assemble into a vector.
target = 'default_flag'
feature = ['loan_amnt']

# NOTE(review): `params` is not consumed by the classifier cell below, which
# hard-codes its own settings (e.g. maxDepth=7 there vs 30 here). Either pass it
# via DecisionTreeClassifier(**params) or delete it to avoid confusion.
params = {
    "featuresCol": "features_col",
    "labelCol": "target",
    "probabilityCol": "probability",
    "maxDepth": 30,
    "maxBins": 100,
}

# Index the raw label into a numeric "target" column. StringIndexer's default
# stringOrderType ("frequencyDesc") gives the most frequent label index 0.0, so
# target == 1.0 is the less frequent class; the undersampling cells below treat
# class 0 as the majority. Then pack the features into one vector column.
target_indexer = StringIndexer(inputCol=target, outputCol="target")
feature_assembler = VectorAssembler(inputCols=feature, outputCol="features_col")
pipeline = Pipeline(stages=[target_indexer, feature_assembler])
pipeline_model = pipeline.fit(data_nulls_excluded)
# NOTE(review): there is no train/test split; everything downstream is fit on
# the full transformed dataset.
training_data = pipeline_model.transform(data_nulls_excluded)

In [None]:
# Whole-frame window. After the groupBy below it only spans the per-class rows,
# so the unpartitioned window is cheap here.
w1 = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# Sampling fraction for the majority class (target == 0.0) that shrinks it to
# roughly the size of the minority class (target == 1.0):
#     undersample = count(target == 1) / count(target == 0)
# Spark functions are referenced through `F` so this cell does not depend on a
# star import shadowing the Python builtins `sum` / `max`.
undersample = (
    training_data
    .groupBy("target")
    .agg(F.count(F.lit(1)).alias("count"))
    .withColumn(
        "undersample",
        F.max(F.when(F.col("target") == 1, F.col("count"))).over(w1)
        / F.max(F.when(F.col("target") == 0, F.col("count"))).over(w1),
    )
    .select("undersample")
    .first()[0]
)

# A missing class makes the ratio null; fail here with a clear message instead
# of deep inside sampleBy.
if undersample is None:
    raise ValueError("training_data must contain both target classes (0.0 and 1.0)")

In [None]:
# Class-balanced training set (despite the name, it holds both classes): keep
# every minority row (target == 1.0) and a random `undersample` fraction of the
# majority rows (target == 0.0). Keys are floats to match the DoubleType label
# produced by StringIndexer; strata absent from `fractions` are sampled at 0.
defaulted_df = training_data.sampleBy(
    "target",
    fractions={0.0: undersample, 1.0: 1.0},
    seed=0,
)

# Per-class row counts after undersampling; expected to be roughly equal.
defaulted_df.groupBy("target").agg(F.count(F.lit(1))).show()

In [None]:
# Single decision tree fit on the undersampled, class-balanced frame.
# Every hyper-parameter is set explicitly through the builder-style setters.
dtc = (
    DecisionTreeClassifier()
    .setLabelCol("target")
    .setFeaturesCol("features_col")
    .setImpurity('gini')
    .setMaxDepth(7)
    .setMaxBins(100)
    .setMinInfoGain(0.0)
)
model = dtc.fit(defaulted_df)

In [None]:
# Echo the hyper-parameters carried by the fitted model, then dump the tree.
print(f"Max depth: {model.getMaxDepth()}")
print(f"Max Bins: {model.getMaxBins()}")
print(f"Min weight: {model.getMinWeightFractionPerNode()}")
print(f"Min instances: {model.getMinInstancesPerNode()}")
print(f"Min info gain: {model.getMinInfoGain()}")
print(model.toDebugString)  # property, not a method

In [None]:
# Persist the fitted tree, replacing anything already saved at this location.
# NOTE(review): the path is plain string concatenation, so output_path is
# assumed to end with a separator.
model.write().overwrite().save("{}tree".format(output_path))

In [None]:
# Node table written by model.write().save() in the previous cell.
tree_df = spark.read.parquet(f"{output_path}tree/data")

(
    tree_df
    # First two entries of the per-node impurity statistics. For a gini
    # classifier these are presumably per-class (weighted) counts - confirm.
    .withColumn("impurity0", F.get(F.col("impurityStats"), 0))
    .withColumn("impurity1", F.get(F.col("impurityStats"), 1))
    # NOTE(review): despite the name this is the SUM of the two stats, not a
    # median/mean; kept as-is to preserve the column.
    .withColumn("med_impurity", F.col("impurity0") + F.col("impurity1"))
    # Nodes whose split uses feature index 0 (the first assembled feature).
    .where(F.col("split.featureIndex") == 0)
    # The original `.withColumn("n_thresholds")` had no column expression and
    # raised TypeError. For a continuous split the threshold is the first
    # element of leftCategoriesOrThreshold.
    # TODO: confirm this is the quantity "n_thresholds" was meant to hold.
    .withColumn("threshold", F.get(F.col("split.leftCategoriesOrThreshold"), 0))
    .orderBy(F.col("med_impurity").desc())
    .show()
)