In [1]:
import torch
import torchvision.transforms as T
from torchvision.models.vgg import vgg16_bn
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
# Importer ZennitCRP
from zennit.composites import EpsilonPlusFlat
from zennit.canonizers import SequentialMergeBatchNorm
from crp.attribution import CondAttribution
from crp.helper import get_layer_names
from crp.concepts import ChannelConcept

In [2]:
import torch
from torchvision.models.vgg import vgg16_bn
import torchvision.transforms as T
from PIL import Image

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# VGG16 (batch-norm variant), moved to the selected device and switched to
# evaluation mode.
# NOTE(review): the positional `True` is presumably the legacy `pretrained`
# flag; newer torchvision releases may prefer `weights=...` — confirm against
# the installed version.
model = vgg16_bn(True)
model = model.to(device)
model.eval()

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the mean / std given below.
transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)



In [3]:
import math

def betterPrintHeatmap(heatmaps):
    """Display up to ten heatmaps on a grid with one shared colour bar.

    The grid holds five heatmaps per row; a sixth column is created on every
    row but always left blank. All panels share one colour scale, spanning the
    global minimum and maximum over the displayed heatmaps.

    Arguments:
    - heatmaps: sliceable sequence of objects exposing
      ``.detach().cpu().numpy()`` (e.g. torch tensors); only the first ten are
      shown. Assumed to be arrays that ``imshow`` accepts — TODO confirm.

    Returns:
    - None. The figure is displayed with ``plt.show()``. When ``heatmaps`` is
      empty a message is printed and nothing is drawn.
    """
    heatmaps = [h.detach().cpu().numpy() for h in heatmaps[:10]]
    if not heatmaps:
        # A zero-row grid and min()/max() over an empty sequence would both
        # raise, and there would be no image to attach the colour bar to.
        print("No heatmap to display.")
        return

    number_line = math.ceil(len(heatmaps) / 5)
    fig, axes = plt.subplots(number_line, 6, figsize=(18, 6))

    # Shared colour limits so that panels are directly comparable.
    vmin = min(h.min() for h in heatmaps)
    vmax = max(h.max() for h in heatmaps)
    print(f"Min value for a pixel: {vmin}, Max value for a pixel: {vmax}")

    heatmap_element = 0
    for i, ax in enumerate(axes.flat):
        col = i % 6
        # Blank the sixth column and any cell left over once every heatmap
        # has been placed.
        if col > 4 or heatmap_element >= len(heatmaps):
            ax.axis("off")
            continue
        im = ax.imshow(heatmaps[heatmap_element], cmap="seismic", interpolation="nearest", vmin=vmin, vmax=vmax)
        ax.set_title(f"Heatmap {heatmap_element+1}")
        heatmap_element += 1

    # Vertical colour bar to the right of the grid (all panels share vmin/vmax,
    # so the last image drawn is representative).
    fig.colorbar(im, ax=axes, orientation='vertical', location='right')

    # Spacing so that neighbouring heatmaps do not overlap.
    plt.subplots_adjust(wspace=0.3, hspace=0.3)

    # Show the grid.
    plt.show()

In [4]:
def compute_feature_importance(model, input_tensor, layer_idx, allFeaturesIDX, pred_class):
    """
    Measure how much the probability of ``pred_class`` changes when a set of
    channels in the output of one layer is zeroed out.

    Arguments:
    - model: network exposing an indexable ``features`` container (e.g. VGG16)
    - input_tensor: input batch passed to ``model``; only sample 0 is scored
    - layer_idx: index into ``model.features`` of the layer to hook (e.g. 40)
    - allFeaturesIDX: list of channel indices (dimension 1 of the layer output)
      to set to zero
    - pred_class: class index whose softmax probability is tracked

    Returns:
    - ``original_score - new_score``: the probability without any hook minus
      the probability with the listed channels zeroed. Both scores and the
      difference are also printed.

    Raises:
    - Whatever the forward pass or the hook raises (e.g. ``IndexError`` for an
      out-of-range channel index). The hook is removed from the model in every
      case.
    """

    def _class_probability():
        # Softmax probability of pred_class for sample 0, without gradients.
        with torch.no_grad():
            logits = model(input_tensor)
            probs = torch.nn.functional.softmax(logits, dim=1)
            return probs[0, pred_class].item()

    # Baseline probability, before any channel is disabled.
    original_score = _class_probability()

    def zero_out_features(module, input, output):
        # Zero every requested channel of the hooked layer's output in place.
        for idx in allFeaturesIDX:
            output[:, idx, :, :] = 0
        return output

    # Temporary hook; try/finally guarantees it never stays attached to the
    # model, even if the forward pass fails.
    hook = model.features[layer_idx].register_forward_hook(zero_out_features)
    try:
        new_score = _class_probability()
    finally:
        hook.remove()

    # Overall importance of the disabled channels.
    importance = original_score - new_score
    print(original_score)
    print(new_score)
    print(f"Importance globale après désactivation des features : {importance:.4f}")

    return importance

