In [None]:
import cv2
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import calinski_harabasz_score, silhouette_score
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Load the image from disk.
image_path = "Images/sample_img_1.png"
image = cv2.imread(image_path)

# cv2.imread does not raise when the file is missing or unreadable; it
# returns None, which would otherwise only surface later as an obscure
# error inside cvtColor/reshape. Fail early with a clear message instead.
if image is None:
    raise FileNotFoundError(f"Could not read image at {image_path!r}")

In [None]:
def calculate_silhouette(x, y, sample_size=5000):
    """Return the silhouette score computed on a random subsample of ``x``.

    Parameters
    ----------
    x : array-like
        Data points; rows are selected along the first axis.
    y : array-like
        Cluster label for each row of ``x`` (indexed with the same row
        indices as ``x``).
    sample_size : int, default 5000
        Maximum number of rows to draw (without replacement). It is clamped
        to ``len(x)`` so that small inputs do not make the draw fail.

    Returns
    -------
    The value returned by ``silhouette_score`` for the sampled rows and
    their labels.
    """
    # Work on the arguments that were passed in, not on a notebook-level
    # variable, so the function behaves the same on a fresh kernel.
    x = np.asarray(x)
    y = np.asarray(y)

    # Sampling without replacement cannot draw more rows than exist.
    n_draw = min(sample_size, len(x))
    indices = np.random.choice(len(x), size=n_draw, replace=False)

    # The same indices are applied to both arrays so each sampled row keeps
    # its own label.
    return silhouette_score(x[indices], y[indices])

In [None]:
# Bounds used to size the score table below.
# NOTE: the scoring cell further down reassigns k_min, k_max and model_score
# before any scores are computed.
k_min, k_max = 3, 15

# Zero-filled table with 3 rows and (k_max - k_min) columns.
model_score = np.zeros((3, k_max - k_min))
model_score.shape

In [None]:
# Convert the BGR image to the Lab colour space.
# NOTE: despite its name, `hsv_image` holds Lab data (COLOR_BGR2Lab), not HSV.
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)

# Flatten to a 2-D array with three columns so each row can be clustered.
X = hsv_image.reshape(-1, 3)

# Candidate cluster counts: k_min, k_min + 1, ..., k_max - 1.
k_min = 2
k_max = 20

# Row 0: silhouette score, row 1: Calinski-Harabasz score, row 2: inertia.
# Column i corresponds to n_clusters == k_min + i.
model_score = np.zeros(shape=(3, k_max - k_min))
for i, n_clusters in enumerate(tqdm(range(k_min, k_max))):
    kmeans = KMeans(n_clusters=n_clusters, init='k-means++', random_state=42, n_init='auto')
    cluster_labels = kmeans.fit_predict(X)
    # The silhouette score is computed on a 5000-row subsample. Seed the
    # subsample as well, otherwise this curve changes from run to run even
    # though KMeans itself is seeded.
    model_score[0, i] = silhouette_score(X, cluster_labels, sample_size=5000, random_state=42)
    model_score[1, i] = calinski_harabasz_score(X, cluster_labels)
    model_score[2, i] = kmeans.inertia_


In [None]:
# One panel per row of model_score, plotted against the cluster counts
# k_min, k_min + 1, ... (one x value per column of the table).
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

k_values = range(k_min, k_min + model_score.shape[1])
metric_titles = ("silhouette_score", "calinski_harabasz_score", "kmeans.inertia")

for ax, scores, title in zip(axs, model_score, metric_titles):
    ax.plot(k_values, scores, 'gs-')
    ax.set_title(title)

plt.show()

In [None]:
# Number of clusters used for the final KMeans fit and shown in the
# clustered-image title below.
cluster_final_num = 15

In [None]:
# Fit the final model with the chosen number of clusters and keep the
# per-row cluster assignment produced by that fit.
kmeans = KMeans(
    n_clusters=cluster_final_num,
    random_state=42,
    n_init='auto',
).fit(X)
labels = kmeans.labels_

# Fold the flat label vector back into the image's shape without its last
# (channel) axis, so each pixel position carries its cluster index.
clustered_image = labels.reshape(image.shape[:-1])

In [None]:
# Side-by-side view: the source image on the left, the cluster-index map on
# the right.
fig_compare, (ax_original, ax_clustered) = plt.subplots(1, 2, figsize=(10, 5))

# matplotlib expects RGB channel order, so convert from BGR before showing.
ax_original.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
ax_original.set_title('Original Image')
ax_original.axis('off')

ax_clustered.imshow(clustered_image, cmap='viridis')
ax_clustered.set_title(f'Clustered Image ({cluster_final_num} Clusters)')
ax_clustered.axis('off')

fig_compare.tight_layout()
plt.show()