In [2]:
import os
# Expose only the GPU with index 5 to this process (torch is imported in a later cell).
# NOTE(review): CUDA reads this variable when it initialises, so it must be set before any CUDA use.
os.environ.update({"CUDA_VISIBLE_DEVICES": "5"})

In [3]:
from abc import ABCMeta, abstractmethod

def constant(obj):
    '''Build a function that ignores whatever it is called with and always hands back ``obj``.

    Parameters
    ----------
    obj : object
        The value every invocation of the returned function yields (the same object, not a copy).

    Returns
    -------
    wrapped_const : function
        A function accepting arbitrary positional and keyword arguments that returns ``obj``.

    '''
    return lambda *_args, **_kwargs: obj


def identity(obj):
    '''Pass the argument through unchanged.

    Parameters
    ----------
    obj : object
        Any value.

    Returns
    -------
    obj : object
        The very object that was passed in (no copy is made).

    '''
    passthrough = obj
    return passthrough

class Attributor(metaclass=ABCMeta):
    '''Base Attributor Class.

    Attributors are convenience objects with an optional composite and when called, compute an attribution, e.g., the
    gradient or anything that is the result of computing the gradient when using the provided composite.  Attributors
    also provide a context to be used in a `with` statement, similar to `CompositeContext`s. If the forward function
    (or `self.__call__`) is called and the composite has not been registered (i.e. `composite.handles` is empty), the
    composite will be temporarily registered to the model.

    Parameters
    ----------
    model: :py:obj:`torch.nn.Module`
        The model for which the attribution will be computed. If `composite` is provided, this will also be the model
        to which the composite will be registered within `with` statements, or when calling the `Attributor` instance.
    composite: :py:obj:`zennit.core.Composite`, optional
        The optional composite to, if provided, be registered to the model within `with` statements or when calling the
        `Attributor` instance.
    attr_output: :py:obj:`torch.Tensor` or callable, optional
        The default output attribution to be used when calling the `Attributor` instance, which is either a Tensor
        compatible with any input used, or a function of the model's output. If None (default), the value will be the
        identity function.

    '''
    def __init__(self, model, composite=None, attr_output=None):
        self.model = model
        self.composite = composite

        if attr_output is None:
            self.attr_output_fn = identity
        else:
            self.attr_output_fn = self._wrap_attr_output(attr_output)

    @staticmethod
    def _wrap_attr_output(attr_output):
        '''Turn a non-None `attr_output` into a function of the model output.

        Callables are returned unchanged; any other object (e.g. a Tensor) is wrapped so that it is returned
        regardless of the model output.
        '''
        if callable(attr_output):
            return attr_output
        return constant(attr_output)

    def __enter__(self):
        '''Register the composite, if provided.

        Returns
        -------
        self: :py:obj:`Attributor`
            The `Attributor` instance.

        '''
        if self.composite is not None:
            self.composite.register(self.model)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        '''Remove the composite, if provided.

        Returns
        -------
        False
            Exceptions raised inside the `with` body are never suppressed.

        '''
        if self.composite is not None:
            self.composite.remove()
        return False

    def __call__(self, input, attr_output=None):
        '''Compute the attribution of the model wrt. `input`, using `attr_output` as the output attribution if
        provided, or the default output attribution otherwise (if not supplied during instantiation either, this will
        be the full output of the model). If a composite was supplied to the `Attributor` instance, but it was not yet
        registered (either manually, or in a `with` statement), it will be registered to the model temporarily during
        the call.

        Parameters
        ----------
        input: :py:obj:`torch.Tensor`
            Input for the model, and wrt. compute the attribution
        attr_output: :py:obj:`torch.Tensor` or callable, optional
            The output attribution, which is either a Tensor compatible with `input` (i.e. has the same shape as the
            output of the model), or a function of the model's output. If None (default), the default attribution will
            be used, which if neither supplied, will result in the model output used as the output attribution.

        Returns
        -------
        output: :py:obj:`torch.Tensor`
            Output of the model with argument `input`.
        attribution: :py:obj:`torch.Tensor`
            Attribution of the model wrt. `input`, with the same shape as `input`.
        '''
        if attr_output is None:
            attr_output_fn = self.attr_output_fn
        else:
            attr_output_fn = self._wrap_attr_output(attr_output)

        # Non-empty handles mean the composite is already registered (e.g. inside `with attributor:`); do not
        # register it a second time.
        if self.composite is None or self.composite.handles:
            return self.forward(input, attr_output_fn)

        with self:
            return self.forward(input, attr_output_fn)

    @property
    def inactive(self):
        '''Return the attributor's composite's ``.inactive`` context manager factory, used as
        ``with attributor.inactive(): ...``.

        Without a composite there is no gradient modification to deactivate, so a no-op
        :py:class:`contextlib.nullcontext` factory is returned instead of failing with an ``AttributeError``.
        '''
        if self.composite is None:
            from contextlib import nullcontext
            return nullcontext
        return self.composite.inactive

    @abstractmethod
    def forward(self, input, attr_output_fn):
        '''Abstract method. Compute the attribution of the model wrt. input, by using `attr_output_fn` as the function
        of the model output to provide the output attribution. This function will not register the composite, and is
        wrapped in the `__call__` of `Attributor`.

        Parameters
        ----------
        input: :py:obj:`torch.Tensor`
            Input for the model, and wrt. compute the attribution
        attr_output_fn: callable
            Function of the model's output which returns the output attribution.

        Returns
        -------
        output: :py:obj:`torch.Tensor`
            Output of the model with argument `input`.
        attribution: :py:obj:`torch.Tensor`
            Attribution of the model wrt. `input`.
        '''