In [5]:
import torch
from torchvision.models.vgg import vgg16_bn

def compute_feature_importance_all_input_tensor(model, all_imput_tensor_of_path, layer_idx, num_features, pred_class):
    """
    Rank the channels of one layer by how much zeroing each of them, one at a
    time, lowers the mean probability of ``pred_class`` over a set of images.

    Arguments:
    - model: network exposing an indexable ``features`` container
    - all_imput_tensor_of_path: dict whose values are tuples with the input
      tensor in position 0 (the other tuple entries are ignored here)
    - layer_idx: index into ``model.features`` of the layer to hook
    - num_features: channels ``0 .. num_features - 1`` are ablated in turn
    - pred_class: class index whose softmax probability is averaged

    Returns a tuple ``(sorted_importance, mean_original_score, feature_score)``:
    - sorted_importance: dict channel index -> (baseline mean - ablated mean),
      ordered by decreasing importance
    - mean_original_score: mean probability with no channel disabled
    - feature_score: dict channel index -> mean probability with that channel
      disabled

    Raises:
    - ValueError if ``all_imput_tensor_of_path`` is empty (a mean is undefined).
    - Whatever the forward pass or the hook raises; the hook is removed from
      the model in every case.
    """
    if not all_imput_tensor_of_path:
        raise ValueError("all_imput_tensor_of_path is empty: cannot average scores over zero images")

    # Extract the tensors once instead of unpacking the tuples on every pass.
    input_tensors = [entry[0] for entry in all_imput_tensor_of_path.values()]
    mean_original_score = _mean_class_probability(model, input_tensors, pred_class)

    layer = model.features[layer_idx]
    feature_score = {}
    feature_importance = {}

    for feature_idx in range(num_features):
        # The default argument pins the current index for this hook instance.
        def zero_out_feature(module, input, output, feature_idx=feature_idx):
            output[:, feature_idx, :, :] = 0  # disable this channel in place
            return output

        hook = layer.register_forward_hook(zero_out_feature)
        try:
            mean_new_score = _mean_class_probability(model, input_tensors, pred_class)
        finally:
            # Never leave the hook attached, even if a forward pass fails.
            hook.remove()

        feature_score[feature_idx] = mean_new_score

        # Progress output.
        importance = mean_original_score - mean_new_score
        feature_importance[feature_idx] = importance
        print(f"Feature {feature_idx+1}/{num_features} - Importance: {importance:.4f}")

    sorted_importance = dict(sorted(feature_importance.items(), key=lambda item: item[1], reverse=True))

    return sorted_importance, mean_original_score, feature_score


def _mean_class_probability(model, input_tensors, pred_class):
    """Mean softmax probability of ``pred_class`` (sample 0) over ``input_tensors``."""
    total = 0
    with torch.no_grad():
        for input_tensor in input_tensors:
            probs = torch.nn.functional.softmax(model(input_tensor), dim=1)
            total += probs[0, pred_class].item()
    return total / len(input_tensors)

