In [1]:
import torch
from EnsembleXAI.Metrics import accordance_precision, accordance_recall

In [2]:
# Custom F1 score: when precision and recall are both 0, the per-sample F1 is 0/0 = NaN.
# Such samples are counted as 0 rather than propagating NaN, so a degenerate explanation
# is scored as "bad" instead of poisoning the mean (or being silently dropped from it).
def F1_score(
        explanations: torch.Tensor, masks: torch.Tensor, threshold: float = 0.0
) -> float:
    """Mean per-sample F1 between thresholded explanations and ground-truth masks.

    Combines EnsembleXAI's ``accordance_recall`` and ``accordance_precision`` into
    the harmonic mean ``2 * P * R / (P + R)``, replaces NaN entries with 0, and
    divides the sum by the size of the first dimension of the per-sample scores.

    Args:
        explanations: attribution maps to score; moved to ``masks.device`` first.
        masks: ground-truth masks, forwarded unchanged to both accordance metrics.
        threshold: forwarded as ``threshold=`` to both accordance metrics.

    Returns:
        The averaged F1 as a Python float (NaN if the metric tensors are empty).
    """
    # Some cells load attributions without ``.cuda()`` while the masks are explicitly
    # moved to the GPU; align devices here so every call site behaves the same.
    # This is a no-op when the devices already match.
    explanations = explanations.to(masks.device)
    acc_recall = accordance_recall(explanations, masks, threshold=threshold)
    acc_prec = accordance_precision(explanations, masks, threshold=threshold)
    values = 2 * (acc_recall * acc_prec) / (acc_recall + acc_prec)
    # NaN comes from 0/0 (P == R == 0) or from a NaN metric value; score it as 0.
    # masked_fill leaves +/-inf untouched, exactly like the former `v != v` trick.
    values = values.masked_fill(torch.isnan(values), 0.0)
    # Sum divided by the size of dim 0 (equal to the mean for a 1-D score tensor).
    value = values.sum() / values.shape[0]
    return value.item()

In [4]:
# Ground-truth masks that every attribution method is scored against; kept on the GPU.
ground_truth_masks = torch.load(
    "ImageNet/ground_truth_masks.pt"
).cuda()

In [5]:
# Integrated Gradients attributions, moved explicitly to the GPU. Later cells omit
# `.cuda()` and keep whatever device torch.load restores.
integrated_gradients = torch.load(
    "Results/attributions_integrated_gradients_0.41_1714368472.568931.pt"
).cuda()

In [6]:
F1_score(explanations=integrated_gradients, masks=ground_truth_masks)

0.40941235423088074

In [7]:
saliency = torch.load(
    "Results/attributions_saliency_0.54_1714400644.2581055.pt"
)
F1_score(explanations=saliency, masks=ground_truth_masks)

0.5377201437950134

In [8]:
gradient_shap = torch.load(
    "Results/attributions_gradient_shap_0.41_1714400936.6504378.pt"
)
F1_score(explanations=gradient_shap, masks=ground_truth_masks)

0.40826356410980225

In [9]:
lime = torch.load(
    "Results/attributions_lime_0.54_1714402004.904838.pt"
)
F1_score(explanations=lime, masks=ground_truth_masks)

0.5439344644546509

In [14]:
# "occulsion" (sic) matches the spelling of the saved file; kept so the name lines up with the path.
occulsion = torch.load(
    "Results/attributions_occulsion_0.47_1714404377.9766824.pt"
)
F1_score(explanations=occulsion, masks=ground_truth_masks)

0.46652284264564514

In [18]:
# Shapley value sampling attributions.
SVS = torch.load(
    "Results/attributions_shapley_value_sampling_0.55_1714416655.5245132.pt"
)
F1_score(explanations=SVS, masks=ground_truth_masks)

0.5504565834999084

In [19]:
# Feature ablation attributions.
fa = torch.load(
    "Results/attributions_feature_ablation_0.50_1714419732.9062037.pt"
)
F1_score(explanations=fa, masks=ground_truth_masks)

0.5002642273902893

In [20]:
# Kernel SHAP attributions.
ks = torch.load(
    "Results/attributions_kernel_shap_0.52_1714421025.4944754.pt"
)
F1_score(explanations=ks, masks=ground_truth_masks)

0.5163415670394897

In [21]:
# Noise tunnel attributions.
# NOTE(review): the recorded F1 for this file is bit-identical to the saliency score
# (0.5377201437950134). Confirm this file is not a copy of the saliency attributions
# (or noise tunnel run with zero noise).
nt = torch.load(
    "Results/attributions_noise_tunnel_0.54_1714425519.9697309.pt"
)
F1_score(explanations=nt, masks=ground_truth_masks)

0.5377201437950134

In [22]:
from EnsembleXAI.Ensemble import normEnsembleXAI
from EnsembleXAI.Normalization import mean_var_normalize

# Load every saved attribution map, keyed by a short method tag. Dict insertion order
# fixes both the load order and the order of the stacked method axis (dim=1) below.
attributions = {
    tag: torch.load(path)
    for tag, path in {
        'attributions_ig': "Results/attributions_integrated_gradients_0.41_1714368472.568931.pt",
        'attributions_s': "Results/attributions_saliency_0.54_1714400644.2581055.pt",
        'attributions_gs': "Results/attributions_gradient_shap_0.41_1714400936.6504378.pt",
        'attributions_gb': 'ImageNet/attributions_gb.pt',
        'attributions_d': 'ImageNet/attributions_d.pt',
        'attributions_ixg': 'ImageNet/attributions_ixg.pt',
        'attributions_l': "Results/attributions_lime_0.54_1714402004.904838.pt",
        'attributions_o': "Results/attributions_occulsion_0.47_1714404377.9766824.pt",
        'attributions_svs': "Results/attributions_shapley_value_sampling_0.55_1714416655.5245132.pt",
        'attributions_fa': "Results/attributions_feature_ablation_0.50_1714419732.9062037.pt",
        'attributions_ks': "Results/attributions_kernel_shap_0.52_1714421025.4944754.pt",
        'attributions_nt': "Results/attributions_noise_tunnel_0.54_1714425519.9697309.pt",
    }.items()
}

# Put every method on a common scale before averaging.
# mean_var_normalize presumably standardizes each map (zero mean / unit variance);
# confirm in EnsembleXAI.Normalization.
normalized_attributions = {
    tag: mean_var_normalize(tensor) for tag, tensor in attributions.items()
}

# New axis 1 indexes the attribution method, in the order listed above.
explanations = torch.stack(list(normalized_attributions.values()), dim=1)

agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')

torch.save(agg, "ImageNet/tuned_avg_agg.pt")

In [23]:
F1_score(explanations=agg, masks=ground_truth_masks)

0.5144526362419128

In [24]:
# Baseline: a previously saved aggregation. Presumably the untuned ensemble;
# confirm the provenance of ImageNet/agg.pt.
F1_score(explanations=torch.load("ImageNet/agg.pt"), masks=ground_truth_masks)

0.4830857217311859