# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.cluster import SpectralClustering
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
from torchvision import datasets, transforms, models
import torch
import torch.nn as nn

# =====================
# CONFIGURATION
# =====================
DATA_DIR = "./data"       # Image root; datasets.ImageFolder requires at least one sub-directory of images inside it
NUM_CLUSTERS = 5          # Number of clusters to find
USE_CNN_FEATURES = True   # True: extract features with ResNet18; False: use the mean colour per image
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =====================
# FEATURE EXTRACTION
# =====================
def extract_features(dataset, batch_size=32, normalize=True):
    """
    Extract image features with an ImageNet-pretrained ResNet18.

    The final fully-connected layer is replaced by ``nn.Identity`` so the
    network returns its penultimate activations instead of class logits.

    Args:
        dataset: map-style dataset yielding ``(image, label)`` pairs. Images
            are assumed to be float tensors of shape (3, H, W) in [0, 1], e.g.
            the output of ``transforms.ToTensor()``; labels are ignored.
        batch_size: number of images per forward pass.
        normalize: if True (default), apply the ImageNet mean/std normalization
            that the pretrained weights expect. Pass False when the dataset
            already normalizes its images, to avoid normalizing twice.

    Returns:
        np.ndarray with one feature row per image, in dataset order. An empty
        dataset yields an array of shape (0, feature_dim).
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model = models.resnet18(weights="IMAGENET1K_V1")
    feature_dim = model.fc.in_features  # width of the vector fed to the removed classifier
    model.fc = nn.Identity()  # Drop the classification head to expose the feature vector
    model = model.to(DEVICE)
    model.eval()

    # ImageNet statistics used to train the pretrained weights; shaped to
    # broadcast over a (batch, channel, H, W) batch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1, 3, 1, 1)

    features = []
    with torch.no_grad():
        for imgs, _ in tqdm(dataloader, desc="Extracting features"):
            imgs = imgs.to(DEVICE)
            if normalize:
                imgs = (imgs - mean) / std
            features.append(model(imgs).cpu().numpy())

    # np.vstack raises on an empty list, so handle an empty dataset explicitly.
    if not features:
        return np.empty((0, feature_dim), dtype=np.float32)
    return np.vstack(features)


def extract_color_features(dataset):
    """
    Extract the mean colour of each image (one value per channel: R, G, B).

    Supports both layouts an ``ImageFolder`` may yield:
      * torch tensors in (C, H, W) layout, e.g. from ``transforms.ToTensor()``;
      * PIL images / arrays in (H, W, C) layout when no tensor transform is used.

    Args:
        dataset: iterable of ``(image, label)`` pairs; labels are ignored.

    Returns:
        np.ndarray with one row of per-channel means per image, in dataset order.
    """
    color_features = []
    for img, _ in tqdm(dataset, desc="Extracting color features"):
        if isinstance(img, torch.Tensor):
            # Channel is the leading axis here. Flattening with reshape(-1, 3)
            # would group consecutive pixels of one channel rather than
            # (R, G, B) triples, so average each channel's pixels instead.
            arr = img.detach().cpu().numpy()
            color_features.append(arr.reshape(arr.shape[0], -1).mean(axis=1))
        else:
            # Channel-last (H, W, 3) layout: each row after reshape is one pixel.
            arr = np.asarray(img)
            color_features.append(arr.reshape(-1, 3).mean(axis=0))
    return np.array(color_features)


# =====================
# MAIN PIPELINE
# =====================
if __name__ == "__main__":
    print("🚀 Spectral Clustering for Images")

    # Resize every image to a fixed size and convert it to a float tensor of
    # shape (C, H, W) with values in [0, 1]. No normalization is applied here,
    # so the same tensors can be shown directly with imshow further below.
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])

    # ImageFolder requires DATA_DIR to contain at least one sub-directory of
    # images; the folder-derived labels are not used by this pipeline.
    dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
    print(f"Loaded {len(dataset)} images from {DATA_DIR}")

    # --- Feature extraction ---
    if USE_CNN_FEATURES:
        features = extract_features(dataset)
    else:
        features = extract_color_features(dataset)

    print(f"Feature shape: {features.shape}")

    # --- Standardize each feature dimension (zero mean, unit variance) ---
    features_std = StandardScaler().fit_transform(features)
    n_samples = features_std.shape[0]

    # --- Spectral Clustering ---
    print("\n🔹 Running Spectral Clustering...")
    clustering = SpectralClustering(
        n_clusters=NUM_CLUSTERS,
        affinity='nearest_neighbors',  # or 'rbf' if the data is continuous
        # The k-NN affinity graph cannot use more neighbours than there are samples.
        n_neighbors=min(10, n_samples),
        assign_labels='kmeans',
        random_state=42,
        n_jobs=-1
    )

    labels = clustering.fit_predict(features_std)

    # --- Evaluate with the Silhouette Score ---
    # silhouette_score raises unless 2 <= number of distinct labels <= n_samples - 1.
    n_found = len(np.unique(labels))
    if 2 <= n_found <= n_samples - 1:
        score = silhouette_score(features_std, labels)
        print(f"✅ Silhouette Score = {score:.4f}")
    else:
        print(f"⚠️ Silhouette Score undefined ({n_found} cluster(s) for {n_samples} samples)")

    # --- Visualize clusters in 2D via PCA ---
    print("📊 Reducing dimensions for visualization...")
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(features_std)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(reduced[:, 0], reduced[:, 1], c=labels, cmap="tab10", s=20)
    plt.title("Spectral Clustering Visualization (PCA 2D)")
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.colorbar(scatter, label="Cluster ID")
    plt.tight_layout()
    plt.show()

    # --- Show a few random sample images from each cluster ---
    print("\n🖼️ Sample images per cluster:")
    samples_per_cluster = 3
    for cluster_id in range(NUM_CLUSTERS):
        cluster_indices = np.where(labels == cluster_id)[0]
        if len(cluster_indices) == 0:
            continue
        chosen = np.random.choice(cluster_indices, min(samples_per_cluster, len(cluster_indices)), replace=False)

        # squeeze=False always returns a 2-D array of Axes; without it a
        # single-image cluster yields a bare Axes and indexing it would fail.
        fig, axes = plt.subplots(1, len(chosen), figsize=(10, 3), squeeze=False)
        fig.suptitle(f"Cluster {cluster_id}", fontsize=14)
        for ax, idx in zip(axes[0], chosen):
            img, _ = dataset[idx]
            ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))  # (C, H, W) -> (H, W, C)
            ax.axis("off")
        plt.show()

    print("🎯 Done.")
