In [2]:
import findspark
findspark.init()
import pyspark
from pyspark.sql import *
import pyspark.sql.functions as F
from pyspark.sql.functions import *
from pyspark.sql.types import *

#Lifetime package in python is designed to calculate the CLTV 
import lifetimes
import sys

### Creating the Spark Session

In [3]:
spark=SparkSession.builder.appName("Spark Programming").getOrCreate()
sc = SQLContext(spark)
type(spark)

pyspark.sql.session.SparkSession

## Reading the Data from the files to the Spark Dataframe:

In [4]:
#Creating Schema for the dataframe
data_2009 = StructType(fields=[StructField('Invoice_No',StringType(),False),
                              StructField('StockCode',StringType(),False),
                              StructField('Description',StringType(),True),
                              StructField('Quantity',IntegerType(),True),
                              StructField('InvoiceDate',DateType(),True),
                              StructField('Price',DoubleType(),True),
                              StructField('Customer_ID',IntegerType(),False),   #Nullable is set to False so that any rows with customer Id as blank will not be included in dataframe
                              StructField('Country',StringType(),True)])


data_2010 = StructType(fields=[StructField('Invoice_No',StringType(),False),
                              StructField('StockCode',StringType(),False),
                              StructField('Description',StringType(),True),
                              StructField('Quantity',IntegerType(),True),
                              StructField('InvoiceDate',DateType(),True),
                              StructField('Price',DoubleType(),True),
                              StructField('Customer_ID',IntegerType(),False),
                              StructField('Country',StringType(),True)])

In [5]:
df1 = spark.read.format('csv').load("Data/Online_Retail_Data_2009.csv",header=True,schema=data_2009)
df2=spark.read.format('csv').load("Data/Online_Retail_Data_2010.csv",header=True,schema=data_2010)

In [6]:
df1.printSchema()

root
 |-- Invoice_No: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: date (nullable = true)
 |-- Price: double (nullable = true)
 |-- Customer_ID: integer (nullable = true)
 |-- Country: string (nullable = true)



In [7]:
df2.printSchema()

root
 |-- Invoice_No: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: date (nullable = true)
 |-- Price: double (nullable = true)
 |-- Customer_ID: integer (nullable = true)
 |-- Country: string (nullable = true)



In [8]:
#Let us merge the two seperated datasets to create a single dataset that contains customer transaction data from year 2009-2011
#To do this we will perform union operation
data=df1.union(df2).distinct()
data.show()

+----------+---------+--------------------+--------+-----------+-----+-----------+--------------+
|Invoice_No|StockCode|         Description|Quantity|InvoiceDate|Price|Customer_ID|       Country|
+----------+---------+--------------------+--------+-----------+-----+-----------+--------------+
|    489461|    72756|  FAIRY CAKE CANDLES|      27| 2009-12-01| 1.49|      17865|United Kingdom|
|    489529|    22030|SWALLOWS GREETING...|       1| 2009-12-01| 0.42|      17984|United Kingdom|
|    489529|    21866|UNION JACK FLAG L...|       1| 2009-12-01| 1.25|      17984|United Kingdom|
|    489564|    90088|CRYSTAL KEY+LOCK ...|      24| 2009-12-01| 0.85|      13526|United Kingdom|
|    489580|    21975|PACK OF 60 DINOSA...|      24| 2009-12-01| 0.55|      12921|United Kingdom|
|    489580|    21212|PACK OF 72 RETRO ...|      24| 2009-12-01| 0.55|      12921|United Kingdom|
|    489593|    21767|FRENCH STYLE WIRE...|       1| 2009-12-01|29.95|      12836|United Kingdom|
|    489593|    2179

In [9]:
#Let us get an idea of shape of the dataset
print(data.count(),len(data.columns))

1033480 8


## Check for the null values (if any)

In [10]:
def count_nulls(dataframe):
    dataframe.agg( *[ count(colmn).alias(colmn) for colmn in dataframe.columns ] ).show()
    
count_nulls(data)

+----------+---------+-----------+--------+-----------+-------+-----------+-------+
|Invoice_No|StockCode|Description|Quantity|InvoiceDate|  Price|Customer_ID|Country|
+----------+---------+-----------+--------+-----------+-------+-----------+-------+
|   1033480|  1033480|    1029205| 1033480|    1033480|1033480|     830390|1033480|
+----------+---------+-----------+--------+-----------+-------+-----------+-------+



In [14]:
# As we can see from above output that some customer IDs are null. This will not be useful as our analysis is at customer level
#So let's drop the rows having Customer_ID as null
data = data.dropna(how='any')
count_nulls(data)

+----------+---------+-----------+--------+-----------+------+-----------+-------+
|Invoice_No|StockCode|Description|Quantity|InvoiceDate| Price|Customer_ID|Country|
+----------+---------+-----------+--------+-----------+------+-----------+-------+
|    829247|   829247|     829247|  829247|     829247|829247|     829247| 829247|
+----------+---------+-----------+--------+-----------+------+-----------+-------+



