# MODEL TRAINING
This script does training with Spark MLLib of a Random Forest Classification model for the customer churn prediction experiment-</br>
Uses BigQuery as a source, and writes test results, model metrics and 
feature importance scores to BigQuery

In [2]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.sql.types import FloatType
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer
import pandas as pd
import sys, logging, argparse, random, tempfile, json
from pyspark.sql.functions import col, udf
from pyspark.sql.functions import round as spark_round
from pyspark.sql.types import StructType, DoubleType, StringType
from pyspark.sql.functions import lit
from pathlib import Path as path
from google.cloud import storage
from urllib.parse import urlparse, urljoin
from datetime import datetime
import random

In [3]:
spark

In [4]:
# 1a. Arguments
pipelineID = random.randint(1, 10000)
projectNbr = "433578906282"
projectID = "vertex-ai-382806"
displayPrintStatements = True

In [5]:
# 1b. Variables 
appBaseName = "customer-churn-model"
appNameSuffix = "training"
appName = f"{appBaseName}-{appNameSuffix}"
modelBaseNm = appBaseName
modelVersion = pipelineID
bqDatasetNm = f"{projectID}.customer_churn_ds"
operation = appNameSuffix
bigQuerySourceTableFQN = f"{bqDatasetNm}.training_data"
bigQueryModelTestResultsTableFQN = f"{bqDatasetNm}.test_predictions"
bigQueryModelMetricsTableFQN = f"{bqDatasetNm}.model_metrics"
bigQueryFeatureImportanceTableFQN = f"{bqDatasetNm}.model_feature_importance_scores"
modelBucketUri = f"gs://s8s_model_bucket-{projectNbr}/{modelBaseNm}/{operation}/{modelVersion}"
metricsBucketUri = f"gs://s8s_metrics_bucket-{projectNbr}/{modelBaseNm}/{operation}/{modelVersion}"
scratchBucketUri = f"s8s-spark-bucket-{projectNbr}/{appBaseName}/pipelineId-{pipelineID}/{appNameSuffix}/"
pipelineExecutionDt = datetime.now().strftime("%Y%m%d%H%M%S")

In [6]:
# Other variables, constants
SPLIT_SEED = 6
SPLIT_SPECS = [0.8, 0.2]

In [7]:
# 1c. Display input and output
if displayPrintStatements:
    print("Starting model training for *Customer Churn* experiment")
    print(".....................................................")
    print(f"The datetime now is - {pipelineExecutionDt}")
    print(" ")
    print("INPUT PARAMETERS")
    print(f"....pipelineID={pipelineID}")
    print(f"....projectID={projectID}")
    print(f"....projectNbr={projectNbr}")
    print(f"....displayPrintStatements={displayPrintStatements}")
    print(" ")
    print("EXPECTED SETUP")  
    print(f"....BQ Dataset={bqDatasetNm}")
    print(f"....Model Training Source Data in BigQuery={bigQuerySourceTableFQN}")
    print(f"....Scratch Bucket for BQ connector=gs://s8s-spark-bucket-{projectNbr}") 
    print(f"....Model Bucket=gs://s8s-model-bucket-{projectNbr}")  
    print(f"....Metrics Bucket=gs://s8s-metrics-bucket-{projectNbr}") 
    print(" ")
    print("OUTPUT")
    print(f"....Model in GCS={modelBucketUri}")
    print(f"....Model metrics in GCS={metricsBucketUri}")  
    print(f"....Model metrics in BigQuery={bigQueryModelMetricsTableFQN}")      
    print(f"....Model feature importance scores in BigQuery={bigQueryFeatureImportanceTableFQN}") 
    print(f"....Model test results in BigQuery={bigQueryModelTestResultsTableFQN}") 

Starting model training for *Customer Churn* experiment
.....................................................
The datetime now is - 20240603065934
 
INPUT PARAMETERS
....pipelineID=4244
....projectID=vertex-ai-382806
....projectNbr=433578906282
....displayPrintStatements=True
 
EXPECTED SETUP
....BQ Dataset=vertex-ai-382806.customer_churn_ds
....Model Training Source Data in BigQuery=vertex-ai-382806.customer_churn_ds.training_data
....Scratch Bucket for BQ connector=gs://s8s-spark-bucket-433578906282
....Model Bucket=gs://s8s-model-bucket-433578906282
....Metrics Bucket=gs://s8s-metrics-bucket-433578906282
 