In [6]:
def load_images(folder_path, extensions=(".jpeg",)):
    """Load the images of a folder as preprocessed, batched tensors.

    Arguments:
    - folder_path: directory to scan (not recursive)
    - extensions: filename suffix, or tuple of suffixes, to accept. The match
      is case-sensitive; the default keeps the original ``.jpeg``-only filter.

    Returns:
    - dict mapping each matching filename to the result of applying the
      module-level ``transform`` to the RGB image, with a leading batch
      dimension added, moved to the module-level ``device``.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)

    all_imput_tensor_of_path = {}
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # resulting dict order (and everything derived from it) reproducible.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith(suffixes):
            continue
        image_path = os.path.join(folder_path, filename)

        # The context manager closes the underlying file once the pixels have
        # been converted, instead of relying on garbage collection.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        input_tensor = transform(rgb_image).unsqueeze(0).to(device)
        all_imput_tensor_of_path[filename] = input_tensor
    return all_imput_tensor_of_path

In [7]:
def separate_well_bad_predicted(all_imput_tensor_of_path, expected_class):
    """Split images by whether the model's top class equals ``expected_class``.

    Each input tensor is run through the module-level ``model`` without
    gradients. The top class is the argmax of the output along dimension 1,
    and its score is the softmax probability of that class for sample 0.

    Arguments:
    - all_imput_tensor_of_path: dict name -> input tensor
    - expected_class: class index counted as a correct prediction

    Returns:
    - ``(all_wrong_predicted, all_well_predicted)``: two dicts mapping name to
      ``(input_tensor, predicted_class, predicted_class_probability)``.
      The number of wrongly predicted images is printed.
    """
    all_well_predicted, all_wrong_predicted = {}, {}
    for image_name, tensor in all_imput_tensor_of_path.items():
        with torch.no_grad():
            logits = model(tensor)
            top_class = torch.argmax(logits, dim=1).item()
            top_prob = torch.nn.functional.softmax(logits, dim=1)[0, top_class].item()
        # Route the record to the bucket matching the prediction outcome.
        target = all_well_predicted if top_class == expected_class else all_wrong_predicted
        target[image_name] = (tensor, top_class, top_prob)
    print(f"Nombre d'images mal prédites : {len(all_wrong_predicted)}")
    return all_wrong_predicted, all_well_predicted

In [8]:
def positive_negative_features(importance_dict):
    """Split feature keys by the sign of their importance.

    Arguments:
    - importance_dict: dict feature -> importance value

    Returns:
    - ``(positiveFeatures, negativeFeatures)``: keys with importance ``>= 0``
      (zero counts as positive) and keys with importance ``< 0``, each in the
      dict's iteration order. A value satisfying neither comparison lands in
      neither list. The two counts are printed.
    """
    positiveFeatures = [name for name, score in importance_dict.items() if score >= 0]
    negativeFeatures = [name for name, score in importance_dict.items() if score < 0]

    print(f"Il y a {len(positiveFeatures)} features positives et {len(negativeFeatures)} features négatives.")
    return positiveFeatures, negativeFeatures

In [12]:
# Image folders for the two classes being compared (one directory per class).
# NOTE(review): these are absolute, machine-specific paths — consider deriving
# them from a configurable data directory.
folder_path_toucan = r"/home/timotheewsl/PER/Guillaume/data/train/n01843383" #toucan
folder_path_hornbill = r"/home/timotheewsl/PER/Guillaume/data/train/n01829413"#hornbill

In [13]:
# Class index the model is expected to output for each folder.
# NOTE(review): presumably ImageNet-1k class indices (96 = toucan,
# 93 = hornbill) — confirm against the label map used by the model.
expected_class_toucan = 96
expected_class_hornbill = 93

In [14]:
# Load every `.jpeg` image of each folder as a preprocessed tensor, keyed by
# filename.
all_imput_tensor_of_path_toucan = load_images(folder_path_toucan)
all_imput_tensor_of_path_hornbill = load_images(folder_path_hornbill)

In [15]:
# For each class, separate the images the model classifies as the expected
# class from those it does not; each call prints the number of wrong ones.
all_wrong_predicted_toucan, all_well_predicted_toucan = separate_well_bad_predicted(all_imput_tensor_of_path_toucan, expected_class_toucan)
all_wrong_predicted_hornbill, all_well_predicted_hornbill = separate_well_bad_predicted(all_imput_tensor_of_path_hornbill, expected_class_hornbill)

Nombre d'images mal prédites : 35
Nombre d'images mal prédites : 52


In [16]:
# Per-channel ablation over the correctly classified toucan images: each of
# the 512 channels of model.features[40] is zeroed in turn and the change in
# the mean probability of the toucan class is recorded (one line per channel).
importance_dict_toucan, mean_original_score_toucan, feature_score_toucan = compute_feature_importance_all_input_tensor(model, all_well_predicted_toucan, layer_idx=40, num_features=512, pred_class=expected_class_toucan)

Feature 1/512 - Importance: -0.0002
Feature 2/512 - Importance: 0.0002
Feature 3/512 - Importance: -0.0007
Feature 4/512 - Importance: -0.0011
Feature 5/512 - Importance: 0.0003
Feature 6/512 - Importance: -0.0002
Feature 7/512 - Importance: -0.0001
Feature 8/512 - Importance: 0.0006
Feature 9/512 - Importance: -0.0001
Feature 10/512 - Importance: -0.0008
Feature 11/512 - Importance: -0.0015
Feature 12/512 - Importance: -0.0003
Feature 13/512 - Importance: -0.0011
Feature 14/512 - Importance: -0.0016
Feature 15/512 - Importance: 0.0005
Feature 16/512 - Importance: 0.0026
Feature 17/512 - Importance: -0.0013
Feature 18/512 - Importance: -0.0002
Feature 19/512 - Importance: 0.0006
Feature 20/512 - Importance: -0.0018
Feature 21/512 - Importance: 0.0001
Feature 22/512 - Importance: -0.0009
Feature 23/512 - Importance: -0.0004
Feature 24/512 - Importance: 0.0000
Feature 25/512 - Importance: 0.0015
Feature 26/512 - Importance: 0.0001
Feature 27/512 - Importance: -0.0001
Feature 28/512 - Imp

In [17]:
# Same per-channel ablation, this time over the correctly classified hornbill
# images and scored on the hornbill class.
importance_dict_hornbill, mean_original_score_hornbill, feature_score_hornbill = compute_feature_importance_all_input_tensor(model, all_well_predicted_hornbill, layer_idx=40, num_features=512, pred_class=expected_class_hornbill)

Feature 1/512 - Importance: -0.0006
Feature 2/512 - Importance: 0.0027
Feature 3/512 - Importance: -0.0006
Feature 4/512 - Importance: 0.0024
Feature 5/512 - Importance: 0.0023
Feature 6/512 - Importance: -0.0001
Feature 7/512 - Importance: -0.0004
Feature 8/512 - Importance: -0.0011
Feature 9/512 - Importance: -0.0008
Feature 10/512 - Importance: 0.0022
Feature 11/512 - Importance: 0.0027
Feature 12/512 - Importance: -0.0009
Feature 13/512 - Importance: 0.0011
Feature 14/512 - Importance: 0.0008
Feature 15/512 - Importance: 0.0049
Feature 16/512 - Importance: 0.0070
Feature 17/512 - Importance: 0.0017
Feature 18/512 - Importance: -0.0002
Feature 19/512 - Importance: 0.0044
Feature 20/512 - Importance: 0.0102
Feature 21/512 - Importance: -0.0001
Feature 22/512 - Importance: 0.0001
Feature 23/512 - Importance: -0.0007
Feature 24/512 - Importance: -0.0000
Feature 25/512 - Importance: -0.0013
Feature 26/512 - Importance: 0.0001
Feature 27/512 - Importance: 0.0001
Feature 28/512 - Importan

In [18]:
# Split the toucan channel indices by the sign of their importance
# (an importance of exactly 0 counts as positive).
positiveFeatures_toucan, negativeFeatures_toucan = positive_negative_features(importance_dict_toucan)

Il y a 270 features positives et 242 features négatives.


In [19]:
# Split the hornbill channel indices by the sign of their importance
# (an importance of exactly 0 counts as positive).
positiveFeatures_hornbill, negativeFeatures_hornbill = positive_negative_features(importance_dict_hornbill)

Il y a 250 features positives et 262 features négatives.


In [20]:
def categorize_indices(dict1, dict2):
    """Group keys by the sign of their score in two score dictionaries.

    Every key present in either dict is considered; a key missing from one
    dict is given the score 0 there (and therefore counts as non-negative).

    Arguments:
    - dict1, dict2: dicts key -> numeric score

    Returns a 4-tuple of lists of keys:
    - both scores ``>= 0``
    - both scores ``< 0``
    - first ``>= 0`` and second ``< 0``
    - first ``< 0`` and second ``>= 0``

    Keys are visited in the iteration order of the union set. A score that is
    neither ``>= 0`` nor ``< 0`` leaves its key out of every list.
    """

    def _sign(score):
        # "pos" includes zero; None marks a score matching neither comparison.
        if score >= 0:
            return "pos"
        if score < 0:
            return "neg"
        return None

    groups = {
        ("pos", "pos"): [],
        ("neg", "neg"): [],
        ("pos", "neg"): [],
        ("neg", "pos"): [],
    }

    for key in set(dict1.keys()).union(set(dict2.keys())):
        signs = (_sign(dict1.get(key, 0)), _sign(dict2.get(key, 0)))
        if signs in groups:
            groups[signs].append(key)

    return (
        groups[("pos", "pos")],
        groups[("neg", "neg")],
        groups[("pos", "neg")],
        groups[("neg", "pos")],
    )

# Compare the sign of every channel's importance between the two classes and
# list the channel indices of each of the four sign combinations.
res = categorize_indices(importance_dict_toucan, importance_dict_hornbill)
for _label, _indices in zip(
    (
        "Les IDX avec score positif ou nul pour les deux :",
        "Les IDX avec score négatif pour les deux :",
        "Les IDX avec premier positif ou nul et deuxième négatif :",
        "Les IDX avec premier négatif et deuxième positif ou nul :",
    ),
    res,
):
    print(_label, _indices)

Les IDX avec score positif ou nul pour les deux : [1, 4, 14, 15, 18, 25, 27, 30, 33, 37, 40, 44, 51, 54, 60, 62, 74, 83, 87, 97, 98, 101, 104, 106, 107, 111, 121, 122, 125, 129, 135, 139, 149, 152, 153, 154, 155, 157, 159, 167, 168, 172, 184, 186, 188, 191, 192, 194, 196, 200, 204, 211, 226, 234, 246, 249, 255, 257, 269, 271, 272, 276, 280, 281, 283, 296, 298, 301, 305, 310, 311, 312, 313, 318, 325, 328, 331, 332, 337, 340, 344, 351, 355, 357, 367, 377, 380, 381, 392, 398, 399, 402, 417, 420, 421, 430, 432, 440, 442, 446, 469, 471, 474, 476, 477, 483, 500, 503, 504, 507, 511]
Les IDX avec score négatif pour les deux : [0, 2, 5, 6, 8, 11, 17, 22, 28, 49, 50, 59, 64, 67, 71, 72, 85, 92, 103, 108, 109, 120, 123, 124, 127, 131, 142, 147, 156, 169, 170, 175, 180, 185, 189, 190, 198, 201, 202, 212, 214, 215, 217, 219, 222, 224, 227, 228, 238, 241, 254, 258, 260, 261, 264, 268, 273, 274, 279, 289, 293, 308, 317, 324, 329, 335, 346, 348, 354, 358, 361, 363, 371, 379, 382, 383, 393, 413, 414, 4

In [21]:
# Number of channel indices in each of the four categories held in `res`,
# in the order returned by categorize_indices.
print(f"both_pos_or_null : {len(res[0])}")
print(f"both_neg : {len(res[1])}")
print(f"first_pos_second_neg : {len(res[2])}")
print(f"first_neg_second_pos : {len(res[3])}")

both_pos_or_null : 111
both_neg : 103
first_pos_second_neg : 159
first_neg_second_pos : 139
