# In [None]:
import os
import json
import torch
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
import shutil
import umap

# Paths
# embedding_folder is scanned for ``*.pt`` embedding files; output_folder
# receives the per-cluster image copies and ``labels.json``.
embedding_folder = '/data1/dxw_data/llm/redbook_final/script_next/image_embeddings_tag'
output_folder = '/data1/dxw_data/llm/redbook_final/script_next/combined_seg_img_pure_094_cluster_imagebind3'
os.makedirs(output_folder, exist_ok=True)

# Collect the embedding files in sorted order. os.listdir() returns entries in
# arbitrary order, and the rest of the script maps row i of the embedding
# matrix back to embedding_files[i], so the order must be reproducible.
embedding_files = sorted(
    os.path.join(embedding_folder, fname)
    for fname in os.listdir(embedding_folder)
    if fname.endswith('.pt')
)
if not embedding_files:
    raise FileNotFoundError(f'No .pt embedding files found in {embedding_folder}')

# Load all embeddings into a list
all_embeddings = []
for embedding_file in tqdm(embedding_files, desc="Loading embeddings"):
    # map_location='cpu' lets tensors that were saved from a GPU load on a
    # machine without that device.
    # NOTE: torch.load unpickles the file; only point this at trusted data.
    embedding = torch.load(embedding_file, map_location='cpu')

    # A flat (dim,) vector would be glued end-to-end by torch.cat below;
    # promote it to a single row instead.
    if embedding.dim() == 1:
        embedding = embedding.unsqueeze(0)

    # Row i must correspond to embedding_files[i]; fail now rather than after
    # the expensive UMAP / K-Means steps if a file breaks that assumption.
    if embedding.dim() != 2 or embedding.shape[0] != 1:
        raise ValueError(
            f'Expected a single embedding row in {embedding_file}, '
            f'got tensor of shape {tuple(embedding.shape)}'
        )
    all_embeddings.append(embedding)

# Concatenate all embeddings into a single (n_files, dim) tensor on the CPU.
# detach() keeps the later .numpy() call valid for tensors saved with
# requires_grad=True.
all_embeddings = torch.cat(all_embeddings, dim=0).detach().cpu()

# Apply UMAP for dimensionality reduction (projects the embeddings to 2-D)
umap_model = umap.UMAP(n_components=2, random_state=0)
reduced_embeddings = umap_model.fit_transform(all_embeddings.numpy())

# Determine optimal number of clusters using Average Silhouette Method.
# KMeans rejects n_clusters > n_samples and silhouette_score() requires
# 2 <= n_labels <= n_samples - 1, so the sweep is capped by the sample count.
n_samples = reduced_embeddings.shape[0]
max_k = min(500, n_samples - 1)
if max_k < 2:
    raise ValueError(
        f'Need at least 3 embeddings to evaluate silhouette scores, got {n_samples}'
    )

silhouette_scores = []
k_values = list(range(2, max_k + 1, 5))  # k = 2, 7, 12, ... (at most 497)

for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    labels = kmeans.fit_predict(reduced_embeddings)

    # K-Means can return fewer distinct labels than requested; the silhouette
    # score is only defined when the label count is within the valid range.
    n_labels = len(set(labels))
    if 1 < n_labels < n_samples:
        score = silhouette_score(reduced_embeddings, labels)
    else:
        score = -1  # Placeholder: the lowest possible silhouette value
    silhouette_scores.append(score)

# Visualise how the average silhouette score varies with the cluster count
fig, ax = plt.subplots(figsize=(18, 6))
ax.plot(k_values, silhouette_scores, marker='o')
ax.set_xlabel('Number of clusters (k)')
ax.set_ylabel('Average Silhouette Score')
ax.set_title('Average Silhouette Score vs. Number of Clusters')
ax.grid(True)
plt.show()

# Pick the k with the highest silhouette score (the first one wins on ties)
best_score = max(silhouette_scores)
best_position = silhouette_scores.index(best_score)
optimal_k = k_values[best_position]

# Refit K-Means with the chosen k and keep the per-sample cluster labels
kmeans = KMeans(n_clusters=optimal_k, random_state=0)
kmeans.fit(reduced_embeddings)
labels = kmeans.labels_

# Save clustered images to output folders based on the clustering results.
# Labels are matched to embedding files by position, so the counts must agree.
if len(labels) != len(embedding_files):
    raise ValueError(
        f'Got {len(labels)} cluster labels for {len(embedding_files)} embedding files; '
        'cannot map labels back to images'
    )

# Folders searched for the source images, in order.
#  1. The location the script has always used. The replace() only changes the
#     path when embedding_folder contains 'combined_embeddings'; otherwise this
#     is embedding_folder itself.
#  2. A sibling 'combined_seg_img_pure_094' folder next to embedding_folder.
#     NOTE(review): inferred from the output folder name - confirm that this
#     is where the segmented images actually live.
image_folder_candidates = [
    embedding_folder.replace('combined_embeddings', 'combined_seg_img_pure_094'),
    os.path.join(os.path.dirname(embedding_folder.rstrip(os.sep)), 'combined_seg_img_pure_094'),
]

embedding_suffix = '_embedding.pt'
missing_images = []
for embedding_file, label in tqdm(zip(embedding_files, labels), desc="Saving clustered images", total=len(labels)):
    label_folder = os.path.join(output_folder, str(label))
    os.makedirs(label_folder, exist_ok=True)

    # Derive the image file name from the embedding file name. Files without
    # the '_embedding.pt' suffix just swap their extension, so a '.pt' file is
    # never mistaken for its own image.
    embedding_name = os.path.basename(embedding_file)
    if embedding_name.endswith(embedding_suffix):
        image_stem = embedding_name[:-len(embedding_suffix)]
    else:
        image_stem = os.path.splitext(embedding_name)[0]
    image_filename = image_stem + '.png'

    # Use the first candidate folder that actually contains the image
    source_image_path = None
    for image_folder in image_folder_candidates:
        candidate_path = os.path.join(image_folder, image_filename)
        if os.path.exists(candidate_path):
            source_image_path = candidate_path
            break

    # Copy image to corresponding cluster folder, remembering any that are
    # missing so they can be reported instead of silently skipped.
    if source_image_path is None:
        missing_images.append(image_filename)
        continue
    shutil.copy(source_image_path, os.path.join(label_folder, image_filename))

if missing_images:
    print(
        f'Warning: {len(missing_images)} of {len(embedding_files)} source images were not found '
        f'(searched: {image_folder_candidates}); first missing: {missing_images[0]}'
    )

# Save clustering labels to JSON file (embedding file name -> cluster id).
# int() converts the numpy integer labels into JSON-serialisable Python ints.
labels_json = {
    os.path.basename(embedding_file): int(label)
    for embedding_file, label in zip(embedding_files, labels)
}
with open(os.path.join(output_folder, 'labels.json'), 'w', encoding='utf-8') as f:
    json.dump(labels_json, f)

print(f'Clustering complete. Output saved to {output_folder}')