OUTPUT
....Model in GCS=gs://s8s_model_bucket-433578906282/customer-churn-model/training/4244
....Model metrics in GCS=gs://s8s_metrics_bucket-433578906282/customer-churn-model/training/4244
....Model metrics in BigQuery=vertex-ai-382806.customer_churn_ds.model_metrics
....Model feature importance scores in BigQuery=vertex-ai-382806.customer_churn_ds.model_feature_importance_score

In [8]:
# 2. Spark Session creation
print('....Initializing spark & spark configs')
spark = SparkSession.builder.appName(appName).getOrCreate()

# Spark configuration setting for writes to BigQuery
spark.conf.set("parentProject", projectID)
spark.conf.set("temporaryGcsBucket", scratchBucketUri)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Add Python modules
sc.addPyFile(f"gs://s8s_code_bucket-{projectNbr}/pyspark/common_utils.py")
import common_utils

....Initializing spark & spark configs


24/06/03 06:59:34 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


### TRAINING DATA - READ, SPLIT

In [11]:
# 3. Read training data
print('....Read the training dataset into a dataframe')
inputDF = spark.read \
    .format('bigquery') \
    .load(bigQuerySourceTableFQN)

inputDF.printSchema()

if displayPrintStatements:
    print(f"inputDF count={inputDF.count()}")

....Read the training dataset into a dataframe
root
 |-- customer_id: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- senior_citizen: long (nullable = true)
 |-- partner: string (nullable = true)
 |-- dependents: string (nullable = true)
 |-- tenure: long (nullable = true)
 |-- tenure_group: string (nullable = true)
 |-- phone_service: string (nullable = true)
 |-- multiple_lines: string (nullable = true)
 |-- internet_service: string (nullable = true)
 |-- online_security: string (nullable = true)
 |-- online_backup: string (nullable = true)
 |-- device_protection: string (nullable = true)
 |-- tech_support: string (nullable = true)
 |-- streaming_tv: string (nullable = true)
 |-- streaming_movies: string (nullable = true)
 |-- contract: string (nullable = true)
 |-- paperless_billing: string (nullable = true)
 |-- payment_method: string (nullable = true)
 |-- monthly_charges: double (nullable = true)
 |-- total_charges: double (nullable = true)
 |-- churn: string 

                                                                                

In [12]:
# Typecast some columns to the right datatype
inputDF = inputDF.withColumn("partner", inputDF.partner.cast('string')) \
    .withColumn("dependents", inputDF.dependents.cast('string')) \
    .withColumn("phone_service", inputDF.phone_service.cast('string')) \
    .withColumn("paperless_billing", inputDF.paperless_billing.cast('string')) \
    .withColumn("churn", inputDF.churn.cast('string')) \
    .withColumn("monthly_charges", inputDF.monthly_charges.cast('float')) \
    .withColumn("total_charges", inputDF.total_charges.cast('float'))

In [13]:
# 4. Split to training and test datasets
print('....Split the dataset')
trainDF, testDF = inputDF.randomSplit(SPLIT_SPECS, seed=SPLIT_SEED)

....Split the dataset


### PREPROCESSING & FEATURE ENGINEERING 

In [14]:
# 5. Pre-process training data
print('....Data pre-procesing')
dataPreprocessingStagesList = []
# 5a. Create and append to pipeline stages - string indexing and one hot encoding
for eachCategoricalColumn in common_utils.CATEGORICAL_COLUMN_LIST:
    # Category indexing with StringIndexer
    stringIndexer = StringIndexer(inputCol=eachCategoricalColumn, outputCol=eachCategoricalColumn + "Index")
    # Use OneHotEncoder to convert categorical variables into binary SparseVectors
    encoder = OneHotEncoder(inputCols=[stringIndexer.getOutputCol()], outputCols=[eachCategoricalColumn + "classVec"])
    # Add stages.  This is a lazy operation
    dataPreprocessingStagesList += [stringIndexer, encoder]

# 5b. Convert label into label indices using the StringIndexer and append to pipeline stages
labelStringIndexer = StringIndexer(inputCol="churn", outputCol="label")
dataPreprocessingStagesList += [labelStringIndexer]


....Data pre-procesing


