In [0]:
# imports
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, MinMaxScaler
from pyspark.sql.functions import col
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.rdd import RDD
import numpy as np

In [0]:
def Kmeans(data, k, ct=0.0001, iterations=30, initial_centroids=None, seed=None):
    """
    Cluster the rows of a Spark DataFrame with iterative K-means
    (assign each point to its nearest centroid, then move every centroid to
    the mean of its points).

    The input is first passed through PrepareData, so clustering runs on the
    "scaled_features" vectors it produces and the returned centroids are
    expressed in that scaled space.

    Args:
        data: Spark DataFrame accepted by PrepareData.
        k: Number of centroids to sample from the data when
            initial_centroids is not supplied.
        ct: Convergence tolerance. Iteration stops once the summed Euclidean
            movement of all centroids in one update is below this value.
        iterations: Maximum number of assign/update rounds.
        initial_centroids: Optional sequence (list, tuple or NumPy array) of
            starting centroids. When given, its rows are used as-is and their
            count determines the number of clusters.
        seed: Optional integer seed for the random sampling used to pick
            initial centroids and to re-seed empty clusters. None (default)
            keeps the sampling non-deterministic.

    Returns:
        A list of tuples, one per centroid in cluster-id order, with every
        coordinate rounded to 5 decimals. An empty list is returned when no
        centroid could be obtained (e.g. the DataFrame has no rows).
    """
    # Step 1: Prepare and normalize the data.
    prepared = PrepareData(data)
    # Cached because every iteration launches at least one Spark job over it;
    # without caching the assembler/scaler lineage is re-run each time.
    data_rdd = prepared.rdd.map(lambda row: row["scaled_features"].toArray()).cache()

    try:
        # Step 2: Initialize centroids.
        # `is not None` (rather than truthiness) so NumPy arrays are accepted.
        if initial_centroids is not None:
            centroids = np.array(initial_centroids, dtype=float)
        else:
            centroids = np.array(data_rdd.takeSample(False, k, seed), dtype=float)

        if centroids.size == 0:
            # Nothing to cluster around (empty data or empty centroid list).
            return []

        n_clusters = centroids.shape[0]

        for iteration in range(iterations):
            # Bind the current centroids as a default argument so the function
            # shipped to the executors carries exactly this iteration's values.
            def assign_cluster(point, centers=centroids):
                distances = [np.linalg.norm(point - c) for c in centers]
                # Plain int key; the 1 is the per-point count for the mean.
                return (int(np.argmin(distances)), (point, 1))

            # Single pass: per-cluster coordinate sum and point count, keyed
            # by cluster id so results can be put back in the right row.
            cluster_stats = (
                data_rdd.map(assign_cluster)
                .reduceByKey(lambda a, b: (a[0] + b[0], a[1] + b[1]))
                .collectAsMap()
            )

            # Start from the old centroids so row i always means cluster i.
            new_centroids = centroids.copy()
            for cluster_id, (total, count) in cluster_stats.items():
                new_centroids[cluster_id] = np.asarray(total, dtype=float) / count

            # Handle missing (empty) clusters: re-seed them with random points
            # in the *new* centroid array so the replacement is not lost.
            missing_clusters = [c for c in range(n_clusters) if c not in cluster_stats]
            if missing_clusters:
                # Vary the seed per iteration so repeated re-seeding does not
                # keep drawing the same points.
                reseed = None if seed is None else seed + iteration + 1
                replacements = data_rdd.takeSample(False, len(missing_clusters), reseed)
                # zip() tolerates fewer samples than missing clusters; the
                # unmatched ones simply keep their previous centroid.
                for cluster_id, point in zip(missing_clusters, replacements):
                    new_centroids[cluster_id] = point

            # Check for convergence (rows are aligned by cluster id).
            centroid_shift = sum(
                np.linalg.norm(new_c - old_c)
                for new_c, old_c in zip(new_centroids, centroids)
            )

            # On convergence the centroids from before this sub-tolerance
            # update are returned (they differ from the new ones by < ct).
            if centroid_shift < ct:
                break

            centroids = new_centroids

        return [tuple(round(x, 5) for x in c) for c in centroids]
    finally:
        # Release the cached partitions whether we finished or failed.
        data_rdd.unpersist()


def PrepareData(data):
    """
    Build the normalized feature vectors used for clustering.

    Every column except the last one (presumably the label - confirm against
    the dataset) is assembled into a single "features" vector, which is then
    rescaled with a MinMaxScaler fitted on the same data.

    Args:
        data: Spark DataFrame whose leading columns are the features.

    Returns:
        Spark DataFrame with a single column, "scaled_features".
    """
    # Pack all but the last column into one vector column.
    assembler = VectorAssembler(inputCols=data.columns[:-1], outputCol="features")
    assembled = assembler.transform(data).select("features")

    # Fit the scaler on the assembled vectors and apply it to the same frame.
    rescaled = (
        MinMaxScaler(inputCol="features", outputCol="scaled_features")
        .fit(assembled)
        .transform(assembled)
    )
    return rescaled.select("scaled_features")

In [0]:
# """ ******************* Importing the data ******************* """

# Iris = "/FileStore/tables/Iris.csv"
# file_type = "csv"
# infer_schema = "True"
# first_row_is_header = "True"
# delimiter = ","
# iris = spark.read.format(file_type) \
#     .option("inferSchema", infer_schema) \
#     .option("header", first_row_is_header) \
#     .option("sep", delimiter) \
#     .load(Iris)

# # """ ******************* Example for testing ******************* """
# new_centroids = Kmeans(iris, 2, initial_centroids=[(0.5,0.5,0.5,0.5,0.5),(0.3,0.3,0.3,0.3,0.3)])
# round_new_centroids=[tuple(round(num, 5) for num in tup) for tup in new_centroids]
# expected_new_centroids=[(0.16871,0.19553,0.58252,0.08475,0.06618),(0.67067,0.54882,0.36532,0.66478,0.65951)]
# if (not len(new_centroids)==len(expected_new_centroids)):
#     print("Failed - Number of clusters is different than requested")
# if set(round_new_centroids)==set(expected_new_centroids):
#     print("The test passed successfully")
# else:
#     print("The test failed")

