In [2]:
import os
import re
import json
from sklearn.cluster import KMeans
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.cluster import SpectralClustering

In [3]:
def list_files(directory, extension:str=".npz"):
    """Liste tous les fichiers .npz dans le dossier spécifié."""
    return [f for f in os.listdir(directory) if f.endswith(extension)]

In [4]:
def get_filename_without_extension(path: str) -> str:
    return os.path.splitext(os.path.basename(path))[0]

In [5]:
def extract_code(filename: str) -> str:
    match = re.search(r"(n\d+)", filename)
    return match.group(1) if match else ""

In [6]:
def extract_code_2(filename: str) -> str:
    match = re.search(r"(n\d+_\d+)", filename)
    return match.group(1) if match else ""

In [7]:
def load_and_process_flat(file_name : str,
                          base_path,
                          layer_name : str = "layer_40"):
    data = np.load(f"{base_path}/heatmap/{file_name}")
    heatmaps = data[layer_name]
    heatmaps = np.abs(heatmaps)
    heatmaps_flat = heatmaps.reshape(heatmaps.shape[0], -1)
    return heatmaps, heatmaps_flat

In [8]:
def visualize_clusters(clusters_analysis_file_base_name : str,
                       base_path : str,
                       target_occurence_min_per_class : dict,
                       list_image_name_to_test : list):
    all_feature_in_each_class = np.load(f"{base_path}/clusters_analysis/{clusters_analysis_file_base_name}_all_feature_in_each_class.npz")
    class_and_occurence_per_feature = np.load(f"{base_path}/clusters_analysis/{clusters_analysis_file_base_name}_class_and_occurence_per_feature.npz")

    dict_feature_per_class = {}
    for key in all_feature_in_each_class.keys():
        if key not in class_and_occurence_per_feature.keys():
            dict_feature_per_class[key] = all_feature_in_each_class[key]
            print(f"Key {key} not found in class_and_occurence_per_feature. Getting all features of all_feature_in_each_class for this key.")
            continue
        if key not in target_occurence_min_per_class.keys():
            dict_feature_per_class[key] = all_feature_in_each_class[key]
            print(f"Key {key} not found in target_occurence_min_per_class. Getting all features of all_feature_in_each_class for this key.")
            continue
        dict_feature_per_class[key] = []
        for feature in all_feature_in_each_class[key]:
            min_occurence_for_this_class = target_occurence_min_per_class[key]
            if class_and_occurence_per_feature[feature][1] >= min_occurence_for_this_class:
                dict_feature_per_class[key].append(feature)


    num_row = len(list(dict_feature_per_class.keys()))
    num_col = len(list_image_name_to_test)
    fig, axes = plt.subplots(nrows=num_col, ncols=num_col, figsize=(16, 4 * num_row))

    vmin, vmax = float("inf"), float("-inf")
    heatmap_per_class_per_image = {}
    for key in dict_feature_per_class.keys():
        heatmap_per_image = {}
        for image_name in list_image_name_to_test:
            heatmap_file_name = f"{image_name}.npz"
            heatmaps, _ = load_and_process_flat(heatmap_file_name, base_path, layer_name="layer_40")
            feature_index = np.array(dict_feature_per_class[key])
            cluster_heatmaps = heatmaps[feature_index]
            combined_heatmap = np.sum(cluster_heatmaps, axis=0)
            vmin = min(vmin, combined_heatmap.min())
            vmax = max(vmax, combined_heatmap.max())
            heatmap_per_image[image_name] = combined_heatmap
        heatmap_per_class_per_image[key] = heatmap_per_image


    num_methods = len(methods)

    for cluster_idx, cluster_num in enumerate(range(min_cluster, max_cluster + 1)):
        fig, axes = plt.subplots(nrows=num_methods, ncols=max(4, cluster_idx + min_cluster), figsize=(16, 4 * num_methods))
        heatmaps_by_method = {}
        vmin, vmax = float("inf"), float("-inf")

        for method in methods:
            cluster_label = data[f"{cluster_num}_{method}"]
            heatmaps, _ = load_and_process_flat(picture_name, base_path, layer_name="layer_40")
            heatmaps_by_cluster = {}

            for cluster in set(cluster_label):
                if cluster == -1:
                    continue
                indices = np.where(cluster_label == cluster)[0]
                cluster_heatmaps = heatmaps[indices]
                heatmaps_by_cluster[cluster] = cluster_heatmaps

                if cluster_heatmaps.shape[0] > 0:
                    combined_heatmap = np.sum(cluster_heatmaps, axis=0)
                    vmin = min(vmin, combined_heatmap.min())
                    vmax = max(vmax, combined_heatmap.max())

            heatmaps_by_method[method] = heatmaps_by_cluster

        for method_idx, method in enumerate(methods):
            heatmaps_by_cluster = heatmaps_by_method[method]

            for cluster, heatmaps in heatmaps_by_cluster.items():
                if heatmaps.shape[0] == 0:
                    continue
                combined_heatmap = np.sum(heatmaps, axis=0)

                ax = axes[method_idx, cluster] if num_methods > 1 else axes[cluster]
                im = ax.imshow(combined_heatmap, cmap="jet", alpha=0.7, vmin=vmin, vmax=vmax)
                ax.axis("off")
                ax.set_title(f"Cluster {cluster}\n{method}\n{heatmaps.shape[0]} éléments")

            # Ajouter la barre des couleurs dans la dernière colonne
            cax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # Position relative de la colorbar
            fig.colorbar(im, cax=cax)

        #plt.tight_layout(rect=[0, 0, 0.9, 1])  # Ajustement pour laisser de la place à la colorbar
        plt.show()

