# Pink Flamingo - Eglence Inc.
# Cluster Analysis

by Cédric Membrez, June 2020

In [1]:
# imports
from pyspark.sql import SQLContext
from pyspark.sql.functions import months_between, round #to_date, datediff,
from pyspark.sql.functions import format_number
from pyspark.sql.functions import year, month, dayofmonth
from pyspark.sql.functions import sum as _sum
from pyspark.sql.functions import count as _count
from pyspark.sql.functions import col as _col

from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler

import pycountry_convert as pc
import pycountry
import pandas as pd
from datetime import datetime

import matplotlib.pyplot as plt

In [2]:
# create Spark Context
#sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

### READ DATA FILES
1) read separate CSV files

2) merge into one dataframe

3) clean dataframe

In [3]:
# read ad-clicks.csv
adclicks_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/ad-clicks.csv',
                           format='com.databricks.spark.csv',
                           header='true',
                           inferSchema='true')

In [4]:
# buy-clicks.csv
buyclicks_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/buy-clicks.csv',
                                format='com.databricks.spark.csv',
                                header='true',
                                inferSchema='true')

In [5]:
# game-clicks.csv
gameclicks_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/game-clicks.csv',
                                     format='com.databricks.spark.csv',
                                     header='true',
                                     inferSchema='true')

In [6]:
# level-events.csv
levelevents_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/level-events.csv',
                                      format='com.databricks.spark.csv',
                                      header='true',
                                      inferSchema='true')

In [7]:
# team-assignments.csv
teamassignments_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/team-assignments.csv',
                                          format='com.databricks.spark.csv',
                                          header='true',
                                          inferSchema='true')

In [8]:
# team.csv
team_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/team.csv',
                               format='com.databricks.spark.csv',
                               header='true',
                               inferSchema='true')

In [9]:
# user-session.csv
usersession_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/user-session.csv',
                                      format='com.databricks.spark.csv',
                                      header='true',
                                      inferSchema='true')

In [10]:
# users.csv
users_raw = sqlContext.read.load('file:///home/cloudera/flamingo-data/users.csv',
                                format='com.databricks.spark.csv',
                                header='true',
                                inferSchema='true')

In [11]:
# NUMBER OF ROWS PER FILE
print("Number of rows per file\n")
print("ad-clicks: {}".format(adclicks_raw.count()))
print("buy-clicks: {}".format(buyclicks_raw.count()))
print("game-clicks: {}".format(gameclicks_raw.count()))
print("level-events: {}".format(levelevents_raw.count()))
print("team-assignments: {}".format(teamassignments_raw.count()))
print("team: {}".format(team_raw.count()))
print("user-session: {}".format(usersession_raw.count()))
print("users: {}".format(users_raw.count()))

Number of rows per file

ad-clicks: 16323
buy-clicks: 2947
game-clicks: 755806
level-events: 1254
team-assignments: 9826
team: 109
user-session: 9250
users: 2393


In [12]:
adclicks_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- txId: integer (nullable = true)
 |-- userSessionId: integer (nullable = true)
 |-- teamId: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- adId: integer (nullable = true)
 |-- adCategory: string (nullable = true)



In [13]:
adclicks_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
txId,16323,24613.82925932733,9513.244787359043,5972,39833
userSessionId,16323,22090.77344850824,8780.27306545314,5649,39623
teamId,16323,70.29492127672609,39.63199500554996,2,179
userId,16323,1187.4641916314404,691.5619445575,1,2387
adId,16323,14.654046437542119,8.623599159144106,0,29


In [14]:
buyclicks_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- txId: integer (nullable = true)
 |-- userSessionId: integer (nullable = true)
 |-- team: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- buyId: integer (nullable = true)
 |-- price: double (nullable = true)



In [15]:
buyclicks_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
txId,2947,25443.01085850017,9343.543792592367,6004,39842
userSessionId,2947,22884.75229046488,8669.35362676463,5652,39275
team,2947,70.31896844248388,40.27452540199273,2,178
userId,2947,1187.4591109602986,685.7038088341923,1,2387
buyId,2947,2.530709195792331,1.7799870722907862,0,5
price,2947,7.263997285374957,7.076313004712134,1.0,20.0


