In [11]:
# One-time environment setup: uncomment and run these lines to install the
# notebook's dependencies with conda (OpenJDK, PySpark, findspark).
# NOTE(review): findspark is not imported in the cells shown here — confirm it is still needed.
# %conda install -y openjdk
# %conda install -y pyspark
# %conda install -y -c conda-forge findspark

In [12]:
from pyspark.sql import SparkSession
from pyspark.ml.clustering import PowerIterationClustering
import os

In [13]:
# Create (or reuse) the SparkSession for this notebook and report its version
spark = (
    SparkSession.builder
    .appName("PowerIterationClustering Example")
    .getOrCreate()
)
version_banner = 'Spark Version: {}'.format(spark.version)
print(version_banner)

Spark Version: 3.5.0


In [14]:
# Weighted edge list of the example graph; each tuple is (src, dst, weight).
# One row of tuples per source vertex.
data = [
    (1, 0, 0.5),
    (2, 0, 0.5), (2, 1, 0.7),
    (3, 0, 0.5), (3, 1, 0.7), (3, 2, 0.9),
    (4, 0, 0.5), (4, 1, 0.7), (4, 2, 0.9), (4, 3, 1.1),
    (5, 0, 0.5), (5, 1, 0.7), (5, 2, 0.9), (5, 3, 1.1), (5, 4, 1.3),
]

# Name the columns at creation time, then collapse to a single partition
df = spark.createDataFrame(data, schema=["src", "dst", "weight"]).repartition(1)
df.show()

[Stage 0:>                                                          (0 + 8) / 8]

+---+---+------+
|src|dst|weight|
+---+---+------+
|  1|  0|   0.5|
|  2|  0|   0.5|
|  2|  1|   0.7|
|  3|  0|   0.5|
|  3|  1|   0.7|
|  3|  2|   0.9|
|  4|  0|   0.5|
|  4|  1|   0.7|
|  4|  2|   0.9|
|  4|  3|   1.1|
|  5|  0|   0.5|
|  5|  1|   0.7|
|  5|  2|   0.9|
|  5|  3|   1.1|
|  5|  4|   1.3|
+---+---+------+



                                                                                

In [15]:
# Configure the PIC estimator: 2 clusters, at most 40 iterations,
# edge weights read from the "weight" column
pic = PowerIterationClustering(k=2, maxIter=40, weightCol="weight")
pic

PowerIterationClustering_ac1fcd513a58

In [16]:
# Run PIC on the edge DataFrame and list each vertex's cluster, ordered by vertex id
assignments = pic.assignClusters(df)
assignments.orderBy("id").show(truncate=False)

+---+-------+
|id |cluster|
+---+-------+
|0  |0      |
|1  |0      |
|2  |0      |
|3  |0      |
|4  |0      |
|5  |1      |
+---+-------+



In [17]:
# Persist the configured PIC estimator under ./pic, replacing any earlier save
pic_path = os.getcwd() + "/pic"
pic_writer = pic.write().overwrite()
pic_writer.save(pic_path)

In [18]:
# Read the saved estimator back from disk and show its number of clusters (k)
pic2 = PowerIterationClustering.read().load(pic_path)
pic2.getOrDefault(pic2.k)

2

In [19]:
# Show the maxIter value carried by the reloaded estimator
pic2.getOrDefault(pic2.maxIter)

40

In [20]:
# Compare the cluster assignments from before with those produced by the reloaded model.
# Both results are sorted by vertex id and collected in full before comparing:
# Spark does not guarantee row order without an explicit sort, so an unsorted
# take(6) could report a mismatch even when the assignments are identical, and
# the hard-coded 6 would silently ignore rows if the graph grew.
(
    pic2.assignClusters(df).orderBy("id").collect()
    == assignments.orderBy("id").collect()
)

True

In [21]:
# Stop the Spark session now that the example is finished
spark.stop()