# K-meansによるクラスタリング

#### irisのデータを用いて、その種類のクラスタリングを試みる

In [1]:
import numpy as np
import pandas as pd

In [2]:
from pyspark.sql import SparkSession

# Create (or reuse, if one already exists) a local Spark session named
# "kmeans"; every later cell reads data through this `spark` handle.
spark = (
    SparkSession.builder
    .master("local")
    .appName("kmeans")
    .getOrCreate()
)

In [3]:
# Load the iris CSV: the first row supplies column names and inferSchema
# lets Spark type the measurement columns as doubles (see printSchema below).
filename = "./data/iris.csv"
data = spark.read.options(header=True, inferSchema=True, sep=",").csv(filename)
data.show()

+------------+-----------+------------+-----------+-------+
|sepal.length|sepal.width|petal.length|petal.width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
|         5.4|        3.9|         1.7|        0.4| Setosa|
|         4.6|        3.4|         1.4|        0.3| Setosa|
|         5.0|        3.4|         1.5|        0.2| Setosa|
|         4.4|        2.9|         1.4|        0.2| Setosa|
|         4.9|        3.1|         1.5|        0.1| Setosa|
|         5.4|        3.7|         1.5|        0.2| Setosa|
|         4.8|        3.4|         1.6|        0.2| Setosa|
|         4.8|        3.0|         1.4|        0.1| Setosa|
|         4.3|        3.0|         1.1| 

In [4]:
# Column names such as "sepal.length" contain dots, which Spark interprets as
# nested-field access unless back-quoted. Replace each "." with "_" so the
# columns can be referenced by plain names ("variety" has no dot and is
# left as-is).
data = data.toDF(*[col_name.replace(".", "_") for col_name in data.columns])
data.show()

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
|         5.4|        3.9|         1.7|        0.4| Setosa|
|         4.6|        3.4|         1.4|        0.3| Setosa|
|         5.0|        3.4|         1.5|        0.2| Setosa|
|         4.4|        2.9|         1.4|        0.2| Setosa|
|         4.9|        3.1|         1.5|        0.1| Setosa|
|         5.4|        3.7|         1.5|        0.2| Setosa|
|         4.8|        3.4|         1.6|        0.2| Setosa|
|         4.8|        3.0|         1.4|        0.1| Setosa|
|         4.3|        3.0|         1.1| 

In [5]:
# Print the schema as a tree: column names, inferred types and nullability.
data.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- variety: string (nullable = true)



In [6]:
# (column name, type) pairs; as the cell's last expression, the notebook displays the list.
data.dtypes

[('sepal_length', 'double'),
 ('sepal_width', 'double'),
 ('petal_length', 'double'),
 ('petal_width', 'double'),
 ('variety', 'string')]

In [7]:
# Descriptive statistics per column. The statistics listed are exactly what
# summary() computes when called with no arguments; failures are reported
# and the notebook carries on, since this cell is diagnostic only.
try:
    summary_data = data.summary(
        "count", "mean", "stddev", "min", "25%", "50%", "75%", "max"
    )
    summary_data.show()
except Exception as err:
    print("Error occurred:", err)

+-------+------------------+-------------------+------------------+------------------+---------+
|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|  variety|
+-------+------------------+-------------------+------------------+------------------+---------+
|  count|               150|                150|               150|               150|      150|
|   mean| 5.843333333333335|  3.057333333333334|3.7580000000000027| 1.199333333333334|     NULL|
| stddev|0.8280661279778637|0.43586628493669793|1.7652982332594662|0.7622376689603467|     NULL|
|    min|               4.3|                2.0|               1.0|               0.1|   Setosa|
|    25%|               5.1|                2.8|               1.6|               0.3|     NULL|
|    50%|               5.8|                3.0|               4.3|               1.3|     NULL|
|    75%|               6.4|                3.3|               5.1|               1.8|     NULL|
|    max|               7.9|  

In [8]:
# Class balance check: number of rows for each iris variety.
variety_counts = data.groupBy("variety").count()
variety_counts.show()

+----------+-----+
|   variety|count|
+----------+-----+
| Virginica|   50|
|    Setosa|   50|
|Versicolor|   50|
+----------+-----+



In [9]:
# Fix the column order: the four measurements followed by the label.
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
df = data.select(*feature_cols, "variety")
df.show()

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
|         5.4|        3.9|         1.7|        0.4| Setosa|
|         4.6|        3.4|         1.4|        0.3| Setosa|
|         5.0|        3.4|         1.5|        0.2| Setosa|
|         4.4|        2.9|         1.4|        0.2| Setosa|
|         4.9|        3.1|         1.5|        0.1| Setosa|
|         5.4|        3.7|         1.5|        0.2| Setosa|
|         4.8|        3.4|         1.6|        0.2| Setosa|
|         4.8|        3.0|         1.4|        0.1| Setosa|
|         4.3|        3.0|         1.1| 