In [16]:
gameclicks_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- clickId: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- userSessionId: integer (nullable = true)
 |-- isHit: integer (nullable = true)
 |-- teamId: integer (nullable = true)
 |-- teamLevel: integer (nullable = true)



In [17]:
gameclicks_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
clickId,755806,377902.5,218182.54311509102,0,755805
userId,755806,1230.0917761965372,689.1974297436151,0,2389
userSessionId,755806,21444.64732484262,8807.252224659778,5648,39790
isHit,755806,0.1103232840173272,0.31329249414120713,0,1
teamId,755806,91.55913289918313,43.862726470866114,2,181
teamLevel,755806,4.561732772695639,2.0500182843454655,1,8


In [18]:
levelevents_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- eventId: integer (nullable = true)
 |-- teamId: integer (nullable = true)
 |-- teamLevel: integer (nullable = true)
 |-- eventType: string (nullable = true)



In [19]:
levelevents_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
eventId,1254,626.5,362.14292758522845,0,1253
teamId,1254,94.33971291866028,44.37661933285508,2,179
teamLevel,1254,4.085326953748006,1.9584409193937282,1,8


In [20]:
teamassignments_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- team: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- assignmentId: integer (nullable = true)



In [21]:
teamassignments_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
team,9826,87.46468552819051,43.42936266588799,2,184
userId,9826,1209.5669651943822,689.1245775643522,0,2392
assignmentId,9826,18491.93018522288,9883.61709631016,5000,39860


In [22]:
team_raw.printSchema()

root
 |-- teamId: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- teamCreationTime: timestamp (nullable = true)
 |-- teamEndTime: timestamp (nullable = true)
 |-- strength: double (nullable = true)
 |-- currentLevel: integer (nullable = true)



In [23]:
team_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
teamId,109,56.88990825688074,38.115161620566106,0,171
strength,109,0.4835207572530579,0.2788252003005053,0.00401521168869,0.994851162257
currentLevel,109,1.0,0.0,1,1


In [24]:
usersession_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- userSessionId: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- teamId: integer (nullable = true)
 |-- assignmentId: integer (nullable = true)
 |-- sessionType: string (nullable = true)
 |-- teamLevel: integer (nullable = true)
 |-- platformType: string (nullable = true)



In [25]:
usersession_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
userSessionId,9250,17973.522594594593,7953.932718534248,5648,38722
userId,9250,1189.6912432432432,691.0445045861406,0,2389
teamId,9250,72.3907027027027,41.458779047213525,2,179
assignmentId,9250,10288.336,6960.614801008102,5002,37948
teamLevel,9250,4.357405405405405,1.9248621663280168,1,8


In [26]:
users_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- userId: integer (nullable = true)
 |-- nick: string (nullable = true)
 |-- twitter: string (nullable = true)
 |-- dob: string (nullable = true)
 |-- country: string (nullable = true)



In [27]:
users_raw.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
userId,2393,1196.0,690.9439195766904,0,2392


# USER
* COUNTRY - REGION -> pycountry
* dob: date-of-birth -> age

In [28]:
# retrieve country codes from user table
country_codes = users_raw.select('country').toPandas()["country"]

In [29]:
# convert country to continent codes
continent_codes = []
for x in country_codes.tolist():
    try:
        continent_codes.append(pc.country_alpha2_to_continent_code(x))
    except Exception as e:
        print(e)