In [15]:
# 6. Feature engineering
print('....Feature engineering')
featureEngineeringStageList = []
assemblerInputs = common_utils.NUMERIC_COLUMN_LIST + [c + "classVec" for c in common_utils.CATEGORICAL_COLUMN_LIST]
featuresVectorAssembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
featureEngineeringStageList += [featuresVectorAssembler]

....Feature engineering


### MODEL TRAINING

In [16]:
# 5. Model training
print('....Model training')
modelTrainingStageList = []
rfClassifier = RandomForestClassifier(labelCol="label", featuresCol="features")
modelTrainingStageList += [rfClassifier]

....Model training


In [17]:
# 6. Create a model training pipeline for stages defined
print('....Instantiating pipeline model')
pipeline = Pipeline(stages=dataPreprocessingStagesList + featureEngineeringStageList + modelTrainingStageList) 


....Instantiating pipeline model


In [18]:
# 9. Fit the model
print('....Fit the model')
pipelineModel = pipeline.fit(trainDF)

....Fit the model


                                                                                

In [None]:
# import mleap.pyspark
# from mleap.pyspark.spark_support import SimpleSparkSerializer
# pipelineModel.serializeToBundle("jar:file:/tmp/mleap_python_model_export/churn-pipeline-json.zip", pipelineModel.transform(trainDF))
# upload_model_to_gcs('/tmp/mleap_python_model_export/churn-pipeline-json.zip', 'gs://s8s_model_bucket-433578906282/customer-churn-model/')

In [None]:
# def upload_model_to_gcs(local_model_path, gcs_destination):
#     """
#     Uploads a model file from the local filesystem to Google Cloud Storage.

#     Args:
#         local_model_path (str): The path to the model file on your local machine.
#         gcs_destination (str): The GCS URI where you want to store the model.
#     """
#     import subprocess

#     command = [
#         "gsutil",
#         "cp",
#         local_model_path,
#         gcs_destination
#     ]

#     try:
#         result = subprocess.run(command, check=True, capture_output=True, text=True)
#         print(f"Upload successful:\n{result.stdout}")
#     except subprocess.CalledProcessError as e:
#         print(f"Upload failed with error:\n{e.stderr}")


### MODEL TESTING

In [19]:
# 10. Test the model with the test dataset
print('....Test the model')
predictionsDF = pipelineModel.transform(testDF)
predictionsDF.show(2)

....Test the model


24/06/03 07:01:19 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 71:>                                                         (0 + 1) / 1]

+-----------+------+--------------+-------+----------+------+------------+-------------+--------------+----------------+---------------+-------------+-----------------+------------+------------+----------------+--------------+-----------------+----------------+---------------+-------------+-----+-----------+---------------------+-----------+--------------+-------------------+----------------------+------------+---------------+---------------+------------------+------------------+---------------------+-------------------+----------------------+---------------------+------------------------+--------------------+-----------------------+------------------+---------------------+----------------------+-------------------------+-----------------+--------------------+-----------------+--------------------+---------------------+------------------------+-------------+----------------+----------------------+-------------------------+-------------------+----------------------+-----+---------------

                                                                                

In [22]:
# 11. Persist model to GCS
print('....Persist the model to GCS')
pipelineModel.write().overwrite().save(modelBucketUri)

....Persist the model to GCS


                                                                                

In [23]:
# 12. Persist model testing results to BigQuery
persistPredictionsDF = predictionsDF.withColumn("pipeline_id", lit(pipelineID).cast("string")) \
                                   .withColumn("model_version", lit(pipelineID).cast("string")) \
                                   .withColumn("pipeline_execution_dt", lit(pipelineExecutionDt)) \
                                   .withColumn("operation", lit(appNameSuffix)) 

persistPredictionsDF.write.format('bigquery') \
.mode("overwrite")\
.option('table', bigQueryModelTestResultsTableFQN) \
.save()

                                                                                

### MODEL EXPLAINABILITY

In [24]:
# 13a. Model explainability - feature importance
pipelineModel.stages[-1].featureImportances

SparseVector(23, {0: 0.1292, 1: 0.2336, 3: 0.0132, 4: 0.0121, 5: 0.0028, 6: 0.0115, 7: 0.0039, 8: 0.0006, 9: 0.1446, 10: 0.0164, 11: 0.0339, 12: 0.0142, 13: 0.0011, 14: 0.0087, 15: 0.0001, 16: 0.0045, 17: 0.1838, 18: 0.0633, 19: 0.0153, 20: 0.0953, 21: 0.0057, 22: 0.0062})

