In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").appName("PCA").getOrCreate()

In [9]:
iris = spark.read.options(delimiter=",", header="true", inferSchema="true").csv("iris.csv")
iris.show()

+---------+--------+---------+--------+-----------+
|sp_length|sp_width|pt_length|pt_width|    species|
+---------+--------+---------+--------+-----------+
|      5.1|     3.5|      1.4|     0.2|Iris-setosa|
|      4.9|     3.0|      1.4|     0.2|Iris-setosa|
|      4.7|     3.2|      1.3|     0.2|Iris-setosa|
|      4.6|     3.1|      1.5|     0.2|Iris-setosa|
|      5.0|     3.6|      1.4|     0.2|Iris-setosa|
|      5.4|     3.9|      1.7|     0.4|Iris-setosa|
|      4.6|     3.4|      1.4|     0.3|Iris-setosa|
|      5.0|     3.4|      1.5|     0.2|Iris-setosa|
|      4.4|     2.9|      1.4|     0.2|Iris-setosa|
|      4.9|     3.1|      1.5|     0.1|Iris-setosa|
|      5.4|     3.7|      1.5|     0.2|Iris-setosa|
|      4.8|     3.4|      1.6|     0.2|Iris-setosa|
|      4.8|     3.0|      1.4|     0.1|Iris-setosa|
|      4.3|     3.0|      1.1|     0.1|Iris-setosa|
|      5.8|     4.0|      1.2|     0.2|Iris-setosa|
|      5.7|     4.4|      1.5|     0.4|Iris-setosa|
|      5.4| 

In [10]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols=["sp_length", "sp_width", "pt_length", "pt_width"], outputCol="features")
data = assembler.transform(iris)
data.show()

+---------+--------+---------+--------+-----------+-----------------+
|sp_length|sp_width|pt_length|pt_width|    species|         features|
+---------+--------+---------+--------+-----------+-----------------+
|      5.1|     3.5|      1.4|     0.2|Iris-setosa|[5.1,3.5,1.4,0.2]|
|      4.9|     3.0|      1.4|     0.2|Iris-setosa|[4.9,3.0,1.4,0.2]|
|      4.7|     3.2|      1.3|     0.2|Iris-setosa|[4.7,3.2,1.3,0.2]|
|      4.6|     3.1|      1.5|     0.2|Iris-setosa|[4.6,3.1,1.5,0.2]|
|      5.0|     3.6|      1.4|     0.2|Iris-setosa|[5.0,3.6,1.4,0.2]|
|      5.4|     3.9|      1.7|     0.4|Iris-setosa|[5.4,3.9,1.7,0.4]|
|      4.6|     3.4|      1.4|     0.3|Iris-setosa|[4.6,3.4,1.4,0.3]|
|      5.0|     3.4|      1.5|     0.2|Iris-setosa|[5.0,3.4,1.5,0.2]|
|      4.4|     2.9|      1.4|     0.2|Iris-setosa|[4.4,2.9,1.4,0.2]|
|      4.9|     3.1|      1.5|     0.1|Iris-setosa|[4.9,3.1,1.5,0.1]|
|      5.4|     3.7|      1.5|     0.2|Iris-setosa|[5.4,3.7,1.5,0.2]|
|      4.8|     3.4|

In [11]:
from pyspark.ml.feature import PCA

pca = PCA(k=2, inputCol="features", outputCol="pcaFeatures")
model = pca.fit(data)

result = model.transform(data).select("features", "pcaFeatures", "species")
result.show(truncate=False)

+-----------------+----------------------------------------+-----------+
|features         |pcaFeatures                             |species    |
+-----------------+----------------------------------------+-----------+
|[5.1,3.5,1.4,0.2]|[-2.827135972679027,-5.641331045573321] |Iris-setosa|
|[4.9,3.0,1.4,0.2]|[-2.7959524821488437,-5.145166883252896]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|[-2.6215235581650584,-5.177378121203909]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|[-2.7649059004742402,-5.003599415056946]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|[-2.7827501159516603,-5.648648294377395]|Iris-setosa|
|[5.4,3.9,1.7,0.4]|[-3.231445736773378,-6.062506444034077] |Iris-setosa|
|[4.6,3.4,1.4,0.3]|[-2.690452415602345,-5.232619219784267] |Iris-setosa|
|[5.0,3.4,1.5,0.2]|[-2.8848611044591563,-5.485129079769225]|Iris-setosa|
|[4.4,2.9,1.4,0.2]|[-2.6233845324473406,-4.743925704477345]|Iris-setosa|
|[4.9,3.1,1.5,0.1]|[-2.8374984110638537,-5.208032027056187]|Iris-setosa|
|[5.4,3.7,1.5,0.2]|[-3.004816308444072,-5.966658744

In [14]:
from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol="species", outputCol="label")
indexed_result = indexer.fit(result).transform(result)
indexed_result.groupBy("species", "label").count().show()

+---------------+-----+-----+
|        species|label|count|
+---------------+-----+-----+
|    Iris-setosa|  0.0|   50|
| Iris-virginica|  2.0|   50|
|Iris-versicolor|  1.0|   50|
+---------------+-----+-----+



In [15]:
%matplotlib notebook
import matplotlib.pyplot as plt

collected = indexed_result.select("pcaFeatures", "label").collect()

dots = [item[0] for item in collected]
colors = [item[1] for item in collected]

plt.scatter(*zip(*dots), c=colors)
plt.show()

TypeError: scatter() got multiple values for argument 'c'

In [21]:
from pyspark.ml.classification import LogisticRegression

seed = 20

train_df, test_df = indexed_result.select("pcaFeatures", "label").toDF("features", "label").randomSplit([0.7, 0.3], seed=seed)
lr = LogisticRegression()
lr_model = lr.fit(train_df)
prediction = lr_model.transform(test_df)
result = prediction.collect()
accuracy = [r.prediction == r.label for r in result]

print(sum(accuracy), "/", len(accuracy))

48 / 50