"Invalid Country Alpha-2 code: 'UM'"
"Invalid Country Alpha-2 code: 'AN'"
"Invalid Country Alpha-2 code: 'AQ'"
"Invalid Country Alpha-2 code: 'VA'"
"Invalid Country Alpha-2 code: 'PN'"
"Invalid Country Alpha-2 code: 'TF'"
"Invalid Country Alpha-2 code: 'AN'"
"Invalid Country Alpha-2 code: 'TF'"
"Invalid Country Alpha-2 code: 'PN'"
"Invalid Country Alpha-2 code: 'AQ'"
"Invalid Country Alpha-2 code: 'AQ'"
"Invalid Country Alpha-2 code: 'UM'"
"Invalid Country Alpha-2 code: 'VA'"
"Invalid Country Alpha-2 code: 'AQ'"
"Invalid Country Alpha-2 code: 'KV'"
"Invalid Country Alpha-2 code: 'EH'"
"Invalid Country Alpha-2 code: 'EH'"
"Invalid Country Alpha-2 code: 'UM'"
"Invalid Country Alpha-2 code: 'EH'"
"Invalid Country Alpha-2 code: 'TF'"
"Invalid Country Alpha-2 code: 'PN'"
"Invalid Country Alpha-2 code: 'AQ'"
"Invalid Country Alpha-2 code: 'PN'"
"Invalid Country Alpha-2 code: 'PN'"
"Invalid Country Alpha-2 code: 'TF'"
"Invalid Country Alpha-2 code: 'AQ'"
"Invalid Country Alpha-2 code: 'PN'"
"

In [30]:
# translate continent code to names
continent_codes_names_dict = {
    'AF': 'Africa',
    'AN': 'Antarctica',
    'AS': 'Asia',
    'EU': 'Europe',
    'NA': 'North America',
    'OC': 'Oceania',
    'SA': 'South America'
}
continent_names = [continent_codes_names_dict[x] for x in continent_codes]

In [31]:
# assess number of lines 
# -> 'dropped' values because some unavailable country codes in pycountry
print(len(country_codes))
print(len(continent_codes))
print(len(continent_names))

2393
2313
2313


In [32]:
# shape into a DataFrame
continent_codes_df = pd.DataFrame(continent_names, 
                                  columns=["continentname"])
continent_codes_df["count"] = 1
continent_codes_df.shape

(2313, 2)

In [33]:
# number of users per continent
# remark: the underlying data as 9-10 users per country
# remark: the data is simulated (cf techincal appendix, github)
continent_codes_df.groupby(["continentname"]).count()

Unnamed: 0_level_0,count
continentname,Unnamed: 1_level_1
Africa,555
Antarctica,20
Asia,537
Europe,457
North America,358
Oceania,239
South America,147


### Age: from date of birth 'dob'

In [34]:
users_raw.select('userId', 'timestamp', 'dob').show(5)

+------+--------------------+----------+
|userId|           timestamp|       dob|
+------+--------------------+----------+
|   442|2012-06-19 14:53:...|1994-07-20|
|   949|2012-06-19 19:29:...|1971-04-22|
|  1654|2012-06-20 19:34:...|1970-04-19|
|  1586|2012-06-21 01:18:...|1965-11-23|
|   599|2012-06-21 15:35:...|1994-08-23|
+------+--------------------+----------+
only showing top 5 rows



In [35]:
# users_raw.select('timestamp', to_date('timestamp')).show(5)

In [36]:
# compute user's age
users_age = users_raw.select('userId', round(months_between(users_raw['timestamp'], 
                                    users_raw['dob'])/12,0).alias("age"))

In [37]:
# check few results:
users_age.show(3)

+------+----+
|userId| age|
+------+----+
|   442|18.0|
|   949|41.0|
|  1654|42.0|
+------+----+
only showing top 3 rows



In [38]:
users_age.describe().toPandas().transpose()

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
userId,2393,1196.0,690.9439195766904,0,2392
age,2393,37.68992895946511,14.12441621737243,14.0,69.0


# BuyClicks
purchase behavior of users
* based on buyId, price
* amount spent: total, yearly, monthly?
* frequency of purchase?
* (average price -> already in Classification: Iphone (buy expensive) vs ...)

In [39]:
buyclicks_raw.printSchema()

root
 |-- timestamp: timestamp (nullable = true)
 |-- txId: integer (nullable = true)
 |-- userSessionId: integer (nullable = true)
 |-- team: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- buyId: integer (nullable = true)
 |-- price: double (nullable = true)



