In [0]:
#!pip install sparkxgb

Looking in indexes: https://artifacts.rbi.tech/artifactory/api/pypi/pypi-org-pypi-proxy/simple/
Collecting sparkxgb
  Downloading https://artifacts.rbi.tech/artifactory/api/pypi/pypi-org-pypi-proxy/packages/packages/7b/b4/9fe10fc109e4c52c12c10deb9bfa1c5886bbfc19d11b2289c8aa45ca01b9/sparkxgb-0.1.tar.gz (3.6 kB)
Collecting pyspark==3.1.1
  Downloading https://artifacts.rbi.tech/artifactory/api/pypi/pypi-org-pypi-proxy/packages/packages/45/b0/9d6860891ab14a39d4bddf80ba26ce51c2f9dc4805e5c6978ac0472c120a/pyspark-3.1.1.tar.gz (212.3 MB)
[?25l[K     |                                | 10 kB 27.0 MB/s eta 0:00:08[K     |                                | 20 kB 32.6 MB/s eta 0:00:07[K     |                                | 30 kB 39.2 MB/s eta 0:00:06[K     |                                | 40 kB 41.6 MB/s eta 0:00:06[K     |                                | 51 kB 41.0 MB/s eta 0:00:06[K     |                                | 61 kB 43.1 MB/s eta 0:00:05[K     |                

# XGBoost

#### Adapted from the example in [this repo](https://github.com/sllynn/spark-xgboost/blob/master/examples/spark-xgboost_adultdataset.ipynb)

#### Importing modules and creating a pretty-printer

In [0]:
from sparkxgb import XGBoostClassifier, XGBoostRegressor
from pprint import PrettyPrinter

from pyspark.sql.types import StringType

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
# Shared pretty-printer (library defaults spelled out) used to display the column mapping.
pp = PrettyPrinter(indent=1, width=80)

In [0]:
# Column names for the headerless adult.data CSV; "label" is the income class.
col_names = [
  "age", "workclass", "fnlwgt",
  "education", "education-num",
  "marital-status", "occupation",
  "relationship", "race", "sex",
  "capital-gain", "capital-loss",
  "hours-per-week", "native-country",
  "label"
]

# 80/20 train/test split. The explicit seed makes the split (and therefore every
# downstream count and metric) reproducible: without it, randomSplit draws a
# fresh random seed on each run.
train_sdf, test_sdf = (
  spark.read.csv(
    path="/databricks-datasets/adult/adult.data",
    inferSchema=True
  )
  .toDF(*col_names)
  .repartition(200)
  .randomSplit([0.8, 0.2], seed=42)
)

In [0]:
# Partition the schema: StringType columns are categorical and get an indexed
# counterpart named "<col>_ix"; all other columns are used as-is.
schema_fields = train_sdf.schema.fields
string_columns = [f.name for f in schema_fields if isinstance(f.dataType, StringType)]
non_string_columns = [f.name for f in schema_fields if not isinstance(f.dataType, StringType)]

string_col_replacements = [f"{name}_ix" for name in string_columns]
string_column_map = list(zip(string_columns, string_col_replacements))

# Assumes the label is the last string column (it is last in col_names): its
# indexed name is the target and the remaining indexed columns are predictors.
target = string_col_replacements[-1]
predictors = non_string_columns + string_col_replacements[:-1]

pp.pprint({
  "string_column_map": string_column_map,
  "target_variable": target,
  "predictor_variables": predictors,
})

{'predictor_variables': ['age',
                         'fnlwgt',
                         'education-num',
                         'capital-gain',
                         'capital-loss',
                         'hours-per-week',
                         'workclass_ix',
                         'education_ix',
                         'marital-status_ix',
                         'occupation_ix',
                         'relationship_ix',
                         'race_ix',
                         'sex_ix',
                         'native-country_ix'],
 'string_column_map': [('workclass', 'workclass_ix'),
                       ('education', 'education_ix'),
                       ('marital-status', 'marital-status_ix'),
                       ('occupation', 'occupation_ix'),
                       ('relationship', 'relationship_ix'),
                       ('race', 'race_ix'),
                       ('sex', 'sex_ix'),
                       ('native-country', 'native-country_ix

In [0]:
# Index each categorical column, then pack all predictors into a single
# "features" vector.
# Predictor indexers use handleInvalid="keep": a category seen only in the test
# split maps to an extra "unseen" index instead of failing at transform time.
# The label indexer keeps the default ("error"), because an unseen label is a
# data problem that should surface.
si = [
  StringIndexer(
    inputCol=src_col,
    outputCol=ix_col,
    handleInvalid="error" if ix_col == target else "keep",
  )
  for src_col, ix_col in string_column_map
]
va = VectorAssembler(inputCols=predictors, outputCol="features")
pipeline = Pipeline(stages=[*si, va])

# Fit on the training split only. StringIndexer assigns indices by label
# frequency, so fitting on train+test would let test-set statistics leak into
# the feature encoding used for training.
fitted_pipeline = pipeline.fit(train_sdf)

In [0]:
# Featurise the training split and mark it for caching; cache() returns the same
# DataFrame. The count forces evaluation and is shown as the cell output.
train_sdf_prepared = fitted_pipeline.transform(train_sdf).cache()
train_sdf_prepared.count()

Out[26]: 26152

In [0]:
# Featurise the held-out split with the same fitted pipeline and mark it for
# caching; the count forces evaluation and is shown as the cell output.
test_sdf_prepared = fitted_pipeline.transform(test_sdf).cache()
test_sdf_prepared.count()

Out[27]: 6409

In [0]:
# Base XGBoost hyper-parameters. eta and maxDepth are also swept by the
# ParamGridBuilder in the next cell.
# NOTE(review): missing=0.0 tells XGBoost to treat every 0.0 as missing. That
# includes index 0 of each StringIndexer output (a real category) and zero
# capital-gain/capital-loss values. Confirm this is intended; XGBoost4J-Spark may
# require 0.0 when VectorAssembler emits sparse vectors.
xgbParams = {
  "eta": 0.1,
  "maxDepth": 2,
  "missing": 0.0,
  "objective": "binary:logistic",
  "numRound": 5,
  "numWorkers": 2,
}

# Classifier reads the assembled "features" vector and the indexed label.
xgb = XGBoostClassifier(**xgbParams).setFeaturesCol("features").setLabelCol("label_ix")

# Evaluator scoring rawPrediction against the indexed label (its default metric
# is used, since none is set here).
bce = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label_ix")

In [0]:
# Hyper-parameter grid: 3 learning rates x 3 tree depths = 9 candidates.
# NOTE(review): with numRound=5, eta=1e-3 barely moves the model from its base
# score; consider widening numRound if small etas are meant to be competitive.
param_grid = (
  ParamGridBuilder()
  .addGrid(xgb.eta, [1e-1, 1e-2, 1e-3])
  .addGrid(xgb.maxDepth, [2, 4, 8])
  .build()
)

# 5-fold cross-validation scored by the binary evaluator defined above.
# An explicit seed pins the fold assignment so the selected "best" params are
# reproducible. PySpark's default seed is derived from hash() of the class name,
# which can vary between Python sessions.
cv = CrossValidator(
  estimator=xgb,
  estimatorParamMaps=param_grid,
  evaluator=bce,
  numFolds=5,
  seed=42,
)

In [0]:
import mlflow
import mlflow.spark

spark_model_name = "best_model_spark"

with mlflow.start_run():
  # Grid search with k-fold CV over the featurised training data.
  model = cv.fit(train_sdf_prepared)

  # Record the winning hyper-parameters on the run.
  best_params = {
    "eta_best": model.bestModel.getEta(),
    "maxDepth_best": model.bestModel.getMaxDepth(),
  }
  mlflow.log_params(best_params)

  # Log the featuriser and the fitted CV model as separate artifacts. The model
  # was trained on the featuriser's output, so scoring new raw rows needs both,
  # applied in that order.
  mlflow.spark.log_model(fitted_pipeline, "featuriser")
  mlflow.spark.log_model(model, spark_model_name)

  # Score the held-out split and record the evaluator's metric.
  test_predictions = model.transform(test_sdf_prepared)
  metrics = {"roc_test": bce.evaluate(test_predictions)}
  mlflow.log_metrics(metrics)



## Alternative Gradient Boosted Approaches

There are many other gradient-boosted approaches, such as [CatBoost](https://catboost.ai/), [LightGBM](https://github.com/microsoft/LightGBM), and vanilla gradient-boosted trees in [SparkML](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.classification.GBTClassifier.html)/[scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html). Each has its own [pros and cons](https://towardsdatascience.com/catboost-vs-light-gbm-vs-xgboost-5f93620723db), which you can read more about.

-sandbox
&copy; 2020 Databricks, Inc. All rights reserved.<br/>
Apache, Apache Spark, Spark and the Spark logo are trademarks of the <a href="http://www.apache.org/">Apache Software Foundation</a>.<br/>
<br/>
<a href="https://databricks.com/privacy-policy">Privacy Policy</a> | <a href="https://databricks.com/terms-of-use">Terms of Use</a> | <a href="http://help.databricks.com/">Support</a>