In [25]:
# 13b. Function to parse feature importance
def fnExtractFeatureImportance(featureImportanceSparseVector, predictionsDataframe, featureColumnListing):
    featureColumnMetadataList = []
    for i in predictionsDataframe.schema[featureColumnListing].metadata["ml_attr"]["attrs"]:
        featureColumnMetadataList = featureColumnMetadataList + predictionsDataframe.schema[featureColumnListing].metadata["ml_attr"]["attrs"][i]
        
    featureColumnMetadataPDF = pd.DataFrame(featureColumnMetadataList)
    featureColumnMetadataPDF['importance_score'] = featureColumnMetadataPDF['idx'].apply(lambda x: featureImportanceSparseVector[x])
    return(featureColumnMetadataPDF.sort_values('importance_score', ascending = False))


In [26]:
# 13c. Print feature importance
fnExtractFeatureImportance(pipelineModel.stages[-1].featureImportances, predictionsDF, "features")


Unnamed: 0,name,idx,importance_score
22,total_charges,1,0.23365
15,contractclassVec_Month-to-month,17,0.183778
7,internet_serviceclassVec_Fiber optic,9,0.14456
21,monthly_charges,0,0.129179
18,payment_methodclassVec_Electronic check,20,0.095273
16,contractclassVec_Two year,18,0.063349
9,online_securityclassVec_No,11,0.033914
8,internet_serviceclassVec_DSL,10,0.016387
17,paperless_billingclassVec_Yes,19,0.015341
10,online_backupclassVec_No,12,0.014167


In [27]:
# 13d. Capture into a Pandas DF
featureImportantcePDF = fnExtractFeatureImportance(pipelineModel.stages[-1].featureImportances, predictionsDF, "features")


In [28]:
# 13e. Persist feature importance scores to BigQuery
# Convert Pandas to Spark DF & use Spark to persist
featureImportantceDF = spark.createDataFrame(featureImportantcePDF).toDF("feature_index","feature_nm","importance_score")

persistFeatureImportanceDF = featureImportantceDF.withColumn("pipeline_id", lit(pipelineID).cast("string")) \
                                   .withColumn("model_version", lit(pipelineID).cast("string")) \
                                   .withColumn("pipeline_execution_dt", lit(pipelineExecutionDt)) \
                                   .withColumn("operation", lit(operation)) 

persistFeatureImportanceDF.show(2)

persistFeatureImportanceDF.write.format('bigquery') \
.mode("overwrite")\
.option('table', bigQueryFeatureImportanceTableFQN) \
.save()

+--------------------+----------+-------------------+-----------+-------------+---------------------+---------+
|       feature_index|feature_nm|   importance_score|pipeline_id|model_version|pipeline_execution_dt|operation|
+--------------------+----------+-------------------+-----------+-------------+---------------------+---------+
|       total_charges|         1|0.23364955448527888|       4244|         4244|       20240603065934| training|
|contractclassVec_...|        17|0.18377782190175462|       4244|         4244|       20240603065934| training|
+--------------------+----------+-------------------+-----------+-------------+---------------------+---------+
only showing top 2 rows



                                                                                

### MODEL EVALUATION

