# Import libraries

In [1]:
import os
os.environ['JAVA_HOME'] = 'C:\\Program Files\\Java\\jdk-11'
os.environ['PATH'] = os.environ['JAVA_HOME'] + '\\bin;' + os.environ['PATH']
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
from pyspark.ml import Pipeline
from pyspark.ml.feature import StandardScaler, StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.feature import IndexToString

# Create Spark Session

In [2]:
spark = SparkSession.builder.appName("Churn Prediction").config("spark.memory.offHeap.enabled","true").config("spark.memory.offHeap.size","10g").getOrCreate()

# Create dataframe

In [3]:
df_path = r'data\WA_Fn-UseC_-Telco-Customer-Churn.csv'
df = spark.read.csv(df_path,header=True)
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: string (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: string (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)



In [4]:
df.show(5)

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|          No|No phone service|            DSL|            No|         Yes|              No|         No|    

In [5]:
# Change certain types to numeric
to_numeric = ['tenure', 'MonthlyCharges', 'TotalCharges']

for col in to_numeric:
    df = df.withColumn(col, F.col(col).cast(FloatType()))

df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: float (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: float (nullable = true)
 |-- TotalCharges: float (nullable = true)
 |-- Churn: string (nullable = true)



# Data Cleaning

In [6]:
# Delete ID as it is unnecessary
df = df.drop('CustomerID')

# Check for null values and outliers
print('\nChecking for null_values')
for col in df.columns:
    null_count = df.filter(F.col(col).isNull()).count()
    if null_count>0:
        print(f"Column '{col}' has {null_count} null values")
        first_value_type = type(df.select(col).first()[0])

        if first_value_type==str:
            most_frequent_value = df.groupby(col).count().orderBy('count',ascending=False).first()[0]
            df = df.fillna({col: most_frequent_value})
        else:
            avg_value = df.agg(F.avg('TotalCharges')).collect()[0][0]
            df = df.fillna({col: avg_value})

print('\nChecking for outliers')
for col in df.columns:
    first_value_type = type(df.select(col).first()[0])

    if first_value_type!=str:    
        quantiles = df.approxQuantile(col, [0.25,0.75],0.01)
        Q1 = quantiles[0]
        Q3 = quantiles[1]
        IQR = Q3-Q1

        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR

        outliers = df.filter((F.col(col) < lower_bound) | (F.col(col) > upper_bound))
        if outliers.count()>0:
            print(f"Column '{col}' has {outliers.count()} outliers")
            df = df.filter((F.col(col) >= lower_bound) & (F.col(col) <= upper_bound))


Checking for null_values
Column 'TotalCharges' has 11 null values

Checking for outliers
Column 'TotalCharges' has 5 outliers


# Data preparation

In [7]:
labelCol = 'Churn'

# Convert categorical data to numeric representation
indexers = [StringIndexer(inputCol=col, outputCol=col+'_indexed') for col in df.columns if col not in to_numeric and col!=labelCol]

# Perform OneHotEncoding to numerical categorical data
encoders = [OneHotEncoder(inputCol=indexer.getOutputCol(), outputCol=indexer.getOutputCol() + '_encoded') for indexer in indexers]

# Assemble columns in feature vector col
assembler = VectorAssembler(inputCols=to_numeric+[encoder.getOutputCol() for encoder in encoders], outputCol='feature_vector')

# Standardize feature vector col
scaler = StandardScaler(inputCol='feature_vector', outputCol='scaled_feature_vector')

# Create data preparation pipeline
pipeline = Pipeline(stages=indexers + encoders + [assembler, scaler])

# Fit to data and transform
model = pipeline.fit(df)
final_df = model.transform(df)

# Transforming labelCol to integer col
label_indexer = StringIndexer(inputCol="Churn", outputCol="label").fit(final_df)
final_df = label_indexer.transform(final_df)

# Convert the label column from double (float) to integer
final_df = final_df.withColumn("label", F.col("label").cast("int"))

final_df.select('scaled_feature_vector').show(5)

+---------------------+
|scaled_feature_vector|
+---------------------+
| (30,[0,1,2,4,6,11...|
| (30,[0,1,2,3,4,5,...|
| [0.08148068183113...|
| (30,[0,1,2,3,4,5,...|
| (30,[0,1,2,4,5,6,...|
+---------------------+
only showing top 5 rows



# Train test split

In [8]:
train_df, test_df = final_df.randomSplit([0.7, 0.3], seed=42)
print(train_df.count(), test_df.count())

5032 2006


# Modeling

In [9]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

gbt = GBTClassifier(featuresCol='scaled_feature_vector', labelCol='label')
paramGrid = ParamGridBuilder()\
    .addGrid(gbt.maxDepth,[3,5])\
    .addGrid(gbt.maxIter, [20,50])\
    .build()
crossval = CrossValidator(estimator=gbt,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(),
                          numFolds=3)
model = crossval.fit(train_df)
predictions = model.transform(test_df)
predictions.show(5)

+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+------------+-----+--------------+---------------------+---------------+------------------+--------------------+---------------------+-----------------------+----------------------+--------------------+------------------------+-------------------+-------------------+-----------------------+----------------+------------------------+---------------------+----------------------+-----------------------------+-----------------------+--------------------------+----------------------------+-----------------------------+-------------------------------+------------------------------+----------------------------+--------------------------------+---------------------------+---------------------------+-------------------------------+------------------------+-

# Convert predicted values back to labels

In [10]:
label_converter = IndexToString(inputCol='prediction',outputCol='prediction_label', labels=label_indexer.labels)

# Apply the label_converter to the predictions DataFrame
predictions_with_labels = label_converter.transform(predictions)
predictions_with_labels.select('Churn','prediction_label').show(5)

+-----+----------------+
|Churn|prediction_label|
+-----+----------------+
|   No|             Yes|
|  Yes|             Yes|
|   No|              No|
|  Yes|              No|
|   No|              No|
+-----+----------------+
only showing top 5 rows



# Metrics

In [11]:
evaluator_multi = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator_multi.evaluate(predictions)
print(f"Accuracy: {accuracy}")

precision = evaluator_multi.evaluate(predictions, {evaluator_multi.metricName: "weightedPrecision"})
print(f"Precision: {precision}")

recall = evaluator_multi.evaluate(predictions, {evaluator_multi.metricName: "weightedRecall"})
print(f"Recall: {recall}")

f1 = evaluator_multi.evaluate(predictions, {evaluator_multi.metricName: "f1"})
print(f"F1 Score: {f1}")

# Evaluate the model using BinaryClassificationEvaluator for AUC
evaluator_auc = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
auc = evaluator_auc.evaluate(predictions)
print(f"AUC: {auc}")

Accuracy: 0.8045862412761715
Precision: 0.7946565932343269
Recall: 0.8045862412761715
F1 Score: 0.7966776310280208
AUC: 0.8377787483300815