In [40]:
buyclicks_raw.show(3)

+--------------------+----+-------------+----+------+-----+-----+
|           timestamp|txId|userSessionId|team|userId|buyId|price|
+--------------------+----+-------------+----+------+-----+-----+
|2016-05-26 15:36:...|6004|         5820|   9|  1300|    2|  3.0|
|2016-05-26 15:36:...|6005|         5775|  35|   868|    4| 10.0|
|2016-05-26 15:36:...|6006|         5679|  97|   819|    5| 20.0|
+--------------------+----+-------------+----+------+-----+-----+
only showing top 3 rows



In [41]:
buyclicks_raw.count()

2947

In [42]:
copy_df = buyclicks_raw
copy_df_noNa = copy_df.na.drop()
copy_df_noNa.count()

2947

In [43]:
# our dataset covers only the year 2016, and the months of May and June.
# -> we can group by months:
# group the purchases per Month, per Day, and count #items bought:
buyclicks_perDay = buyclicks_raw.select(month("timestamp").alias("month"),
                dayofmonth("timestamp").alias("day"), "userId", "buyId",
                     "price").groupBy("month", "day").count()

In [44]:
buyclicks_perDay.show(100)  # show all 22 days of data

+-----+---+-----+
|month|day|count|
+-----+---+-----+
|    5| 26|   22|
|    5| 27|   54|
|    5| 28|   69|
|    5| 29|   66|
|    5| 30|   72|
|    5| 31|   82|
|    6|  1|   97|
|    6|  2|   85|
|    6|  3|  119|
|    6|  4|  135|
|    6|  5|  132|
|    6|  6|  148|
|    6|  7|  134|
|    6|  8|  146|
|    6|  9|  187|
|    6| 10|  166|
|    6| 11|  207|
|    6| 12|  220|
|    6| 13|  199|
|    6| 14|  257|
|    6| 15|  221|
|    6| 16|  129|
+-----+---+-----+



In [45]:
# NUMBER ITEMS per USER
buyclicks_raw.groupBy("userId").count().\
                orderBy("count", ascending=False).show(15)

+------+-----+
|userId|count|
+------+-----+
|  1300|   16|
|  1143|   15|
|  1027|   15|
|  1022|   15|
|  2229|   15|
|  1697|   14|
|   221|   14|
|  2248|   14|
|  1260|   14|
|   355|   13|
|   670|   13|
|  1162|   13|
|    12|   13|
|  1892|   13|
|   643|   13|
+------+-----+
only showing top 15 rows



In [46]:
# NUMBER OF ITEMS per SESSION per USER
buyclicks_raw.groupBy("userId", "userSessionId").count().\
                orderBy(["count", "userId", "userSessionId"],
                        ascending=False).show(5)

+------+-------------+-----+
|userId|userSessionId|count|
+------+-------------+-----+
|  2209|        27354|    6|
|  1623|        20557|    6|
|  1294|        34632|    6|
|   243|        20762|    6|
|  1997|        26703|    5|
+------+-------------+-----+
only showing top 5 rows



In [47]:
# sum price, pivoted by prices
buyclicks_pivotPrice = buyclicks_raw.groupBy("userId", 
                                "buyId").pivot("price").sum("price")

In [48]:
buyclicks_pivotPrice.orderBy("userId").show(15)

+------+-----+----+----+----+----+----+-----+
|userId|buyId| 1.0| 2.0| 3.0| 5.0|10.0| 20.0|
+------+-----+----+----+----+----+----+-----+
|     1|    1|null| 4.0|null|null|null| null|
|     1|    2|null|null|15.0|null|null| null|
|     1|    0| 2.0|null|null|null|null| null|
|     8|    5|null|null|null|null|null| 20.0|
|     8|    4|null|null|null|null|30.0| null|
|     8|    2|null|null| 3.0|null|null| null|
|     9|    4|null|null|null|null|40.0| null|
|     9|    5|null|null|null|null|null| 40.0|
|    10|    1|null| 2.0|null|null|null| null|
|    10|    0| 9.0|null|null|null|null| null|
|    12|    5|null|null|null|null|null|200.0|
|    12|    1|null| 2.0|null|null|null| null|
|    12|    2|null|null| 3.0|null|null| null|
|    12|    4|null|null|null|null|10.0| null|
|    13|    2|null|null|15.0|null|null| null|
+------+-----+----+----+----+----+----+-----+
only showing top 15 rows



