# PySpark Preprocessor

In [22]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *

# new session
import os
import sys

# Have Spark launch its Python workers with the interpreter running this
# notebook. The typo-check cell further down applies Python UDFs, and its
# recorded run aborted with "Python worker failed to connect back"; a common
# cause is the JVM not finding a matching Python executable.
# NOTE(review): environment-specific remedy - confirm on the target machine.
# setdefault leaves any PYSPARK_PYTHON already exported by the environment intact.
os.environ.setdefault("PYSPARK_PYTHON", sys.executable)

spark = SparkSession.builder.getOrCreate()

## Load schema and data

In [23]:
# CSV schema: one (column name, Spark type) pair per field, in file order.
_csv_fields = [
    ("customerID", StringType()),
    ("gender", StringType()),
    ("SeniorCitizen", IntegerType()),
    ("Partner", StringType()),
    ("Dependents", StringType()),
    ("tenure", IntegerType()),
    ("PhoneService", StringType()),
    ("MultipleLines", StringType()),
    ("InternetService", StringType()),
    ("OnlineSecurity", StringType()),
    ("OnlineBackup", StringType()),
    ("DeviceProtection", StringType()),
    ("TechSupport", StringType()),
    ("StreamingTV", StringType()),
    ("StreamingMovies", StringType()),
    ("Contract", StringType()),
    ("PaperlessBilling", StringType()),
    ("PaymentMethod", StringType()),
    ("MonthlyCharges", FloatType()),
    ("TotalCharges", FloatType()),
    ("Churn", StringType()),
]

# Only customerID is declared non-nullable; every other field stays nullable.
schema = StructType(
    [
        StructField(field_name, field_type, nullable=(field_name != "customerID"))
        for field_name, field_type in _csv_fields
    ]
)

In [24]:
# Load the churn CSV using the explicit schema; the first line of the file is
# treated as the header row.
df = spark.read.csv(
    "WA_Fn-UseC_-Telco-Customer-Churn.csv",
    schema=schema,
    header=True,
)

# df.show()

In [25]:
# Print the column names and types of the loaded frame.
df.printSchema()

# Last expression of the cell, so the row count is shown as the cell output.
df.count()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: float (nullable = true)
 |-- TotalCharges: float (nullable = true)
 |-- Churn: string (nullable = true)



7043

## Description
- customerID: string
- gender: Male/Female
- SeniorCitizen: 0/1
- Partner: Yes/No
- Dependents: Yes/No
- tenure: integer
- PhoneService: Yes/No
    - MultipleLines: Yes/No/No phone service
- InternetService: string
    - OnlineSecurity: Yes/No/No internet service
    - OnlineBackup: Yes/No/No internet service
    - DeviceProtection: Yes/No/No internet service
    - TechSupport: Yes/No/No internet service
    - StreamingTV: Yes/No/No internet service
    - StreamingMovies: Yes/No/No internet service
- Contract: string
- PaperlessBilling: Yes/No
- PaymentMethod: string
- MonthlyCharges: float
- TotalCharges: float
- Churn: Yes/No


## Cleaning
- check for typos
- rename columns
- drop duplicates
- remove rows with null values
- check consistency for PhoneService or InternetService related columns
- check consistency for SeniorCitizen
- fill missing values for PhoneService or InternetService related columns
- delete invalid values (NaN or negatives)
- one-hot encoding for string valued columns
- reformat for boolean columns

### Rename columns

In [26]:
# Capitalise the three lower-case headers so every column name starts upper-case.
column_renames = {"customerID": "CustomerID", "gender": "Gender", "tenure": "Tenure"}
df = df.withColumnsRenamed(column_renames)

df.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|CustomerID|Gender|SeniorCitizen|Partner|Dependents|Tenure|PhoneService|   MultipleLines|InternetService|     OnlineSecurity|       OnlineBackup|   DeviceProtection|        TechSupport|        StreamingTV|    StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|  

### Check for typos

In [None]:
from rapidfuzz import process
# Canonical spellings accepted for each family of categorical columns.
genders = ["Male", "Female"]
yesno = ["Yes", "No"]
# The service-dependent columns allow one extra placeholder on top of Yes/No.
yesno_phone = [*yesno, "No phone service"]
yesno_internet = [*yesno, "No internet service"]

