In [4]:
data = sc.textFile("iris num.csv")

Display 2 elements to make sure data is loaded

In [5]:
data.take(2)

[u'5.1,3.5,1.4,0.2,0', u'4.9,3,1.4,0.2,0']

In [6]:
from numpy import array

In [7]:
pdata = data.map(lambda line : array([float(x) for x in line.split(",")]))

In [8]:
pdata.take(2)

[array([5.1, 3.5, 1.4, 0.2, 0. ]), array([4.9, 3. , 1.4, 0.2, 0. ])]

Prepare the data for spark mllib algorithms. Those algorithms require the data to be labeled input / output by encapsulating it in LabeledPoint objects

In [9]:
from pyspark.mllib.regression import LabeledPoint

In [10]:
def parse(l):
    return LabeledPoint(l[4], l[0:4])

In [11]:
fdata = pdata.map(lambda l: parse(l))

In [12]:
fdata.take(2)

[LabeledPoint(0.0, [5.1,3.5,1.4,0.2]), LabeledPoint(0.0, [4.9,3.0,1.4,0.2])]

Divide the data into training and test

In [53]:
(trainingData, testData) = fdata.randomSplit([0.8, 0.2])

Use the decision tree classifier to train the model

In [54]:
from pyspark.mllib.tree import DecisionTree

In [55]:
model = DecisionTree.trainClassifier(trainingData, numClasses=3, categoricalFeaturesInfo={})

In [56]:
predictions = model.predict(testData.map(lambda row: row.features))

Create Confusion Matrix to evaluate the accuracy of the model

We create a matrix containing the test labels as a first column (real values) and predicted values as second column

In [57]:
predictionsAndLabels = testData.map(lambda labeledpoint: labeledpoint.label).zip(predictions)

In [58]:
predictionsAndLabels.collect()

