In [1]:
from pyspark import SparkContext
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SQLContext

In [2]:
# `sc` is only predefined when the notebook is launched through the pyspark shell.
# getOrCreate() reuses that context when present and starts one otherwise, so the
# notebook also survives Restart & Run All from a plain Python kernel.
sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

### Loading data

In [6]:
# Read the prepared modeling data (JSON source; the schema is inferred by Spark).
model_data = sqlContext.read.format('json').load('./data/model_data')

In [7]:
# Inspect the inferred schema. crime_code comes back as a string, so it must be
# indexed to a numeric column before it can serve as a classification label.
model_data.printSchema()

root
 |-- area_name: string (nullable = true)
 |-- crime_code: string (nullable = true)
 |-- is_holiday: boolean (nullable = true)
 |-- time_bucket: long (nullable = true)



In [8]:
# Peek at the first record.
# NOTE(review): is_holiday and time_bucket are already None in this first row —
# check their null rate before relying on them as model features.
model_data.head()

Row(area_name=u'Central', crime_code=u'946', is_holiday=None, time_bucket=None)

In [11]:
# Make model_data queryable from SQL under the table name `crime`.
sqlContext.registerDataFrameAsTable(model_data, 'crime')

In [12]:
# Number of distinct crime codes, i.e. the number of classes a crime_code
# classifier has to distinguish. The alias replaces the auto-generated `_c0` header.
sqlContext.sql('select count(distinct crime_code) as n_crime_codes from crime').show()

+---+
|_c0|
+---+
|132|
+---+



### Baseline Modeling

In [32]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [33]:
# Area features: map each area_name to a numeric index, then one-hot encode it.
# dropLast=False keeps one vector slot per area instead of dropping the last one.
stringIndexer = StringIndexer(outputCol="indexed_area", inputCol="area_name")
encoder = OneHotEncoder(outputCol="areaVec", inputCol="indexed_area", dropLast=False)

model = stringIndexer.fit(model_data)
indexed = model.transform(model_data)
encoded = encoder.transform(indexed)
encoded.show()

+---------+----------+----------+-----------+------------+---------------+
|area_name|crime_code|is_holiday|time_bucket|indexed_area|        areaVec|
+---------+----------+----------+-----------+------------+---------------+
|  Central|       946|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       330|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       442|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       442|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       626|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       442|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       442|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       648|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       740|      null|       null|         8.0| (21,[8],[1.0])|
|  Central|       745|      null|       null|         8.0| (21,[8],[1.0])|
|  Rampart|       626|   

In [29]:
evaluator = MulticlassClassificationEvaluator(predictionCol=['indexed', 'is_holiday', 'time_bucket'], labelCol='crime_code')

In [35]:
evaluator.evaluate(encoded)

Py4JJavaError: An error occurred while calling o229.evaluate.
: java.lang.ClassCastException: java.util.ArrayList cannot be cast to java.lang.String
	at org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate(MulticlassClassificationEvaluator.scala:74)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:497)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:231)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:381)
	at py4j.Gateway.invoke(Gateway.java:259)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:209)
	at java.lang.Thread.run(Thread.java:745)