def match_gender(word):
    """Map a raw gender value to its closest canonical spelling.

    Args:
        word: Raw cell value, or None.

    Returns:
        The entry of ``genders`` that best matches ``word`` when the match
        score is at least 80; otherwise None. A None input returns None.
    """
    if word is None:
        return None
    # rapidfuzz's extractOne yields (choice, score, index); unpacking it into
    # only two names raises ValueError, so the index is discarded explicitly.
    best_match, score, _ = process.extractOne(word, genders)

    return best_match if score >= 80 else None

def match_yesno(word):
    """Map a raw Yes/No value to its closest canonical spelling.

    Args:
        word: Raw cell value, or None.

    Returns:
        The entry of ``yesno`` that best matches ``word`` when the match
        score is at least 80; otherwise None. A None input returns None.
    """
    if word is None:
        return None
    # rapidfuzz's extractOne yields (choice, score, index); unpacking it into
    # only two names raises ValueError, so the index is discarded explicitly.
    best_match, score, _ = process.extractOne(word, yesno)

    return best_match if score >= 80 else None


def match_yesno_phone(word):
    """Map a raw phone-dependent value to its closest canonical spelling.

    Args:
        word: Raw cell value, or None.

    Returns:
        The entry of ``yesno_phone`` that best matches ``word`` when the match
        score is at least 80; otherwise None. A None input returns None.
    """
    if word is None:
        return None
    # rapidfuzz's extractOne yields (choice, score, index); unpacking it into
    # only two names raises ValueError, so the index is discarded explicitly.
    best_match, score, _ = process.extractOne(word, yesno_phone)

    return best_match if score >= 80 else None


def match_yesno_internet(word):
    """Map a raw internet-dependent value to its closest canonical spelling.

    Args:
        word: Raw cell value, or None.

    Returns:
        The entry of ``yesno_internet`` that best matches ``word`` when the
        match score is at least 80; otherwise None. A None input returns None.
    """
    if word is None:
        return None
    # rapidfuzz's extractOne yields (choice, score, index); unpacking it into
    # only two names raises ValueError, so the index is discarded explicitly.
    best_match, score, _ = process.extractOne(word, yesno_internet)

    return best_match if score >= 80 else None


# Wrap every matcher as a Spark UDF that returns a string column.
(
    udf_match_gender,
    udf_match_yesno,
    udf_match_yesno_phone,
    udf_match_yesno_internet,
) = (
    udf(matcher, StringType())
    for matcher in (
        match_gender,
        match_yesno,
        match_yesno_phone,
        match_yesno_internet,
    )
)

# Pair each categorical column with the matcher UDF for its allowed values.
# A value close to an allowed spelling is replaced by that spelling; anything
# else becomes null.
typo_fixers = [
    ("Gender", udf_match_gender),
    ("Partner", udf_match_yesno),
    ("Dependents", udf_match_yesno),
    ("PhoneService", udf_match_yesno),
    ("MultipleLines", udf_match_yesno_phone),
    ("OnlineSecurity", udf_match_yesno_internet),
    ("OnlineBackup", udf_match_yesno_internet),
    ("DeviceProtection", udf_match_yesno_internet),
    ("TechSupport", udf_match_yesno_internet),
    ("StreamingTV", udf_match_yesno_internet),
    ("StreamingMovies", udf_match_yesno_internet),
    ("PaperlessBilling", udf_match_yesno),
    ("Churn", udf_match_yesno),
]

df = df.withColumns(
    {column_name: fixer(col(column_name)) for column_name, fixer in typo_fixers}
)
df.show()

Py4JJavaError: An error occurred while calling o561.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 41.0 failed 1 times, most recent failure: Lost task 0.0 in stage 41.0 (TID 25) (latitude executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:54)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:131)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:701)
	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:745)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:698)
	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:663)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:639)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:585)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:543)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 27 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:530)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4333)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3316)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4323)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4321)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4321)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:3316)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:3539)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:75)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:52)
	at java.base/java.lang.reflect.Method.invoke(Method.java:580)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.sql.execution.python.BatchEvalPythonExec.evaluate(BatchEvalPythonExec.scala:54)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:131)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:858)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:858)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	... 1 more
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.base/sun.nio.ch.NioSocketImpl.timedAccept(NioSocketImpl.java:701)
	at java.base/sun.nio.ch.NioSocketImpl.accept(NioSocketImpl.java:745)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:698)
	at java.base/java.net.ServerSocket.platformImplAccept(ServerSocket.java:663)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:639)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:585)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:543)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 27 more


