In [None]:
from datetime import date
from pyspark.sql.types import *
from pyspark.sql.functions import lit

In [None]:
# Set skip_reload to True to keep the existing database and skip the cells
# below that reload the members, transactions and user_logs tables (each of
# them is guarded by this flag). This is intended for scenarios where you
# wish to alter some of the churn label prediction logic but do not wish to
# rerun the whole notebook.
skip_reload = False

# Name of the database that houses the tables created by this notebook.
# Please use a personalized database name here if you wish to avoid
# interfering with other users who might be running this accelerator in the
# same workspace. NOTE: when skip_reload is False this database is dropped
# (CASCADE) and recreated.
database_name = 'kkbox_churn'

In [None]:
# Prepare the working database and (unless skip_reload is set) rebuild the
# members table from the raw CSV file.
if skip_reload:
  # keep whatever was loaded previously; just make sure the database exists
  # and is selected so that later cells resolve unqualified table names
  _ = spark.sql(f'CREATE DATABASE IF NOT EXISTS {database_name}')
  _ = spark.sql(f'USE {database_name}')
else:
  # delete the old database if needed and start from a clean one
  _ = spark.sql(f'DROP DATABASE IF EXISTS {database_name} CASCADE')
  _ = spark.sql(f'CREATE DATABASE {database_name}')
  _ = spark.sql(f'USE {database_name}')

  # drop any old delta lake files that might have been created
  # (the path must match the location written to and registered below;
  # True requests a recursive delete)
  dbutils.fs.rm("/tmp/kkbox_churn/silver/members", True)

  # members dataset schema
  member_schema = StructType([
    StructField('msno', StringType()),
    StructField('city', IntegerType()),
    StructField('bd', IntegerType()),
    StructField('gender', StringType()),
    StructField('registered_via', IntegerType()),
    StructField('registration_init_time', DateType())
    ])

  # read data from csv; dateFormat tells the reader how DateType columns
  # are encoded in the file
  members = (
    spark
      .read
      .csv(
        '/tmp/kkbox_churn/members/members_v3.csv',
        schema=member_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  (
    members
      .write
      .format('delta')
      .mode('overwrite')
      .save('/tmp/kkbox_churn/silver/members')
    )

  # create table object to make delta lake queryable
  _ = spark.sql('''
      CREATE TABLE members 
      USING DELTA 
      LOCATION '/tmp/kkbox_churn/silver/members'
      ''')

In [None]:
# Rebuild the transactions table from the raw CSV files (skipped when
# skip_reload is set).
if not skip_reload:

  # drop any old delta lake files that might have been created
  dbutils.fs.rm("/tmp/kkbox_churn/silver/transactions", True)

  # transaction dataset schema
  transaction_schema = StructType([
    StructField('msno', StringType()),
    StructField('payment_method_id', IntegerType()),
    StructField('payment_plan_days', IntegerType()),
    StructField('plan_list_price', IntegerType()),
    StructField('actual_amount_paid', IntegerType()),
    StructField('is_auto_renew', IntegerType()),
    StructField('transaction_date', DateType()),
    StructField('membership_expire_date', DateType()),
    StructField('is_cancel', IntegerType())  
    ])

  # read data from csv; the path is a directory, so every file in it is read
  transactions = (
    spark
      .read
      .csv(
        '/tmp/kkbox_churn/transactions',
        schema=transaction_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format, partitioned by transaction date
  ( transactions
      .write
      .format('delta')
      .partitionBy('transaction_date')
      .mode('overwrite')
      .save('/tmp/kkbox_churn/silver/transactions')
    )

  # create table object to make delta lake queryable
  # IF NOT EXISTS keeps this cell re-runnable on its own: the database is only
  # dropped by the members cell, so on a standalone re-run the table
  # definition from the previous run is still registered
  _ = spark.sql('''
      CREATE TABLE IF NOT EXISTS transactions
      USING DELTA 
      LOCATION '/tmp/kkbox_churn/silver/transactions'
      ''')

In [None]:
# Rebuild the user_logs table from the raw CSV files (skipped when
# skip_reload is set).
if not skip_reload:
  # location of the user_logs delta files; clear out anything left behind
  # by an earlier run before writing
  user_logs_delta_path = "/tmp/kkbox_churn/silver/user_logs"
  dbutils.fs.rm(user_logs_delta_path, True)

  # user logs dataset schema: member id, log date, six integer counters
  # and the total seconds played as a float
  user_logs_counter_columns = ['num_25', 'num_50', 'num_75', 'num_985', 'num_100', 'num_uniq']
  user_logs_schema = StructType(
    [StructField('msno', StringType()), StructField('date', DateType())]
    + [StructField(counter, IntegerType()) for counter in user_logs_counter_columns]
    + [StructField('total_secs', FloatType())]
    )

  # read data from csv
  user_logs_csv_options = dict(
    schema=user_logs_schema,
    header=True,
    dateFormat='yyyyMMdd'
    )
  user_logs = spark.read.csv('/tmp/kkbox_churn/user_logs', **user_logs_csv_options)

  # persist in delta lake format, partitioned by log date
  user_logs_writer = user_logs.write.format('delta').partitionBy('date').mode('overwrite')
  user_logs_writer.save(user_logs_delta_path)

  # create table object to make delta lake queryable
  _ = spark.sql('''
    CREATE TABLE IF NOT EXISTS user_logs
    USING DELTA 
    LOCATION '/tmp/kkbox_churn/silver/user_logs'
    ''')

In [None]:
# Remove any previous train table definition and its underlying delta lake
# files so that the labels can be regenerated by the following cells.
# This runs regardless of skip_reload.
_ = spark.sql('DROP TABLE IF EXISTS train')
dbutils.fs.rm("/tmp/kkbox_churn/silver/train", True)

In [None]:
%scala
 
import java.time.{LocalDate}
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit
 
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._
import scala.collection.mutable
 
/**
 * Returns the membership_expire_date of the transaction that comes last once a
 * member's transactions have been put in order.
 *
 * Ordering rules (all comparisons are string comparisons):
 *  1. transaction_date ascending;
 *  2. on the same date, different plans are ordered by their plan signature
 *     (price + plan days + payment method, concatenated) descending;
 *  3. on the same date and plan: two cancellations are ordered by expiration
 *     date descending, two non-cancellations by expiration date ascending,
 *     and otherwise the row with the smaller is_cancel value goes first.
 *
 * @param wrappedArray transaction structs of one member; every field is read as a String
 * @return membership_expire_date of the last transaction in that order
 */
def calculateLastday(wrappedArray: mutable.WrappedArray[Row]) :String ={
  // shorthand for reading a field of a transaction struct
  def field(row: Row, name: String): String = row.getAs[String](name)

  // rows with equal signatures are treated as belonging to the same plan
  def planSignature(row: Row): String =
    field(row, "plan_list_price") +
      field(row, "payment_plan_days") +
      field(row, "payment_method_id")

  // true when `a` has to be placed before `b`
  def precedes(a: Row, b: Row): Boolean = {
    val dateA = field(a, "transaction_date")
    val dateB = field(b, "transaction_date")
    if (dateA != dateB) {
      dateA < dateB
    } else {
      val sigA = planSignature(a)
      val sigB = planSignature(b)
      if (sigA != sigB) {
        sigA > sigB
      } else {
        (field(a, "is_cancel"), field(b, "is_cancel")) match {
          // multiple cancels: consecutive cancels should only put the expiration date earlier
          case ("1", "1") =>
            field(a, "membership_expire_date") > field(b, "membership_expire_date")
          // multiple renewals: the expiration date keeps extending
          case ("0", "0") =>
            field(a, "membership_expire_date") < field(b, "membership_expire_date")
          // same day, same plan: subscription precedes cancellation
          case (cancelA, cancelB) =>
            cancelA < cancelB
        }
      }
    }
  }

  wrappedArray.sortWith(precedes).last.getAs[String]("membership_expire_date")
}
 
/**
 * Computes the number of days between a member's expiration date and the first
 * non-cancel transaction in `log`.
 *
 * The transactions are first ordered by transaction_date ascending, then (same
 * date) by plan signature descending, then (same date and plan) so that
 * cancellations with later expiration dates come first, renewals with earlier
 * expiration dates come first, and a subscription precedes a cancellation.
 * Walking that order, every cancellation seen before the first non-cancel
 * transaction may move the expiration date earlier; the first non-cancel
 * transaction fixes the gap.
 *
 * @param log            transaction structs of one member; every field is read as a String
 * @param lastExpiration expiration date as a yyyyMMdd string
 * @return days from the (possibly adjusted) expiration date to the first
 *         non-cancel transaction, or 9999 when there is no such transaction
 */
def calculateRenewalGap(log:mutable.WrappedArray[Row], lastExpiration: String): Int = {
  // shorthand for reading a field of a transaction struct
  def field(row: Row, name: String): String = row.getAs[String](name)

  // rows with equal signatures are treated as belonging to the same plan
  def planSignature(row: Row): String =
    field(row, "plan_list_price") +
      field(row, "payment_plan_days") +
      field(row, "payment_method_id")

  // true when `a` has to be placed before `b`
  def precedes(a: Row, b: Row): Boolean = {
    val dateA = field(a, "transaction_date")
    val dateB = field(b, "transaction_date")
    if (dateA != dateB) {
      dateA < dateB
    } else {
      val sigA = planSignature(a)
      val sigB = planSignature(b)
      if (sigA != sigB) {
        sigA > sigB
      } else {
        (field(a, "is_cancel"), field(b, "is_cancel")) match {
          // multiple cancels of the same plan: consecutive cancels should only put the expiration date earlier
          case ("1", "1") =>
            field(a, "membership_expire_date") > field(b, "membership_expire_date")
          // multiple renewals: the expiration date keeps extending
          case ("0", "0") =>
            field(a, "membership_expire_date") < field(b, "membership_expire_date")
          // same date: cancellation should follow subscription
          case (cancelA, cancelB) =>
            cancelA < cancelB
        }
      }
    }
  }

  val orderedDates = log.sortWith(precedes)

  val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")

  // dates arrive as yyyyMMdd strings; reshape them to yyyy-MM-dd before parsing
  def toDate(raw: String): LocalDate =
    LocalDate.parse(s"${raw.substring(0,4)}-${raw.substring(4,6)}-${raw.substring(6,8)}", formatter)

  // sentinel meaning "no non-cancel transaction seen yet"
  val notFound = 9999

  // Search for the first subscription after expiration.
  // If an active cancel is the first action, find the gap between the cancellation and renewal.
  var lastExpireDate = toDate(lastExpiration)
  var gap = notFound
  orderedDates.foreach { transaction =>
    // once a gap has been found the remaining transactions are ignored
    if (gap == notFound) {
      val transDate = toDate(field(transaction, "transaction_date"))
      val expireDate = toDate(field(transaction, "membership_expire_date"))

      if (field(transaction, "is_cancel") == "1") {
        if (expireDate.isBefore(lastExpireDate)) {
          lastExpireDate = expireDate
        }
      } else {
        gap = ChronoUnit.DAYS.between(lastExpireDate, transDate).toInt
      }
    }
  }
  gap
}
 
// Build the training labels: members whose membership (based on January 2017
// transactions) expires in February 2017 are labelled churned (1) unless they
// renew within 30 days of expiration.

// raw transactions; no schema is supplied and the date columns are compared
// below as yyyyMMdd strings
val data = spark
  .read
  .option("header", true)
  .csv("/tmp/kkbox_churn/transactions/")

val historyCutoff = "20170131"

// split the transactions into the history window and everything after it
val transactionDate = col("transaction_date")
val historyData = data.filter(transactionDate >= "20170101" and transactionDate <= lit(historyCutoff))
val futureData = data.filter(transactionDate > lit(historyCutoff))

// the fields both UDFs need from each transaction
val transactionRecord = struct(
  col("payment_method_id"),
  col("payment_plan_days"),
  col("plan_list_price"),
  col("transaction_date"),
  col("membership_expire_date"),
  col("is_cancel")
)

// last expiration date per member within the history window
val calculateLastdayUDF = udf(calculateLastday _)
val userExpire = historyData
  .groupBy("msno")
  .agg(calculateLastdayUDF(collect_list(transactionRecord)).alias("last_expire"))

// only members whose membership expires in the month after the cutoff get a label
val lastExpire = col("last_expire")
val predictionCandidates = userExpire
  .filter(lastExpire >= "20170201" and lastExpire <= "20170228")
  .select("msno", "last_expire")

// attach every transaction made after the cutoff (if any) to each candidate
val joinedData = predictionCandidates.join(futureData, Seq("msno"), "left_outer")

// no transaction after the cutoff at all: churned
val noActivity = joinedData
  .filter(col("payment_method_id").isNull)
  .withColumn("is_churn", lit(1))

// at least one transaction after the cutoff: measure the renewal gap in days
val calculateRenewalGapUDF = udf(calculateRenewalGap _)
val renewals = joinedData
  .filter(col("payment_method_id").isNotNull)
  .groupBy("msno", "last_expire")
  .agg(
    calculateRenewalGapUDF(collect_list(transactionRecord), col("last_expire")).alias("gap")
  )

// renewing in under 30 days keeps the member; anything later counts as churn
val validRenewals = renewals.filter(col("gap") < 30).withColumn("is_churn", lit(0))
val lateRenewals = renewals.filter(col("gap") >= 30).withColumn("is_churn", lit(1))

// combine the three groups into a single (msno, is_churn) label set
val churnedMembers = lateRenewals
  .select("msno","is_churn")
  .union(noActivity.select("msno","is_churn"))
val resultSet = validRenewals
  .select("msno","is_churn")
  .union(churnedMembers)

resultSet.write.format("delta").mode("overwrite").save("/tmp/kkbox_churn/silver/train/")

In [None]:
%sql

-- register the delta lake files written to /tmp/kkbox_churn/silver/train/
-- as a table in the current database
CREATE TABLE train
USING DELTA
LOCATION '/tmp/kkbox_churn/silver/train/';

-- display the generated training labels
SELECT *
FROM train;

In [None]:
# Remove any previous test table definition and its underlying delta lake
# files so that the labels can be regenerated by the following cells.
# This runs regardless of skip_reload.
_ = spark.sql('DROP TABLE IF EXISTS test')
dbutils.fs.rm("/tmp/kkbox_churn/silver/test", True)

In [None]:
%scala

import java.time.{LocalDate}
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._
import scala.collection.mutable

/**
 * Returns the membership_expire_date of the transaction that comes last once a
 * member's transactions have been put in order.
 *
 * Ordering rules (all comparisons are string comparisons):
 *  1. transaction_date ascending;
 *  2. on the same date, different plans are ordered by their plan signature
 *     (price + plan days + payment method, concatenated) descending;
 *  3. on the same date and plan: two cancellations are ordered by expiration
 *     date descending, two non-cancellations by expiration date ascending,
 *     and otherwise the row with the smaller is_cancel value goes first.
 *
 * @param wrappedArray transaction structs of one member; every field is read as a String
 * @return membership_expire_date of the last transaction in that order
 */
def calculateLastday(wrappedArray: mutable.WrappedArray[Row]) :String ={
  // shorthand for reading a field of a transaction struct
  def field(row: Row, name: String): String = row.getAs[String](name)

  // rows with equal signatures are treated as belonging to the same plan
  def planSignature(row: Row): String =
    field(row, "plan_list_price") +
      field(row, "payment_plan_days") +
      field(row, "payment_method_id")

  // true when `a` has to be placed before `b`
  def precedes(a: Row, b: Row): Boolean = {
    val dateA = field(a, "transaction_date")
    val dateB = field(b, "transaction_date")
    if (dateA != dateB) {
      dateA < dateB
    } else {
      val sigA = planSignature(a)
      val sigB = planSignature(b)
      if (sigA != sigB) {
        sigA > sigB
      } else {
        (field(a, "is_cancel"), field(b, "is_cancel")) match {
          // multiple cancels: consecutive cancels should only put the expiration date earlier
          case ("1", "1") =>
            field(a, "membership_expire_date") > field(b, "membership_expire_date")
          // multiple renewals: the expiration date keeps extending
          case ("0", "0") =>
            field(a, "membership_expire_date") < field(b, "membership_expire_date")
          // same day, same plan: subscription precedes cancellation
          case (cancelA, cancelB) =>
            cancelA < cancelB
        }
      }
    }
  }

  wrappedArray.sortWith(precedes).last.getAs[String]("membership_expire_date")
}

/**
 * Computes the number of days between a member's expiration date and the first
 * non-cancel transaction in `log`.
 *
 * The transactions are first ordered by transaction_date ascending, then (same
 * date) by plan signature descending, then (same date and plan) so that
 * cancellations with later expiration dates come first, renewals with earlier
 * expiration dates come first, and a subscription precedes a cancellation.
 * Walking that order, every cancellation seen before the first non-cancel
 * transaction may move the expiration date earlier; the first non-cancel
 * transaction fixes the gap.
 *
 * @param log            transaction structs of one member; every field is read as a String
 * @param lastExpiration expiration date as a yyyyMMdd string
 * @return days from the (possibly adjusted) expiration date to the first
 *         non-cancel transaction, or 9999 when there is no such transaction
 */
def calculateRenewalGap(log:mutable.WrappedArray[Row], lastExpiration: String): Int = {
  // shorthand for reading a field of a transaction struct
  def field(row: Row, name: String): String = row.getAs[String](name)

  // rows with equal signatures are treated as belonging to the same plan
  def planSignature(row: Row): String =
    field(row, "plan_list_price") +
      field(row, "payment_plan_days") +
      field(row, "payment_method_id")

  // true when `a` has to be placed before `b`
  def precedes(a: Row, b: Row): Boolean = {
    val dateA = field(a, "transaction_date")
    val dateB = field(b, "transaction_date")
    if (dateA != dateB) {
      dateA < dateB
    } else {
      val sigA = planSignature(a)
      val sigB = planSignature(b)
      if (sigA != sigB) {
        sigA > sigB
      } else {
        (field(a, "is_cancel"), field(b, "is_cancel")) match {
          // multiple cancels of the same plan: consecutive cancels should only put the expiration date earlier
          case ("1", "1") =>
            field(a, "membership_expire_date") > field(b, "membership_expire_date")
          // multiple renewals: the expiration date keeps extending
          case ("0", "0") =>
            field(a, "membership_expire_date") < field(b, "membership_expire_date")
          // same date: cancellation should follow subscription
          case (cancelA, cancelB) =>
            cancelA < cancelB
        }
      }
    }
  }

  val orderedDates = log.sortWith(precedes)

  val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")

  // dates arrive as yyyyMMdd strings; reshape them to yyyy-MM-dd before parsing
  def toDate(raw: String): LocalDate =
    LocalDate.parse(s"${raw.substring(0,4)}-${raw.substring(4,6)}-${raw.substring(6,8)}", formatter)

  // sentinel meaning "no non-cancel transaction seen yet"
  val notFound = 9999

  // Search for the first subscription after expiration.
  // If an active cancel is the first action, find the gap between the cancellation and renewal.
  var lastExpireDate = toDate(lastExpiration)
  var gap = notFound
  orderedDates.foreach { transaction =>
    // once a gap has been found the remaining transactions are ignored
    if (gap == notFound) {
      val transDate = toDate(field(transaction, "transaction_date"))
      val expireDate = toDate(field(transaction, "membership_expire_date"))

      if (field(transaction, "is_cancel") == "1") {
        if (expireDate.isBefore(lastExpireDate)) {
          lastExpireDate = expireDate
        }
      } else {
        gap = ChronoUnit.DAYS.between(lastExpireDate, transDate).toInt
      }
    }
  }
  gap
}

// Build the test labels: members whose membership (based on February 2017
// transactions) expires in March 2017 are labelled churned (1) unless they
// renew within 30 days of expiration.

// raw transactions; no schema is supplied and the date columns are compared
// below as yyyyMMdd strings
val data = spark
  .read
  .option("header", true)
  .csv("/tmp/kkbox_churn/transactions/")

val historyCutoff = "20170228"

// split the transactions into the history window and everything after it
val transactionDate = col("transaction_date")
val historyData = data.filter(transactionDate >= "20170201" and transactionDate <= lit(historyCutoff))
val futureData = data.filter(transactionDate > lit(historyCutoff))

// the fields both UDFs need from each transaction
val transactionRecord = struct(
  col("payment_method_id"),
  col("payment_plan_days"),
  col("plan_list_price"),
  col("transaction_date"),
  col("membership_expire_date"),
  col("is_cancel")
)

// last expiration date per member within the history window
val calculateLastdayUDF = udf(calculateLastday _)
val userExpire = historyData
  .groupBy("msno")
  .agg(calculateLastdayUDF(collect_list(transactionRecord)).alias("last_expire"))

// only members whose membership expires in the month after the cutoff get a label
val lastExpire = col("last_expire")
val predictionCandidates = userExpire
  .filter(lastExpire >= "20170301" and lastExpire <= "20170331")
  .select("msno", "last_expire")

// attach every transaction made after the cutoff (if any) to each candidate
val joinedData = predictionCandidates.join(futureData, Seq("msno"), "left_outer")

// no transaction after the cutoff at all: churned
val noActivity = joinedData
  .filter(col("payment_method_id").isNull)
  .withColumn("is_churn", lit(1))

// at least one transaction after the cutoff: measure the renewal gap in days
val calculateRenewalGapUDF = udf(calculateRenewalGap _)
val renewals = joinedData
  .filter(col("payment_method_id").isNotNull)
  .groupBy("msno", "last_expire")
  .agg(
    calculateRenewalGapUDF(collect_list(transactionRecord), col("last_expire")).alias("gap")
  )

// renewing in under 30 days keeps the member; anything later counts as churn
val validRenewals = renewals.filter(col("gap") < 30).withColumn("is_churn", lit(0))
val lateRenewals = renewals.filter(col("gap") >= 30).withColumn("is_churn", lit(1))

// combine the three groups into a single (msno, is_churn) label set
val churnedMembers = lateRenewals
  .select("msno","is_churn")
  .union(noActivity.select("msno","is_churn"))
val resultSet = validRenewals
  .select("msno","is_churn")
  .union(churnedMembers)

resultSet.write.format("delta").mode("overwrite").save("/tmp/kkbox_churn/silver/test/")

In [None]:
%sql

-- register the delta lake files written to /tmp/kkbox_churn/silver/test/
-- as a table in the current database
CREATE TABLE test
USING DELTA
LOCATION '/tmp/kkbox_churn/silver/test/';

-- display the generated test labels
SELECT *
FROM test;