In [49]:
# ORDERED table by USER, SESSION, TIME, ITEM
buyclicks_raw.orderBy("userId", "userSessionId", 
                      "timestamp", "buyId").show(15)

+--------------------+-----+-------------+----+------+-----+-----+
|           timestamp| txId|userSessionId|team|userId|buyId|price|
+--------------------+-----+-------------+----+------+-----+-----+
|2016-05-31 06:06:...|11790|        10041|  99|     1|    2|  3.0|
|2016-05-31 06:36:...|11817|        10041|  99|     1|    2|  3.0|
|2016-06-01 08:06:...|13447|        12713|  99|     1|    2|  3.0|
|2016-06-02 04:36:...|14479|        12713|  99|     1|    0|  1.0|
|2016-06-07 05:06:...|22551|        21014|  99|     1|    2|  3.0|
|2016-06-09 11:36:...|25852|        21014|  99|     1|    1|  2.0|
|2016-06-10 03:36:...|27703|        26938|  99|     1|    2|  3.0|
|2016-06-13 05:36:...|32998|        26938|  99|     1|    0|  1.0|
|2016-06-14 09:36:...|35904|        34802|  99|     1|    1|  2.0|
|2016-06-10 10:06:...|28140|        27918| 124|     8|    5| 20.0|
|2016-06-11 14:06:...|29958|        27918| 124|     8|    4| 10.0|
|2016-06-12 23:06:...|32573|        27918| 124|     8|    2|  

In [50]:
# TOTAL SPENT BY USER
buyclicks_total_by_user = buyclicks_raw.groupBy("userId").agg(_sum("price").alias("totalspent")).orderBy("totalspent",
                                    ascending=False)

In [51]:
buyclicks_total_by_user.show(5)

+------+----------+
|userId|totalspent|
+------+----------+
|  2229|     223.0|
|    12|     215.0|
|   471|     202.0|
|   511|     200.0|
|  1027|     189.0|
+------+----------+
only showing top 5 rows



## AD CLICKS

In [52]:
adclicks_raw.count()

16323

In [53]:
adclicks_raw.show(3)

+--------------------+----+-------------+------+------+----+-----------+
|           timestamp|txId|userSessionId|teamId|userId|adId| adCategory|
+--------------------+----+-------------+------+------+----+-----------+
|2016-05-26 15:13:...|5974|         5809|    27|   611|   2|electronics|
|2016-05-26 15:17:...|5976|         5705|    18|  1874|  21|     movies|
|2016-05-26 15:22:...|5978|         5791|    53|  2139|  25|  computers|
+--------------------+----+-------------+------+------+----+-----------+
only showing top 3 rows



In [54]:
# NUMBER OF CLICKED AD per USER:
adclicks_clicks_by_user = adclicks_raw.groupBy("userId").count().\
                        select("userId",_col("count").alias("adClicks"))

In [55]:
adclicks_clicks_by_user.show(3)

+------+--------+
|userId|adClicks|
+------+--------+
|   231|      19|
|  2032|      39|
|   233|      37|
+------+--------+
only showing top 3 rows



In [56]:
adclicks_clicks_by_user.orderBy("adClicks", ascending=False).show(5)

+------+--------+
|userId|adClicks|
+------+--------+
|  2221|      67|
|  2306|      61|
|  2009|      59|
|   807|      58|
|   243|      56|
+------+--------+
only showing top 5 rows



In [57]:
adclicks_raw.groupBy("userSessionId").count().orderBy(["count"],
                                        ascending=False).show(15)

