In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.cluster import SpectralClustering
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
from torchvision import datasets, transforms, models
import torch
import torch.nn as nn

# =====================
# CONFIGURATION
# =====================
# Root image directory. torchvision's ImageFolder expects the images to sit
# inside sub-folders of this directory (root/<folder>/<image>); the folder
# names themselves are not used by the clustering.
DATA_DIR = "./data"

# Number of clusters requested from SpectralClustering.
NUM_CLUSTERS = 5

# True: embed images with a pretrained ResNet18; False: use the mean R, G, B colour.
USE_CNN_FEATURES = True

# Run the CNN on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


# =====================
# FEATURE EXTRACTION
# =====================
def extract_features(dataset, batch_size=32):
    """
    Embed every image of ``dataset`` with an ImageNet-pretrained ResNet18.

    The network's final classification layer (``fc``) is swapped for an
    identity, so each forward pass returns the features that layer would have
    received instead of class scores.

    Args:
        dataset: torch dataset yielding ``(image_tensor, label)`` pairs; the
            labels are ignored.
        batch_size: number of images per forward pass.

    Returns:
        np.ndarray with one feature row per image, in dataset order
        (the loader does not shuffle).
    """
    backbone = models.resnet18(weights="IMAGENET1K_V1")
    backbone.fc = nn.Identity()  # drop the classifier head, keep the embedding
    backbone = backbone.to(DEVICE)
    backbone.eval()

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    embedded_batches = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for batch, _labels in tqdm(loader, desc="Extracting features"):
            embeddings = backbone(batch.to(DEVICE))
            embedded_batches.append(embeddings.cpu().numpy())

    return np.vstack(embedded_batches)


def extract_color_features(dataset):
    """
    Compute the mean colour of every image (one value per channel, i.e. R, G, B).

    Images may arrive as torch tensors, which are assumed to be channel-first
    (C, H, W) as produced by ``transforms.ToTensor()``. They may also arrive as
    PIL images / arrays, which are channel-last (H, W, C). The channel axis is
    moved last before averaging, so each output column really is one channel's
    mean. Previously a (C, H, W) tensor was reshaped to (-1, 3) directly, which
    mixed the three colour planes together.

    Args:
        dataset: iterable of ``(image, label)`` pairs; the labels are ignored.

    Returns:
        np.ndarray of shape (len(dataset), C), one row of per-channel means
        per image. For RGB input C is 3.
    """
    color_features = []
    for img, _ in tqdm(dataset, desc="Extracting color features"):
        if isinstance(img, torch.Tensor):
            arr = img.detach().cpu().numpy()
            if arr.ndim == 3:
                arr = np.moveaxis(arr, 0, -1)  # (C, H, W) -> (H, W, C)
        else:
            arr = np.asarray(img)  # PIL image / ndarray: already (H, W, C)
        if arr.ndim == 2:
            arr = arr[..., np.newaxis]  # single-channel image -> (H, W, 1)
        # Flatten the spatial axes and average each channel independently.
        color_features.append(arr.reshape(-1, arr.shape[-1]).mean(axis=0))
    return np.array(color_features)


# =====================
# HELPERS
# =====================
# Per-channel ImageNet mean/std used to normalize inputs for torchvision's
# pretrained ResNet18 weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def pick_cluster_samples(labels, n_clusters, samples_per_cluster, rng):
    """
    Randomly choose up to ``samples_per_cluster`` distinct member indices per cluster.

    Args:
        labels: 1-D sequence of cluster ids, one per sample.
        n_clusters: cluster ids ``0 .. n_clusters - 1`` are considered.
        samples_per_cluster: maximum number of indices picked per cluster.
        rng: ``numpy.random.Generator`` used for the draw (seed it for
            reproducible picks).

    Returns:
        dict mapping cluster id -> np.ndarray of sample indices without
        repeats. Clusters with no members are omitted.
    """
    labels = np.asarray(labels)
    picks = {}
    for cluster_id in range(n_clusters):
        members = np.flatnonzero(labels == cluster_id)
        if members.size == 0:
            continue
        size = min(samples_per_cluster, members.size)
        picks[cluster_id] = rng.choice(members, size=size, replace=False)
    return picks