In [10]:
# Pack the four numeric columns into a single vector column named
# "features", the input format Spark ML estimators such as KMeans expect.
from pyspark.ml.feature import VectorAssembler

assemble = VectorAssembler(
    inputCols=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    outputCol="features",
)
# NOTE: despite its name, `pred` holds the assembled input features, not
# model predictions; the name is kept because later cells refer to it.
pred = assemble.transform(df)

In [11]:
# Preview the first 20 rows (long cell values truncated) with the new vector column.
pred.show(n=20, truncate=True)

+------------+-----------+------------+-----------+-------+-----------------+
|sepal_length|sepal_width|petal_length|petal_width|variety|         features|
+------------+-----------+------------+-----------+-------+-----------------+
|         5.1|        3.5|         1.4|        0.2| Setosa|[5.1,3.5,1.4,0.2]|
|         4.9|        3.0|         1.4|        0.2| Setosa|[4.9,3.0,1.4,0.2]|
|         4.7|        3.2|         1.3|        0.2| Setosa|[4.7,3.2,1.3,0.2]|
|         4.6|        3.1|         1.5|        0.2| Setosa|[4.6,3.1,1.5,0.2]|
|         5.0|        3.6|         1.4|        0.2| Setosa|[5.0,3.6,1.4,0.2]|
|         5.4|        3.9|         1.7|        0.4| Setosa|[5.4,3.9,1.7,0.4]|
|         4.6|        3.4|         1.4|        0.3| Setosa|[4.6,3.4,1.4,0.3]|
|         5.0|        3.4|         1.5|        0.2| Setosa|[5.0,3.4,1.5,0.2]|
|         4.4|        2.9|         1.4|        0.2| Setosa|[4.4,2.9,1.4,0.2]|
|         4.9|        3.1|         1.5|        0.1| Setosa|[4.9,

In [12]:
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

In [13]:
# Fit K-means for k = 2..6 and score each clustering with the silhouette
# coefficient (ClusteringEvaluator's default metric; higher is better).
# The evaluator has no k-dependent state, so it is built once, outside the loop.
evaluator = ClusteringEvaluator()
silhouette_by_k = {}  # k -> silhouette score, kept for later comparison
for k in range(2, 7):
    kmeans = KMeans().setK(k).setSeed(1)  # fixed seed for reproducible runs
    model = kmeans.fit(pred)
    prediction = model.transform(pred)
    sil = evaluator.evaluate(prediction)
    silhouette_by_k[k] = sil
    print("k={}".format(k))
    print("シルエット係数={}".format(sil))

k=2
シルエット係数=0.850351222925146
k=3
シルエット係数=0.7356596054332228
k=4
シルエット係数=0.6722537284209328
k=5
シルエット係数=0.615835261889785
k=6
シルエット係数=0.5479277763909295


In [14]:
# Final model with k = 3. The silhouette sweep above scores k = 2 highest,
# but the data contains three labelled varieties (50 rows each), so k = 3
# is used to compare the clusters against those labels.
kmeans = KMeans().setK(3).setSeed(1)
model = kmeans.fit(pred)
# Adds a "prediction" column holding each row's cluster id.
prediction = model.transform(pred)

In [15]:
# First 20 rows with their assigned cluster id in the "prediction" column.
prediction.show(n=20, truncate=True)

+------------+-----------+------------+-----------+-------+-----------------+----------+
|sepal_length|sepal_width|petal_length|petal_width|variety|         features|prediction|
+------------+-----------+------------+-----------+-------+-----------------+----------+
|         5.1|        3.5|         1.4|        0.2| Setosa|[5.1,3.5,1.4,0.2]|         1|
|         4.9|        3.0|         1.4|        0.2| Setosa|[4.9,3.0,1.4,0.2]|         1|
|         4.7|        3.2|         1.3|        0.2| Setosa|[4.7,3.2,1.3,0.2]|         1|
|         4.6|        3.1|         1.5|        0.2| Setosa|[4.6,3.1,1.5,0.2]|         1|
|         5.0|        3.6|         1.4|        0.2| Setosa|[5.0,3.6,1.4,0.2]|         1|
|         5.4|        3.9|         1.7|        0.4| Setosa|[5.4,3.9,1.7,0.4]|         1|
|         4.6|        3.4|         1.4|        0.3| Setosa|[4.6,3.4,1.4,0.3]|         1|
|         5.0|        3.4|         1.5|        0.2| Setosa|[5.0,3.4,1.5,0.2]|         1|
|         4.4|       

In [16]:
# Number of rows per cluster, ordered by cluster id, to compare against the
# 50/50/50 split of the true varieties.
cluster_sizes = prediction.groupBy("prediction").count()
cluster_sizes.sort("prediction").show()

+----------+-----+
|prediction|count|
+----------+-----+
|         0|   62|
|         1|   50|
|         2|   38|
+----------+-----+