## Let's see how many customers are under our analysis?

In [13]:
data.select(countDistinct(data['Customer_ID']).alias("#ofCustomers")).show()

+------------+
|#ofCustomers|
+------------+
|        5937|
+------------+



In [18]:
#Converting the dataframe to a queriable view. This will allow us to use power of SQL to query the dataframes
data.createOrReplaceTempView('Transactions')

In [23]:
#Getting the date for the first transaction and the date for the last transaction
#This will help us to define the time period
min_date, max_date = data.select(min("InvoiceDate"), max("InvoiceDate")).first()
print("Transaction Start Date: ",min_date , "Transaction End Date", max_date)

Transaction Start Date:  2009-12-01 Transaction End Date 2011-12-09


In [43]:
#Let us try to identify the total number of customers we have and their total transactions done within this period
sql_q="select count(distinct(Customer_ID)) as Total_Customers,count(distinct(Invoice_No)) as Customer_Total_Purchases from Transactions"
spark.sql(sql_q).show()

+---------------+------------------------+
|Total_Customers|Customer_Total_Purchases|
+---------------+------------------------+
|           5937|                   53628|
+---------------+------------------------+



In [31]:
sql_q1="select InvoiceDate, count(distinct(Invoice_No)) as Total_Purchases from Transactions group by InvoiceDate order by InvoiceDate; "
results_q1=spark.sql(sql_q1).show()

+-----------+---------------+
|InvoiceDate|Total_Purchases|
+-----------+---------------+
| 2009-12-01|            166|
| 2009-12-02|            133|
| 2009-12-03|            150|
| 2009-12-04|            107|
| 2009-12-05|             32|
| 2009-12-06|             90|
| 2009-12-07|            123|
| 2009-12-08|            170|
| 2009-12-09|            114|
| 2009-12-10|            123|
| 2009-12-11|            121|
| 2009-12-13|             77|
| 2009-12-14|            122|
| 2009-12-15|            123|
| 2009-12-16|            106|
| 2009-12-17|            162|
| 2009-12-18|             84|
| 2009-12-20|             40|
| 2009-12-21|             52|
| 2009-12-22|            203|
+-----------+---------------+
only showing top 20 rows



In [None]:
plot=results_q1.toHandy()

In [47]:
# Now let's check how many transactions are done by each customer over the defined time period:
sql_q2="select Customer_ID,count(distinct(Invoice_No)) as Total_Purchases from Transactions group by Customer_ID"
spark.sql(sql_q2).show()

+-----------+---------------+
|Customer_ID|Total_Purchases|
+-----------+---------------+
|      15727|             25|
|      16574|             17|
|      18024|             20|
|      17389|             85|
|      15619|             15|
|      15447|             17|
|      18051|             14|
|      16339|             18|
|      13289|             20|
|      14450|             15|
|      12940|             13|
|      17679|             20|
|      17420|             17|
|      16386|             23|
|      13623|             19|
|      17753|             20|
|      16503|             24|
|      15846|              7|
|      14570|             12|
|      14832|              4|
+-----------+---------------+
only showing top 20 rows



## Calculating RFM (Recency, Frequency, Monetary Value) Table
For each customer, RFM values are defined as:
1. Recency
2. Frequency:
3. Monetary Value: 
We will also calculate "Account_Age" which is the duration for which customer has been with the organization. This is calculated as the number of days between current date and date of the first purchase(Transaction made)

Since our last date in the Transactions dataset is 2011-12-09. We will take a date after that as our reference date to calculate the above metrics.

In [49]:
#Calculating the Total_Sales for each Purchase(or Transaction) for each customer
data = data.withColumn("Total_Sales", round(data["Quantity"] * data["Price"], 2))

In [61]:
#We will find the first and the last transaction date for each customer
first_txn_dt=data.groupBy(data.Customer_ID).agg(min(to_date(data.InvoiceDate)).alias('First_Txn_Date'))
last_txn_dt=data.groupBy().agg(max(to_date(data.InvoiceDate)).alias('Last_Txn_Date'))

In [73]:
monetary = data.groupBy('Customer_ID').agg(round(sum('Total_Sales'), 2).alias('Monetary_Value'))
sub_reslt = (data.crossJoin(last_txn_dt).join(first_txn_dt, data.Customer_ID==first_txn_dt.Customer_ID, how='inner')
     .select(data.Customer_ID.alias('Customer_ID'),first_txn_dt.First_Txn_Date,to_date(data.InvoiceDate).alias('Txn_Date'), 
      last_txn_dt.Last_Txn_Date
      )
     .distinct()
    )

In [66]:
sub_reslt.show()