### Drop duplicates and null values

In [6]:
# Remove fully identical rows first, then keep a single row per CustomerID.
df = df.dropDuplicates().dropDuplicates(["CustomerID"])

# df.count()

In [7]:
# Columns that must be populated in every row. The columns that depend on
# PhoneService / InternetService are deliberately not listed here.
required_columns = [
    "CustomerID",
    "Gender",
    "SeniorCitizen",
    "Partner",
    "Dependents",
    "Tenure",
    "PhoneService",
    "InternetService",
    "Contract",
    "PaperlessBilling",
    "PaymentMethod",
    "MonthlyCharges",
    "TotalCharges",
    "Churn",
]

# Drop any row with a null in at least one required column.
df = df.na.drop(subset=required_columns)
# df.count()

### Check consistency for PhoneService

In [8]:
# A row without phone service must report 'No phone service' in MultipleLines;
# any other value contradicts PhoneService.
phone_mismatch_sql = "PhoneService == 'No' AND MultipleLines <> 'No phone service'"
invalid_phone = df.filter(phone_mismatch_sql)

# Keep only the rows that were not flagged as contradictory.
df = df.subtract(invalid_phone)
# df.count()

### Check consistency for InternetService

In [9]:
# A row without internet service must report 'No internet service' in every
# internet add-on column; flag rows where any add-on says otherwise.
internet_mismatch_sql = """InternetService = 'No' AND (
                           OnlineSecurity <> 'No internet service' OR
                           OnlineBackup <> 'No internet service' OR
                           DeviceProtection <> 'No internet service' OR
                           TechSupport <> 'No internet service' OR
                           StreamingTV <> 'No internet service' OR
                           StreamingMovies <> 'No internet service')"""
invalid_internet = df.filter(internet_mismatch_sql)

# Keep only the rows that were not flagged as contradictory.
df = df.subtract(invalid_internet)
# df.count()

### Check consistency for SeniorCitizen

In [10]:
# SeniorCitizen is a 0/1 flag; flag rows holding any other value.
senior_out_of_range_sql = "SeniorCitizen <> 0 AND SeniorCitizen <> 1"
invalid_senior = df.filter(senior_out_of_range_sql)

# Keep only the rows with a valid flag.
df = df.subtract(invalid_senior)
# df.count()

### Fill missing values (PhoneService and InternetService)

- if the service is present (PhoneService or InternetService) and a dependent column is missing, drop the row
- if the service is 'No', fill the missing dependent column with 'No phone service' or 'No internet service' respectively


In [12]:
# A customer with phone service but no MultipleLines value cannot be repaired,
# so such rows are removed.
phone_gap_sql = "PhoneService == 'Yes' AND MultipleLines IS NULL"
missing_phone = df.filter(phone_gap_sql)

# Keep only the rows that were not flagged as incomplete.
df = df.subtract(missing_phone)
# df.count()

In [13]:
# If the customer has internet service but any internet add-on column is
# missing, the row cannot be repaired: drop it.
# InternetService is treated as a categorical string elsewhere in this notebook
# (it is one-hot encoded, and "no service" is the value 'No'), so "has service"
# is tested as `<> 'No'`. The earlier test `== 'Yes'` could not match such
# values, which left this filter empty and the incomplete rows in the data.
missing_internet = df.where(
    """InternetService <> 'No' AND (
    OnlineSecurity IS NULL OR
    OnlineBackup IS NULL OR
    DeviceProtection IS NULL OR
    TechSupport IS NULL OR
    StreamingTV IS NULL OR
    StreamingMovies IS NULL)"""
)

# subtract rows with internet service but missing add-on values
df = df.subtract(missing_internet)
# df.count()

In [14]:
# Rows without a service get the matching placeholder in every dependent
# column; all other rows keep their current value.

# handle PhoneService
no_phone = col("PhoneService") == "No"
df = df.withColumn(
    "MultipleLines",
    when(no_phone, "No phone service").otherwise(col("MultipleLines")),
)