+-------------+-----+
|userSessionId|count|
+-------------+-----+
|        26952|   17|
|        26714|   17|
|        28577|   17|
|        26396|   17|
|        20555|   16|
|        26441|   16|
|        26409|   16|
|        26918|   15|
|        26552|   15|
|        26890|   15|
|        20824|   15|
|        20860|   15|
|        26681|   15|
|        16043|   15|
|        26689|   14|
+-------------+-----+
only showing top 15 rows



# GAME CLICK

In [58]:
gameclicks_raw.count()

755806

In [59]:
gameclicks_raw.show(3)

+--------------------+-------+------+-------------+-----+------+---------+
|           timestamp|clickId|userId|userSessionId|isHit|teamId|teamLevel|
+--------------------+-------+------+-------------+-----+------+---------+
|2016-05-26 15:06:...|    105|  1038|         5916|    0|    25|        1|
|2016-05-26 15:07:...|    154|  1099|         5898|    0|    44|        1|
|2016-05-26 15:07:...|    229|   899|         5757|    0|    71|        1|
+--------------------+-------+------+-------------+-----+------+---------+
only showing top 3 rows



In [60]:
gameclicks_count_by_user = gameclicks_raw.groupBy("userId").count()

In [61]:
gameclicks_hit_by_user = gameclicks_raw.groupBy("userId").sum("isHit")

In [62]:
type(gameclicks_hit_by_user)

pyspark.sql.dataframe.DataFrame

In [63]:
gameclicks_union = gameclicks_count_by_user.join(gameclicks_hit_by_user, 
                                        on=["userId"])

In [64]:
gameclicks_hitDetails_per_user = gameclicks_union.withColumn("hitRatio", 
                round(_col("sum(isHit)") / _col("count"),4))

In [65]:
gameclicks_hitRatio_per_user = gameclicks_hitDetails_per_user.\
                select("userId", "hitRatio")

## CLEANING DATA
* drop na: myDF.na.drop()

In [66]:
# union the ad-, buy-clicks data and age
# 1193 users gameclicked
users_age.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- age: double (nullable = true)



In [67]:
buyclicks_total_by_user.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- totalspent: double (nullable = true)



In [68]:
adclicks_clicks_by_user.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- adClicks: long (nullable = false)



In [69]:
adBuyClicks_age_data = users_age.join(buyclicks_total_by_user,
                                                on="userId").\
                                        join(adclicks_clicks_by_user,
                                            on="userId")

In [70]:
noUserId_adBuyClicks_age_data= adBuyClicks_age_data.select("age", "totalspent", "adClicks")

In [71]:
print("{} rows x {} columns".format(noUserId_adBuyClicks_age_data.count(), 
                                    len(noUserId_adBuyClicks_age_data.columns)))

543 rows x 3 columns


In [72]:
clicks_data_clean = noUserId_adBuyClicks_age_data.na.drop()

In [73]:
clicks_data_clean.count()

543

# FEATURES

In [74]:
from mpl_toolkits.mplot3d import Axes3D

In [75]:
myData = clicks_data_clean.toPandas()

In [76]:
plt.scatter(myData["age"], myData["totalspent"])
plt.show()

In [77]:
plt.scatter(myData["age"], myData["adClicks"])
plt.show()

In [78]:
plt.scatter(myData["adClicks"], myData["totalspent"])
plt.show()

In [79]:
# 3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# x-, y-, z-axis
ax.scatter(myData["age"],
           myData["totalspent"],
           myData["adClicks"])
# set axis labels
ax.set_xlabel('age')
ax.set_ylabel('total spent')
ax.set_zlabel('ad clicks')

plt.show()

In [80]:
# features used
featuresUsed = ['age', 'totalspent', 'adClicks']

In [81]:
# Vector Assembler
assembler = VectorAssembler(inputCols=featuresUsed, 
                            outputCol="features_unscaled")


In [82]:
#transform
assembled = assembler.transform(clicks_data_clean)

In [83]:
# Features using StandardScaler
scaler = StandardScaler(inputCol="features_unscaled",
                       outputCol="features",
                       withStd=True,
                       withMean=True)

