In [1]:
# Notebook parameters: the tunable values for a run, gathered in one cell.

import os

# Spark master URL and application name used when the session is built
spark_master = "local[*]"
app_name = "augment"

# Source CSV, resolved relative to the working directory
input_file = os.path.join("data", "WA_Fn-UseC_-Telco-Customer-Churn-.csv")

# Output settings handed to churn.augment.register_options
output_prefix = ""
output_mode = "overwrite"
output_kind = "parquet"

# Values for spark.driver.memory and spark.executor.memory
driver_memory = "12g"
executor_memory = "8g"

# Handed to register_options as `dup_times`
# (presumably the replication factor for the generated tables -- confirm in churn.augment)
dup_times = 100


In [2]:
import churn.augment

# Hand the notebook parameters to the augmentation module in one call.
# NOTE(review): the churn.augment helpers used later are called without these
# settings, so they presumably read them from what is registered here -- confirm.
churn.augment.register_options(
    spark_master=spark_master,
    app_name=app_name,
    input_file=input_file,
    output_prefix=output_prefix,
    output_mode=output_mode,
    output_kind=output_kind,
    driver_memory=driver_memory,
    executor_memory=executor_memory,
    dup_times=dup_times,
    use_decimal=True,
)

# Sanity-checking

We're going to make sure we're running with a compatible JVM first — if we run on macOS, we might get one that doesn't work with Scala.

In [3]:
from os import getenv

In [4]:
# Point this process at a specific Java 11 install, then echo the value back
# as the cell result to confirm what is in effect.
# NOTE(review): the path names one particular JDK package build; adjust it for
# the host the notebook runs on.
os.environ.update(
    JAVA_HOME="/usr/lib/jvm/java-11-openjdk-11.0.12.0.7-4.fc34.x86_64"
)
getenv("JAVA_HOME")

'/usr/lib/jvm/java-11-openjdk-11.0.12.0.7-4.fc34.x86_64'

In [5]:
# Use the same interpreter for the Spark driver and for the Python workers;
# PySpark consults both variables when it launches Python processes.
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3.9"
os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3.9"

# Only a cell's final expression is displayed, so report both settings in a
# single mapping instead of two bare expressions (the first of which would be
# silently discarded).
{name: getenv(name) for name in ("PYSPARK_PYTHON", "PYSPARK_DRIVER_PYTHON")}

'/usr/bin/python3.9'

# Spark setup

In [6]:
import pyspark

In [7]:
# Build (or reuse) the Spark session. The chain is parenthesised so each
# setting sits on its own line without backslash continuations.
session = (
    pyspark.sql.SparkSession.builder
    .master(spark_master)
    .appName(app_name)
    # Memory sizes come from the parameters cell
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", 1)
    # RAPIDS accelerator tuning
    .config("spark.rapids.sql.concurrentGpuTasks", 1)
    .config("spark.rapids.memory.pinnedPool.size", "2G")
    .config("spark.locality.wait", "0s")
    .config("spark.sql.files.maxPartitionBytes", "512m")
    # Load the RAPIDS SQL plugin and the jars that provide it
    .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
    .config("spark.jars", "/opt/sparkRapidsPlugin/cudf-21.08.2-cuda11.jar,/opt/sparkRapidsPlugin/rapids-4-spark_2.12-21.08.0.jar")
    .getOrCreate()
)
session

21/09/25 14:34:23 WARN Utils: Your hostname, virt resolves to a loopback address: 127.0.0.1; using 192.168.86.109 instead (on interface wlp2s0)
21/09/25 14:34:23 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
21/09/25 14:34:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/09/25 14:34:29 WARN GpuDeviceManager: Initial RMM allocation (3437.32421875 MB) is larger than the adjusted maximum allocation (3018.375 MB), lowering initial allocation to the adjusted maximum allocation.
21/09/25 14:34:31 WARN SQLExecPlugin: RAPIDS Accelerator 21.08.0 using cudf 21.08.2. To disable GPU support set `spark.rapids.sql.enabled` to false
21/09/25 14:34:31 WARN Plugin: Installing rapids UDF compi

# Schema definition

Most of the fields are strings representing booleans or categoricals, but a few (`tenure`, `MonthlyCharges`, and `TotalCharges`) are numeric.

In [8]:
from churn.augment import load_supplied_data

# Load the source dataset named by `input_file` through the Spark session.
# The helper prints how many records (and how many non-null records) it read.
df = load_supplied_data(session, input_file)



read 7043 records from source dataset (7032 non-null records)




# Splitting the data frame

The training data schema looks like this:

- customerID
- gender
- SeniorCitizen
- Partner
- Dependents
- tenure
- PhoneService
- MultipleLines
- InternetService
- OnlineSecurity
- OnlineBackup
- DeviceProtection
- TechSupport
- StreamingTV
- StreamingMovies
- Contract
- PaperlessBilling
- PaymentMethod
- MonthlyCharges
- TotalCharges
- Churn

We want to divide the data frame into several frames that we can join together in an ETL job.

Those frames will look like this:

