In [2]:
import re
from glob import glob
from zennitcrp.crp.attribution import CondAttribution
import os
import json
from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm
from crp.attribution import CondAttribution
from sklearn.cluster import KMeans
import numpy as np
import hdbscan
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.cluster import SpectralClustering
from sklearn.metrics import silhouette_score
import torch
from torchvision.models.vgg import vgg16_bn
import torchvision.transforms as T
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

In [3]:
def analysis_clusters(methods_cluster,
                      heatmaps_scaled,
                      picture_name,
                      cluster_folder_save_path,
                      min_cluster : int = 2,
                      max_cluster : int = 16):
    """
    Sweep several clustering algorithms over a range of cluster counts and
    save the resulting labels and silhouette scores.

    Arguments:
    - methods_cluster: dict {method_name: factory}; each factory is called as
      factory(n_clusters) and must return an object exposing fit_predict
    - heatmaps_scaled: the samples handed to fit_predict and silhouette_score
    - picture_name: identifier used in the two output file names
    - cluster_folder_save_path: existing directory receiving the .npz files
    - min_cluster, max_cluster: bounds of range(min_cluster, max_cluster);
      the upper bound is exclusive

    Returns:
    - (silhouette_scores, cluster_range) where silhouette_scores maps each
      method name to the list of scores, one per value of cluster_range

    Side effects:
    - writes clusters_{picture_name}.npz with one "{n_clusters}_{method}" label
      array per run, for every method in methods_cluster (not only a fixed
      set of method names)
    - writes scores_{picture_name}.npz with one score array per method plus
      the min/max cluster bounds
    """
    cluster_range = range(min_cluster, max_cluster)
    silhouette_scores = {method: [] for method in methods_cluster}
    labels_per_method = {method: {} for method in methods_cluster}

    for n_clusters in cluster_range:
        for method_name, clustering_function in methods_cluster.items():
            model = clustering_function(n_clusters)
            labels = np.asarray(model.fit_predict(heatmaps_scaled))

            # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1
            # and raises ValueError outside that range; score such runs as -1.
            n_labels = np.unique(labels).size
            if 1 < n_labels < len(labels):
                score = silhouette_score(heatmaps_scaled, labels)
            else:
                score = -1
            silhouette_scores[method_name].append(score)
            labels_per_method[method_name][n_clusters] = labels

    # Flatten to the "{n_clusters}_{method}" key layout used by the saved file.
    general_dict = {
        f"{n_clusters}_{method_name}": labels
        for method_name, labels_by_count in labels_per_method.items()
        for n_clusters, labels in labels_by_count.items()
    }
    general_score_dict = {
        method_name: np.array(scores)
        for method_name, scores in silhouette_scores.items()
    }
    general_score_dict[f"{MIN_CLUSTER_CONST}"] = min_cluster
    general_score_dict[f"{MAX_CLUSTER_CONST}"] = max_cluster

    np.savez(f"{cluster_folder_save_path}/clusters_{picture_name}.npz", **general_dict)
    np.savez(f"{cluster_folder_save_path}/scores_{picture_name}.npz", **general_score_dict)
    return silhouette_scores, cluster_range

In [4]:
def plot_silhouette_scores(silhouette_scores, cluster_range, file_name):
    """
    Draw one silhouette-score curve per clustering method on a single figure.

    Arguments:
    - silhouette_scores: dict {method_name: scores}, one score per entry of
      cluster_range
    - cluster_range: the cluster counts used as x values
    - file_name: label appended to the figure title
    """
    figure, axis = plt.subplots(figsize=(10, 6))

    for method_name in silhouette_scores:
        axis.plot(cluster_range,
                  silhouette_scores[method_name],
                  marker='o',
                  linestyle='-',
                  label=method_name)

    axis.set(xlabel="Nombre de Clusters",
             ylabel="Silhouette Score",
             title=f"Comparaison des Algorithmes de Clustering\n{file_name}")
    axis.legend()
    axis.grid(True)
    plt.show()

