In [5]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, SparkSession, Row
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import functions as func
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier,DecisionTreeClassificationModel
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

In [6]:
max_bins_1 = 32
max_bins_2 = 50
max_depth = 10
max_iterations = 20

In [7]:
### Initialize streaming context
conf = SparkConf()\
                .setMaster("local[2]")\
                .setAppName("MobileAnalyticsDecisionTreeClassifier")\
                .set("spark.executor.memory", "2g")\
                .set("spark.driver.memory", "2g")
sc = SparkContext(conf=conf)
sc.setLogLevel("ERROR")
sqlContext = SQLContext(sc)
spark = SparkSession.builder.appName("spark play").getOrCreate()

ValueError: Cannot run multiple SparkContexts at once; existing SparkContext(app=MobileAnalyticsDecisionTreeClassifier, master=local[2]) created by __init__ at <ipython-input-3-37e0f190f2bc>:3 

In [8]:
train = spark.read.csv("data/features/train/part-00000-152f8c78-aede-459c-8067-8612b82b1bcc-c000.csv", header=True, mode="DROPMALFORMED", inferSchema='true', encoding="utf-8").persist()
print train.count()
train.show()

1266952
+------------+--------+-----------------------+--------+--------+---------+---------+---------+---------+---------+---------+---------+-------------+-------------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+--------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------------------+------------------+----------------------------+-------------------------+---------------------------+----------------------+-------------------------+--------------------------+-----------------+------------------+
|device_index|   label|events_per_device_count|min_hour|max_hour|mon_count|tue_count|wed_count|thu_count|fri_count|sat_count|sun_count|weekend_count|weekday_count|am_count|pm_count|h0_count|h1_count|h2_count|h3_count|h4_count|h5_count|h6_count|h7_count|h8_count|h9_count|h10_count|h11_count|h12_count|h13_count|h14_count|h15_count|h16_count|h1

In [9]:
test = spark.read.csv("data/features/test/part-00000-152f8c78-aede-459c-8067-8612b82b1bcc-c000.csv", header=True, mode="DROPMALFORMED", inferSchema='true', encoding="utf-8").persist()
print test.count()
test.show()

AnalysisException: u'Path does not exist: file:/Users/swaite/Stirling/CSIE-63/final-project/data/features/test/part-00000-152f8c78-aede-459c-8067-8612b82b1bcc-c000.csv;'

In [10]:
assembler = VectorAssembler(
                inputCols=[
                            "device_index", 
                             "events_per_device_count",
                             "min_hour",
                             "max_hour",
                             "mon_count",
                             "tue_count",
                             "wed_count",
                             "thu_count",
                             "fri_count",
                             "sat_count",
                             "sun_count",
                             "weekend_count",
                             "weekday_count",
                             "am_count",
                             "pm_count",
                             "h0_count",
                             "h1_count",
                             "h2_count",
                             "h3_count",
                             "h4_count",
                             "h5_count",
                             "h6_count",
                             "h7_count",
                             "h8_count",
                             "h9_count",
                             "h10_count",
                             "h11_count",
                             "h12_count",
                             "h13_count",
                             "h14_count",
                             "h15_count",
                             "h16_count",
                             "h17_count",
                             "h18_count",
                             "h19_count",
                             "h20_count",
                             "h21_count",
                             "h22_count",
                             "h23_count",
                             "apps_per_device_count",
                             "apps_per_event_avg",
                             "apps_active_per_device_count",
                             "apps_active_per_event_avg",
                             "categories_per_device_count", 
                             "categories_per_app_avg",
                             "latitude_per_device_count", 
                             "longitude_per_device_count",
                             "phone_brand_index",
                             "device_model_index"
                ],
                outputCol="features"
            )

In [23]:
dt = DecisionTreeClassifier()\
        .setFeaturesCol("features")\
        .setLabelCol("phone_brand_index")\
        .setPredictionCol("prediction")\
        .setProbabilityCol("probability")\
        .setRawPredictionCol("confidence")

In [24]:
pipeline = Pipeline(stages=[assembler, dt])

In [25]:
params = ParamGridBuilder()\
            .addGrid(dt.maxBins, [max_bins_1, max_bins_2])\
            .addGrid(dt.maxDepth, [max_depth])\
            .addGrid(dt.impurity, ["entropy", "gini"]) \
            .build() 

In [26]:
# Iterate and choose the best fit model  
cross_validator = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=params,
                          evaluator=MulticlassClassificationEvaluator(),
                          numFolds=5)  # use 3+ folds in practice         
                
cross_validator_model =  cross_validator.fit(train)

IllegalArgumentException: u'requirement failed: Classifier inferred 127 from label values in column DecisionTreeClassifier_49c891beef8fbb282861__labelCol, but this exceeded the max numClasses (100) allowed to be inferred from values.  To avoid this error for labels with > 100 classes, specify numClasses explicitly in the metadata; this can be done by applying StringIndexer to the label column.'