- **Customer metadata**
  - customerID
  - gender
  - date of birth (we'll derive age and senior citizen status from this)
  - Partner
  - Dependents
  - (nominal) MonthlyCharges
- **Billing events**
  - customerID
  - date (we'll derive tenure from the number/duration of billing events)
  - kind (one of "AccountCreation", "Charge", or "AccountTermination")
  - value (either a positive nonzero amount or 0.00; we'll derive TotalCharges from the sum of amounts and Churn from the existence of an AccountTermination event)
- **Customer phone features**
  - customerID
  - feature (one of "PhoneService" or "MultipleLines")
- **Customer internet features**
  - customerID
  - feature (one of "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies")
  - value (one of "Fiber", "DSL", "Yes", "No")
- **Customer account features**
  - customerID
  - feature (one of "Contract", "PaperlessBilling", "PaymentMethod")
  - value (one of "Month-to-month", "One year", "Two year", "No", "Yes", "Credit card (automatic)", "Mailed check", "Bank transfer (automatic)", "Electronic check")

In [9]:
# Print the column names and types of the loaded frame.
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: double (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



We'll start by generating a series of monthly charges, then a series of account creation events, and finally a series of churn events. `billingEvents` is the data frame containing all of these events: account activation, account termination, and individual payment events.

In [10]:
from churn.augment import billing_events
# Derive the billing-events frame from the loaded customer data.
billingEvents = billing_events(df)

Our next step is to generate customer metadata, which includes the following fields:

  - gender
  - date of birth (we'll derive age and senior citizen status from this)
  - Partner
  - Dependents
  
We'll calculate date of birth by using the hash of the customer ID as a pseudorandom number and then assuming that ages are uniformly distributed between 18 and 65 and exponentially distributed over 65.

In [12]:
from churn.augment import customer_meta
# Derive the customer-metadata frame from the loaded customer data.
customerMeta = customer_meta(df)

Now we can generate customer phone features, which include:

  - customerID
  - feature (one of "PhoneService" or "MultipleLines")
  - value (always "Yes"; there are no records for "No" or "No Phone Service")

In [13]:
from churn.augment import phone_features
# Derive the per-customer phone-feature frame from the loaded customer data.
customerPhoneFeatures = phone_features(df)

Customer internet features include:

  - customerID
  - feature (one of "InternetService", "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies")
  - value (one of "Fiber", "DSL", "Yes" -- no records for "No" or "No internet service")

In [14]:
from churn.augment import internet_features
# Derive the per-customer internet-feature frame from the loaded customer data.
customerInternetFeatures = internet_features(df)

Customer account features include:

  - customerID
  - feature (one of "Contract", "PaperlessBilling", "PaymentMethod")
  - value (one of "Month-to-month", "One year", "Two year", "Yes", "Credit card (automatic)", "Mailed check", "Bank transfer (automatic)", "Electronic check")

In [15]:
from churn.augment import account_features
# Derive the per-customer account-feature frame from the loaded customer data.
customerAccountFeatures = account_features(df)

# Write outputs

In [16]:
%%time

from churn.augment import write_df

write_df(billingEvents, "billing_events", partition_by="month")
write_df(customerMeta, "customer_meta", skip_replication=True)
write_df(customerPhoneFeatures, "customer_phone_features")
write_df(customerInternetFeatures.orderBy("customerID"), "customer_internet_features")
write_df(customerAccountFeatures, "customer_account_features")



CPU times: user 272 ms, sys: 63 ms, total: 335 ms
Wall time: 3min 25s




In [17]:
# Read each written table back and print how many distinct customer IDs it holds.
# NOTE(review): the read path is hard-coded as "<table>.parquet" and does not use
# `output_prefix` / `output_kind` from the parameters cell -- confirm it still
# matches where write_df puts its output if those parameters are changed.
table_names = (
    "billing_events",
    "customer_meta",
    "customer_phone_features",
    "customer_internet_features",
    "customer_account_features",
)

for table_name in table_names:
    written = session.read.parquet("%s.parquet" % table_name)
    distinct_customers = written.select("customerID").distinct().count()
    print(table_name, distinct_customers)



billing_events 703200




customer_meta 703200




customer_phone_features 635200




customer_internet_features 551200




customer_account_features 703200




In [18]:
import pyspark.sql.functions as F
from functools import reduce

# One (table, customerID) frame per written table; every row is tagged with
# the name of the table it was read from.
output_dfs = [
    session.read.parquet("%s.parquet" % table_name).select(
        F.lit(table_name).alias("table"),
        "customerID",
    )
    for table_name in (
        "billing_events",
        "customer_meta",
        "customer_phone_features",
        "customer_internet_features",
        "customer_account_features",
    )
]

# Stack the per-table frames into a single frame.
all_customers = reduce(lambda stacked, frame: stacked.unionAll(frame), output_dfs)



In [19]:

# Approximate distinct-customer count, shared by both aggregations below.
approx_customers = F.approx_count_distinct("customerID").alias("approx_unique_customers")

# Per-table counts, plus a single row labelled "all" covering every table together.
each_table = all_customers.groupBy("table").agg(approx_customers)
overall = all_customers.groupBy(F.lit("all").alias("table")).agg(approx_customers)

each_table.union(overall).show()

21/09/25 14:40:40 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+--------------------+-----------------------+
|               table|approx_unique_customers|
+--------------------+-----------------------+
|customer_internet...|                 521053|
|       customer_meta|                 699470|
|customer_phone_fe...|                 631148|
|customer_account_...|                 699470|
|      billing_events|                 699470|
|                 all|                 699470|
+--------------------+-----------------------+





In [20]:
# Collect the combined per-table and overall counts into a local list of rows.
rows = each_table.union(overall).collect()



In [21]:
# Map each table label (first column) to its approximate count (second column).
{row[0]: row[1] for row in rows}

{'customer_internet_features': 521053,
 'customer_meta': 699470,
 'customer_phone_features': 631148,
 'customer_account_features': 699470,
 'billing_events': 699470,
 'all': 699470}

In [None]:
# Stop the Spark session at the end of the notebook.
session.stop()