In [21]:
def visualize_clusters(clusters_analysis_file_base_name : str,
                       base_path : str,
                       target_occurence_min_per_class : dict,
                       list_image_name_to_test : list):
    all_feature_in_each_class = np.load(f"{base_path}/clusters_analysis/{clusters_analysis_file_base_name}_all_feature_in_each_class.npz")
    class_and_occurence_per_feature = np.load(f"{base_path}/clusters_analysis/{clusters_analysis_file_base_name}_class_and_occurence_per_feature.npz")

    dict_feature_per_class = {}
    for key in all_feature_in_each_class.keys():
        if key not in class_and_occurence_per_feature.keys():
            dict_feature_per_class[key] = all_feature_in_each_class[key]
            print(f"Key {key} not found in class_and_occurence_per_feature. Getting all features of all_feature_in_each_class for this key.")
            continue
        if key not in target_occurence_min_per_class.keys():
            dict_feature_per_class[key] = all_feature_in_each_class[key]
            print(f"Key {key} not found in target_occurence_min_per_class. Getting all features of all_feature_in_each_class for this key.")
            continue
        dict_feature_per_class[key] = []
        for feature in all_feature_in_each_class[key]:
            min_occurence_for_this_class = target_occurence_min_per_class[key]
            if class_and_occurence_per_feature[feature][1] >= min_occurence_for_this_class:
                dict_feature_per_class[key].append(feature)


    num_row = len(list(dict_feature_per_class.keys()))
    num_col = len(list_image_name_to_test)

    vmin, vmax = float("inf"), float("-inf")
    heatmap_per_class_per_image = {}
    for key in dict_feature_per_class.keys():
        heatmap_per_image = {}
        for image_name in list_image_name_to_test:
            heatmap_file_name = f"{image_name}.npz"
            heatmaps, _ = load_and_process_flat(heatmap_file_name, base_path, layer_name="layer_40")
            feature_index = np.array(dict_feature_per_class[key])
            cluster_heatmaps = heatmaps[feature_index]
            combined_heatmap = np.sum(cluster_heatmaps, axis=0)
            vmin = min(vmin, combined_heatmap.min())
            vmax = max(vmax, combined_heatmap.max())
            heatmap_per_image[image_name] = combined_heatmap
        heatmap_per_class_per_image[key] = heatmap_per_image


    fig, axes = plt.subplots(nrows=num_row, ncols=num_col, figsize=(4 * num_col, 4 * num_row))
    fig.tight_layout(pad=3.0)

    # Parcourir les classes (lignes) et les images (colonnes) pour afficher les heatmaps
    for row_idx, (class_name, heatmap_per_image) in enumerate(heatmap_per_class_per_image.items()):
        for col_idx, image_name in enumerate(list_image_name_to_test):
            ax = axes[row_idx, col_idx]
            heatmap = heatmap_per_image[image_name]

            im = ax.imshow(heatmap, cmap='viridis', vmin=vmin, vmax=vmax)
            ax.set_title(f"{class_name} - {image_name}", fontsize=10)
            ax.axis('off')

    # Ajouter une barre de couleur commune
    cbar = fig.colorbar(im, ax=axes, orientation='horizontal', fraction=0.05, pad=0.05)
    cbar.set_label('Activation Intensity')

    plt.show()

TODO : put file name in key and in values a tuple with min and max born and cluster method to display

In [9]:
base_path = "./data/v6/clusters"

In [10]:
def load_and_display_npz(base_path: str, file_name: str):
    """
    Charge un fichier .npz à partir de base_path et file_name,
    affiche toutes les clés et leurs valeurs.

    Args:
        base_path (str): Chemin de base.
        file_name (str): Nom du fichier (sans extension .npz).
    """
    file_path = os.path.join(base_path, "clusters_analysis", f"{file_name}.npz")

    try:
        with np.load(file_path, allow_pickle=True) as data:
            print(f"Fichier chargé : {file_path}\n")
            print("Clés dans le fichier:")
            for key in data.keys():
                print(f"- {key}")

            print("\nValeurs associées:")
            for key, value in data.items():
                print(f"{key}: {value}\n")
    except FileNotFoundError:
        print(f"Erreur : Le fichier '{file_path}' n'a pas été trouvé.")
    except Exception as e:
        print(f"Une erreur est survenue : {e}")

