In [1]:
import torch
from EnsembleXAI.Ensemble import normEnsembleXAI
from EnsembleXAI.Normalization import mean_var_normalize
from EnsembleXAI.Metrics import accordance_precision, accordance_recall

In [8]:
# Custom F1 score: when precision and recall are both 0 the per-sample F1 is 0/0 = NaN.
# Such samples are counted as F1 = 0 (a really bad explanation is treated as merely bad)
# instead of letting NaN poison the average.
def F1_score(
        explanations: torch.Tensor, masks: torch.Tensor, threshold: float = 0.0
) -> float:
    """Average F1 of accordance recall and precision, with 0/0 samples scored as 0.

    Args:
        explanations: attribution tensor, passed unchanged to ``accordance_recall`` and
            ``accordance_precision``.
        masks: ground-truth masks, passed unchanged to the same two metrics.
        threshold: forwarded as ``threshold=`` to both metrics.

    Returns:
        Sum of the per-sample harmonic means ``2*r*p / (r + p)`` divided by the size of
        the first dimension. NaN entries (both metrics zero) are replaced by 0 first.
        NOTE(review): this equals the mean when the metrics return one value per
        sample (1-D) — presumably the case; confirm.
    """
    acc_recall = accordance_recall(explanations, masks, threshold=threshold)
    acc_prec = accordance_precision(explanations, masks, threshold=threshold)
    values = 2 * (acc_recall * acc_prec) / (acc_recall + acc_prec)
    # 0/0 (recall == precision == 0) produces NaN; score those samples as 0.
    # Out-of-place and explicit, instead of the `x != x` NaN trick.
    values = torch.where(torch.isnan(values), torch.zeros_like(values), values)
    value = values.sum() / values.shape[0]
    return value.item()

In [3]:
# Load each method's saved attribution tensor. Keys and file names share the same
# short method code (presumably Captum method abbreviations, e.g. ig, nt — confirm).
attributions = {
    f'attributions_{code}': torch.load(f'ImageNet/attributions_{code}.pt')
    for code in ('ig', 's', 'gs', 'gb', 'd', 'ixg', 'l', 'o', 'svs', 'fa', 'ks', 'nt')
}

# Mean/variance-normalized copy of every method's attributions, same keys and order.
normalized_attributions = {name: mean_var_normalize(raw) for name, raw in attributions.items()}

In [4]:
# Ground-truth masks, moved to the current CUDA device (a GPU is required here).
ground_truth_masks = torch.load("ImageNet/ground_truth_masks.pt").to("cuda")

In [9]:
# Score every method by its F1 against the ground-truth masks, then rescale the
# scores so that the per-method weights sum to 1.
f1_scores = {
    attribution: F1_score(attributions[attribution], ground_truth_masks) for attribution in attributions
}
weight_sum = sum(f1_scores.values())
# `not > 0` also catches NaN. If every method scores 0, normalizing would divide by
# zero, and there is no ranking to weight by — fail loudly with a clear message.
if not weight_sum > 0:
    raise ValueError(f"Sum of F1 scores is {weight_sum}; cannot derive attribution weights.")
attributions_weight = {
    attribution: score / weight_sum for attribution, score in f1_scores.items()
}

In [10]:
# Display the final per-method weights. The weighting cell divides each F1 score
# by their sum, so these add up to 1.
attributions_weight

{'attributions_ig': 0.07576233122529771,
 'attributions_s': 0.09991986593436414,
 'attributions_gs': 0.07570613808468314,
 'attributions_gb': 0.07502384027356462,
 'attributions_d': 0.0710520701080619,
 'attributions_ixg': 0.07553182139136243,
 'attributions_l': 0.08509368208397593,
 'attributions_o': 0.08121739650805401,
 'attributions_svs': 0.1013098916247693,
 'attributions_fa': 0.0910585615316911,
 'attributions_ks': 0.09256187064423206,
 'attributions_nt': 0.07576253058994363}

In [11]:
# Scale every method's raw attribution tensor by its scalar F1-derived weight.
attributions_weighted = {}
for name, raw in attributions.items():
    attributions_weighted[name] = attributions_weight[name] * raw

In [13]:
# Uniform share 1/N: a method whose normalized weight falls below this is
# worse than an equal split, and is dropped when trimming.
threshold = 1 / len(attributions)
threshold

0.08333333333333333


In [14]:
# Trimmed ensemble: keep only the methods whose weight is not below the uniform share 1/N.
threshold = 1 / len(attributions)
attributions_trimmed = {
    name: attributions[name]
    for name, weight in attributions_weight.items()
    if not weight < threshold
}

In [15]:
# Names of the methods whose weight was not below the 1/N threshold.
attributions_trimmed.keys()

dict_keys(['attributions_s', 'attributions_l', 'attributions_svs', 'attributions_fa', 'attributions_ks'])

In [16]:
# Normalize each method's map FIRST, then scale it by its weight.
# Scaling the raw map and normalizing afterwards would cancel the weight:
# if mean_var_normalize computes (x - mean) / std, a positive scalar factor drops
# out, and the "weighted" ensemble would be identical to the unweighted one.
# NOTE(review): relies on mean_var_normalize being scale-invariant, as its name suggests — confirm.
normalized_weighted_attributions = {
    attr: attributions_weight[attr] * normalized_attributions[attr] for attr in attributions_weight
}
normalized_trimmed_attributions = {attr: mean_var_normalize(attributions_trimmed[attr]) for attr in attributions_trimmed}

In [17]:
# Weighted ensemble: stack |attribution| of every method along a new dim 1, aggregate
# with 'avg' (presumably averaging over that method axis), and save the result.
# (The former `_attributions` re-normalization was never read and only wasted compute.)
explanations = torch.stack(
    [torch.abs(normalized_weighted_attributions[attr]) for attr in normalized_weighted_attributions],
    dim=1,
)
agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')
torch.save(agg, "ImageNet/weighted_agg.pt")

In [18]:
# Trimmed ensemble: stack |attribution| of every surviving method along a new dim 1,
# aggregate with 'avg' (presumably averaging over that method axis), and save the result.
# (The former `_attributions` re-normalization was never read and only wasted compute.)
explanations = torch.stack(
    [torch.abs(normalized_trimmed_attributions[attr]) for attr in normalized_trimmed_attributions],
    dim=1,
)
agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')
torch.save(agg, "ImageNet/trimmed_agg.pt")