In [10]:
import torch

class GradientTimesInput(Attributor):
    '''Model-agnostic gradient times input.

    The gradient of the model output wrt. the input is computed with the output attribution as ``grad_outputs``.
    When an Attributor composite is registered, its hooks may modify that gradient. The gradient is then multiplied
    element-wise with the input.
    '''
    def forward(self, input, attr_output_fn):
        '''Compute gradient times input.

        Parameters
        ----------
        input: :py:obj:`torch.Tensor`
            Input for the model.
        attr_output_fn: callable
            Function which receives a detached copy of the model output and returns the output attribution. The
            output attribution is passed as ``grad_outputs`` to :py:func:`torch.autograd.grad`.

        Returns
        -------
        output: :py:obj:`torch.Tensor`
            Output of the model for ``input``.
        relevance: :py:obj:`torch.Tensor`
            Gradient wrt. ``input`` multiplied element-wise with ``input``.
        '''
        # Fresh leaf tensor that requires grad, so the gradient is taken wrt. exactly this input.
        input_detached = input.detach().requires_grad_(True)
        output = self.model(input_detached)
        gradient, = torch.autograd.grad(
            (output,), (input_detached,), (attr_output_fn(output.detach()),)
        )
        # (debug print of gradient.shape removed: it wrote to stdout on every attribution call)
        relevance = gradient * input
        return output, relevance

In [None]:
# Small demo: attribute a toy MLP (10 -> 10 -> 1) on a random batch with gradient times input.
batch_size = 4
model = torch.nn.Sequential(
    torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1)
)
input = torch.randn(batch_size, 10)
attr = GradientTimesInput(model)
# No composite and no attr_output given, so the model output itself is used as the output attribution.
output, relevance = attr(input)

In [12]:
from contextlib import contextmanager
import weakref

class RemovableHandle:
    '''Forward ``.remove()`` to an object through a weak reference, without keeping that object alive.

    Parameters
    ----------
    instance : object
        A weak-referenceable object exposing a ``remove()`` method.
    '''
    def __init__(self, instance):
        self.instance_ref = weakref.ref(instance)

    def remove(self):
        '''Invoke ``remove()`` on the referenced object; do nothing if it has already been garbage-collected.'''
        target = self.instance_ref()
        if target is None:
            return
        target.remove()


class RemovableHandleList(list):
    '''A list of handles whose ``remove()`` removes every contained handle and then empties the list.

    Note: ``remove`` here shadows :py:meth:`list.remove` and takes no argument.
    '''
    def remove(self):
        '''Call ``remove()`` on every member in order (removing hooks or reverting canonizers), then clear.'''
        for entry in iter(self):
            entry.remove()
        del self[:]

