In [1]:
import torch
from torchvision.models import resnet18, ResNet18_Weights
from captum.attr import Lime
from captum._utils.models.linear_model import *
from EnsembleXAI.Normalization import mean_var_normalize
from EnsembleXAI.Ensemble import normEnsembleXAI

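### Lime attribution

Compute Lime attributions with three different interpretable (surrogate) models: Ridge, ordinary linear regression and Lasso. For each variant, the Lime attributions are saved to disk. They are then normalized with `mean_var_normalize` together with the precomputed attributions of the other methods and averaged into one ensemble explanation with `normEnsembleXAI`.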

In [3]:
# Pretrained ResNet-18 (default torchvision weights), switched to evaluation mode.
# `.cuda()` moves the parameters in place and returns the same module, so
# `model` and `model_cuda` refer to one and the same object on the GPU.
model = resnet18(weights=ResNet18_Weights.DEFAULT).eval()
model_cuda = model.cuda()

In [4]:
# Load the input images, the predicted labels (used as Lime targets) and the
# feature masks saved by an earlier step, and move each of them to the GPU.
proper_data, preds, proper_masks = (
    torch.load(tensor_path).cuda()
    for tensor_path in (
        "ImageNet/proper_data.pt",
        "ImageNet/preds.pt",
        "ImageNet/proper_masks.pt",
    )
)

In [5]:
num_batches: int = 50  # number of contiguous chunks the dataset is split into for each Lime run

In [6]:
# Precomputed attribution maps of every method in the ensemble, keyed by name.
# 'attributions_l' (Lime) has no file here and stays None until a Lime ensemble
# cell below assigns it. Insertion order fixes the order of the methods when the
# normalized maps are later stacked along dim=1.
attributions = {
    key: None if path is None else torch.load(path)
    for key, path in (
        ('attributions_ig', 'ImageNet/attributions_ig.pt'),
        ('attributions_s', 'ImageNet/attributions_s.pt'),
        ('attributions_gs', 'ImageNet/attributions_gs.pt'),
        ('attributions_gb', 'ImageNet/attributions_gb.pt'),
        ('attributions_d', 'ImageNet/attributions_d.pt'),
        ('attributions_ixg', 'ImageNet/attributions_ixg.pt'),
        ('attributions_l', None),
        ('attributions_o', 'ImageNet/attributions_o.pt'),
        ('attributions_svs', 'ImageNet/attributions_svs.pt'),
        ('attributions_fa', 'ImageNet/attributions_fa.pt'),
        ('attributions_ks', 'ImageNet/attributions_ks.pt'),
        ('attributions_nt', 'ImageNet/attributions_nt.pt'),
    )
}

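#### Ridge surrogate

Lime with a Ridge interpretable model (`alpha=0.01`), followed by the ensemble built from its attributions.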

In [6]:
%%time
lime = Lime(model_cuda, interpretable_model=SkLearnRidge(alpha=0.01))
attributions_l = None

for i in range(num_batches):
    batch_slice = slice(i * len(proper_data) // num_batches, (i + 1) * len(proper_data) // num_batches)
    if attributions_l is None:
        attributions_l = lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=proper_masks[batch_slice])
    else:
        temp = lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=proper_masks[batch_slice])
        attributions_l = torch.cat((attributions_l, temp), dim = 0)

torch.save(attributions_l, "ImageNet/attributions_l_ridge.pt")



CPU times: total: 6.92 s
Wall time: 39.7 s


In [11]:
# Insert the Ridge-based Lime maps. Normalize every method's maps with mean_var_normalize,
# stack them along a new dim=1 (one slot per method, in dict order), and average them.
attributions["attributions_l"] = torch.load("ImageNet/attributions_l_ridge.pt")
normalized_attributions = {
    method: mean_var_normalize(method_maps) for method, method_maps in attributions.items()
}
explanations = torch.stack(list(normalized_attributions.values()), dim=1)
agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')
torch.save(agg, "ImageNet/limeridge_agg.pt")

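#### Linear-regression surrogate

Lime with an ordinary least-squares linear-regression interpretable model (`SkLearnLinearRegression`), followed by the ensemble built from its attributions.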

In [9]:
%%time
lime = Lime(model_cuda, interpretable_model=SkLearnLinearRegression())
attributions_l = None

for i in range(num_batches):
    batch_slice = slice(i * len(proper_data) // num_batches, (i + 1) * len(proper_data) // num_batches)
    if attributions_l is None:
        attributions_l = lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=proper_masks[batch_slice])
    else:
        temp = lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=proper_masks[batch_slice])
        attributions_l = torch.cat((attributions_l, temp), dim = 0)

torch.save(attributions_l, "ImageNet/attributions_l_linearregression.pt")



CPU times: total: 9.36 s
Wall time: 46.7 s


In [12]:
# Insert the linear-regression-based Lime maps. Normalize every method's maps with
# mean_var_normalize, stack them along a new dim=1 (one slot per method, in dict order),
# and average them.
attributions['attributions_l'] = torch.load("ImageNet/attributions_l_linearregression.pt")
normalized_attributions = {
    method: mean_var_normalize(method_maps) for method, method_maps in attributions.items()
}
explanations = torch.stack(list(normalized_attributions.values()), dim=1)
agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')
torch.save(agg, "ImageNet/limelinearregression_agg.pt")

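#### Lasso surrogate (default hyperparameters)

Lime with a Lasso interpretable model constructed without arguments, i.e. with the library's default hyperparameters, followed by the ensemble built from its attributions.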

In [8]:
%%time
lime = Lime(model_cuda, interpretable_model=SkLearnLasso())
attributions_l = None

for i in range(num_batches):
    batch_slice = slice(i * len(proper_data) // num_batches, (i + 1) * len(proper_data) // num_batches)
    if attributions_l is None:
        attributions_l = lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=proper_masks[batch_slice])
    else:
        temp = lime.attribute(proper_data[batch_slice], target=preds[batch_slice], feature_mask=proper_masks[batch_slice])
        attributions_l = torch.cat((attributions_l, temp), dim = 0)

torch.save(attributions_l, "ImageNet/attributions_l_lasso.pt")



CPU times: total: 8.38 s
Wall time: 41.5 s


In [9]:
# Insert the Lasso-based Lime maps. Normalize every method's maps with mean_var_normalize,
# stack them along a new dim=1 (one slot per method, in dict order), and average them.
attributions['attributions_l'] = torch.load("ImageNet/attributions_l_lasso.pt")
normalized_attributions = {
    method: mean_var_normalize(method_maps) for method, method_maps in attributions.items()
}
explanations = torch.stack(list(normalized_attributions.values()), dim=1)
agg = normEnsembleXAI(explanations.detach(), aggregating_func='avg')
torch.save(agg, "ImageNet/limelasso_agg.pt")