def show_cluster_samples(dataset, picks, display_transform):
    """
    Show the picked images of each cluster, one matplotlib figure per cluster.

    Images are reloaded from their file paths through ``display_transform``
    (no normalization), so they render with their real colours regardless of
    the transform used for feature extraction.

    Args:
        dataset: the ImageFolder the cluster labels were computed on.
        picks: mapping cluster id -> sample indices (see pick_cluster_samples).
        display_transform: transform turning a loaded image into a
            (C, H, W) tensor.
    """
    for cluster_id, chosen in picks.items():
        # squeeze=False keeps ``axes`` 2-D even for a single image; otherwise
        # plt.subplots returns a bare Axes object and indexing it fails.
        fig, axes = plt.subplots(1, len(chosen), figsize=(10, 3), squeeze=False)
        fig.suptitle(f"Cluster {cluster_id}", fontsize=14)
        for ax, idx in zip(axes[0], chosen):
            path, _ = dataset.samples[idx]
            img = display_transform(dataset.loader(path))
            ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))
            ax.axis("off")
        plt.show()


# =====================
# MAIN PIPELINE
# =====================
if __name__ == "__main__":
    print("üöÄ Spectral Clustering for Images")

    # Resize + ToTensor only: used for colour features and for displaying images.
    display_transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ])
    if USE_CNN_FEATURES:
        # The pretrained ResNet18 expects ImageNet-normalized inputs; raw
        # ToTensor output degrades the extracted features.
        transform = transforms.Compose([
            display_transform,
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
    else:
        transform = display_transform

    dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
    print(f"Loaded {len(dataset)} images from {DATA_DIR}")

    # --- Feature extraction ---
    if USE_CNN_FEATURES:
        features = extract_features(dataset)
    else:
        features = extract_color_features(dataset)

    print(f"Feature shape: {features.shape}")

    n_samples = features.shape[0]
    if n_samples <= NUM_CLUSTERS:
        raise ValueError(
            f"Need more than NUM_CLUSTERS={NUM_CLUSTERS} images to cluster, got {n_samples}."
        )

    # --- Standardize features (zero mean, unit variance per dimension) ---
    features_std = StandardScaler().fit_transform(features)

    # --- Spectral Clustering ---
    print("\nüîπ Running Spectral Clustering...")
    clustering = SpectralClustering(
        n_clusters=NUM_CLUSTERS,
        affinity='nearest_neighbors',  # or 'rbf' for a dense Gaussian-kernel affinity
        # The k-NN graph cannot use more neighbours than there are samples;
        # clamp so small datasets do not crash.
        n_neighbors=min(10, n_samples - 1),
        assign_labels='kmeans',
        random_state=42,
        n_jobs=-1,
    )

    labels = clustering.fit_predict(features_std)

    # --- Evaluate with the silhouette score ---
    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1.
    n_found = len(np.unique(labels))
    if 2 <= n_found <= n_samples - 1:
        score = silhouette_score(features_std, labels)
        print(f"‚úÖ Silhouette Score = {score:.4f}")
    else:
        print(f"Silhouette Score undefined: {n_found} distinct cluster(s) found.")

    # --- 2-D PCA visualization ---
    print("üìä Reducing dimensions for visualization...")
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(features_std)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(reduced[:, 0], reduced[:, 1], c=labels, cmap="tab10", s=20)
    plt.title("Spectral Clustering Visualization (PCA 2D)")
    plt.xlabel("PCA 1")
    plt.ylabel("PCA 2")
    plt.colorbar(scatter, label="Cluster ID")
    plt.tight_layout()
    plt.show()

    # --- Show a few sample images per cluster ---
    print("\nüñºÔ∏è Sample images per cluster:")
    rng = np.random.default_rng(42)  # seeded like SpectralClustering for reproducible picks
    picks = pick_cluster_samples(labels, NUM_CLUSTERS, samples_per_cluster=3, rng=rng)
    show_cluster_samples(dataset, picks, display_transform)

    print("üéØ Done.")