In [11]:
load_and_display_npz(base_path, "e1_all_feature_in_each_class")

Fichier chargé : ./data/v6/clusters\clusters_analysis\e1_all_feature_in_each_class.npz

Clés dans le fichier:
- None
- Eye
- Beak
- Wing
- Tail

Valeurs associées:
None: [  0   2   5   8  10  11  12  13  18  19  22  28  42  46  51  52  56  57
  61  65  69  81  83  84  86  88  95 101 113 124 125 138 139 155 157 168
 170 175 183 190 191 195 200 201 224 225 226 229 240 245 248 264 272 274
 282 294 297 299 304 305 312 319 321 325 330 332 343 353 355 359 360 362
 364 373 386 389 393 398 399 407 408 415 421 431 434 440 442 449 453 456
 458 473 477 481 494 495 501 506  27  40  53  62  70  71  74  76  79  97
  99 103 117 122 130 137 145 147 156 181 186 187 217 223 228 243 247 253
 257 262 263 267 281 318 336 345 370 387 392 411 430 468 479 483 505  14
  20  68  77 132 136 184 271 290 310 326 327 363 365 396 404 413 429 438
 478  37  39 104 135 176 189 209 211 227 259 260 303 314 320 323 354 372
 405 439 441 445 464 487 493]

Eye: [  1   3   7   9  17  24  25  31  32  33  34  36  38  41  45  47

In [12]:
load_and_display_npz(base_path, "e1_class_and_occurence_per_feature")

Fichier chargé : ./data/v6/clusters\clusters_analysis\e1_class_and_occurence_per_feature.npz

Clés dans le fichier:
- 0
- 2
- 5
- 8
- 10
- 11
- 12
- 13
- 18
- 19
- 22
- 28
- 42
- 46
- 51
- 52
- 56
- 57
- 61
- 65
- 69
- 81
- 83
- 84
- 86
- 88
- 95
- 101
- 113
- 124
- 125
- 138
- 139
- 155
- 157
- 168
- 170
- 175
- 183
- 190
- 191
- 195
- 200
- 201
- 224
- 225
- 226
- 229
- 240
- 245
- 248
- 264
- 272
- 274
- 282
- 294
- 297
- 299
- 304
- 305
- 312
- 319
- 321
- 325
- 330
- 332
- 343
- 353
- 355
- 359
- 360
- 362
- 364
- 373
- 386
- 389
- 393
- 398
- 399
- 407
- 408
- 415
- 421
- 431
- 434
- 440
- 442
- 449
- 453
- 456
- 458
- 473
- 477
- 481
- 494
- 495
- 501
- 506
- 1
- 3
- 7
- 9
- 17
- 24
- 25
- 27
- 31
- 32
- 33
- 34
- 36
- 38
- 40
- 41
- 45
- 47
- 48
- 49
- 53
- 54
- 62
- 63
- 66
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 79
- 80
- 87
- 90
- 91
- 92
- 93
- 94
- 97
- 98
- 99
- 100
- 103
- 106
- 107
- 109
- 110
- 112
- 114
- 117
- 121
- 122
- 123
- 126
- 128
- 129
- 130
- 134
- 137
- 142
- 

In [14]:
SUFFIX_NB_LEVEL_OCCURENCE_PER_CLASS = "nb_level_occurence_per_class"
SUFFIX_CLASS_AND_OCCURENCE_PER_FEATURE = "class_and_occurence_per_feature"
SUFFIX_ALL_FEATURE_IN_EACH_CLASS = "all_feature_in_each_class"

In [15]:
list_image_name_to_test = [
    "n01443537_10008",
    "n01443537_10009",
]

In [16]:
cluster_analysis_file_to_use = "e1"

In [17]:
file_suffix = f"_{cluster_analysis_file_to_use}"

In [18]:
directory_clusters = f"{base_path}/clusters_analysis"
files_clusters = list_files(directory_clusters, extension=".npz")

In [19]:
with open(f"{base_path}/imagenet_class_index_reversed.json", "r") as file:
    dictionary_class_index = json.load(file)

In [20]:
for image_name in list_image_name_to_test:
    heatmap_path = f"{base_path}/heatmap/{image_name}.npz"
    file_path = f"{directory_clusters}/{key}"
    if not os.path.exists(file_path):
        print(f"File {file_path} does not exist.")
        continue
    min_cluster, max_cluster, methods = dictionary_file_to_which_cluster_visualization[key]
    data = np.load(file_path)
    picture_name = get_filename_without_extension(file_path)
    class_id = extract_code(picture_name)
    heatmap_file_name = f"{extract_code_2(picture_name)}.npz"
    print(f"Visualizing clusters for {dictionary_class_index[class_id]} ({class_id})")
    visualize_clusters(data, min_cluster, max_cluster, methods, heatmap_file_name, base_path)

NameError: name 'key' is not defined