class CompositeContext:
    '''Context manager which registers a composite to a module on entry, and removes its hooks and canonizers on exit.

    Parameters
    ----------
    module: :py:class:`torch.nn.Module`
        The module to which `composite` is registered while the context is active.
    composite: :py:class:`zennit.core.Composite`
        The composite to register to `module`.
    '''
    def __init__(self, module, composite):
        self.module, self.composite = module, composite

    def __enter__(self):
        composite, module = self.composite, self.module
        composite.register(module)
        # The module (not this context object) is what `with ... as x` binds.
        return module

    def __exit__(self, exc_type, exc_value, traceback):
        self.composite.remove()
        # Returning False lets any exception from the with-body propagate.
        return False

class Composite:
    '''A Composite to apply canonizers and register hooks to modules.
    One Composite instance may only be applied to a single module at a time.

    Parameters
    ----------
    module_map: callable, optional
        A function ``(ctx: dict, name: str, module: torch.nn.Module) -> Hook or None`` which maps a context, name and
        module to a matching :py:class:`~zennit.core.Hook`, or ``None`` if there is no matching
        :py:class:`~zennit.core.Hook`.
    canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
        List of canonizer instances to be applied before applying hooks.
    '''
    def __init__(self, module_map=None, canonizers=None):
        if module_map is None:
            module_map = self._empty_module_map
        if canonizers is None:
            canonizers = []

        self.module_map = module_map
        self.canonizers = canonizers

        # Handles of everything registered by this composite; removing them undoes the registration.
        self.handles = RemovableHandleList()
        # Weak references to the hook copies currently registered, used by `inactive` to toggle them.
        self.hook_refs = weakref.WeakSet()

    def register(self, module):
        '''Apply all canonizers and register all hooks to a module (and its recursive children).
        Previous canonizers of this composite are reverted and all hooks registered by this composite are removed.
        The module or any of its children (recursively) may still have other hooks attached.

        Parameters
        ----------
        module: :py:class:`torch.nn.Module`
            Hooks and canonizers will be applied to this module recursively according to ``module_map`` and
            ``canonizers``.
        '''
        self.remove()

        for canonizer in self.canonizers:
            self.handles += canonizer.apply(module)

        # One dict shared across all module_map calls of this registration, so module_map can keep state.
        ctx = {}
        for name, child in module.named_modules():
            template = self.module_map(ctx, name, child)
            if template is not None:
                # Register a copy so the template returned by module_map is never attached itself.
                hook = template.copy()
                self.hook_refs.add(hook)
                self.handles.append(hook.register(child))

    def remove(self):
        '''Remove all handles for hooks and canonizers.
        Hooks will simply be removed from their corresponding Modules.
        Canonizers will revert the state of the modules they changed.
        '''
        self.handles.remove()
        self.hook_refs.clear()

    def context(self, module):
        '''Return a CompositeContext object with this instance and the supplied module.

        Parameters
        ----------
        module: :py:class:`torch.nn.Module`
            Module for which to register this composite in the context.

        Returns
        -------
        :py:class:`zennit.core.CompositeContext`
            A context object which registers the composite to ``module`` on entering, and removes it on exiting.
        '''
        return CompositeContext(module, self)

    @contextmanager
    def inactive(self):
        '''Context manager to temporarily deactivate the gradient modification. This can be used to compute the
        gradient of the modified gradient.

        On exit, each hook's previous ``active`` state is restored rather than being unconditionally set to ``True``.
        As a result, nested ``inactive`` contexts keep the hooks deactivated until the outermost one exits, and
        hooks that were already inactive stay inactive.
        '''
        # Snapshot before modifying anything, so the finally-clause restores exactly what was changed. The list
        # holds strong references to the hooks for the duration of the context.
        previous = [(hook, getattr(hook, 'active', True)) for hook in self.hook_refs]
        try:
            for hook, _ in previous:
                hook.active = False
            yield self
        finally:
            for hook, was_active in previous:
                hook.active = was_active

    @staticmethod
    def _empty_module_map(ctx, name, module):
        '''Empty module_map, does not assign any rules.'''
        return None