# <center>Customer churn - a predictive model using Apache Spark framework</center>

### <div style="text-align: right">*M.S. student Darius Alexandru Cocirta*</div>

![alternative text](https://d1eipm3vz40hy0.cloudfront.net/images/AMER/calculate-customer-churn.jpg)

## Context

With the rapid development of the telecommunications industry, service providers have focused on expanding their subscriber base. In such a competitive environment, however, retaining existing customers has become a major challenge. Acquiring a new customer is widely reported to cost far more than retaining an existing one. It is therefore important for telecom companies to use advanced analytics to understand customer behavior and, in turn, predict whether a customer will leave the company.

## What is customer churn?

Simply put, customer churn occurs when **customers or subscribers stop doing business with a company or service**. Also known as customer attrition, churn is a critical metric because retaining existing customers is much less expensive than acquiring new ones. Earning business from new customers means working leads all the way through the sales funnel, using marketing and sales resources throughout the process. Customer retention, on the other hand, is generally more cost-effective because you have already earned the trust and loyalty of existing customers.

## Assignment

* Creating a model to predict customer churn

## About the data

Dataset source: https://www.kaggle.com/datasets/barun2104/telecom-churn

This dataset contains customer-level information for a telecom company. For each customer, it records various attributes related to the services used.

# <center>Algorithm implementation</center>

## Importing needed libraries

In [28]:
import findspark
findspark.init()
from IPython.core.display import HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

## Creating the SparkSession and SparkContext

In [3]:
spark = SparkSession.builder.master("local").appName("Customer churn prediction").getOrCreate()
sc = spark.sparkContext

In [4]:
sc

## Data quality assessment

### Importing dataset

In [5]:
df = spark.read.csv(r"C:\Users\Asus\Desktop\dis\p_users_churn_pyspark\telecom_churn.csv", inferSchema = True, header = True)

In [6]:
df.printSchema()

root
 |-- Churn: integer (nullable = true)
 |-- AccountWeeks: integer (nullable = true)
 |-- ContractRenewal: integer (nullable = true)
 |-- DataPlan: integer (nullable = true)
 |-- DataUsage: double (nullable = true)
 |-- CustServCalls: integer (nullable = true)
 |-- DayMins: double (nullable = true)
 |-- DayCalls: integer (nullable = true)
 |-- MonthlyCharge: double (nullable = true)
 |-- OverageFee: double (nullable = true)
 |-- RoamMins: double (nullable = true)



In [7]:
df.count()

3333

In [8]:
df.show()

+-----+------------+---------------+--------+---------+-------------+-------+--------+-------------+----------+--------+
|Churn|AccountWeeks|ContractRenewal|DataPlan|DataUsage|CustServCalls|DayMins|DayCalls|MonthlyCharge|OverageFee|RoamMins|
+-----+------------+---------------+--------+---------+-------------+-------+--------+-------------+----------+--------+
|    0|         128|              1|       1|      2.7|            1|  265.1|     110|         89.0|      9.87|    10.0|
|    0|         107|              1|       1|      3.7|            1|  161.6|     123|         82.0|      9.78|    13.7|
|    0|         137|              1|       0|      0.0|            0|  243.4|     114|         52.0|      6.06|    12.2|
|    0|          84|              0|       0|      0.0|            2|  299.4|      71|         57.0|       3.1|     6.6|
|    0|          75|              0|       0|      0.0|            3|  166.7|     113|         41.0|      7.42|    10.1|
|    0|         118|            

In [10]:
df.columns

['Churn',
 'AccountWeeks',
 'ContractRenewal',
 'DataPlan',
 'DataUsage',
 'CustServCalls',
 'DayMins',
 'DayCalls',
 'MonthlyCharge',
 'OverageFee',
 'RoamMins']

In [11]:
len(df.columns)

11

### Target variable

In [13]:
df.groupBy("churn").count().show()

+-----+-----+
|churn|count|
+-----+-----+
|    1|  483|
|    0| 2850|
+-----+-----+
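
The counts above show a strong class imbalance. As a rough sketch in plain Python (the two counts are copied from the table; the variable names are only illustrative), the churn rate and the accuracy of a trivial classifier that always predicts `0` can be worked out as follows:

```python
# Class counts copied from the groupBy output above
churned = 483
stayed = 2850
total = churned + stayed

# Share of customers who churned (the positive class)
churn_rate = churned / total

# Accuracy of a trivial model that always predicts "no churn" (0)
majority_baseline_accuracy = stayed / total

print(f"total customers: {total}")
print(f"churn rate: {churn_rate:.4f}")
print(f"majority-class baseline accuracy: {majority_baseline_accuracy:.4f}")
```

Any accuracy reported later in the notebook is best read against this baseline rather than against 50%.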



## Exploratory analysis 

In [15]:
df.describe().show()

+-------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+------------------+-----------------+------------------+------------------+
|summary|              Churn|      AccountWeeks|    ContractRenewal|           DataPlan|         DataUsage|     CustServCalls|           DayMins|          DayCalls|    MonthlyCharge|        OverageFee|          RoamMins|
+-------+-------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+------------------+-----------------+------------------+------------------+
|  count|               3333|              3333|               3333|               3333|              3333|              3333|              3333|              3333|             3333|              3333|              3333|
|   mean|0.14491449144914492|101.06480648064806|  0.903090309030903|0.27662766276627665|0.8164746474647478|1.5628562

## Modeling

### Preprocessing

In [31]:
assembler = VectorAssembler(inputCols = ['AccountWeeks',
 'ContractRenewal',
 'DataPlan',
 'DataUsage',
 'CustServCalls',
 'DayMins',
 'DayCalls',
 'MonthlyCharge',
 'OverageFee',
 'RoamMins'], outputCol = "features")

In [32]:
output = assembler.transform(df)

In [33]:
dataframe = output.select('features', 'churn')

In [34]:
dataframe.show(truncate = False)

+---------------------------------------------------+-----+
|features                                           |churn|
+---------------------------------------------------+-----+
|[128.0,1.0,1.0,2.7,1.0,265.1,110.0,89.0,9.87,10.0] |0    |
|[107.0,1.0,1.0,3.7,1.0,161.6,123.0,82.0,9.78,13.7] |0    |
|[137.0,1.0,0.0,0.0,0.0,243.4,114.0,52.0,6.06,12.2] |0    |
|[84.0,0.0,0.0,0.0,2.0,299.4,71.0,57.0,3.1,6.6]     |0    |
|[75.0,0.0,0.0,0.0,3.0,166.7,113.0,41.0,7.42,10.1]  |0    |
|[118.0,0.0,0.0,0.0,0.0,223.4,98.0,57.0,11.03,6.3]  |0    |
|[121.0,1.0,1.0,2.03,3.0,218.2,88.0,87.3,17.43,7.5] |0    |
|[147.0,0.0,0.0,0.0,0.0,157.0,79.0,36.0,5.16,7.1]   |0    |
|[117.0,1.0,0.0,0.19,1.0,184.5,97.0,63.9,17.58,8.7] |0    |
|[141.0,0.0,1.0,3.02,0.0,258.6,84.0,93.2,11.1,11.2] |0    |
|[65.0,1.0,0.0,0.29,4.0,129.1,137.0,44.9,11.43,12.7]|1    |
|[74.0,1.0,0.0,0.34,0.0,187.7,127.0,49.4,8.17,9.1]  |0    |
|[168.0,1.0,0.0,0.0,1.0,128.8,96.0,31.0,5.25,11.2]  |0    |
|[95.0,1.0,0.0,0.44,3.0,156.6,88.0,52.4,

**Train-test split**

In [423]:
train, test = dataframe.randomSplit([0.7, 0.3], seed = 20)

In [424]:
train.count()

2340

In [425]:
test.count()

993

### Fitting logistic regression model

Logistic regression is a statistical method for predicting a binary outcome from prior observations. The model estimates the probability of the dependent variable (here, `churn`) from one or more independent variables (here, the ten columns assembled into `features`).
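
To make the idea concrete, here is a minimal sketch of the logistic (sigmoid) function that turns a linear score into a probability. It is plain Python and not Spark's implementation. The weights, bias and feature values in `predict_probability` are made up for illustration; only the score `3.79972763177822` is taken from this notebook (a truncated `rawPrediction` value printed further down).

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function: maps any real-valued score to a value in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict_probability(weights, bias, features):
    """Probability from a linear score: sigmoid(bias + sum(w_i * x_i))."""
    score = bias + sum(w * x for w, x in zip(weights, features))
    return sigmoid(score)


# Hypothetical coefficients and one hypothetical two-feature customer
example_probability = predict_probability(weights=[0.5, -0.25], bias=0.1, features=[2.0, 4.0])
print(example_probability)

# Score taken from the first rawPrediction entry shown later in this notebook
# (the printed value is truncated, so the result is approximate)
print(sigmoid(3.79972763177822))
```

Applied to that first `rawPrediction` entry, the function gives roughly 0.978, which is consistent with the first `probability` entry printed in the same row.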

In [427]:
log_reg = LogisticRegression(labelCol = "churn").fit(train)

In [428]:
log_reg_summary = log_reg.summary

In [429]:
log_reg

LogisticRegressionModel: uid=LogisticRegression_e80848866ba3, numClasses=2, numFeatures=10

In [430]:
log_reg_summary.predictions.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[1.0,1.0,0.0,0.0,...|  0.0|[3.79972763177822...|[0.97811289894901...|       0.0|
|[1.0,1.0,0.0,0.0,...|  0.0|[3.11620355648105...|[0.95755619961555...|       0.0|
|[1.0,1.0,0.0,0.0,...|  1.0|[0.70894274910876...|[0.67016750393706...|       0.0|
|[1.0,1.0,0.0,0.25...|  0.0|[1.79035447197302...|[0.85697073032035...|       0.0|
|[1.0,1.0,1.0,2.19...|  0.0|[3.20849682002671...|[0.96115277852310...|       0.0|
|[1.0,1.0,1.0,2.27...|  0.0|[3.77664139097788...|[0.97761317372948...|       0.0|
|[1.0,1.0,1.0,2.7,...|  0.0|[3.99399507317982...|[0.98190741921345...|       0.0|
|[2.0,0.0,0.0,0.27...|  1.0|[1.09368828352853...|[0.74907561298243...|       0.0|
|[3.0,0.0,0.0,0.26...|  0.0|[0.52428324621496...|[0.62814878584282...|       0.0|
|[3.0,1.0,0.0,0.

In [431]:
log_reg_summary.predictions.describe().show()

+-------+-------------------+--------------------+
|summary|              churn|          prediction|
+-------+-------------------+--------------------+
|  count|               2340|                2340|
|   mean| 0.1482905982905983|0.049145299145299144|
| stddev|0.35546377367657656|  0.2162175232424847|
|    min|                0.0|                 0.0|
|    max|                1.0|                 1.0|
+-------+-------------------+--------------------+
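
Because both `churn` and `prediction` take only the values 0 and 1, each mean in the table above multiplied by the row count gives the number of ones in that column. A small sketch of that arithmetic, with the figures copied from the `describe()` output (variable names are illustrative):

```python
# Values copied from the describe() output above
train_rows = 2340
mean_label = 0.1482905982905983        # mean of the churn column
mean_prediction = 0.049145299145299144  # mean of the prediction column

# For a 0/1 column, mean * count equals the number of ones
actual_churners = round(train_rows * mean_label)
predicted_churners = round(train_rows * mean_prediction)

print(f"actual churners in training data: {actual_churners}")
print(f"customers predicted to churn: {predicted_churners}")
```

On the training data the model therefore flags far fewer customers than actually churn, which already hints at low recall.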



### Model evaluation

### Accuracy & precision & recall

In [432]:
results = log_reg.evaluate(test).predictions

In [433]:
results.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[1.0,1.0,0.0,0.0,...|    0|[3.27636752575371...|[0.96360911945117...|       0.0|
|[3.0,1.0,0.0,0.0,...|    0|[1.28775964790889...|[0.78376774495107...|       0.0|
|[6.0,1.0,0.0,0.0,...|    0|[1.90954591690239...|[0.87096812538758...|       0.0|
|[6.0,1.0,0.0,0.0,...|    0|[1.23094022470559...|[0.77398309321558...|       0.0|
|[10.0,1.0,0.0,0.0...|    0|[1.78702825687788...|[0.85656254553746...|       0.0|
|[10.0,1.0,0.0,0.4...|    0|[2.00309865710841...|[0.88112203340335...|       0.0|
|[12.0,0.0,0.0,0.0...|    0|[0.08224573535645...|[0.52054985126052...|       0.0|
|[13.0,1.0,1.0,3.2...|    0|[0.45449393768817...|[0.61170717511600...|       0.0|
|[13.0,1.0,1.0,3.4...|    0|[2.50972559709356...|[0.92482081425520...|       0.0|
|[15.0,1.0,0.0,0

In [434]:
results.select(['churn', 'prediction']).show(10)

+-----+----------+
|churn|prediction|
+-----+----------+
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
|    0|       0.0|
+-----+----------+
only showing top 10 rows



### Computing model accuracy & precision & recall

In [444]:
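# Confusion-matrix counts on the test set (churn: 1 = churned, 0 = stayed)
true_pos = results[(results.churn == 1) & (results.prediction == 1)].count()   # churners correctly flagged
true_neg = results[(results.churn == 0) & (results.prediction == 0)].count()   # non-churners correctly passed
false_pos = results[(results.churn == 0) & (results.prediction == 1)].count()  # non-churners wrongly flagged
false_neg = results[(results.churn == 1) & (results.prediction == 0)].count()  # churners missed by the model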

**Accuracy**

In [445]:
accuracy = float((true_pos + true_neg) / results.count())

In [446]:
print(accuracy)

0.8741188318227593


**Precision**

In [447]:
precision = float(true_pos / (true_pos + false_pos))

In [448]:
print(precision)

0.6037735849056604


**Recall**

In [449]:
recall = float(true_pos / (true_pos + false_neg))

In [450]:
print(recall)

0.23529411764705882


**F1 score**

In [452]:
f1_score = 2*((precision*recall)/(precision + recall))

In [453]:
f1_score

0.33862433862433866
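
The four ratios above all derive from the same confusion-matrix counts, so they can be gathered in one helper. The sketch below is plain Python with a hypothetical function name. The notebook does not print the four counts; the values used here are back-calculated from the reported metrics and the test-set size of 993, so treat them as an assumption.

```python
def classification_metrics(tp: int, tn: int, fp: int, fn: int) -> dict:
    """Accuracy, precision, recall and F1 from confusion-matrix counts."""
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total else 0.0
    # Guard against division by zero when no positives are predicted or present
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Assumed counts, inferred from the metrics printed above (not shown by the notebook)
metrics = classification_metrics(tp=32, tn=836, fp=21, fn=104)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

The zero-denominator guards matter in practice: a model that predicts no positives at all would otherwise raise a `ZeroDivisionError` when computing precision.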

## Conclusion

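**I ended up with a model accuracy of about 87% on the test set. This figure needs context, because the dataset is highly imbalanced: I had 2850 negative instances and only 483 positive ones, so a model that always predicts "no churn" would already be right about 85.5% of the time (2850 / 3333). Recall is the more telling metric here. A recall of about 0.24 means the model identifies roughly one in four of the customers who actually churn, while a precision of 0.60 means that about 60% of the customers it flags do churn. Whether this is acceptable depends on the context and the desired outcome of the model; for a retention use case, where missing a churner is costly, it leaves clear room for improvement. The same applies to the F1 score of 0.34.**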
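
Recall also depends on the decision threshold applied to the predicted probability (PySpark's `LogisticRegression` exposes a `threshold` parameter that defaults to 0.5). The sketch below is plain Python with made-up probabilities and labels, not results from this notebook; it only illustrates how lowering the threshold can trade precision for recall.

```python
def precision_recall_at(threshold, probabilities, labels):
    """Precision and recall when predicting 1 for probability >= threshold."""
    predictions = [1 if p >= threshold else 0 for p in probabilities]
    tp = sum(1 for y, y_hat in zip(labels, predictions) if y == 1 and y_hat == 1)
    fp = sum(1 for y, y_hat in zip(labels, predictions) if y == 0 and y_hat == 1)
    fn = sum(1 for y, y_hat in zip(labels, predictions) if y == 1 and y_hat == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Hypothetical churn probabilities and true labels for ten customers
probabilities = [0.05, 0.10, 0.20, 0.30, 0.35, 0.45, 0.55, 0.60, 0.80, 0.90]
labels = [0, 0, 0, 1, 0, 1, 0, 1, 1, 1]

for threshold in (0.5, 0.3):
    precision, recall = precision_recall_at(threshold, probabilities, labels)
    print(f"threshold={threshold}: precision={precision:.2f}, recall={recall:.2f}")
```

In this toy data, moving the threshold from 0.5 to 0.3 catches every churner at the cost of more false alarms; whether that trade is worthwhile depends on the relative cost of a missed churner and an unnecessary retention offer.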