**Churn Input Dataframe**

This notebook creates the input Spark DataFrame for churn prediction.

**Select companies for prediction**

In [0]:
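# NOTE(review): pred_datetime, start_dateime and registered_only are used by later cells, but the
# only assignments visible in this notebook are the commented-out ones below. Presumably they are
# supplied by the %run'd notebook or by the calling notebook; confirm before running standalone.
# from datetime import datetime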

# # end date for the prediction
# pred_datetime = datetime(2020, 7, 21)

# # start date
# start_dateime = datetime(2000, 1, 1)

# # registered users only or both
# registered_only = False

**Load all tables**

In [0]:
%run ./CAB_Churn_ReadData

In [0]:
def _load_silver_delta(relative_path):
  """Read the Delta table stored at ``{silverRootPath}/{relative_path}``."""
  return spark.read.format('delta').load(f"{silverRootPath}/{relative_path}")

# tables under the cab/ folder
companies = _load_silver_delta("cab/companies")
accounts = _load_silver_delta("cab/accounts")
licenses = _load_silver_delta("cab/licenses")
trips = _load_silver_delta("cab/trips")
users = _load_silver_delta("cab/users")
form_headers = _load_silver_delta("cab/form_headers")
timekeeping_statuses = _load_silver_delta("cab/timekeeping_statuses")
order_headers = _load_silver_delta("cab/order_headers")

# tables under the dynamics/ folder
incidents = _load_silver_delta("dynamics/incidents")

**Create the churn dataframe for modeling:**
- Find billable Companies.
- Aggregate Licenses to company level.
- Aggregate dynamics tickets.
- Aggregate trips.
- Aggregate forms submitted.

**Get billable companies**

In [0]:
from pyspark.sql.functions import sum, min, max, col, current_date, udf, collect_list
from pyspark.sql.types import *

# Billable companies: companies joined to their billing account and to their licenses,
# restricted to billable licenses on non-test accounts and non-legacy companies.
account_match = companies.AccountId == accounts.BillingAccountId
license_match = (licenses.AccountId == companies.AccountId) & (licenses.InstanceId == companies.InstanceId)

company_columns = [
  col('CompanyId'), col('InstanceId'), col('CompanyName'),
  col('TierId').alias('Tier'),
  # still the raw completion time here; it is turned into a boolean flag further down
  col('SetupCompletionTime').alias('IsSetupComplete'),
  col('UserIntegrationType'),
  col('BillingAccountId').alias('AccountId'),
  accounts.AccountName,
  col('City'), col('RegionName'), col('PostalCode'), col('CountryCode'), col('Culture'),
]

billable_companies = (
  companies
  .join(accounts, account_match)
  .join(licenses, license_match)
  .filter(licenses.Billable == True)
  .filter(accounts.IsTest == False)
  .filter(companies.IsLegacy == False)
  # license-side copies are removed before selecting by bare column name
  # (presumably to keep CompanyId / TierId / InstanceId unambiguous)
  .drop(licenses.CompanyId).drop(licenses.TierId).drop(licenses.InstanceId)
  .select(*company_columns)
  .distinct()
)

# IsSetupComplete: True when a setup completion time is present, False otherwise.
# Companies without a UserIntegrationType are excluded.
billable_companies = (
  billable_companies
  .withColumn('IsSetupComplete', billable_companies.IsSetupComplete.isNotNull())
  .filter(col('UserIntegrationType').isNotNull())
)

# optionally keep only companies whose setup is complete
if registered_only == True:
  billable_companies = billable_companies.filter(col('IsSetupComplete') == True)

# display(billable_companies)

**Calculate Company Churn**
- add features: licenses by company, license duration. 
- add target: Churn

Aggregate activation, expiration, and deactivation per company

In [0]:
# Group licenses per (CompanyId, AccountId, InstanceId), keeping only groups whose
# earliest CreatedOn falls inside [start_dateime, pred_datetime).
#
# A null ExpirationDate / DeactDate must win the max() aggregation, so nulls are
# temporarily replaced with a far-future sentinel, aggregated, and then the sentinel
# is turned back into null.
from pyspark.sql.functions import lit, when

OPEN_ENDED_SENTINEL = '2100-01-01'

licenses = licenses.fillna({'DeactDate' : OPEN_ENDED_SENTINEL, 'ExpirationDate' : OPEN_ENDED_SENTINEL})
licenses_by_account = licenses.groupBy('CompanyId', 'AccountId', 'InstanceId') \
                              .agg(min('CreatedOn').alias('Activation'), max('ExpirationDate').alias('Expiration'), max('DeactDate').alias('Deactivation')) \
                              .filter(col('Activation').isNotNull()) \
                              .filter((col('Activation') >= start_dateime) & (col('Activation') < pred_datetime))

# Turn the sentinel back into null. The sentinel is cast to each aggregated column's
# own data type so that the comparison is made on typed values; the previous code
# compared against two differently formatted string literals
# ('2100-01-01T00:00:00.000+0000' and '2100-01-01'), which only matches if the
# implicit string conversion happens to agree with the column type and time zone.
for end_column in ('Deactivation', 'Expiration'):
  typed_sentinel = lit(OPEN_ENDED_SENTINEL).cast(licenses_by_account.schema[end_column].dataType)
  licenses_by_account = licenses_by_account.withColumn(
    end_column,
    when(col(end_column) == typed_sentinel, None).otherwise(col(end_column)))

Calculate license duration

In [0]:
# License duration in whole days, measured no further than pred_datetime.
@udf(returnType=IntegerType())
def licenseDuration(colActivation, colDeactivation):
  """Return the license duration in days, or -1 when the activation is unusable.

  colActivation: activation value of the license group (may be None).
  colDeactivation: deactivation value of the license group (may be None).

  Returns -1 when colActivation is None or is not earlier than pred_datetime.
  When colDeactivation is set and earlier than pred_datetime, returns the days
  between activation and deactivation (0 if that difference is negative).
  Otherwise returns the days between activation and pred_datetime.
  """
  if colActivation is None or not (colActivation < pred_datetime):
    # activation missing or not before the cut-off
    return -1

  ended_before_cutoff = colDeactivation is not None and colDeactivation < pred_datetime
  if not ended_before_cutoff:
    return (pred_datetime - colActivation).days

  # Deactivation can precede activation when both fall within one day (known db
  # issue), which yields a negative day count; report 0 in that case.
  elapsed_days = (colDeactivation - colActivation).days
  return elapsed_days if elapsed_days >= 0 else 0

# Add the Duration column and discard rows the UDF flagged with -1
# (activation missing or not before pred_datetime).
company_churn = (
  licenses_by_account
  .withColumn('Duration', licenseDuration(col('Activation'), col('Deactivation')))
  .filter(col('Duration') != -1)
)

Calculate license counts: Standalone and Addon

In [0]:
# A plain aggregate over 'Count' would also include replaced licenses, so the
# per-company total is computed by hand over the collected lists.
@udf(returnType=IntegerType())
def getStandaloneLicenseCounts(vals, licenses, orginalLicenses, statuses, standalones):
  """Total the counts of standalone licenses in one group.

  All arguments are parallel lists, read by position:
    vals: license counts.
    licenses: license ids (not used by the calculation).
    orginalLicenses: original license ids; 0 stands for "no original license".
    statuses: license statuses; 0 is treated as active.
    standalones: truthy for standalone licenses.

  A standalone entry is counted when its status is 0, or, when it is not
  active, only if its original license id is 0.

  NOTE(review): the lists line up only if every collected column kept the same
  rows; confirm Count, Status and Standalone never contain nulls.
  """
  # the running total deliberately avoids sum(): that name is bound to the
  # pyspark function by the notebook's imports
  total = 0
  for position, license_count in enumerate(vals):
    if not standalones[position]:
      continue
    if statuses[position] == 0 or orginalLicenses[position] == 0:
      total += license_count
  return total

# OriginalLicenseId is null when there was no tier change etc., and collect_list
# leaves null values out of the aggregation. Filling them with 0 first keeps the
# collected lists usable by the counting UDFs.
licenses_tmp = licenses.fillna({'OriginalLicenseId' : 0})

standalone_total = getStandaloneLicenseCounts(
  collect_list('Count'),
  collect_list('LicenseId'),
  collect_list('OriginalLicenseId'),
  collect_list('Status'),
  collect_list('Standalone'),
).alias('StandaloneLicenses')
license_count_standalone = licenses_tmp.groupBy('CompanyId', 'AccountId').agg(standalone_total)

# attach the standalone total to each company/account row, then remove the
# join keys that came from the count side
standalone_match = (company_churn.CompanyId == license_count_standalone.CompanyId) & (company_churn.AccountId == license_count_standalone.AccountId)
company_churn = (
  company_churn.join(license_count_standalone, standalone_match)
  .drop(license_count_standalone.AccountId)
  .drop(license_count_standalone.CompanyId)
)

# Same hand-written total, this time for add-on (non-standalone) licenses.
@udf(returnType=IntegerType())
def getAddonLicenseCounts(vals, licenses, orginalLicenses, statuses, Standalones):
  """Total the counts of add-on licenses in one group.

  All arguments are parallel lists, read by position:
    vals: license counts.
    licenses: license ids (not used by the calculation).
    orginalLicenses: original license ids; 0 stands for "no original license".
    statuses: license statuses; 0 is treated as active.
    Standalones: standalone flags; an entry is an add-on when its flag equals False.

  An add-on entry is counted when its status is 0, or, when it is not active,
  only if its original license id is 0.

  NOTE(review): the lists line up only if every collected column kept the same
  rows; confirm Count, Status and Standalone never contain nulls.
  """
  # the running total deliberately avoids sum(): that name is bound to the
  # pyspark function by the notebook's imports
  total = 0
  for position, license_count in enumerate(vals):
    is_addon = Standalones[position] == False
    if is_addon and (statuses[position] == 0 or orginalLicenses[position] == 0):
      total += license_count
  return total

# per company/account total of add-on licenses
addon_total = getAddonLicenseCounts(
  collect_list('Count'),
  collect_list('LicenseId'),
  collect_list('OriginalLicenseId'),
  collect_list('Status'),
  collect_list('Standalone'),
).alias('AddonLicenses')
license_count_addon = licenses_tmp.groupBy('CompanyId', 'AccountId').agg(addon_total)

# attach the add-on total, then remove the join keys that came from the count side
addon_match = (company_churn.CompanyId == license_count_addon.CompanyId) & (company_churn.AccountId == license_count_addon.AccountId)
company_churn = (
  company_churn.join(license_count_addon, addon_match)
  .drop(license_count_addon.AccountId)
  .drop(license_count_addon.CompanyId)
)

Calculate whether each company has churned

In [0]:
# Churn flag: 1 = churned, 0 = not churned
@udf(returnType=IntegerType())
def isChurned(colDeactivation):
  """Return 1 when colDeactivation is set and earlier than pred_datetime, else 0."""
  deactivated_before_cutoff = colDeactivation is not None and colDeactivation < pred_datetime
  return 1 if deactivated_before_cutoff else 0

# Flag every row with Churn, then keep only the rows that are NOT churned:
# ***** only active companies are used as input *****
company_churn = (
  company_churn
  .withColumn('Churn', isChurned(col('Deactivation')))
  .filter(col('Churn') != 1)
)

Add feature: incidents by account

In [0]:
# accounts.createOrReplaceTempView('Accounts')
# incidents.createOrReplaceTempView('Incidents')

# cutoff_date = pred_datetime.strftime('%Y-%m-%d')

# account_incidents = spark.sql('''
#                             SELECT 
#                                 a.BillingAccountId as AccountId
#                                 ,a.AccountName
#                                 ,a.Status
#                                 ,(SELECT count(1) FROM Incidents i WHERE a.accountid == i.accountid) as IncidentCount
#                             FROM Accounts a
#                             where a.BillingAccountId is not null and a.IsTest=false
#                       ''')

# display(account_incidents.orderBy(col('IncidentCount').desc()))

In [0]:
# Incidents joined to their account. The accounts-side 'modifiedon' column is
# dropped before the join, so the 'modifiedon' filter below does not see it.
account_incidents_tmp = (
  accounts.drop('modifiedon')
  .join(incidents, incidents.accountid == accounts.accountid)
  .filter(col('BillingAccountId').isNotNull())
  .filter(col('IsTest') == False)
  .filter(col('modifiedon') < pred_datetime)
)

# number of matching incidents per billing account
account_incidents = (
  account_incidents_tmp
  .groupBy('BillingAccountId')
  .count()
  .withColumnRenamed('count', 'Incidents')
)

# display(account_incidents.filter('BillingAccountId=2520'))

Add feature: number of trips by company

In [0]:
# trips that started before pred_datetime, counted per (CompanyId, InstanceId)
company_trips = (
  trips
  .filter(col('StartLocalTime') < pred_datetime)
  .groupBy('CompanyId', 'InstanceId')
  .count()
  .withColumnRenamed('count', 'Trips')
)
                  

Add feature: number of forms submitted by company

In [0]:
# forms whose StartLocalTimeTag is before pred_datetime, counted per (CompanyId, InstanceId)
company_forms = (
  form_headers
  .filter(col('StartLocalTimeTag') < pred_datetime)
  .groupBy('CompanyId', 'InstanceId')
  .count()
  .withColumnRenamed('count', 'Forms')
)
                  

Add feature: number of timekeeping entries submitted by company

In [0]:
# timekeeping status rows that started before pred_datetime, counted per (CompanyId, InstanceId)
timekeeping = (
  timekeeping_statuses
  .filter(col('StartLocalTime') < pred_datetime)
  .groupBy('CompanyId', 'InstanceId')
  .count()
  .withColumnRenamed('count', 'Timekeeping')
)

Add feature: number of orders submitted by company

In [0]:
# orders whose StartTime is before pred_datetime, counted per (CompanyId, InstanceId)
orders = (
  order_headers
  .filter(col('StartTime') < pred_datetime)
  .groupBy('CompanyId', 'InstanceId')
  .count()
  .withColumnRenamed('count', 'Orders')
)

**Account Churn: combine all features**

In [0]:
def _left_join_company_count(base_df, counts_df, count_column):
  """Left-join a per-company count onto base_df and zero-fill missing counts.

  base_df: DataFrame with CompanyId and InstanceId columns.
  counts_df: DataFrame with CompanyId, InstanceId and the count_column column.
  count_column: name of the count column; nulls left by the outer join become 0.

  Returns base_df's rows with count_column added; the CompanyId / InstanceId
  columns coming from counts_df are dropped.
  """
  same_company = (counts_df.CompanyId == base_df.CompanyId) & (counts_df.InstanceId == base_df.InstanceId)
  joined = (
    base_df.join(counts_df, same_company, "left_outer")
    .drop(counts_df.CompanyId).drop(counts_df.InstanceId)
  )
  return joined.fillna({count_column : 0})


# billable companies that have license/churn rows (inner join on CompanyId + InstanceId);
# the key columns coming from company_churn are dropped afterwards
license_rows_match = (company_churn.CompanyId == billable_companies.CompanyId) & (company_churn.InstanceId == billable_companies.InstanceId)
billable_companies_w_licenses = (
  billable_companies.join(company_churn, license_rows_match)
  .drop(company_churn.CompanyId).drop(company_churn.AccountId).drop(company_churn.InstanceId)
)

print('billable companies join with licenses: ', billable_companies_w_licenses.count())

#**** "incidents" can only be associated through the account id;
# accounts without incidents get 0
incident_match = account_incidents.BillingAccountId == billable_companies_w_licenses.AccountId
billable_account_churn = (
  billable_companies_w_licenses.join(account_incidents, incident_match, "left_outer")
  .drop(account_incidents.BillingAccountId)
  .fillna({'Incidents' : 0})
)

#**** "trips", "forms", "timekeeping" and "orders" are all per-company counts and
# are attached the same way, in this order
company_count_features = (
  (company_trips, 'Trips'),
  (company_forms, 'Forms'),
  (timekeeping, 'Timekeeping'),
  (orders, 'Orders'),
)
for counts_df, count_column in company_count_features:
  billable_account_churn = _left_join_company_count(billable_account_churn, counts_df, count_column)

# cache it
billable_account_churn_cache = billable_account_churn.cache()


**Define features**

In [0]:
# model input columns, split by how they are treated
categorical_features = [
  "IsSetupComplete",
  "Tier",
  "UserIntegrationType",
]
numerical_features = [
  "StandaloneLicenses",
  "AddonLicenses",
  "Duration",
  "Incidents",
  "Trips",
  "Forms",
  "Timekeeping",
  "Orders",
]

**Convert to Pandas dataframe**

In [0]:
import platform
import pandas as pd
import databricks.koalas as ks
import sklearn
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rc("font", size=14)
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn import metrics
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import seaborn as sns
sns.set(style="white")
sns.set(style="whitegrid", color_codes=True)