[(0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (0.0, 0.0),
 (1.0, 2.0),
 (1.0, 1.0),
 (1.0, 1.0),
 (1.0, 1.0),
 (1.0, 1.0),
 (1.0, 1.0),
 (1.0, 1.0),
 (1.0, 1.0),
 (1.0, 2.0),
 (1.0, 1.0),
 (2.0, 2.0),
 (2.0, 2.0),
 (2.0, 2.0),
 (2.0, 2.0),
 (2.0, 2.0),
 (2.0, 2.0),
 (2.0, 2.0),
 (2.0, 2.0)]

## Evaluate the accuracy of the model

In [59]:
from pyspark.mllib.evaluation import MulticlassMetrics

In [60]:
metrics = MulticlassMetrics(predictionsAndLabels)

### Display Model Accuracy

In [61]:
accuracy = metrics.accuracy

In [62]:
print(accuracy)

0.935483870968


In [63]:
confusionMatrix = metrics.confusionMatrix()

In [64]:
confusionMatrix.toArray()

array([[13.,  0.,  0.],
       [ 0.,  8.,  0.],
       [ 0.,  2.,  8.]])

### Display Model Precision

In [65]:
print("Recall for Class #1", metrics.recall(0))
print("Recall for Class #2", metrics.recall(1))
print("Recall for Class #3", metrics.recall(2))

('Recall for Class #1', 1.0)
('Recall for Class #2', 1.0)
('Recall for Class #3', 0.8)


In [66]:
print("Precision for Class #1", metrics.precision(0))
print("Precision for Class #2", metrics.precision(1))
print("Precision for Class #3", metrics.precision(2))

('Precision for Class #1', 1.0)
('Precision for Class #2', 0.8)
('Precision for Class #3', 1.0)


In [71]:
metrics.precision()

0.9354838709677419

In [72]:
metrics.recall()

0.9354838709677419

In [73]:
metrics.fMeasure()

0.9354838709677419

## Use Spark Dataframe

In [68]:
df = spark.read.load("Iris1.csv", format="csv", sep=",", inferSchema="true", header="true")

In [69]:
df.show()

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
|         5.4|        3.9|         1.7|        0.4| setosa|
|         4.6|        3.4|         1.4|        0.3| setosa|
|         5.0|        3.4|         1.5|        0.2| setosa|
|         4.4|        2.9|         1.4|        0.2| setosa|
|         4.9|        3.1|         1.5|        0.1| setosa|
|         5.4|        3.7|         1.5|        0.2| setosa|
|         4.8|        3.4|         1.6|        0.2| setosa|
|         4.8|        3.0|         1.4|        0.1| setosa|
|         4.3|        3.0|         1.1| 

In [70]:
df.show(2)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
+------------+-----------+------------+-----------+-------+
only showing top 2 rows



### Use Spark sqlContext to load the data as a data frame

In [74]:
df1 = sqlContext.read.format("csv").options(header='true', inferSchema='true').load('Iris1.csv')

In [75]:
df1.show()

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|species|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| setosa|
|         4.9|        3.0|         1.4|        0.2| setosa|
|         4.7|        3.2|         1.3|        0.2| setosa|
|         4.6|        3.1|         1.5|        0.2| setosa|
|         5.0|        3.6|         1.4|        0.2| setosa|
|         5.4|        3.9|         1.7|        0.4| setosa|
|         4.6|        3.4|         1.4|        0.3| setosa|
|         5.0|        3.4|         1.5|        0.2| setosa|
|         4.4|        2.9|         1.4|        0.2| setosa|
|         4.9|        3.1|         1.5|        0.1| setosa|
|         5.4|        3.7|         1.5|        0.2| setosa|
|         4.8|        3.4|         1.6|        0.2| setosa|
|         4.8|        3.0|         1.4|        0.1| setosa|
|         4.3|        3.0|         1.1| 

Print row count

In [77]:
print('Row count: %s' % df.count())

Row count: 150


Filter the data based on a condition

In [78]:
df.filter(df['petal_length'] > 6).show()

+------------+-----------+------------+-----------+---------+
|sepal_length|sepal_width|petal_length|petal_width|  species|
+------------+-----------+------------+-----------+---------+
|         7.6|        3.0|         6.6|        2.1|virginica|
|         7.3|        2.9|         6.3|        1.8|virginica|
|         7.2|        3.6|         6.1|        2.5|virginica|
|         7.7|        3.8|         6.7|        2.2|virginica|
|         7.7|        2.6|         6.9|        2.3|virginica|
|         7.7|        2.8|         6.7|        2.0|virginica|
|         7.4|        2.8|         6.1|        1.9|virginica|
|         7.9|        3.8|         6.4|        2.0|virginica|
|         7.7|        3.0|         6.1|        2.3|virginica|
+------------+-----------+------------+-----------+---------+



In [81]:
df.groupBy(df['species']).count().show()

+----------+-----+
|   species|count|
+----------+-----+
| virginica|   50|
|versicolor|   50|
|    setosa|   50|
+----------+-----+



Print the first 10 elements

In [85]:
df.head(10)

[Row(sepal_length=5.1, sepal_width=3.5, petal_length=1.4, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=4.9, sepal_width=3.0, petal_length=1.4, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=4.7, sepal_width=3.2, petal_length=1.3, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=4.6, sepal_width=3.1, petal_length=1.5, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=5.0, sepal_width=3.6, petal_length=1.4, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=5.4, sepal_width=3.9, petal_length=1.7, petal_width=0.4, species=u'setosa'),
 Row(sepal_length=4.6, sepal_width=3.4, petal_length=1.4, petal_width=0.3, species=u'setosa'),
 Row(sepal_length=5.0, sepal_width=3.4, petal_length=1.5, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=4.4, sepal_width=2.9, petal_length=1.4, petal_width=0.2, species=u'setosa'),
 Row(sepal_length=4.9, sepal_width=3.1, petal_length=1.5, petal_width=0.1, species=u'setosa')]

### Execute SQL Query

In [86]:
df.registerTempTable('mytable')

In [87]:
distinct_classes = sqlContext.sql("select distinct species from mytable")
distinct_classes.show()

+----------+
|   species|
+----------+
| virginica|
|versicolor|
|    setosa|
+----------+