In [5]:
def load_and_process_normalized(file_name : str,
                                base_path,
                                layer_name : str = "layer_40"):
    """
    Load the heatmaps of one layer from an .npz archive and min-max scale each one.

    NOTE: this notebook defines load_and_process_normalized a second time
    further down with the same logic; when all cells are run in order the
    later definition is the one that stays bound to the name.

    Arguments:
    - file_name: archive name (with extension) inside f"{base_path}/heatmap"
    - base_path: root data directory
    - layer_name: key of the array to read from the archive

    Returns:
    - heatmaps: absolute values of the stored array (not rescaled)
    - heatmaps_flat: array of shape (heatmaps.shape[0], -1) in which every row
      is scaled to [0, 1] independently; rows whose values are all equal stay 0
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes the handle once the array has been read.
    with np.load(f"{base_path}/heatmap/{file_name}") as data:
        heatmaps = np.abs(data[layer_name])

    flat = heatmaps.reshape(heatmaps.shape[0], -1)
    lows = flat.min(axis=1, keepdims=True)
    spans = flat.max(axis=1, keepdims=True) - lows

    # Keep floating inputs in their own dtype; integer inputs must be promoted,
    # otherwise the scaled values would be truncated to 0/1.
    if np.issubdtype(heatmaps.dtype, np.floating):
        out_dtype = heatmaps.dtype
    else:
        out_dtype = np.float64

    heatmaps_flat = np.zeros(flat.shape, dtype=out_dtype)
    # where=... skips constant rows (avoids division by zero); they stay at 0.
    np.divide(flat - lows, spans, out=heatmaps_flat, where=spans > 0)
    return heatmaps, heatmaps_flat

In [6]:
def compute_feature_importance(model, input_tensor, layer_idx, num_features, pred_class):
    """
    Measure the importance of every feature (channel) of one layer for a class
    by ablation: each feature is zeroed in turn and the drop in the softmax
    probability of `pred_class` is recorded.

    Arguments:
    - model: network exposing an indexable `features` container (e.g. VGG16)
    - input_tensor: the input image as a batched tensor; it is not modified
    - layer_idx: index of the layer inside model.features (e.g. 40)
    - num_features: number of features in that layer (e.g. 512)
    - pred_class: index of the class whose probability is tracked

    Returns:
    - A dict {feature_idx: importance} ordered by decreasing importance, where
      importance = original probability - probability with the feature zeroed
    """
    # Only forward passes are needed: work on a detached view instead of
    # flipping requires_grad on the caller's tensor.
    ablation_input = input_tensor.detach()
    target_layer = model.features[layer_idx]

    def _class_probability():
        with torch.no_grad():
            logits = model(ablation_input)
            probs = torch.nn.functional.softmax(logits, dim=1)
            return probs[0, pred_class].item()

    original_score = _class_probability()

    feature_importance = {}

    # Disable each feature one at a time and measure the impact.
    for feature_idx in range(num_features):
        def zero_out_feature(module, input, output, feature_idx=feature_idx):
            output[:, feature_idx, :, :] = 0  # switch the feature off
            return output

        hook = target_layer.register_forward_hook(zero_out_feature)
        try:
            new_score = _class_probability()
        finally:
            # Always detach the hook, even if the forward pass raises, so the
            # model is never left with a stale ablation hook.
            hook.remove()

        feature_importance[feature_idx] = float(original_score - new_score)

    # Sort the features by decreasing importance.
    return dict(sorted(feature_importance.items(), key=lambda item: item[1], reverse=True))


In [7]:
def processPicture(model,
                   transform,
                   global_dictionary : dict,
                   picture_path : str,
                   heatmap_folder_save_path : str,
                   device : str = "cpu"):
    """
    Run the model on one picture, compute conditional attribution heatmaps for
    every feature of every configured layer, save them, and record the
    prediction and per-feature importances.

    Relies on the module-level mapping `features_per_layer`
    ({layer index: number of features}), which must be defined before the call.

    Arguments:
    - model: the network to explain
    - transform: callable turning the PIL image into an unbatched tensor
    - global_dictionary: dict updated in place with this picture's record
    - picture_path: path of the image file
    - heatmap_folder_save_path: directory receiving "{image_name}.npz"
    - device: device the input tensor is moved to

    Returns:
    - global_dictionary, with global_dictionary[image_name] set to a dict
      holding the predicted class, its probability (a float) and, per layer,
      the importance dict from compute_feature_importance

    Side effects:
    - writes "{image_name}.npz" with one "layer_{idx}" array per layer
    """
    image_name = os.path.splitext(os.path.basename(picture_path))[0]
    image = Image.open(picture_path).convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)
    input_tensor.requires_grad = True

    # The prediction itself needs no gradient; storing a plain float also
    # avoids keeping the autograd graph alive inside the result dictionary.
    with torch.no_grad():
        output = model(input_tensor)
        pred_class = torch.argmax(output, dim=1).item()
        probs = torch.nn.functional.softmax(output, dim=1)

    local_dictionary = {
        CLASSE_PREDICTED: pred_class,
        PROBABILITY: probs[0, pred_class].item(),
    }

    composite = EpsilonPlusFlat([SequentialMergeBatchNorm()])
    attribution = CondAttribution(model, no_param_grad=True)

    # NOTE(review): the attribution target "y" is hard-coded to 40 and does not
    # follow pred_class — confirm this is the intended output index.
    target_output = 40
    num_feature_per_batch = 8

    features_dict = {}
    layers_heatmaps = {}
    for layer_idx, num_features in features_per_layer.items():
        # Condition on the layer being iterated (was hard-coded to "features.40").
        layer_name = f"features.{layer_idx}"
        batches = []
        for start in range(0, num_features, num_feature_per_batch):
            stop = min(start + num_feature_per_batch, num_features)
            conditions = [{"y": [target_output], layer_name: [j]} for j in range(start, stop)]
            heatmaps, _, _, _ = attribution(input_tensor, conditions, composite)
            # Move each batch to host memory right away instead of holding
            # every batch on the device until the end of the layer.
            batches.append(heatmaps.detach().cpu().numpy())
        layers_heatmaps[layer_idx] = np.concatenate(batches, axis=0)

        features_dict[layer_idx] = compute_feature_importance(model,
                                                              input_tensor,
                                                              layer_idx=layer_idx,
                                                              num_features=num_features,
                                                              pred_class=pred_class)

    save_path = os.path.join(heatmap_folder_save_path, f"{image_name}.npz")
    save_dict = {f"layer_{idx_layers}": heatmaps for idx_layers, heatmaps in layers_heatmaps.items()}
    np.savez(save_path, **save_dict)

    local_dictionary[FEATURES] = features_dict
    global_dictionary[image_name] = local_dictionary
    return global_dictionary

In [8]:
def tensor_to_list(obj):
    """
    `default=` hook for json.dump: convert a torch.Tensor to nested Python
    lists/scalars and reject every other type with a TypeError.
    """
    if not isinstance(obj, torch.Tensor):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
    return obj.tolist()

In [9]:
def get_filename_without_extension(path: str) -> str:
    """Return the last component of `path` with its final extension removed."""
    base_name = os.path.basename(path)
    stem, _extension = os.path.splitext(base_name)
    return stem

In [10]:
def create_json_from_data(data, output_filename):
    """
    Write the per-image records to a structured JSON file.

    Only the predicted class, the probability and the feature dictionary of
    each record are kept; tensors are converted through tensor_to_list.

    :param data: dict {image_name: record} holding the image information.
    :param output_filename: path of the JSON file to create.
    """
    exported_fields = (CLASSE_PREDICTED, PROBABILITY, FEATURES)

    # Rebuild each record with exactly the exported fields.
    image_data = {
        image_name: {field: info[field] for field in exported_fields}
        for image_name, info in data.items()
    }

    with open(output_filename, 'w') as json_file:
        json.dump(image_data, json_file, indent=4, default=tensor_to_list)
    print(f"Le fichier JSON '{output_filename}' a été créé avec succès.")

In [None]:
# Keys of the per-image record built by processPicture and exported by
# create_json_from_data.
CLASSE_PREDICTED = "classe_predicted"
PROBABILITY = "probability"
FEATURES = "features"

In [None]:
# Names identifying the clustering methods; analysis_clusters compares the keys
# of `methods_cluster` against them and uses them in the saved .npz keys.
GMM_CONST = "GMM"
KMEANS_CONST = "KMeans"
SPECTRAL_CONST = "SpectralClustering"
AGGLOMERATIVE_CONST = "AgglomerativeClustering"
HDBSCAN_CONST = "HDBSCAN"
# Keys under which analysis_clusters stores the cluster-count bounds in the
# scores archive.
MIN_CLUSTER_CONST = "min_cluster"
MAX_CLUSTER_CONST = "max_cluster"

TODO : modifier pour mettre les bons chemins
picture_folder_path : chemin vers le dossier contenant les photos
cluster_folder_save_path : chemin où sont sauvegardés les .npz des clusters et les scores de chaque méthode de clustering
heatmap_folder_save_path : chemin où sont sauvegardés les .npz des heatmaps et l'importance de chaque feature (dans un .json)

In [11]:
# Root data directory: a() reads images from f"{base_path}/pictures", stores
# heatmaps in f"{base_path}/heatmap" and cluster files in f"{base_path}/clusters".
base_path = "./data/v6/clusters"

In [None]:
def getDevice():
    """Return "cuda:0" when CUDA is available, "cpu" otherwise."""
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"

In [None]:
def getModel(device : str = "cpu"):
    """
    Build the VGG16-BN network (the positional True requests torchvision's
    pretrained weights), move it to `device` and switch it to eval mode.
    """
    network = vgg16_bn(True)
    network = network.to(device)
    network.eval()
    return network

In [None]:
def getTransform():
    """
    Return the preprocessing pipeline applied to each picture: resize to
    224x224, convert to a tensor, then normalise each channel.
    """
    # Per-channel statistics used for normalisation (the values commonly used
    # for ImageNet-trained torchvision models).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)

In [None]:
def extract_heatmap_from_image(model,
                    transform,
                    device,
                    heatmap_folder_save_path : str,
                    list_image_paths : list,
                    avoid_images : list = None):
    """
    Process every picture of `list_image_paths` that is not listed in
    `avoid_images`: compute and save its heatmaps through processPicture, then
    write its importance record to "{image_name}_importance.json" in
    `heatmap_folder_save_path`.

    Each processed path is appended to `avoid_images` (the caller's list is
    modified in place when one is supplied), so it is skipped if met again.
    """
    already_done = [] if avoid_images is None else avoid_images

    for image_path in list_image_paths:
        if image_path in already_done:
            print(f"Image {image_path} avoid")
            continue

        image_name, _extension = os.path.splitext(os.path.basename(image_path))
        # A fresh dictionary per picture: each JSON file holds a single record.
        picture_record = processPicture(model=model,
                                        transform=transform,
                                        global_dictionary={},
                                        picture_path=image_path,
                                        heatmap_folder_save_path=heatmap_folder_save_path,
                                        device=device)
        already_done.append(image_path)
        create_json_from_data(picture_record, f"{heatmap_folder_save_path}/{image_name}_importance.json")


In [None]:
def get_all_image_paths(picture_folder_path : str, extension : str = "jpeg"):
    """
    List the files directly inside `picture_folder_path` whose name ends with
    ".{extension}"; the order is whatever glob returns.
    """
    pattern = os.path.join(picture_folder_path, "*.{}".format(extension))
    return glob(pattern)

In [None]:
def perform_kmeans_clustering(heatmaps_scaled, n_clusters : int = 10, random_state : int = 7):
    """
    Cluster the samples with KMeans and return one label per sample.

    Arguments:
    - heatmaps_scaled: samples handed to KMeans.fit_predict
    - n_clusters: number of clusters to form
    - random_state: seed forwarded to KMeans; defaults to 7, the value that
      was previously hard-coded, so existing calls give the same result
    """
    model = KMeans(n_clusters=n_clusters, random_state=random_state)
    return model.fit_predict(heatmaps_scaled)

In [None]:
def load_and_process_normalized(file_name_with_extension : str,
                                base_path,
                                layer_name : str = "layer_40"):
    """
    Load the heatmaps of one layer from an .npz archive and min-max scale each one.

    Arguments:
    - file_name_with_extension: archive name inside f"{base_path}/heatmap"
    - base_path: root data directory
    - layer_name: key of the array to read from the archive

    Returns:
    - heatmaps: absolute values of the stored array (not rescaled)
    - heatmaps_flat: array of shape (heatmaps.shape[0], -1) in which every row
      is scaled to [0, 1] independently; rows whose values are all equal stay 0
    """
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes the handle once the array has been read.
    with np.load(f"{base_path}/heatmap/{file_name_with_extension}") as data:
        heatmaps = np.abs(data[layer_name])

    flat = heatmaps.reshape(heatmaps.shape[0], -1)
    lows = flat.min(axis=1, keepdims=True)
    spans = flat.max(axis=1, keepdims=True) - lows

    # Keep floating inputs in their own dtype; integer inputs must be promoted,
    # otherwise the scaled values would be truncated to 0/1.
    if np.issubdtype(heatmaps.dtype, np.floating):
        out_dtype = heatmaps.dtype
    else:
        out_dtype = np.float64

    heatmaps_flat = np.zeros(flat.shape, dtype=out_dtype)
    # where=... skips constant rows (avoids division by zero); they stay at 0.
    np.divide(flat - lows, spans, out=heatmaps_flat, where=spans > 0)
    return heatmaps, heatmaps_flat

In [None]:
def indices_par_valeur(lst):
    """
    Group positions by value: return {value: [indices where lst == value]}
    with one entry per distinct value, in the sorted order given by np.unique.
    """
    values = np.array(lst)
    positions_by_value = {}
    for distinct_value in np.unique(values):
        matches = np.nonzero(values == distinct_value)[0]
        positions_by_value[distinct_value] = matches.tolist()
    return positions_by_value

In [None]:
def dictionary_one_hot_encoding(lst):
    """
    Return {value: membership list} with one entry per distinct value of lst;
    each list has len(lst) entries, 1 where lst equals the value and 0 elsewhere.
    """
    values = np.array(lst)
    length = len(values)

    def _membership(target):
        # Start from all zeros (one slot per input element), then flag the
        # positions holding `target`.
        flags = np.zeros(length, dtype=int)
        flags[np.nonzero(values == target)[0]] = 1
        return flags.tolist()

    return {target: _membership(target) for target in np.unique(values)}

In [None]:
# Number of "big clusters": passed to a(), which uses it as n_clusters for the
# KMeans run over the one-hot encoded per-picture clusters.
final_number_big_clusters = 10

In [None]:
# Run configuration passed to a() further down.
random_state = 7  # seed forwarded to the t-SNE and big-cluster KMeans steps
n_clusters_per_picture = 10  # number of KMeans clusters computed per picture
avoid_images = []  # image paths to skip during heatmap extraction
save_all_hot_encoding = True  # rebuild and save the one-hot encoding file
save_cluster = True  # recompute and save the per-picture cluster labels
extract_heatmap = True  # recompute the heatmaps from the pictures

Extrait toutes les heatmaps des images situées dans "{base_path}/pictures".
Ensuite, fait un clustering KMeans à 10 clusters pour chaque image.
Ensuite, fait un KMeans à 10 clusters (les « big-clusters ») sur l'ensemble de ces clusters.
Ensuite, applique PCA et t-SNE pour visualiser les clusters.
Enfin, sauvegarde dans un dictionnaire, pour chaque big-cluster, les features associées.

In [None]:
def a(base_path : str,
      final_number_big_clusters : int = 10,
      extract_heatmap : bool = True,
      save_cluster : bool = True,
      save_all_hot_encoding : bool = True,
      avoid_images : list = None,
      n_clusters_per_picture : int = 10,
      random_state : int = 7):
    """
    Run the whole pipeline: heatmap extraction, per-picture clustering, one-hot
    encoding of the clusters, clustering of those encodings into "big
    clusters", 2-D visualisation, and assignment of each feature to a big cluster.

    Arguments:
    - base_path: root directory; pictures are read from f"{base_path}/pictures",
      heatmaps live in f"{base_path}/heatmap" and cluster files in
      f"{base_path}/clusters"
    - final_number_big_clusters: number of big clusters (second KMeans)
    - extract_heatmap: when True, (re)compute the heatmaps of every picture
    - save_cluster: when True, cluster each heatmap archive and save labels.npz
    - save_all_hot_encoding: when True, rebuild hot_encoding.npz from labels.npz
    - avoid_images: picture paths to skip during extraction
    - n_clusters_per_picture: number of clusters of the per-picture KMeans
    - random_state: seed of the t-SNE and big-cluster KMeans steps

    Returns:
    - dict {big cluster index: [feature indices]}; each feature is assigned to
      the big cluster in which it appears most often

    Raises:
    - ValueError if labels.npz holds no per-picture labels
    """
    heatmap_folder_save_path = f"{base_path}/heatmap"
    cluster_folder_save_path = f"{base_path}/clusters"
    labels_path = f"{cluster_folder_save_path}/labels.npz"
    hot_encoding_path = f"{cluster_folder_save_path}/hot_encoding.npz"
    # Metadata entry stored next to the per-picture labels in labels.npz.
    n_clusters_key = "n_clusters"

    if extract_heatmap:
        # The model is only needed for extraction: do not load it otherwise.
        device = getDevice()
        model = getModel(device)
        transform = getTransform()
        os.makedirs(heatmap_folder_save_path, exist_ok=True)
        list_image_paths = get_all_image_paths(f"{base_path}/pictures")
        # Keyword arguments: the helper's positional order is
        # (model, transform, device, ...), not (device, model, transform, ...).
        extract_heatmap_from_image(model=model,
                                   transform=transform,
                                   device=device,
                                   heatmap_folder_save_path=heatmap_folder_save_path,
                                   list_image_paths=list_image_paths,
                                   avoid_images=avoid_images)

    if save_cluster:
        os.makedirs(cluster_folder_save_path, exist_ok=True)
        dict_label = {n_clusters_key: n_clusters_per_picture}
        # sorted() makes the row order of the encoding reproducible.
        for file_name_with_extension in sorted(os.listdir(heatmap_folder_save_path)):
            file_name, extension = os.path.splitext(file_name_with_extension)
            # The heatmap folder also holds the *_importance.json files.
            if extension != ".npz":
                continue
            _, heatmaps_scaled = load_and_process_normalized(file_name_with_extension, base_path)
            dict_label[file_name] = perform_kmeans_clustering(heatmaps_scaled,
                                                              n_clusters=n_clusters_per_picture)
        np.savez(labels_path, **dict_label)

    if save_all_hot_encoding:
        os.makedirs(cluster_folder_save_path, exist_ok=True)
        encodings = []
        with np.load(labels_path) as labels_file:
            for file_name in labels_file.files:
                if file_name == n_clusters_key:
                    continue  # metadata, not a label array
                index_in_each_cluster = dictionary_one_hot_encoding(labels_file[file_name])
                encodings.append(np.array(list(index_in_each_cluster.values())))
        if not encodings:
            raise ValueError(f"No per-picture cluster labels found in {labels_path}")
        # One row per (picture, cluster), one column per feature.
        np.savez(hot_encoding_path, all_hot_encoding=np.concatenate(encodings, axis=0))

    with np.load(hot_encoding_path) as hot_encoding_file:
        all_hot_encoding = hot_encoding_file["all_hot_encoding"]

    pca_result = PCA(n_components=2).fit_transform(all_hot_encoding)

    # t-SNE requires perplexity < n_samples. The iteration count is left at the
    # library default (1000) because the keyword naming it differs between
    # scikit-learn versions.
    perplexity = min(30, all_hot_encoding.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=random_state, perplexity=perplexity)
    tsne_result = tsne.fit_transform(all_hot_encoding)

    kmeans = KMeans(n_clusters=final_number_big_clusters, random_state=random_state)
    big_clusters = kmeans.fit_predict(all_hot_encoding)

    fig, (ax_pca, ax_tsne) = plt.subplots(1, 2, figsize=(10, 5))

    scatter_pca = ax_pca.scatter(pca_result[:, 0], pca_result[:, 1], c=big_clusters, cmap='tab10', s=5)
    fig.colorbar(scatter_pca, ax=ax_pca)
    ax_pca.set_title('PCA Visualization with K-Means Clusters')

    scatter_tsne = ax_tsne.scatter(tsne_result[:, 0], tsne_result[:, 1], c=big_clusters, cmap='tab10', s=5)
    fig.colorbar(scatter_tsne, ax=ax_tsne)
    ax_tsne.set_title('t-SNE Visualization with K-Means Clusters')

    fig.tight_layout()
    plt.show()

    # For every feature, count in how many rows of each big cluster it is set.
    feature_to_big_cluster = {}
    for big_cluster_index, clusters_list in indices_par_valeur(big_clusters).items():
        for cluster_index in clusters_list:
            active_features = np.nonzero(all_hot_encoding[cluster_index] == 1)[0].tolist()
            for feature_index in active_features:
                counts = feature_to_big_cluster.setdefault(feature_index,
                                                           [0] * final_number_big_clusters)
                counts[big_cluster_index] += 1

    # Assign each feature to the big cluster where it occurs most often.
    big_cluster_index_to_feature = {}
    for feature_index, big_clusters_list in feature_to_big_cluster.items():
        best_big_cluster = int(np.argmax(big_clusters_list))
        big_cluster_index_to_feature.setdefault(best_big_cluster, []).append(feature_index)

    return big_cluster_index_to_feature


In [None]:
def visualize_clusters(dict_feature_per_class: dict,
                       base_path: str,
                       list_image_name_to_test: list):
    """
    Show, for each class (big cluster) and each test image, the sum of the
    "layer_40" heatmaps of the features assigned to that class.

    Arguments:
    - dict_feature_per_class: dict {class key: list of feature indices}
    - base_path: root directory handed to load_and_process_normalized
    - list_image_name_to_test: image names without extension; the heatmaps are
      read from "{image_name}.npz"

    The figure has one row per class that produced at least one heatmap and
    one column per test image; every panel shares the same colour scale.
    Images whose heatmaps cannot be loaded are reported and left blank.
    """
    # Load every archive once instead of once per class.
    heatmaps_by_image = {}
    for image_name in list_image_name_to_test:
        try:
            heatmaps_by_image[image_name], _ = load_and_process_normalized(f"{image_name}.npz",
                                                                           base_path,
                                                                           layer_name="layer_40")
        except (OSError, KeyError, ValueError) as e:
            print(f"Error processing {image_name}: {str(e)}")

    # Shared colour range, widened as heatmaps are produced.
    vmin, vmax = float("inf"), float("-inf")

    heatmap_per_class_per_image = {}
    for key, features in dict_feature_per_class.items():
        feature_index = np.array(features)
        if len(feature_index) == 0:
            continue

        heatmap_per_image = {}
        for image_name, heatmaps in heatmaps_by_image.items():
            try:
                combined_heatmap = np.sum(heatmaps[feature_index], axis=0)
            except IndexError as e:
                print(f"Error processing {image_name} for class {key}: {str(e)}")
                continue
            vmin = min(vmin, combined_heatmap.min())
            vmax = max(vmax, combined_heatmap.max())
            heatmap_per_image[image_name] = combined_heatmap

        if heatmap_per_image:  # keep only the classes that produced heatmaps
            heatmap_per_class_per_image[key] = heatmap_per_image

    classes_to_display = list(heatmap_per_class_per_image)

    if not classes_to_display:
        print("No data to display!")
        return

    nb_line_adjust = len(classes_to_display)
    nb_col_adjust = len(list_image_name_to_test)

    # squeeze=False keeps `axes` two-dimensional whatever the grid size.
    fig, axes = plt.subplots(nrows=nb_line_adjust,
                             ncols=nb_col_adjust,
                             figsize=(4 * nb_col_adjust, 4 * nb_line_adjust),
                             squeeze=False)
    fig.tight_layout(pad=3.0)

    im = None
    for row_idx, class_name in enumerate(classes_to_display):
        heatmap_per_image = heatmap_per_class_per_image[class_name]
        for col_idx, image_name in enumerate(list_image_name_to_test):
            ax = axes[row_idx, col_idx]
            heatmap = heatmap_per_image.get(image_name)

            if heatmap is not None:
                im = ax.imshow(heatmap, cmap='viridis', vmin=vmin, vmax=vmax)
                ax.set_title(f"{class_name} - {image_name}", fontsize=10)
            ax.axis('off')

    if im is not None:
        cbar = fig.colorbar(im, ax=axes.ravel().tolist(), orientation='vertical', fraction=0.05, pad=0.02)
        cbar.set_label('Activation Intensity')

    plt.show()

In [None]:
# Run the whole pipeline with the notebook-level configuration defined above.
big_bluster_index_to_feature = a(
    base_path=base_path,
    final_number_big_clusters=final_number_big_clusters,
    extract_heatmap=extract_heatmap,
    save_cluster=save_cluster,
    save_all_hot_encoding=save_all_hot_encoding,
    avoid_images=avoid_images,
    n_clusters_per_picture=n_clusters_per_picture,
    random_state=random_state,
)

In [None]:
# Image names (without extension) to display; visualize_clusters reads the
# heatmaps of each one from "{image_name}.npz".
list_image_name_to_test = [
    'n02437312_2790',
    'n02423022_9745',
    'n02437616_14498',
    'n02437312_3178',
    'n02437616_8125',
    'n02423022_2042'
]

In [None]:
# Plot, for each big cluster, the summed heatmaps of its features on the test images.
visualize_clusters(big_bluster_index_to_feature, base_path, list_image_name_to_test)