In [1]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, SparkSession, Row
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import functions as func
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier,DecisionTreeClassificationModel
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

In [2]:
max_bins_1 = 32
max_bins_2 = 50
max_depth = 10
max_iterations = 20

In [3]:
### Initialize streaming context
conf = SparkConf()\
                .setMaster("local[2]")\
                .setAppName("MobileAnalyticsDecisionTreeClassifier")\
                .set("spark.executor.memory", "2g")\
                .set("spark.driver.memory", "2g")
sc = SparkContext(conf=conf)
sc.setLogLevel("ERROR")
sqlContext = SQLContext(sc)
spark = SparkSession.builder.appName("spark play").getOrCreate()

In [24]:
train = spark.read.csv("data/features/train/part-00000-f4d4e2ca-2586-4c69-8cdf-70682260c76f-c000.csv", header=True, mode="DROPMALFORMED", inferSchema='true', encoding="utf-8").persist()
print train.count()
train.show()

186716
+--------------------+-----------------------+---------------------+-------------------+---------+---------+-------------+--------+--------+---------+---------+---------+---------+---------+---------+---------+-------------+-------------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+-----+------------+-----------------------------+---------------------------+------------------------+---------------+---------------+-------------------+---------------+--------------+--------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+-------------------+-------------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+----------

In [25]:
test = spark.read.csv("data/features/test/part-00000-2f2a0753-fe6e-4506-8782-d6e223f89a6d-c000.csv", header=True, mode="DROPMALFORMED", inferSchema='true', encoding="utf-8").persist()
print test.count()
test.show()

186716
+--------------------+-----------------------+---------------------+-------------------+---------+---------+-------------+--------+--------+---------+---------+---------+---------+---------+---------+---------+-------------+-------------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+-----+------------+-----------------------------+---------------------------+------------------------+---------------+---------------+-------------------+---------------+--------------+--------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+-------------------+-------------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+----------

In [26]:
assembler = VectorAssembler(
                inputCols=[
                    "device_index",
                    "events_per_device_count_index",
                    "apps_per_device_count_index",
                    "apps_per_event_avg_index",
                    "lat_count_index",
                    "lng_count_index",
                    "lat_lng_count_index",
                    "min_hour_index",
                    "max_hour_index",
                    "mon_count_index",
                    "tue_count_index",
                    "wed_count_index",
                    "thu_count_index",
                    "fri_count_index",
                    "sat_count_index",
                    "sun_count_index",
                    "weekend_count_index",
                    "weekday_count_index",
                    "am_count_index",
                    "pm_count_index",
                    "h0_count_index",
                    "h1_count_index",
                    "h2_count_index",
                    "h3_count_index",
                    "h4_count_index",
                    "h5_count_index",
                    "h6_count_index",
                    "h7_count_index",
                    "h8_count_index",
                    "h9_count_index",
                    "h10_count_index",
                    "h11_count_index",
                    "h12_count_index",
                    "h13_count_index",
                    "h14_count_index",
                    "h15_count_index",
                    "h16_count_index",
                    "h17_count_index",
                    "h18_count_index",
                    "h19_count_index",
                    "h20_count_index",
                    "h21_count_index",
                    "h22_count_index",
                    "h23_count_index"
                ],
                outputCol="features"
            )

In [27]:
dt = DecisionTreeClassifier()\
        .setFeaturesCol("features")\
        .setLabelCol("label")\
        .setPredictionCol("prediction")\
        .setProbabilityCol("probability")\
        .setRawPredictionCol("confidence")

In [28]:
pipeline = Pipeline(stages=[assembler, dt])

In [29]:
params = ParamGridBuilder()\
            .addGrid(dt.maxBins, [max_bins_1, max_bins_2])\
            .addGrid(dt.maxDepth, [max_depth])\
            .addGrid(dt.impurity, ["entropy", "gini"]) \
            .build() 

In [30]:
# Iterate and choose the best fit model  
cross_validator = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=params,
                          evaluator=MulticlassClassificationEvaluator(),
                          numFolds=5)  # use 3+ folds in practice         
                
cross_validator_model =  cross_validator.fit(train)

In [32]:
output = cross_validator_model.transform(test)\
                        .select("device_index","probability")

In [33]:
print output.show()

+------------+-----------+
|device_index|probability|
+------------+-----------+
|    113266.0|  [0.0,1.0]|
|    154649.0|  [0.0,1.0]|
|     61033.0|  [0.0,1.0]|
|    101658.0|  [0.0,1.0]|
|    124814.0|  [0.0,1.0]|
|      4028.0|  [0.0,1.0]|
|    125062.0|  [0.0,1.0]|
|    149173.0|  [0.0,1.0]|
|     58289.0|  [0.0,1.0]|
|     98247.0|  [0.0,1.0]|
|     84901.0|  [0.0,1.0]|
|     98388.0|  [1.0,0.0]|
|    115560.0|  [0.0,1.0]|
|     32766.0|  [0.0,1.0]|
|    127697.0|  [0.0,1.0]|
|     26220.0|  [0.0,1.0]|
|     93321.0|  [1.0,0.0]|
|    117316.0|  [0.0,1.0]|
|      6756.0|  [0.0,1.0]|
|    150734.0|  [0.0,1.0]|
+------------+-----------+
only showing top 20 rows

None


In [31]:
output.repartition(1).write.option("header", True).csv("data/output/dt")