+-----------+--------------+----------+-------------+
|Customer_ID|First_Txn_Date|  Txn_Date|Last_Txn_Date|
+-----------+--------------+----------+-------------+
|      14450|    2009-12-01|2011-06-12|   2011-12-09|
|      14148|    2009-12-14|2011-04-20|   2011-12-09|
|      13207|    2009-12-02|2009-12-07|   2011-12-09|
|      13329|    2009-12-04|2010-01-22|   2011-12-09|
|      14944|    2009-12-10|2009-12-10|   2011-12-09|
|      13468|    2009-12-03|2011-01-25|   2011-12-09|
|      17346|    2009-12-02|2010-03-03|   2011-12-09|
|      13313|    2009-12-02|2011-07-21|   2011-12-09|
|      13544|    2009-12-18|2010-10-24|   2011-12-09|
|      14245|    2009-12-15|2011-05-26|   2011-12-09|
|      15827|    2009-12-08|2011-06-13|   2011-12-09|
|      16391|    2009-12-01|2009-12-15|   2011-12-09|
|      17677|    2009-12-07|2010-03-22|   2011-12-09|
|      17904|    2009-12-01|2010-01-29|   2011-12-09|
|      12727|    2009-12-17|2010-02-01|   2011-12-09|
|      14760|    2009-12-03|

In [69]:
monetary.show()

+-----------+--------------+
|Customer_ID|Monetary_Value|
+-----------+--------------+
|      14450|       1182.52|
|      16574|       1452.68|
|      17679|       2977.51|
|      13285|       3427.88|
|      15727|       8518.38|
|      16503|       3827.72|
|      17389|      54616.11|
|      15957|       1149.36|
|      16861|       1229.82|
|      14570|        668.89|
|      18024|        357.74|
|      12799|        616.19|
|      13832|        602.99|
|      13623|        2130.2|
|      12940|         951.2|
|      15447|        628.07|
|      17420|       1850.62|
|      13289|        633.86|
|      17753|        653.45|
|      13840|        745.77|
+-----------+--------------+
only showing top 20 rows



In [82]:
rf_metrics = (sub_reslt.groupBy(sub_reslt.Customer_ID, sub_reslt.Last_Txn_Date, sub_reslt.First_Txn_Date)
           .agg((countDistinct(sub_reslt.Txn_Date)-1).cast(FloatType()).alias('Frequency'),
             datediff(max(sub_reslt.Txn_Date), sub_reslt.First_Txn_Date).cast(FloatType()).alias('Recency'),
             datediff(a.Last_Txn_Date, a.First_Txn_Date).cast(FloatType()).alias('Account_Age'),
            )
           .select('Customer_ID','Frequency','Recency','Account_Age')
           .orderBy('Customer_ID')
          )

In [83]:
rfm_metrics=rf_metrics.join(monetary,rf_metrics.Customer_ID==monetary.Customer_ID,how='inner')

In [84]:
rfm_metrics.show()

+-----------+---------+-------+-----------+-----------+--------------+
|Customer_ID|Frequency|Recency|Account_Age|Customer_ID|Monetary_Value|
+-----------+---------+-------+-----------+-----------+--------------+
|      12799|      8.0|  196.0|      725.0|      12799|        616.19|
|      12940|     11.0|  673.0|      719.0|      12940|         951.2|
|      13285|     12.0|  714.0|      737.0|      13285|       3427.88|
|      13289|     15.0|   51.0|      731.0|      13289|        633.86|
|      13623|     14.0|  707.0|      737.0|      13623|        2130.2|
|      13832|      3.0|  670.0|      687.0|      13832|        602.99|
|      13840|      6.0|  322.0|      738.0|      13840|        745.77|
|      14450|     12.0|  558.0|      738.0|      14450|       1182.52|
|      14570|      9.0|  449.0|      729.0|      14570|        668.89|
|      14832|      1.0|   67.0|      697.0|      14832|       -249.88|
|      15447|     15.0|  407.0|      737.0|      15447|        628.07|
|     

In [86]:
rfm_metrics.describe().show()

+-------+------------------+------------------+-----------------+-----------------+------------------+------------------+
|summary|       Customer_ID|         Frequency|          Recency|      Account_Age|       Customer_ID|    Monetary_Value|
+-------+------------------+------------------+-----------------+-----------------+------------------+------------------+
|  count|              5937|              5937|             5937|             5937|              5937|              5937|
|   mean|15314.092470944921|14.203806636348324|520.2135758800741|720.8881590028634|15314.092470944921|2796.5656223682004|
| stddev|1714.1657024031815|11.511825880492022|216.3221496679391|65.14393887771152|1714.1657024031815| 12790.79665140313|
|    min|             12346|               0.0|              0.0|             15.0|             12346|         -24972.99|
|    max|             18287|             277.0|            738.0|            738.0|             18287|         529721.15|
+-------+---------------