# handle InternetService
no_internet = col("InternetService") == "No"
internet_dependents = [
    "OnlineSecurity",
    "OnlineBackup",
    "DeviceProtection",
    "TechSupport",
    "StreamingTV",
    "StreamingMovies",
]
df = df.withColumns(
    {
        dependent: when(no_internet, "No internet service").otherwise(col(dependent))
        for dependent in internet_dependents
    }
)

### Delete invalid values

In [15]:
# Negative tenure or charges, and NaN charges, are treated as invalid.
invalid_value_sql = "Tenure < 0 OR MonthlyCharges < 0 OR TotalCharges < 0 OR MonthlyCharges == 'NaN' OR TotalCharges == 'NaN'"
invalid_float = df.filter(invalid_value_sql)

# Keep only the rows with valid numeric values.
df = df.subtract(invalid_float)
# df.count()

### One-hot encoding

In [16]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer

# Categorical string columns to index; each one gets an "<name>Indexed" column.
categorical_columns = [
    "Gender",
    "MultipleLines",
    "InternetService",
    "OnlineSecurity",
    "OnlineBackup",
    "DeviceProtection",
    "TechSupport",
    "StreamingTV",
    "StreamingMovies",
    "Contract",
    "PaymentMethod",
]

si = StringIndexer(
    inputCols=categorical_columns,
    outputCols=[f"{name}Indexed" for name in categorical_columns],
)
si_model = si.fit(df)

# Add the numeric index columns next to the original string columns.
indexed_df = si_model.transform(df)
indexed_df.show()



+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+-------------+--------------------+----------------------+---------------------+-------------------+-----------------------+------------------+------------------+----------------------+---------------+--------------------+
|CustomerID|Gender|SeniorCitizen|Partner|Dependents|Tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|GenderIndexed|MultipleLinesIndexed|InternetServiceIndexed|OnlineSecurityIndexed|OnlineBackupIndexed|DeviceProtectionIndexed|TechSupportIndexed|StreamingTVIndexed|StreamingMoviesIndexed|ContractIndexed|PaymentMethodIndexed|


In [17]:
# Base names of the indexed categoricals: "<name>Indexed" is the encoder input
# and "<name>Vector" is its output.
encoded_features = [
    "Gender",
    "MultipleLines",
    "InternetService",
    "OnlineSecurity",
    "OnlineBackup",
    "DeviceProtection",
    "TechSupport",
    "StreamingTV",
    "StreamingMovies",
    "Contract",
    "PaymentMethod",
]

ohe = OneHotEncoder(
    inputCols=[f"{name}Indexed" for name in encoded_features],
    outputCols=[f"{name}Vector" for name in encoded_features],
)
ohe_model = ohe.fit(indexed_df)

# Append the encoded vector columns to the indexed frame.
encoded_df = ohe_model.transform(indexed_df)
encoded_df.show()

# Later cells continue to work on `df`.
df = encoded_df



+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+-------------+--------------------+----------------------+---------------------+-------------------+-----------------------+------------------+------------------+----------------------+---------------+--------------------+-------------+-------------------+---------------------+--------------------+------------------+----------------------+-----------------+-----------------+---------------------+--------------+-------------------+
|CustomerID|Gender|SeniorCitizen|Partner|Dependents|Tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|GenderIndex

### Reformat values

In [18]:
# Convert the 0/1 flag and the Yes/No columns to booleans. Any value other than
# 1 / 'Yes' (including null) becomes False.
yes_no_flags = ["Partner", "Dependents", "PhoneService", "PaperlessBilling", "Churn"]

boolean_columns = {
    "SeniorCitizen": when(col("SeniorCitizen") == 1, True).otherwise(False)
}
boolean_columns.update(
    {
        flag: when(col(flag) == "Yes", True).otherwise(False)
        for flag in yes_no_flags
    }
)

df = df.withColumns(boolean_columns)
df.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+-------------+--------------------+----------------------+---------------------+-------------------+-----------------------+------------------+------------------+----------------------+---------------+--------------------+-------------+-------------------+---------------------+--------------------+------------------+----------------------+-----------------+-----------------+---------------------+--------------+-------------------+
|CustomerID|Gender|SeniorCitizen|Partner|Dependents|Tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|GenderIndex

## Export

In [None]:
# Write the preprocessed frame as Parquet, replacing any previous export.
df.write.parquet("preprocessed.parquet", mode="overwrite")