In [84]:
# fit, transform
scalerModel = scaler.fit(assembled)
scaledData = scalerModel.transform(assembled)

In [85]:
# data persist
scaledData = scaledData.select("features")
scaledData.persist()

DataFrame[features: vector]

# CLUSTERING

In [86]:
#from sklearn.cluster import KMeans as skKMeans


In [87]:
# KMeans clustering: how many clusters?


In [88]:
# generate 2 clusters
kmeans = KMeans(k=3, seed=1)

In [89]:
# fit
model = kmeans.fit(scaledData)

In [90]:
# transform
transformed = model.transform(scaledData)

### PREPARE DATA: for 3D PLOT

In [91]:
print(transformed.count(), " ", len(transformed.columns))

543   2


In [92]:
transformed.show(5)

+--------------------+----------+
|            features|prediction|
+--------------------+----------+
|[-0.6230147439545...|         0|
|[-1.2514302349155...|         1|
|[1.26223172892838...|         1|
|[0.70364018140750...|         2|
|[1.33205567236848...|         1|
+--------------------+----------+
only showing top 5 rows



In [93]:
dataClusters = transformed.toPandas()
dataClusters.shape

(543, 2)

In [94]:
dataClusters.head(5)

Unnamed: 0,features,prediction
0,"[-0.623014743955, 0.57372866686, -0.681757031926]",0
1,"[-1.25143023492, -0.469410302344, 0.63261920928]",1
2,"[1.26223172893, -0.27533793598, 0.50118158516]",1
3,"[0.703640181408, 1.35001813231, 0.304025148979]",2
4,"[1.33205567237, 0.331138208906, 0.764056833401]",1


In [95]:
dataClusters_x = []
dataClusters_y = []
dataClusters_z = []
for point in dataClusters.features:
    dataClusters_x.append(point[0])
    dataClusters_y.append(point[1])
    dataClusters_z.append(point[2])

In [96]:
reshapedDataClusters = pd.DataFrame(data={'x': dataClusters_x,
                                         'y': dataClusters_y,
                                         'z': dataClusters_z,
                        'prediction': dataClusters.prediction})

In [97]:
index0 = reshapedDataClusters.prediction==0
index1 = reshapedDataClusters.prediction==1
index2 = reshapedDataClusters.prediction==2

In [98]:
# CENTER OF CLUSTERS

In [99]:
# get center
centers = model.clusterCenters()

In [100]:
# print center
print(centers)

[array([-0.17525485, -0.42214367, -0.92011812]), array([ 0.26174513, -0.21143332,  0.80604385]), array([-0.16496967,  2.02733069,  0.77019059])]


# ANALYSIS

In [101]:
# analyze of clusters:
# first center is located at: array([ 0.3270975 ,  0.49940114,  0.82038724])
# second center at: array([-0.31410807, -0.47956933, -0.78780869])

In [102]:
figClus = plt.figure()
axClus = figClus.add_subplot(111, projection='3d')
axClus.scatter(reshapedDataClusters.x[index0],
               reshapedDataClusters.y[index0],
               reshapedDataClusters.z[index0],
              c='blue', label='cluster-1')
axClus.scatter(reshapedDataClusters.x[index1],
               reshapedDataClusters.y[index1],
               reshapedDataClusters.z[index1],
              c='orange', label='cluster-2')
axClus.scatter(reshapedDataClusters.x[index2],
               reshapedDataClusters.y[index2],
               reshapedDataClusters.z[index2],
              c='red', label='cluster-3')

# set labels
axClus.set_xlabel('age')
axClus.set_ylabel('total spent')
axClus.set_zlabel('ad clicks')

# clusters' center
axClus.scatter(centers[0][0],
               centers[0][1],
               centers[0][2],
              c='yellow', label='centroid-1')
axClus.scatter(centers[1][0],
               centers[1][1],
               centers[1][2],
              c='green', label='centroid-2')
axClus.scatter(centers[2][0],
               centers[2][1],
               centers[2][2],
              c='black', label='centroid-3')

plt.show()