In [2]:
import torch
import torch.nn as nn

In [34]:
class Net(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(1, 6, 1)
        self.fc1 = nn.Linear(6*4*4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        print(x)
        x = torch.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [38]:
model = Net()
input = torch.ones(1, 1, 4, 4)
output = model(input)
print(output.shape)

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
torch.Size([1, 10])


In [39]:
from captum.attr import IntegratedGradients
from captum.metrics import infidelity
saliency = IntegratedGradients(model)
attr = saliency.attribute(input, target=3) 
print(attr.shape)

tensor([[[[5.6680e-04, 5.6680e-04, 5.6680e-04, 5.6680e-04],
          [5.6680e-04, 5.6680e-04, 5.6680e-04, 5.6680e-04],
          [5.6680e-04, 5.6680e-04, 5.6680e-04, 5.6680e-04],
          [5.6680e-04, 5.6680e-04, 5.6680e-04, 5.6680e-04]]],


        [[[2.9840e-03, 2.9840e-03, 2.9840e-03, 2.9840e-03],
          [2.9840e-03, 2.9840e-03, 2.9840e-03, 2.9840e-03],
          [2.9840e-03, 2.9840e-03, 2.9840e-03, 2.9840e-03],
          [2.9840e-03, 2.9840e-03, 2.9840e-03, 2.9840e-03]]],


        [[[7.3230e-03, 7.3230e-03, 7.3230e-03, 7.3230e-03],
          [7.3230e-03, 7.3230e-03, 7.3230e-03, 7.3230e-03],
          [7.3230e-03, 7.3230e-03, 7.3230e-03, 7.3230e-03],
          [7.3230e-03, 7.3230e-03, 7.3230e-03, 7.3230e-03]]],


        [[[1.3568e-02, 1.3568e-02, 1.3568e-02, 1.3568e-02],
          [1.3568e-02, 1.3568e-02, 1.3568e-02, 1.3568e-02],
          [1.3568e-02, 1.3568e-02, 1.3568e-02, 1.3568e-02],
          [1.3568e-02, 1.3568e-02, 1.3568e-02, 1.3568e-02]]],


        [[[2.1695e-02, 2

In [40]:
def perturb_fn(inputs):
    noise = torch.normal(0, 0.003, inputs.shape)
    return noise, inputs - noise

infid = infidelity(model, perturb_fn, input, attr, target=3)

tensor([[[[1.0012, 0.9958, 1.0041, 0.9942],
          [0.9982, 0.9988, 1.0068, 1.0029],
          [0.9986, 0.9990, 1.0009, 0.9978],
          [0.9991, 1.0022, 1.0002, 0.9934]]],


        [[[0.9958, 0.9992, 0.9980, 1.0018],
          [0.9990, 1.0004, 0.9920, 0.9967],
          [0.9954, 0.9997, 0.9952, 0.9968],
          [0.9958, 0.9985, 0.9979, 0.9943]]],


        [[[0.9965, 0.9963, 1.0071, 1.0003],
          [0.9982, 0.9989, 0.9967, 1.0035],
          [0.9999, 0.9990, 1.0031, 0.9972],
          [0.9971, 0.9995, 1.0023, 1.0029]]],


        [[[0.9989, 1.0040, 1.0002, 0.9998],
          [0.9992, 1.0034, 0.9981, 1.0033],
          [0.9948, 0.9975, 1.0002, 0.9985],
          [0.9966, 1.0005, 0.9991, 1.0028]]],


        [[[1.0021, 1.0018, 1.0015, 0.9985],
          [0.9952, 0.9980, 1.0003, 0.9988],
          [0.9978, 1.0034, 1.0032, 0.9997],
          [1.0008, 0.9960, 0.9991, 0.9951]]],


        [[[0.9999, 0.9983, 1.0006, 0.9979],
          [0.9964, 1.0020, 1.0029, 0.9977],
          [1