In [36]:
import timm

model = timm.create_model("vgg16")
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [37]:
import torch
from torch import nn
from types import MethodType
from copy import deepcopy

def abslrp_rule(self, output_relevance: torch.tensor) -> torch.tensor:
    input = self.saved_tensors["input"][0]
    output = self.saved_tensors["output"]
    if "abs_output" in self.saved_tensors:
        output = output + self.saved_tensors["abs_output"]
    output_relevance = output_relevance * output.sign() / (output + 1e-9)
    # apply abslrp rule
    relevance = (
        torch.autograd.grad(output, input, output_relevance, retain_graph=True)[0]
        * input
    )
    # delete saved tensors
    self.saved_tensors = {}
    return relevance

def composite_abslrp_rule(self, output_relevance: torch.tensor) -> torch.tensor:
    for child in list(self.children())[::-1]:
        output_relevance = child.explain(output_relevance)
    return output_relevance

def abslrp_forward_hook(module: nn.Module, args: tuple, output: tuple) -> None:
    print(module.__class__)
    module.saved_tensors = {}
    # if module has learnable parameters, create a temporary copy of the module and infer over absolute parameters
    if getattr(module, "weight", None) is not None:
        abs_module = deepcopy(module)
        # remove this hook from copied model
        if getattr(abs_module, "_forward_hooks", None):
            for i, hook in abs_module._forward_hooks.items():
                if hook.__name__ == "abslrp_forward_hook":
                    break
            del abs_module._forward_hooks[i]

        abs_module.weight.data = abs_module.weight.data.abs()
        if getattr(module, "bias", None) is not None:
            abs_module.bias.data = abs_module.bias.data.abs()
        abs_output = abs_module(*args)
        # save the outputs and inputs
        module.saved_tensors["abs_output"] = abs_output
    module.saved_tensors["output"] = output
    module.saved_tensors["input"] = args

def apply_rule(module: nn.Module) -> list[nn.Module]:
    children_list = list(module.children())
    if not children_list:
        module.explain = MethodType(abslrp_rule, module)
        module.register_forward_hook(abslrp_forward_hook)
        return

    module.explain = MethodType(composite_abslrp_rule, module)
    for child_module in children_list:
        apply_rule(child_module)

    return

In [38]:
apply_rule(model)

In [39]:
x = torch.randn(8, 3, 224, 224)
x.requires_grad = True
output = model(x)

<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>


In [40]:
model.explain(output.detach())

tensor([[[[-1.0072e-06, -3.8141e-05,  9.6621e-05,  ...,  1.2323e-04,
           -1.5043e-06,  1.4120e-04],
          [-1.3541e-05, -1.1183e-04,  4.1037e-05,  ...,  4.4946e-05,
            4.7424e-05, -1.3385e-04],
          [-3.0320e-04, -3.9273e-05, -6.6231e-04,  ...,  1.6909e-04,
            1.1594e-04, -1.0147e-04],
          ...,
          [-3.5727e-04,  8.9618e-04, -3.8637e-04,  ...,  1.9110e-04,
            3.9099e-05,  5.9507e-05],
          [-1.0939e-04, -3.3876e-04,  1.6876e-03,  ...,  8.6850e-05,
           -6.6800e-05,  8.3234e-06],
          [-1.5830e-04,  4.1178e-04,  2.2517e-03,  ..., -2.1597e-05,
           -1.2600e-05,  3.3275e-05]],

         [[-4.1443e-05,  1.5024e-04,  3.6927e-04,  ..., -1.2943e-04,
            4.1440e-05, -3.7286e-05],
          [-1.6585e-04,  5.1053e-04, -1.1176e-04,  ...,  4.8987e-04,
            9.2183e-04,  8.9593e-06],
          [ 2.6083e-05, -3.9482e-04,  5.0190e-04,  ..., -4.8869e-04,
           -3.4478e-05, -7.9500e-05],
          ...,
     

In [25]:
model.features[0].saved_tensors

{'output': tensor([[[[0.0000, 0.1340, 0.0000,  ..., 0.2387, 0.4909, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.1710, 0.0000, 0.0000],
           [0.0000, 0.0756, 0.1354,  ..., 0.0957, 0.1056, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0639, 0.0000, 0.0000],
           [0.0090, 0.3438, 0.3167,  ..., 0.0000, 0.0000, 0.2110],
           [0.2764, 0.3691, 0.0000,  ..., 0.0176, 0.1407, 0.2590]],
 
          [[0.0000, 0.0000, 0.3978,  ..., 0.4876, 0.1116, 0.0974],
           [0.0000, 0.0000, 0.0862,  ..., 0.0000, 0.2036, 0.0000],
           [0.0000, 0.0000, 0.2259,  ..., 0.0000, 0.3673, 0.1435],
           ...,
           [0.0000, 0.3118, 0.1653,  ..., 0.0000, 0.0000, 0.0000],
           [0.1337, 0.1739, 0.0000,  ..., 0.4365, 0.0000, 0.0000],
           [0.1032, 0.0711, 0.3093,  ..., 0.0000, 0.5031, 0.0000]],
 
          [[0.0000, 0.0000, 0.3442,  ..., 0.0353, 0.2122, 0.0000],
           [0.2539, 0.0000, 0.4047,  ..., 0.0000, 0.7316, 0.4200],
           [0.

In [10]:
model.ramp()

<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