In [29]:
# 14a. Metrics parsing function
def fnParseModelMetrics(predictionsDF, labelColumn, operation, boolSubsetOnly):
    """
    Get model metrics
    Args:
        predictions: predictions
        labelColumn: target column
        operation: train or test
        boolSubsetOnly: boolean for partial(without true, score, prediction) or full metrics 
    Returns:
        metrics: metrics
        
    Anagha TODO: This function if called from common_utils fails; Need to researchy why
    """
    
    metricLabels = ['area_roc', 'area_prc', 'accuracy', 'f1', 'precision', 'recall']
    metricColumns = ['true', 'score', 'prediction']
    metricKeys = [f'{operation}_{ml}' for ml in metricLabels] + metricColumns

    # Instantiate evaluators
    bcEvaluator = BinaryClassificationEvaluator(labelCol=labelColumn)
    mcEvaluator = MulticlassClassificationEvaluator(labelCol=labelColumn)

    # Capture metrics -> areas, acc, f1, prec, rec
    area_roc = round(bcEvaluator.evaluate(predictionsDF, {bcEvaluator.metricName: 'areaUnderROC'}), 5)
    area_prc = round(bcEvaluator.evaluate(predictionsDF, {bcEvaluator.metricName: 'areaUnderPR'}), 5)
    acc = round(mcEvaluator.evaluate(predictionsDF, {mcEvaluator.metricName: "accuracy"}), 5)
    f1 = round(mcEvaluator.evaluate(predictionsDF, {mcEvaluator.metricName: "f1"}), 5)
    prec = round(mcEvaluator.evaluate(predictionsDF, {mcEvaluator.metricName: "weightedPrecision"}), 5)
    rec = round(mcEvaluator.evaluate(predictionsDF, {mcEvaluator.metricName: "weightedRecall"}), 5)

    # Get the true, score, prediction off of the test results dataframe
    rocDictionary = common_utils.fnGetTrueScoreAndPrediction(predictionsDF, labelColumn)
    true = rocDictionary['true']
    score = rocDictionary['score']
    prediction = rocDictionary['prediction']

    # Create a metric values array
    metricValuesArray = []
    if boolSubsetOnly:
        metricValuesArray.extend((area_roc, area_prc, acc, f1, prec, rec))
    else:
        metricValuesArray.extend((area_roc, area_prc, acc, f1, prec, rec, true, score, prediction))
    
    # Zip the keys and values into a dictionary  
    metricsDictionary = dict(zip(metricKeys, metricValuesArray))

    return metricsDictionary


In [None]:
# 14b. Capture & display metrics
modelMetrics = fnParseModelMetrics(predictionsDF, "label", "test", True)
for m, v in modelMetrics.items():
    print(f'{m}: {v}')
    

                                                                                

In [None]:
# 14c. Persist metrics subset to GCS
blobName = f"{modelBaseNm}/{operation}/{modelVersion}/subset/metrics.json"
common_utils.fnPersistMetrics(urlparse(metricsBucketUri).netloc, modelMetrics, blobName)


In [None]:
# 14d. Persist metrics in full to GCS
# (The version persisted to BQ does not have True, Score and Prediction needed for Confusion Matrix
# This version below has the True, Score and Prediction additionally) 

# 14d.1. Capture
modelMetricsWithTSP = fnParseModelMetrics(predictionsDF, "label", "test", False)

# 14d.2. Persist
blobName = f"{modelBaseNm}/{operation}/{modelVersion}/full/metrics.json"
print(blobName)
common_utils.fnPersistMetrics(urlparse(metricsBucketUri).netloc, modelMetricsWithTSP, blobName)

# 14d.3. Print
for m, v in modelMetricsWithTSP.items():
    print(f'{m}: {v}')
    

In [None]:
# 14e. Persist metrics subset to BigQuery
metricsDF = spark.createDataFrame(modelMetrics.items(), ["metric_nm", "metric_value"]) 
metricsWithPipelineIdDF = metricsDF.withColumn("pipeline_id", lit(pipelineID).cast("string")) \
                                   .withColumn("model_version", lit(pipelineID).cast("string")) \
                                   .withColumn("pipeline_execution_dt", lit(pipelineExecutionDt)) \
                                   .withColumn("operation", lit(operation)) 

metricsWithPipelineIdDF.show()

metricsWithPipelineIdDF.write.format('bigquery') \
.mode("overwrite")\
.option('table', bigQueryModelMetricsTableFQN) \
.save()


                                                                                

+--------------+------------+-----------+-------------+---------------------+---------+
|     metric_nm|metric_value|pipeline_id|model_version|pipeline_execution_dt|operation|
+--------------+------------+-----------+-------------+---------------------+---------+
| test_area_roc|     0.85105|       6451|         6451|       20240529131829| training|
| test_area_prc|     0.66582|       6451|         6451|       20240529131829| training|
| test_accuracy|     0.81228|       6451|         6451|       20240529131829| training|
|       test_f1|     0.78894|       6451|         6451|       20240529131829| training|
|test_precision|     0.80211|       6451|         6451|       20240529131829| training|
|   test_recall|     0.81228|       6451|         6451|       20240529131829| training|
+--------------+------------+-----------+-------------+---------------------+---------+

