#### Step 1: Load the Data

In 2018, KKBox, a popular music streaming service based in Taiwan, released a dataset consisting of a little over two years of (anonymized) customer transaction and activity data with the goal of challenging the Data & AI community to predict which customers would churn in a future period.

NOTE Due to the terms and conditions under which these data are made available, anyone interested in recreating this work will need to agree to those terms and conditions before downloading the dataset, and then create a folder structure in their environment similar to the one described below. You can save the data permanently under a pre-defined mount point named /mnt/kkbox.

We have automated this downloading step for you and used a /tmp/kkbox_churn storage path throughout this accelerator.
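Before loading anything, it can help to confirm that the expected files are in place. The sketch below is a minimal check in plain Python; the relative paths are taken from the CSV reads later in this notebook, while the helper name `missing_files` is hypothetical.

```python
from pathlib import Path

# Relative paths taken from the CSV reads later in this notebook
EXPECTED_FILES = [
    "members/members_v3.csv",
    "transactions/transactions.csv",
    "user_logs/user_logs.csv",
]

def missing_files(data_dir):
    """Return the expected CSV files that are absent under data_dir."""
    root = Path(data_dir)
    return [rel for rel in EXPECTED_FILES if not (root / rel).is_file()]
```

Calling `missing_files(data_dir)` with the storage path used in your environment should return an empty list once the download has completed.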

Read into dataframes, these files form the following data model:

Each service subscriber is uniquely identified by a value in the msno field of the members table. Data in the transactions and user logs tables provide a record of subscription management and streaming activities, respectively. Not every member has a complete set of data in this schema. In addition, the transaction and streaming logs are quite verbose with multiple records being recorded for a subscriber on a given date. On dates where there is no activity, no entries are found for a subscriber in these tables.

In order to protect data privacy, many values in these tables have been ordinal-encoded, limiting their interpretability. In addition, timestamp information has been truncated to a daily level, making the sequencing of records on a given date dependent on business logic addressed in later steps in this notebook.

In [1]:
from datetime import date
from pyspark.sql.types import *
from pyspark.sql.functions import lit
import shutil
import os
import subprocess

In [2]:
# this has been added for scenarios where you might
# wish to alter some of the churn label prediction
# logic but do not wish to rerun the whole notebook
skip_reload = False

# please use a personalized database name here if you wish to avoid interfering with other users who might be running this accelerator in the same workspace
database_name = 'kkbox_churn'
data_dir = f"{os.getenv('HOME')}/databricks/kkbox_churn"

In [None]:
spark.stop()  # Properly stop Spark
del spark     # Delete the variable

In [6]:
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("ChurnCluster") \
    .config("spark.jars.packages", "io.delta:delta-core_2.12:1.2.1") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .config("spark.executor.memory", "56g") \
    .config("spark.driver.memory", "48g") \
    .getOrCreate()

os.environ["SPARK_APP_NAME"] = spark.conf.get("spark.app.name")
os.environ["SPARK_MASTER"] = spark.conf.get("spark.master")

print("Spark Version:", spark.version)

Spark Version: 3.2.4


In [None]:
if skip_reload:
  # create database to house SQL tables
  _ = spark.sql(f'CREATE DATABASE IF NOT EXISTS {database_name}')
  _ = spark.sql(f'USE {database_name}')
else:
  # delete the old database if needed
  _ = spark.sql(f'DROP DATABASE IF EXISTS {database_name} CASCADE')
  _ = spark.sql(f'CREATE DATABASE {database_name}')
  _ = spark.sql(f'USE {database_name}')

  # drop any old delta lake files that might have been created
  folder_path = f'{data_dir}/silver/members'
  if os.path.exists(folder_path):
      shutil.rmtree(folder_path)
    
  # members dataset schema
  member_schema = StructType([
    StructField('msno', StringType()),
    StructField('city', IntegerType()),
    StructField('bd', IntegerType()),
    StructField('gender', StringType()),
    StructField('registered_via', IntegerType()),
    StructField('registration_init_time', DateType())
    ])

  # read data from csv
  members = (
    spark
      .read
      .csv(
        f'{data_dir}/members/members_v3.csv',
        schema=member_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  (
    members
      .write
      .format('delta')
      .mode('overwrite')
      .save(f'{data_dir}/silver/members')
    )

  # create table object to make delta lake queryable
  # (location is derived from data_dir rather than a hard-coded home directory)
  _ = spark.sql(f'''
      CREATE TABLE members 
      USING DELTA 
      LOCATION '{data_dir}/silver/members'
      ''')


In [None]:
# print(members.show())
result = spark.sql("SELECT * FROM kkbox_churn.members LIMIT 10")
result.show()

In [None]:
if not skip_reload:

  # drop any old delta lake files that might have been created
  folder_path = f'{data_dir}/silver/transactions'
  if os.path.exists(folder_path):
      shutil.rmtree(folder_path)

  # transaction dataset schema
  transaction_schema = StructType([
    StructField('msno', StringType()),
    StructField('payment_method_id', IntegerType()),
    StructField('payment_plan_days', IntegerType()),
    StructField('plan_list_price', IntegerType()),
    StructField('actual_amount_paid', IntegerType()),
    StructField('is_auto_renew', IntegerType()),
    StructField('transaction_date', DateType()),
    StructField('membership_expire_date', DateType()),
    StructField('is_cancel', IntegerType())  
    ])

  # read data from csv
  transactions = (
    spark
      .read
      .csv(
        f'{data_dir}/transactions/transactions.csv',
        schema=transaction_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  ( transactions
      .write
      .format('delta')
      .partitionBy('transaction_date')
      .mode('overwrite')
      .save(f'{data_dir}/silver/transactions')
    )

  # create table object to make delta lake queryable
  # (location is derived from data_dir rather than a hard-coded home directory)
  _ = spark.sql(f'''
      CREATE TABLE transactions
      USING DELTA 
      LOCATION '{data_dir}/silver/transactions'
      ''')

In [None]:
# print(transactions.show())
result = spark.sql("SELECT * FROM kkbox_churn.transactions")
result.show()

In [None]:
if not skip_reload:
  # drop any old delta lake files that might have been created
  folder_path = f'{data_dir}/silver/user_logs'
  if os.path.exists(folder_path):
      shutil.rmtree(folder_path)

  # user logs dataset schema
  user_logs_schema = StructType([ 
    StructField('msno', StringType()),
    StructField('date', DateType()),
    StructField('num_25', IntegerType()),
    StructField('num_50', IntegerType()),
    StructField('num_75', IntegerType()),
    StructField('num_985', IntegerType()),
    StructField('num_100', IntegerType()),
    StructField('num_uniq', IntegerType()),
    StructField('total_secs', FloatType())  
    ])

  # read data from csv
  user_logs = (
    spark
      .read
      .csv(
        f'{data_dir}/user_logs/user_logs.csv',
        schema=user_logs_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  ( user_logs
      .write
      .format('delta')
      .partitionBy('date')
      .mode('overwrite')
      .save(f'{data_dir}/silver/user_logs')
    )

  # create table object to make delta lake queryable
  # (location is derived from data_dir rather than a hard-coded home directory)
  _ = spark.sql(f'''
    CREATE TABLE IF NOT EXISTS user_logs
    USING DELTA 
    LOCATION '{data_dir}/silver/user_logs'
    ''')

#### Step 2: Acquire Churn Labels

To build our model, we will need to identify which customers have churned within two periods of interest. These periods are February 2017 and March 2017. We will train our model to predict churn in February 2017 and then evaluate our model's ability to predict churn in March 2017, making these our training and testing datasets, respectively.

Per instructions provided in the Kaggle competition, a KKBox subscriber is not identified as churned until they fail to renew their subscription within 30 days following its expiration. Most subscriptions are themselves on a 30-day renewal schedule (though some subscriptions renew on significantly longer cycles). This means that identifying churn involves a sequential walk through the customer data, looking for renewal gaps that would indicate a customer churned on a prior expiration date.
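As a minimal sketch of this rule for a single expiration event, the plain-Python function below flags churn when no renewal occurs within the grace period. The function name is hypothetical, and treating a renewal on exactly the 30th day as on time is an assumption; the KKBox script remains the authoritative definition.

```python
from datetime import date, timedelta

def is_churn(expire_date, next_transaction_date, grace_days=30):
    """Return True if no renewal occurs within grace_days of expire_date.

    next_transaction_date is None when the subscriber never transacts again.
    Counting a renewal on exactly the last grace day as on time is an assumption.
    """
    if next_transaction_date is None:
        return True
    return next_transaction_date > expire_date + timedelta(days=grace_days)
```

For example, a subscription expiring on February 10, 2017 and renewed on February 25, 2017 is not churn, while one renewed on March 20, 2017 is.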

While the competition makes available pre-labeled training and testing datasets, train.csv and train_v2.csv, respectively, several past participants have noted that these datasets should be regenerated. A Scala script for doing so is provided by KKBox. Modifying the script for this environment, we might regenerate our training and test datasets as follows:

In [None]:
# Drop the training labels table, if it exists, before recreating it
_ = spark.sql('DROP TABLE IF EXISTS train')

In [None]:
# %%sh -e

# Optional: create a Spark cluster (not required for the steps below).
# docker network create spark-net

# docker run -d --rm --network spark-net --name spark-master \
#     -p 8080:8080 -p 7077:7077 -p 4040:4040 \
#     bitnami/spark spark-class org.apache.spark.deploy.master.Master

# docker run -d --rm --network spark-net --name spark-worker \
#     --env SPARK_MODE=worker \
#     --env SPARK_MASTER_URL=spark://spark-master:7077 \
#     bitnami/spark spark-class org.apache.spark.deploy.worker.Worker spark://spark-master:7077


In [None]:
%%sh -e
# Generate training labels

kkbox_churn_dir="$HOME/databricks/kkbox_churn"
sudo chmod 777 $kkbox_churn_dir
sudo rm -rf $kkbox_churn_dir/silver/train

docker run --rm --network host  \
    -v "$kkbox_churn_dir:/opt/spark/work/kkbox_churn" \
    -v "$PWD:/opt/bitnami/spark/work" \
    bitnami/spark:3.4.1 spark-shell --master local[*] \
    --executor-memory 48G \
    --driver-memory 16G \
    --packages io.delta:delta-core_2.12:2.4.0 \
    --conf spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension \
    --conf spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog \
    -i /opt/bitnami/spark/work/scripts/generate_training_labels.scala

sudo chown -R "$(id -un):$(id -gn)" $kkbox_churn_dir/silver/train
sudo chmod -R 777 $kkbox_churn_dir/silver/train

In [7]:
# Access the training labels
_ = spark.sql(f'''
CREATE TABLE IF NOT EXISTS train
USING DELTA
LOCATION '{data_dir}/silver/train/'
''')

_ = spark.sql('SELECT * FROM train').show(10)

+--------------------+--------+
|                msno|is_churn|
+--------------------+--------+
|+++snpr7pmobhLKUg...|       0|
|++0/NopttBsaAn6qH...|       0|
|++3fWHRDC5GWWlovH...|       0|
|++3wqX72HmowxFh5M...|       0|
|++4RuqBw0Ss6bQU4o...|       0|
|++5XBxbJNz3idK9eg...|       0|
|++5Z7z4xXBhCjID+B...|       0|
|++8dXbkKMJ0rXwUc/...|       0|
|++A3frr1K7edD9xpe...|       0|
|++Bks8kE9oclzxZM3...|       0|
+--------------------+--------+
only showing top 10 rows




In [None]:
# Drop the testing labels table, if it exists, before recreating it
_ = spark.sql('DROP TABLE IF EXISTS test')

In [None]:
%%sh -e
# Generate testing labels

kkbox_churn_dir="$HOME/databricks/kkbox_churn"
sudo chmod 777 $kkbox_churn_dir
sudo rm -rf $kkbox_churn_dir/silver/test

docker run --rm --network host  \
    -v "$kkbox_churn_dir:/opt/spark/work/kkbox_churn" \
    -v "$PWD:/opt/bitnami/spark/work" \
    bitnami/spark:3.4.1 spark-shell --master local[*] \
    --executor-memory 48G \
    --driver-memory 16G \
    --packages io.delta:delta-core_2.12:2.4.0 \
    --conf spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension \
    --conf spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog \
    -i /opt/bitnami/spark/work/scripts/generate_testing_labels.scala

sudo chown -R "$(id -un):$(id -gn)" $kkbox_churn_dir/silver/test
sudo chmod -R 777 $kkbox_churn_dir/silver/test

In [None]:
# Access the testing labels
_ = spark.sql(f'''
CREATE TABLE IF NOT EXISTS test
USING DELTA
LOCATION '{data_dir}/silver/test/'
''')

_ = spark.sql('SELECT * FROM test').show(10)

#### Step 3: Cleanse & Enhance Transaction Logs

In the churn script provided by KKBox (and used in the last step), time between transaction events is used in order to determine churn status. In situations where multiple transactions are recorded on a given date, complex logic is used to determine which transaction represents the final state of the account on that date. This logic states that when we have multiple transactions for a given subscriber on a given date, we should:

1. Concatenate the plan_list_price, payment_plan_days, and payment_method_id values and consider the "bigger" of these values as preceding the others
2. If the concatenated value (defined in the last step) is the same across records for this date, cancellations, i.e. records where is_cancel=1, should follow other transactions
3. If there are multiple cancellations in this sequence, the record with the earliest expiration date is the last record for this transaction date
4. If there are no cancellations but multiple non-cancellations in this sequence, the non-cancellation record with the latest expiration date is the last record on the transaction date
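The four rules above can be sketched in plain Python as a function that picks the final record for one subscriber on one date. This is one reading of the rules, not the KKBox script itself: it assumes the concatenated value is compared as a string, and that rules 2 through 4 apply among the records sharing the value that sorts last. The helper names and the dictionary layout are hypothetical; the field names follow the transactions schema.

```python
from datetime import date

def plan_key(txn):
    # Rule 1: concatenate price, plan days, and payment method
    # (comparing the result as a string is an assumption)
    return f"{txn['plan_list_price']}{txn['payment_plan_days']}{txn['payment_method_id']}"

def final_record(same_day_txns):
    """Pick the record treated as the final state for one subscriber on one date."""
    # Rule 1: "bigger" values precede the others, so the smallest value sorts last
    last_key = min(plan_key(t) for t in same_day_txns)
    candidates = [t for t in same_day_txns if plan_key(t) == last_key]
    # Rule 2: cancellations follow non-cancellations
    cancels = [t for t in candidates if t["is_cancel"] == 1]
    if cancels:
        # Rule 3: among cancellations, the earliest expiration date is last
        return min(cancels, key=lambda t: t["membership_expire_date"])
    # Rule 4: with no cancellations, the latest expiration date is last
    return max(candidates, key=lambda t: t["membership_expire_date"])

# Hypothetical records for one subscriber on one transaction date
plan = {"plan_list_price": 149, "payment_plan_days": 30, "payment_method_id": 41}
example = [
    {**plan, "is_cancel": 0, "membership_expire_date": date(2017, 3, 1)},
    {**plan, "is_cancel": 1, "membership_expire_date": date(2017, 2, 10)},
    {**plan, "is_cancel": 1, "membership_expire_date": date(2017, 2, 5)},
]
```

Here `final_record(example)` selects the cancellation expiring on February 5, 2017, because both cancellations follow the renewal and the earlier expiration is treated as last.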

Rewriting this logic in SQL allows us to generate a cleansed version of the transaction log with the final record for each date:

In [None]:
%%sh -e
sudo rm -rf $PWD/spark-warehouse/transactions_clean

In [None]:
# Read the SQL file
with open("scripts/transactions_clean.sql", "r") as sql_file:
    sql_script = sql_file.read()

# Execute the SQL script
for statement in sql_script.split(";"):
    statement = statement.strip()
    if statement:
        spark.sql(statement)

spark.sql("SELECT * FROM transactions_clean LIMIT 10").show()


Using this cleansed transaction data, we can now more easily identify the start and end of subscriptions using the 30-day gap logic found in the Scala code. It's important to note that over the 2+ year period represented by the dataset, many subscribers will churn and many of those that do churn will re-subscribe. With this in mind, we will generate a subscription ID to identify the different subscriptions, each of which will have a non-overlapping starting and ending date for a given subscriber:

In [None]:
%%sh -e
sudo rm -rf $PWD/spark-warehouse/subscription_windows

In [None]:
# Read the SQL file
with open("scripts/subscription_windows.sql", "r") as sql_file:
    sql_script = sql_file.read()

# Execute the SQL script
for statement in sql_script.split(";"):
    statement = statement.strip()
    if statement:
        spark.sql(statement)

spark.sql("SELECT * FROM subscription_windows LIMIT 10").show()
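The 30-day gap logic behind the subscription ID can be illustrated in plain Python for a single subscriber. This is a simplified sketch rather than the contents of the SQL script: it ignores cancellations, the function name is hypothetical, and starting a new subscription only when the gap strictly exceeds 30 days is an assumption.

```python
from datetime import date, timedelta

GRACE = timedelta(days=30)

def assign_subscription_ids(transactions):
    """Assign a subscription number to each transaction of one subscriber.

    transactions: list of (transaction_date, membership_expire_date) tuples,
    sorted by transaction_date. A new subscription starts when a transaction
    falls more than 30 days after the previous expiration date (assumption).
    """
    ids = []
    current_id = 0
    prev_expire = None
    for txn_date, expire_date in transactions:
        if prev_expire is None or txn_date > prev_expire + GRACE:
            current_id += 1
        ids.append(current_id)
        prev_expire = expire_date
    return ids

# Hypothetical history: two back-to-back renewals, then a return after a long gap
history = [
    (date(2016, 1, 1), date(2016, 1, 31)),
    (date(2016, 1, 31), date(2016, 3, 1)),
    (date(2016, 5, 15), date(2016, 6, 14)),
]
```

With this history, the first two transactions share subscription 1 and the third, arriving well after the grace period, opens subscription 2.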



To verify that our subscription windows are aligned with the script used to identify customers at-risk for churn in February and March 2017, let's perform a quick test. The script identifies an at-risk subscription as one where the last transaction recorded in the historical period, i.e. the time period leading up to the start of the month of interest, has an expiration date falling between 30 days before the start of the period of interest and the end of that period. For example, to identify at-risk customers for February 2017, we should look for customers whose active subscriptions are set to expire between 30 days before February 1, 2017 and February 28, 2017. This shifted window allows time for the 30-day grace period to expire within the period of interest.

NOTE Better logic would limit our assessment to those subscriptions with an expiration date between 30-days prior to the start of the period AND 30-days prior to the end of the period. (Such logic would exclude subscriptions expiring within the period of interest but which do not exit the 30-day grace period until after the period is over.) When we use this logic, we find numerous subscriptions that the provided script identifies as at-risk but which we would not. We will align our logic with that of the competition for this exercise.
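The two window definitions discussed above can be compared with a small plain-Python sketch. The function name and the `strict` flag are hypothetical; `strict=False` mirrors the competition's window, while `strict=True` mirrors the narrower alternative described in the note.

```python
from datetime import date, timedelta

def at_risk_window(period_start, period_end, grace_days=30, strict=False):
    """Return (window_start, window_end) for expiration dates considered at-risk.

    strict=False: expiration between 30 days before the period start and the period end.
    strict=True:  expiration between 30 days before the period start and
                  30 days before the period end.
    """
    grace = timedelta(days=grace_days)
    window_start = period_start - grace
    window_end = period_end - grace if strict else period_end
    return window_start, window_end
```

For February 2017, the competition window runs from January 2, 2017 through February 28, 2017, whereas the stricter window would end on January 29, 2017.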

With this logic in mind, let's see if all our labeled at-risk customers adhere to this logic:

NOTE The next two cells should return NO RESULTS if our logic is valid

In [None]:
# Identify Any Subscriptions in Training Dataset Not Believed to Be At-Risk
spark.sql('''
SELECT
  x.msno
FROM train x
LEFT OUTER JOIN (
  SELECT DISTINCT -- subscriptions that had risk in Feb 2017
    a.msno
  FROM subscription_windows a
  INNER JOIN transactions_clean b
    ON a.msno=b.msno AND b.transaction_date BETWEEN a.subscription_start AND a.subscription_end
  WHERE 
        a.subscription_start < '2017-02-01' AND
        (b.membership_expire_date BETWEEN DATE_ADD('2017-02-01',-30) AND '2017-02-28')
  ) y
  ON x.msno=y.msno
WHERE y.msno IS NULL
''').show()

In [None]:
# Identify Any Subscriptions in Testing Dataset Not Believed to Be At-Risk
spark.sql('''
SELECT
  x.msno
FROM test x
LEFT OUTER JOIN (
  SELECT DISTINCT -- subscriptions that had risk in Mar 2017
    a.msno
  FROM subscription_windows a
  INNER JOIN transactions_clean b
    ON a.msno=b.msno AND b.transaction_date BETWEEN a.subscription_start AND a.subscription_end
  WHERE 
        a.subscription_start < '2017-03-01' AND
        (b.membership_expire_date BETWEEN DATE_ADD('2017-03-01',-30) AND '2017-03-31')
  ) y
  ON x.msno=y.msno
WHERE y.msno IS NULL
''').show()

While we do not miss any of the at-risk subscriptions identified by the provided script, if we were to alter the code above we would find a few subscriptions that we identify as at-risk but which the Scala script does not. While it might be useful to examine why this is, so long as there are no members that the Scala script identifies as at-risk that we do not, we should be able to use this dataset to derive features for subscriptions in our testing and training datasets.
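The two directions of this comparison amount to set differences. The sketch below uses hypothetical subscriber IDs purely for illustration; in the notebook the same comparison is expressed as a LEFT OUTER JOIN filtered on NULL.

```python
# Hypothetical subscriber IDs, for illustration only
labeled = {"A", "B", "C"}        # msno values in train (or test)
at_risk = {"A", "B", "C", "D"}   # msno values flagged by our subscription-window logic

# Labeled by the script but not flagged by us: must be empty for our logic to be usable
missed_by_us = labeled - at_risk

# Flagged by us but not labeled by the script: may be non-empty
extra_from_us = at_risk - labeled
```

Only the first difference needs to be empty for the labeled datasets to be fully covered by our subscription windows.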

Leveraging subscription duration information derived in the last few cells, we can now enhance our transaction log to detect account-level changes. This information will form the basis for transaction-feature generation in the next notebook:

In [None]:
%%sh -e
sudo rm -rf $PWD/spark-warehouse/transactions_enhanced

In [None]:
# Read the SQL file
with open("scripts/transactions_enhanced.sql", "r") as sql_file:
    sql_script = sql_file.read()

# Execute the SQL script
for statement in sql_script.split(";"):
    statement = statement.strip()
    if statement:
        spark.sql(statement)

spark.sql("SELECT * FROM transactions_enhanced LIMIT 10").show()


#### Step 4: Generate Dates Table

Finally, it is very likely we will want to derive features from both the transaction log and the user activity data where we examine days without activity. To make this analysis easier, it may be helpful to generate a table containing one record for each date from the beginning date to the end date in our dataset. We know that these data span January 1, 2015 through March 31, 2017. With that in mind, we can generate such a table as follows:

In [None]:
# calculate days in range
start_date = date(2015, 1, 1)
end_date = date(2017, 3, 31)
days = end_date - start_date

# generate temp view of dates in range
( spark
    .range(0, days.days + 1)  # +1 so that end_date itself is included
    .withColumn('start_date', lit(start_date.strftime('%Y-%m-%d')))  # first date in activity dataset
    .selectExpr('date_add(start_date, CAST(id as int)) as date')
    .write.format("delta").mode("overwrite").option("overwriteSchema", "true") 
    .saveAsTable('dates')
  )

# display SQL table content
display(spark.table('dates').orderBy('date'))

spark.sql("SELECT * FROM dates").show()
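As a cross-check on the expected row count, the same range can be built in plain Python. Note that the day difference between the two endpoints is one less than the number of dates in the inclusive range. The helper name is hypothetical.

```python
from datetime import date, timedelta

def inclusive_dates(start, end):
    """Return every date from start through end, both endpoints included."""
    n_days = (end - start).days + 1  # +1 because the difference excludes the end date
    return [start + timedelta(days=i) for i in range(n_days)]

all_dates = inclusive_dates(date(2015, 1, 1), date(2017, 3, 31))
```

For January 1, 2015 through March 31, 2017 this yields 821 dates, so a table covering the full span should contain 821 rows.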

In [46]:
transactions.describe().show()
spark.sql("SELECT * FROM kkbox_churn.transactions").show()



+-------+--------------------+------------------+------------------+------------------+------------------+-------------------+-------------------+
|summary|                msno| payment_method_id| payment_plan_days|   plan_list_price|actual_amount_paid|      is_auto_renew|          is_cancel|
+-------+--------------------+------------------+------------------+------------------+------------------+-------------------+-------------------+
|  count|            21547746|          21547746|          21547746|          21547746|          21547746|           21547746|           21547746|
|   mean|                null|38.933103026181946| 31.33906284211815|139.88501507303826|141.98732048354384| 0.8519661406812573|0.03976522648819046|
| stddev|                null| 3.507936375441224|30.356493755193306|130.96470029176857|132.48242402576437|0.35513355469032165|0.19540715192283983|
|    min|+++FOrTS7ab3tIgIh...|                 1|                 0|                 0|                 0|            

                                                                                

In [None]:
# Inspect all transactions for a single subscriber (matched by msno prefix), ordered by date,
# e.g. to check whether the subscriber has records with both is_cancel = 1 and is_cancel = 0
spark.sql("select * from kkbox_churn.transactions where msno like '18AdyEBVYQuH3K1Pg%' order by transaction_date").show()


AnalysisException: Table or view not found: kkbox_churn.transactions; line 1 pos 14;
'Sort ['transaction_date ASC NULLS FIRST], true
+- 'Project [*]
   +- 'Filter 'msno LIKE 18AdyEBVYQuH3K1Pg%
      +- 'UnresolvedRelation [kkbox_churn, transactions], [], false

