# ONLINE RETAIL BIG DATA PROJECT USING PYSPARK

Data description found in [this link](https://archive.ics.uci.edu/ml/datasets/Online%20Retail).

# Data Processing

#### Load Data

In [0]:
# File location and type
file_location = "/FileStore/tables/OnlineRetail.csv"
file_type = "csv"

# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","

# The applied options are for CSV files. For other file types, these will be ignored.
df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(file_location)

display(df)

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
536365,85123A,WHITE HANGING HEART T-LIGHT HOLDER,6,01/12/2010 08:26,2.55,17850.0,United Kingdom
536365,71053,WHITE METAL LANTERN,6,01/12/2010 08:26,3.39,17850.0,United Kingdom
536365,84406B,CREAM CUPID HEARTS COAT HANGER,8,01/12/2010 08:26,2.75,17850.0,United Kingdom
536365,84029G,KNITTED UNION FLAG HOT WATER BOTTLE,6,01/12/2010 08:26,3.39,17850.0,United Kingdom
536365,84029E,RED WOOLLY HOTTIE WHITE HEART.,6,01/12/2010 08:26,3.39,17850.0,United Kingdom
536365,22752,SET 7 BABUSHKA NESTING BOXES,2,01/12/2010 08:26,7.65,17850.0,United Kingdom
536365,21730,GLASS STAR FROSTED T-LIGHT HOLDER,6,01/12/2010 08:26,4.25,17850.0,United Kingdom
536366,22633,HAND WARMER UNION JACK,6,01/12/2010 08:28,1.85,17850.0,United Kingdom
536366,22632,HAND WARMER RED POLKA DOT,6,01/12/2010 08:28,1.85,17850.0,United Kingdom
536367,84879,ASSORTED COLOUR BIRD ORNAMENT,32,01/12/2010 08:34,1.69,13047.0,United Kingdom


#### Missing values

In [0]:
# find missing values in all columns of the dataframe

from pyspark.sql.functions import col,isnan,when,count

df_missing = df.select([count(when(col(c).contains('None') | col(c).contains('NULL') | (col(c) == '' ) | col(c).isNull() | isnan(c), c)).alias(c) for c in df.columns])
df_missing.show()

+---------+---------+-----------+--------+-----------+---------+----------+-------+
|InvoiceNo|StockCode|Description|Quantity|InvoiceDate|UnitPrice|CustomerID|Country|
+---------+---------+-----------+--------+-----------+---------+----------+-------+
|        0|        0|       1454|       0|          0|        0|    135080|      0|
+---------+---------+-----------+--------+-----------+---------+----------+-------+



The column **CustomeID** contains **135,080** missing values. We shall drop all rows with missing values since we cannot attribute these transactions to customers - our focus in this project.

In [0]:
# drop rows with missing values in any column

df = df.na.drop()

#### Data type conversions

In [0]:
# convert InvoiceDate to a date format

from pyspark.sql.functions import to_date

df1 = df.select('InvoiceNo','StockCode','Description','Quantity','UnitPrice','CustomerID','Country', to_date(df.InvoiceDate, 'dd/MM/yyyy HH:mm').alias('dateFormatted'))

#### Duplicates

In [0]:
# let us keep only the distinct rows and drop any duplicates

df1 = df1.distinct()

# Customer Segmentation

#### Number of customers/products/countries

In [0]:
n_customers = df1.select("CustomerID").distinct().count()
n_products = df1.select("StockCode").distinct().count()
n_countries = df1.select("Country").distinct().count()

print("Number of customers: ", n_customers)
print("Number of products: ", n_products)
print("Number of countries: ", n_countries)

Number of customers:  4372
Number of products:  3684
Number of countries:  37


#### Products per transaction

In [0]:
from pyspark.sql.functions import count, col, desc

df1.groupBy('InvoiceNo').agg(count('StockCode').alias('no_products')).sort(desc('no_products')).show()

+---------+-----------+
|InvoiceNo|no_products|
+---------+-----------+
|   576339|        542|
|   579196|        533|
|   580727|        529|
|   578270|        442|
|   573576|        435|
|   567656|        421|
|   567183|        391|
|   575607|        377|
|   571441|        364|
|   570488|        353|
|   572552|        352|
|   568346|        335|
|   569246|        285|
|   547063|        281|
|   562031|        274|
|   570672|        259|
|   554098|        256|
|   543040|        244|
|   562046|        219|
|   569897|        218|
+---------+-----------+
only showing top 20 rows



#### Transactions per country

In [0]:
from pyspark.sql.functions import count, col, desc

transactions = df1.groupBy('Country').agg(count('InvoiceNo').alias('no_transactions')).sort(desc('no_transactions'))
display(transactions)

Country,no_transactions
United Kingdom,356727
Germany,9480
France,8475
EIRE,7475
Spain,2528
Netherlands,2371
Belgium,2069
Switzerland,1877
Portugal,1471
Australia,1258


Output can only be rendered in Databricks

#### Time series of transactions per month

In [0]:
# transactions per month
from pyspark.sql.functions import month, mean, sum, asc, countDistinct

transactionsTS = df1.groupBy('dateFormatted').agg(countDistinct('InvoiceNo').alias('no_transactions')).sort(asc('dateFormatted'))
transactionsByMonth = transactionsTS.groupBy(month("dateFormatted").alias("month")).agg(sum("no_transactions").alias("transactions")).sort(asc('month'))
display(transactionsByMonth)

month,transactions
1,1236
2,1202
3,1619
4,1384
5,1849
6,1707
7,1593
8,1544
9,2078
10,2263


Output can only be rendered in Databricks

#### Canceled orders
- The invoice numbers of transactions that were cancelled beging with 'C' as indicated in the description of the dataset

In [0]:
cancelled = df1.filter(col("InvoiceNo").contains("C"))
display(cancelled)

InvoiceNo,StockCode,Description,Quantity,UnitPrice,CustomerID,Country,dateFormatted
C536543,22355,CHARLOTTE BAG SUKI DESIGN,-2,0.85,17841,United Kingdom,2010-12-01
C537320,35400,WOODEN BOX ADVENT CALENDAR,-1,8.95,14867,United Kingdom,2010-12-06
C538713,22179,SET 10 LIGHTS NIGHT OWL,-1,6.75,16570,United Kingdom,2010-12-14
C539956,35004C,SET OF 3 COLOURED FLYING DUCKS,-15,4.65,12980,United Kingdom,2010-12-23
C540006,21306,SET/4 DAISY MIRROR MAGNETS,-1,2.1,14606,United Kingdom,2011-01-04
C540417,22197,SMALL POPCORN HOLDER,-10,0.85,13680,United Kingdom,2011-01-07
C541009,22411,JUMBO SHOPPER VINTAGE RED PAISLEY,-1,1.95,17581,United Kingdom,2011-01-13
C536812,22130,PARTY CONE CHRISTMAS DECORATION,-144,0.72,16546,United Kingdom,2010-12-02
C537797,22501,PICNIC BASKET WICKER LARGE,-1,8.5,13113,United Kingdom,2010-12-08
C538341,21647,ASSORTED TUTTI FRUTTI LARGE PURSE,-1,2.1,15514,United Kingdom,2010-12-10


Notice that the quantities of product purchased are negative values

In [0]:
# number of orders that were cancelled

n_cancelled = cancelled.select('InvoiceNo').distinct().count()

print("Number of cancelled orders: ", n_cancelled)

Number of cancelled orders:  3654


#### StockCodes with alphabets

A look at the data shows there are stock codes that contain alphabets in them while some do not. We narrow down to this to understand more about this phenomenon.

From the dataset we note the meaning of these stock codes as below:

- POST -> Postage
- D -> Discount
- C2 -> Carriage
- M -> Manual
- BANK CHARGES -> Bank Charges
- PADS -> Pads to match all the cushions sold
- DOT -> Dotcom postage

In [0]:
# get stock codes with alphabets
alphaStockCodes = df1.filter(df1.StockCode.rlike('^[a-zA-Z]'))
display(alphaStockCodes)

InvoiceNo,StockCode,Description,Quantity,UnitPrice,CustomerID,Country,dateFormatted
537804,M,Manual,12,0.19,12748,United Kingdom,2010-12-08
538003,POST,POSTAGE,8,18.0,12429,Denmark,2010-12-09
539840,POST,POSTAGE,3,15.0,12383,Belgium,2010-12-22
540073,M,Manual,6,1.65,16814,United Kingdom,2011-01-04
537915,POST,POSTAGE,1,28.0,12797,Portugal,2010-12-09
540410,POST,POSTAGE,3,18.0,12530,Germany,2011-01-07
540278,M,Manual,1,1.65,15719,United Kingdom,2011-01-06
540351,POST,POSTAGE,2,18.0,12735,France,2011-01-06
541405,POST,POSTAGE,5,18.0,12683,France,2011-01-17
537201,POST,POSTAGE,6,18.0,12472,Germany,2010-12-05


Output can only be rendered in Databricks

#### Basket prices

In [0]:
# we create a new df with a new column TotalPrice derived from UnitPrice and Quantity
df1 = df1.withColumn("TotalPrice", df1.UnitPrice*df1.Quantity)

mvBasketDF = df1.groupBy('CustomerID','InvoiceNo').agg(sum('TotalPrice').alias('BasketPrice'))
mvBasketDF.show()

+----------+---------+------------------+
|CustomerID|InvoiceNo|       BasketPrice|
+----------+---------+------------------+
|     14723|   539321|208.46000000000004|
|     12601|   540769|            144.26|
|     13328|   542397|           1308.48|
|     16931|   539414|            212.61|
|     17609|   541265|             423.6|
|     15426|   537196|            382.57|
|     17850|   536690|            331.26|
|     15955|   536635| 572.3800000000001|
|     13018|   548920|493.45000000000005|
|     14796|   547730|349.51000000000005|
|     17431|   542613|403.29999999999995|
|     17811|   543801|             210.7|
|     13662|   547803|227.70000000000005|
|     13394|   543054|313.76000000000005|
|     14688|   548742|375.15999999999997|
|     13819|   543658|425.05000000000007|
|     16722|   544914|            314.36|
|     17449|   545338|338.34000000000003|
|     14163|   550180|            411.87|
|     15696|  C552707|             -9.95|
+----------+---------+------------

# Product Categories

#### TF-IDF
- [Term frequency-inverse document frequency (TF-IDF)](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a feature vectorization method widely used in text mining to reflect the importance of a term to a document in the corpus. You can find more information about its implementation on PySpark [here](https://spark.apache.org/docs/latest/ml-features.html#tf-idf).

In [0]:
from pyspark.ml.feature import HashingTF, IDF, Tokenizer

tokenizer = Tokenizer(inputCol="Description", outputCol="descriptionTokens")
wordsData = tokenizer.transform(df1)

# tf transformation
hashingTF = HashingTF(inputCol="descriptionTokens", outputCol="rawFeatures", numFeatures=20)

featurizedData = hashingTF.transform(wordsData)

# alternatively, CountVectorizer can also be used to get term frequency vectors

# idf transformation
idf = IDF(inputCol="rawFeatures", outputCol="features")
idfModel = idf.fit(featurizedData)
rescaledData = idfModel.transform(featurizedData)

#### Creating clusters of products

In [0]:
# Import the required libraries
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

# Create an object for the Logistic Regression model
kmeans_model = KMeans(k=4)

# fit
fit_model = kmeans_model.fit(rescaledData.select('features'))

# wsse = fit_model.computeCost(final_data) for spark 2.7
wssse = fit_model.summary.trainingCost # for spark 3.0
print("The within set sum of squared error of the mode is {}".format(wssse))

# Store the results in a dataframe

results = fit_model.transform(rescaledData.select('features'))

The within set sum of squared error of the mode is 3531130.7557436135


#### Finding optimal number of clusters

- **NOTE:** 
The part below DOES NOT scale. We could not find a suitable way of going about this step.

In [0]:
#silhoutte analysis to find the optimal number of clusters. 
silhouette_score=[]

from pyspark.ml.evaluation import ClusteringEvaluator

evaluator = ClusteringEvaluator(predictionCol='prediction', featuresCol='features',metricName='silhouette', distanceMeasure='squaredEuclidean')

for i in range(2,10):  
    KMeans_model = KMeans(k=i)
    KMeans_fit=KMeans_model.fit(rescaledData)
    output=KMeans_fit.transform(rescaledData)
    score=evaluator.evaluate(output) 
    silhouette_score.append(score)    
    print("Silhouette Score:", score)

Silhouette Score: 0.0976008504262662
Silhouette Score: 0.14821852091401963
Silhouette Score: 0.15049483139255138
Silhouette Score: 0.1603341734663381
Silhouette Score: 0.16864840283191324
Silhouette Score: 0.19020784681942218
Silhouette Score: 0.16420181156094876
Silhouette Score: 0.18791031822847862


- The silhoutte score is maximized when k = 8 but the difference between average silhoutte when k = 8 and k = 4 is very small. And, since the difference between the silhouette score is very small, Clustering with k = 8 may not group the points together that has similar characterstics than does k = 4. So, we are using clusters with 4 k's.

In [0]:
display(results.groupby('prediction').count().sort('prediction'))

prediction,count
0,52421
1,45430
2,238623
3,65129


Output can only be rendered in Databricks

# Customer Categories

#### New feature: Product Category

In [0]:
# we add the product category allocations to the dataframe
from pyspark.sql.types import StructType, StructField, LongType

def with_column_index(df): 
    new_schema = StructType(df.schema.fields + [StructField("ColumnIndex", LongType(), False),])
    return df.rdd.zipWithIndex().map(lambda row: row[0] + (row[1],)).toDF(schema=new_schema)

df1_ci = with_column_index(df1)
df2_ci = with_column_index(results)

join_on_index = df1_ci.join(df2_ci, df1_ci.ColumnIndex == df2_ci.ColumnIndex, 'inner').drop("ColumnIndex","features","StockCode","Quantity","UnitPrice","Country","dateFormatted")

join_on_index.sort('InvoiceNo').show()

+---------+--------------------+----------+------------------+----------+
|InvoiceNo|         Description|CustomerID|        TotalPrice|prediction|
+---------+--------------------+----------+------------------+----------+
|   536365| WHITE METAL LANTERN|     17850|             20.34|         2|
|   536365|CREAM CUPID HEART...|     17850|              22.0|         2|
|   536365|RED WOOLLY HOTTIE...|     17850|             20.34|         3|
|   536365|KNITTED UNION FLA...|     17850|             20.34|         3|
|   536365|SET 7 BABUSHKA NE...|     17850|              15.3|         1|
|   536365|WHITE HANGING HEA...|     17850|15.299999999999999|         0|
|   536365|GLASS STAR FROSTE...|     17850|              25.5|         3|
|   536366|HAND WARMER RED P...|     17850|11.100000000000001|         3|
|   536366|HAND WARMER UNION...|     17850|11.100000000000001|         2|
|   536367|IVORY KNITTED MUG...|     13047| 9.899999999999999|         3|
|   536367|ASSORTED COLOUR B...|     1

#### Monetary value per category

In [0]:
# create columns from product category and fill null values with 0
pivotDF = join_on_index.groupBy("CustomerID","InvoiceNo","Description").pivot("prediction").sum("TotalPrice").na.fill(value=0)
display(pivotDF)

CustomerID,InvoiceNo,Description,0,1,2,3
16697,549910,RECYCLED ACAPULCO MAT TURQUOISE,8.25,0.0,0.0,0.0
16320,552695,HANGING JAM JAR T-LIGHT HOLDER,20.4,0.0,0.0,0.0
17193,556972,SET OF 4 KNICK KNACK TINS LONDON,0.0,0.0,49.8,0.0
16550,537391,PINK POLKADOT CUP,0.0,0.0,0.0,5.95
12415,543989,CHILDS BREAKFAST SET SPACEBOY,0.0,0.0,204.0,0.0
14911,545657,JAM MAKING SET WITH JARS,0.0,0.0,22.5,0.0
15750,557745,UNION JACK FLAG PASSPORT COVER,0.0,0.0,0.0,6.300000000000001
17511,542789,WRAP ALPHABET DESIGN,0.0,0.0,10.5,0.0
18192,568792,DOORMAT SPOTTY HOME SWEET HOME,0.0,16.5,0.0,0.0
16628,561391,WOODLAND CHARLOTTE BAG,0.0,0.0,8.5,0.0


In [0]:
mvBasketDF.sort('CustomerID').show()

+----------+---------+------------------+
|CustomerID|InvoiceNo|       BasketPrice|
+----------+---------+------------------+
|     12346|  C541433|          -77183.6|
|     12346|   541431|           77183.6|
|     12347|   556201|382.52000000000004|
|     12347|   562032| 584.9100000000001|
|     12347|   573511|           1294.32|
|     12347|   537626| 711.7900000000001|
|     12347|   549222|            636.25|
|     12347|   581180|            224.82|
|     12347|   542237|            475.39|
|     12348|   541998|227.43999999999997|
|     12348|   539318|             892.8|
|     12348|   568172|             310.0|
|     12348|   548955|             367.0|
|     12349|   577609|           1757.55|
|     12350|   543037|334.40000000000003|
|     12352|   544156|             296.5|
|     12352|   546869|            120.33|
|     12352|   547390|160.32999999999998|
|     12352|   545332|             840.3|
|     12352|   567505|366.25000000000006|
+----------+---------+------------

#### Value of each category per basket

In [0]:
# value distribution by product category
basketValPerCat = pivotDF.groupBy('CustomerID','InvoiceNo').agg(sum('0').alias('cat0'),sum('1').alias('cat1'),sum('2').alias('cat2'),sum('3').alias('cat3'))

# total basket value as a new column
basketValPerCat = basketValPerCat.withColumn("BasketValue", basketValPerCat.cat0 + basketValPerCat.cat1 + basketValPerCat.cat2 + basketValPerCat.cat3)

# drop rows whose total basket value is negative (ie cancellations)
basketValPerCat = basketValPerCat.where(basketValPerCat.BasketValue > 0)
basketValPerCat.show()

+----------+---------+------------------+------------------+------------------+------------------+------------------+
|CustomerID|InvoiceNo|              cat0|              cat1|              cat2|              cat3|       BasketValue|
+----------+---------+------------------+------------------+------------------+------------------+------------------+
|     16056|   572726|              88.4|60.150000000000006|435.30999999999995|             66.89| 650.7499999999999|
|     13328|   542397|              30.0| 79.19999999999999|             889.8|            309.48|           1308.48|
|     13018|   548920|             160.1|              81.6|             200.8|             50.95|            493.45|
|     12700|   569568|107.30000000000001|            270.18| 907.7099999999998|             191.0|1476.1899999999998|
|     14796|   547730|             77.05|             75.94|            126.02|              70.5|            349.51|
|     16931|   560244|28.269999999999996|             22

#### Customer data

In [0]:
from pyspark.sql.functions import count, min, sum, max, avg

customersDF = basketValPerCat.groupBy('CustomerID').agg(count('InvoiceNo').alias('no_purchases'),min('BasketValue').alias('minVal'),max('BasketValue').alias('maxVal'),\
                                                        avg('BasketValue').alias('meanVal'),sum('BasketValue').alias('totalVal'),sum('cat0').alias('cat0'),sum('cat1').alias('cat1'),sum('cat2').alias('cat2'),sum('cat3').alias('cat3'))

In [0]:
# distribution of total value of purchases per customer

display(customersDF)

CustomerID,no_purchases,minVal,maxVal,meanVal,totalVal,cat0,cat1,cat2,cat3
17389,34,52.5,7427.32,936.2847058823528,31833.68,440.94,243.9,25600.619999999995,5548.219999999999
14450,3,108.42000000000002,213.8,161.08333333333334,483.25,17.58,84.0,266.67,115.0
15727,7,374.8200000000001,1516.95,737.0085714285714,5159.06,767.6300000000001,193.21,3277.66,920.56
15790,1,218.75,218.75,218.75,218.75,14.22,36.5,111.09,56.94
13285,4,506.38,796.8299999999999,677.28,2709.12,256.42,437.9,1602.56,412.24
14570,2,50.150000000000006,167.91,109.03,218.06,38.8,13.95,143.91,21.4
16574,1,451.44000000000005,451.44000000000005,451.44000000000005,451.44000000000005,43.32,15.3,303.72,89.1
13832,1,52.2,52.2,52.2,52.2,32.4,0.0,19.8,0.0
13623,5,88.6,188.57,145.54799999999997,727.7399999999999,103.01,158.73,425.1,40.900000000000006
15619,1,336.40000000000003,336.40000000000003,336.40000000000003,336.40000000000003,0.0,91.8,99.6,145.0


Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

Output can only be rendered in Databricks

The distribution of each individual numerical column is highly skewed probably because there exists absurdly high values related to **cancelled/erronous orders** which were not removed from the dataset. This **will** significantly affect the generizability of the model to be trained on the data.

#### Creating customer categories

- In this step we use unsupervised learning to create customer categories.
- It is the basis of the next step - classification

In [0]:
# Vector assembler is used to create a vector of input features
assembler2 = VectorAssembler(inputCols=['no_purchases','minVal','maxVal','meanVal','totalVal','cat0','cat1','cat2','cat3'],\
                            outputCol="features")

# Pipeline is used to pass the data through indexer and assembler simultaneously. Also, it helps to pre-rocess the test data
# in the same way as that of the train data.
pipe2 = Pipeline(stages=[assembler2])

final_data2=pipe2.fit(customersDF).transform(customersDF)

# Create an object for the Logistic Regression model

kmeans_model2 = KMeans(k=4)

# fit
fit_model2 = kmeans_model2.fit(final_data2.select('features'))

# wsse = fit_model.computeCost(final_data) for spark 2.7
wssse2 = fit_model2.summary.trainingCost # for spark 3.0
print("The within set sum of squared error of the mode is {}".format(wssse2))

# Store the results in a dataframe

results2 = fit_model2.transform(final_data2.select('features'))

The within set sum of squared error of the mode is 98509463715.91588


In [0]:
display(results2.groupby('prediction').count().sort('prediction'))

prediction,count
0,4212
1,4
2,104
3,18


Output can only be rendered in Databricks

In [0]:
# we add the customer category allocations to the customers dataframe
from pyspark.sql.types import StructType, StructField, LongType

def with_column_index(df): 
    new_schema = StructType(df.schema.fields + [StructField("ColumnIndex", LongType(), False),])
    return df.rdd.zipWithIndex().map(lambda row: row[0] + (row[1],)).toDF(schema=new_schema)

# using the function defined earlier
df11_ci = with_column_index(customersDF)
df21_ci = with_column_index(results2)

join_on_index1 = df11_ci.join(df21_ci, df11_ci.ColumnIndex == df21_ci.ColumnIndex, 'inner').drop("ColumnIndex","features","no_purchases","minVal","maxVal","totalVal")

#### Data visualization

In [0]:
display(join_on_index1)

CustomerID,meanVal,cat0,cat1,cat2,cat3,prediction
17389,936.2847058823528,440.94,243.9,25600.619999999995,5548.219999999999,2
14450,161.08333333333334,17.58,84.0,266.67,115.0,0
15727,737.0085714285714,767.6300000000001,193.21,3277.66,920.56,0
15790,218.75,14.22,36.5,111.09,56.94,0
13285,677.28,256.42,437.9,1602.56,412.24,0
14570,109.03,38.8,13.95,143.91,21.4,0
16574,451.44000000000005,43.32,15.3,303.72,89.1,0
13832,52.2,32.4,0.0,19.8,0.0,0
13623,145.54799999999997,103.01,158.73,425.1,40.900000000000006,0
15619,336.40000000000003,0.0,91.8,99.6,145.0,0


# Classification of Customers
- We use only a few of the features from earlier steps in this next step to classify customers.

#### Train-test split

In [0]:
data = join_on_index1.withColumnRenamed("prediction", "customerGroup")

train_data,test_data=data.randomSplit([0.7,0.3])

#### Classifier: Decision Tree

In [0]:
# Import the required libraries
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

In [0]:
# Vector assembler is used to create a vector of input features
assembler = VectorAssembler(inputCols=['CustomerID', 'meanVal', 'cat0', 'cat1', 'cat2', 'cat3'],
                            outputCol="features")

# Create an object for the decision tree classifier
dt_model = DecisionTreeClassifier(labelCol='customerGroup',maxBins=5000)

pipe = Pipeline(stages=[assembler,dt_model])

# fit model
fit_model=pipe.fit(train_data)

# Store the results in a dataframe
results = fit_model.transform(test_data)

# evaluate model
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

ACC_evaluator = MulticlassClassificationEvaluator(labelCol="customerGroup", predictionCol="prediction", metricName="accuracy")
accuracy = ACC_evaluator.evaluate(results)

print("The accuracy of the decision tree classifier is {}".format(accuracy))

The accuracy of the decision tree classifier is 0.9947328818660647


#### Classifier: Random Forest

In [0]:
from pyspark.ml.classification import RandomForestClassifier

# RandomForest model.
rf = RandomForestClassifier(labelCol="customerGroup", numTrees=10)

# Chain indexers and forest in a Pipeline
pipeRF = Pipeline(stages=[assembler, rf])

# fit model
fit_modelRF = pipeRF.fit(train_data)

# Store the results in a dataframe
resultsRF = fit_modelRF.transform(test_data)

# evaluate model
accuracyRF = ACC_evaluator.evaluate(resultsRF)

print("The accuracy of the random forest classifier is {}".format(accuracyRF))

The accuracy of the random forest classifier is 0.9917231000752446
