In [2]:
## Import Libraries
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType, TimestampType
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

## Set seed
seed = 42

In [3]:
## Create Spark Session
spark = SparkSession.builder.appName('logRegConsProject').getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
## Setup Schema
schema = StructType(fields=[StructField('name', StringType(), True),
                            StructField('age', DoubleType(), True),
                            StructField('total_purchase', DoubleType(), True),
                            StructField('account_manager', IntegerType(), True),
                            StructField('years', DoubleType(), True),
                            StructField('num_sites', DoubleType(), True),
                            StructField('onboard_date', TimestampType(), True),
                            StructField('location', StringType(), True),
                            StructField('company', StringType(), True),
                            StructField('churn', IntegerType(), True)])

In [5]:
## Load Data
df = spark.read.csv('gs://spark-training-data/datasets/customer_churn.csv', header=True,
                    inferSchema=False, schema=schema)
df.show(5)
df.printSchema() ## Confirm proper schema

[Stage 0:>                                                          (0 + 1) / 1]

+----------------+----+--------------+---------------+-----+---------+-------------------+--------------------+--------------------+-----+
|            name| age|total_purchase|account_manager|years|num_sites|       onboard_date|            location|             company|churn|
+----------------+----+--------------+---------------+-----+---------+-------------------+--------------------+--------------------+-----+
|Cameron Williams|42.0|       11066.8|              0| 7.22|      8.0|2013-08-30 07:00:40|10265 Elizabeth M...|          Harvey LLC|    1|
|   Kevin Mueller|41.0|      11916.22|              0|  6.5|     11.0|2013-08-13 00:38:46|6157 Frank Garden...|          Wilson PLC|    1|
|     Eric Lozano|38.0|      12884.75|              0| 6.67|     12.0|2016-06-29 06:20:07|1331 Keith Court ...|Miller, Johnson a...|    1|
|   Phillip White|42.0|       8010.76|              0| 6.71|     10.0|2014-04-22 12:43:12|13120 Daniel Moun...|           Smith Inc|    1|
|  Cynthia Norton|37.0|    

                                                                                

In [6]:
## Assembler & Create modeling df
## Does not appear that indexers or encoders are necessary
assembler = VectorAssembler(inputCols=['age', 'total_purchase', 'account_manager',
                                       'years', 'num_sites'],
                           outputCol='features')
output = assembler.transform(df)
final_data = output.select('features','churn')

In [7]:
## Split Data
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=seed)

In [8]:
## Build model
log_reg = LogisticRegression(featuresCol='features', labelCol='churn', predictionCol='prediction')
log_reg_model = log_reg.fit(train_data)

21/11/24 21:57:09 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
21/11/24 21:57:09 WARN com.github.fommil.netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
21/11/24 21:57:11 WARN org.apache.spark.storage.BlockManager: Asked to remove block broadcast_36_piece0, which does not exist
21/11/24 21:57:11 WARN org.apache.spark.storage.BlockManager: Asked to remove block broadcast_36, which does not exist


In [9]:
## Summarize Model
model_summary = log_reg_model.summary
model_summary.predictions.describe().show()

[Stage 46:>                                                         (0 + 1) / 1]

+-------+------------------+-------------------+
|summary|             churn|         prediction|
+-------+------------------+-------------------+
|  count|               667|                667|
|   mean|0.1634182908545727|0.12293853073463268|
| stddev|0.3700243606477147|0.32861306618408714|
|    min|               0.0|                0.0|
|    max|               1.0|                1.0|
+-------+------------------+-------------------+



                                                                                

In [10]:
## Evaluate the model
pred_and_labels = log_reg_model.evaluate(test_data)
pred_and_labels.predictions.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[26.0,8787.39,1.0...|    1|[0.79106193949545...|[0.68805930409057...|       0.0|
|[28.0,9090.43,1.0...|    0|[1.61026634613434...|[0.83344836179841...|       0.0|
|[28.0,11204.23,0....|    0|[1.97148327271884...|[0.87777034205971...|       0.0|
|[28.0,11245.38,0....|    0|[3.75330942021012...|[0.97709680745324...|       0.0|
|[29.0,9617.59,0.0...|    0|[4.42202740353912...|[0.98813266624674...|       0.0|
|[29.0,10203.18,1....|    0|[3.71080374825935...|[0.97612604829734...|       0.0|
|[29.0,11274.46,1....|    0|[4.39058453619493...|[0.98775823543341...|       0.0|
|[30.0,6744.87,0.0...|    0|[3.55749176407943...|[0.97228005685650...|       0.0|
|[30.0,8403.78,1.0...|    0|[5.76304532016813...|[0.99686830825215...|       0.0|
|[30.0,8874.83,0

In [11]:
## Check AUC
churn_eval = BinaryClassificationEvaluator(rawPredictionCol='prediction', labelCol='churn')
auc = churn_eval.evaluate(pred_and_labels.predictions)
auc

0.7456808943089431

In [12]:
## Model "Deployment"
log_reg_model_final = log_reg.fit(final_data)

In [13]:
## Read in new data
new_data = spark.read.csv('gs://spark-training-data/datasets/new_customers.csv', header=True,
                          inferSchema=False, schema=schema)
new_data.printSchema()

root
 |-- name: string (nullable = true)
 |-- age: double (nullable = true)
 |-- total_purchase: double (nullable = true)
 |-- account_manager: integer (nullable = true)
 |-- years: double (nullable = true)
 |-- num_sites: double (nullable = true)
 |-- onboard_date: timestamp (nullable = true)
 |-- location: string (nullable = true)
 |-- company: string (nullable = true)
 |-- churn: integer (nullable = true)



In [15]:
## Transform the data
new_data_assembled = assembler.transform(new_data)
new_data_assembled.show(5)

+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+-----+--------------------+
|          name| age|total_purchase|account_manager|years|num_sites|       onboard_date|            location|         company|churn|            features|
+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+-----+--------------------+
| Andrew Mccall|37.0|       9935.53|              1| 7.71|      8.0|2011-08-29 18:37:54|38612 Johnny Stra...|        King Ltd| null|[37.0,9935.53,1.0...|
|Michele Wright|23.0|       7526.94|              1| 9.28|     15.0|2013-07-22 18:19:54|21083 Nicole Junc...|   Cannon-Benson| null|[23.0,7526.94,1.0...|
|  Jeremy Chang|65.0|         100.0|              1|  1.0|     15.0|2006-12-11 07:48:13|085 Austin Views ...|Barron-Robertson| null|[65.0,100.0,1.0,1...|
|Megan Ferguson|32.0|        6487.5|              0|  9.4|     14.0|2016-10-

In [16]:
## Make Predictions
new_data_predictions = log_reg_model_final.transform(new_data_assembled)
new_data_predictions.select('company','prediction').show()

+----------------+----------+
|         company|prediction|
+----------------+----------+
|        King Ltd|       0.0|
|   Cannon-Benson|       1.0|
|Barron-Robertson|       1.0|
|   Sexton-Golden|       1.0|
|        Wood LLC|       0.0|
|   Parks-Robbins|       1.0|
+----------------+----------+

