In [4]:
import torch
import torch.nn
import jax
import jax.numpy as jnp
from jax import lax
import flax
import flax.linen as fnn
import numpy as np
from typing import Callable, Optional, Tuple, Union, Sequence, Iterable

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def safe_divide(a, b):
    """Elementwise ``a / b`` that is masked to zero wherever ``b`` is exactly 0.

    The denominator is shifted by a small epsilon (and bumped again if the
    shift lands exactly on 0) so the division never divides by an exact zero;
    positions where ``b == 0`` are then multiplied by 0.
    """
    eps = 1e-9
    # NOTE(review): max(b, eps) + min(b, eps) == b + eps elementwise, so this is
    # an unconditional +eps shift rather than a sign-aware stabiliser.
    shifted = torch.clamp(b, min=eps) + torch.clamp(b, max=eps)
    landed_on_zero = (shifted == 0).type(shifted.type())
    shifted = shifted + landed_on_zero * eps
    keep = (b != 0).type(b.type())
    return a / shifted * keep


def forward_hook(self, input, output):
    """Forward hook that caches a module's input and output for ``relprop``.

    ``input`` is the tuple of positional arguments the module was called with;
    only ``input[0]`` is kept. It is stored on ``self.X`` as a detached copy
    with ``requires_grad=True`` (a list of such copies when ``input[0]`` is
    exactly a ``list`` or ``tuple``), so a later fresh forward pass can be
    differentiated w.r.t. it. ``output`` is stored unchanged on ``self.Y``.
    """
    def _as_leaf(tensor):
        leaf = tensor.detach()
        leaf.requires_grad = True
        return leaf

    first = input[0]
    # Exact type match (not isinstance) decides between the multi- and
    # single-input cases.
    if type(first) in (list, tuple):
        self.X = [_as_leaf(t) for t in first]
    else:
        self.X = _as_leaf(first)

    self.Y = output


class RelProp(torch.nn.Module):
    """Base module for layers that support relevance propagation.

    Each instance registers ``forward_hook`` on itself, so calling the module
    caches its (detached, grad-enabled) input on ``self.X`` and its output on
    ``self.Y``. Subclasses override :meth:`relprop`.
    """

    def __init__(self):
        super().__init__()
        # The hook is registered unconditionally, in train and eval mode alike.
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Vector-Jacobian product of ``Z`` w.r.t. ``X`` with cotangent ``S``.

        Returns the tuple produced by ``torch.autograd.grad``; the graph is
        retained so it can be differentiated again.
        """
        return torch.autograd.grad(Z, X, S, retain_graph=True)

    def relprop(self, R, alpha):
        """Identity rule: return the relevance ``R`` unchanged (``alpha`` unused)."""
        return R


class RelPropSimple(RelProp):
    """Gradient-times-input relevance rule on the input cached by the forward hook.

    ``relprop`` re-runs ``forward`` on ``self.X``, divides the relevance by the
    output, back-propagates that ratio and multiplies by the input(s).
    """

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` back to the cached input (``alpha`` unused).

        Returns a single tensor when ``self.X`` is a tensor; otherwise a list
        holding the relevances of the first two cached inputs only.
        """
        out = self.forward(self.X)
        grads = self.gradprop(out, self.X, safe_divide(R, out))

        if torch.is_tensor(self.X):
            return self.X * (grads[0])
        # Multi-input layers (e.g. a list of two operands): first two only.
        return [self.X[0] * grads[0], self.X[1] * grads[1]]

class MatMul(RelPropSimple):
    """Matrix product of a sequence of operands; relevance via ``RelPropSimple``."""

    def forward(self, inputs):
        """Return ``torch.matmul`` applied to the unpacked ``inputs`` sequence."""
        return torch.matmul(*inputs)
    
class Add(RelPropSimple):
    """Elementwise ``torch.add`` of two operands with a normalised relevance split."""

    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        """Propagate ``R`` to the two cached addends (``alpha`` unused).

        Each addend first gets its gradient-times-input relevance. That
        relevance is then rescaled so it sums to the addend's share
        ``|s_k| / (|s_a| + |s_b|)`` of ``R.sum()``, where ``s_k`` is the
        unscaled sum. Returns ``[R_a, R_b]``.
        """
        out = self.forward(self.X)
        grads = self.gradprop(out, self.X, safe_divide(R, out))

        rel_a = self.X[0] * grads[0]
        rel_b = self.X[1] * grads[1]

        sum_a = rel_a.sum()
        sum_b = rel_b.sum()
        total_mag = sum_a.abs() + sum_b.abs()
        total_R = R.sum()

        share_a = safe_divide(sum_a.abs(), total_mag) * total_R
        share_b = safe_divide(sum_b.abs(), total_mag) * total_R

        return [rel_a * safe_divide(share_a, sum_a),
                rel_b * safe_divide(share_b, sum_b)]
    
class IndexSelect(RelProp):
    """``torch.index_select`` with gradient-times-input relevance propagation.

    ``forward`` records ``dim`` and ``indices`` on the module so ``relprop``
    can replay the same selection on the input cached by the forward hook.
    """

    def forward(self, inputs, dim, indices):
        self.dim = dim
        self.indices = indices
        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        """Propagate relevance ``R`` back to the cached input (``alpha`` unused).

        Returns a tensor shaped like the cached input. Positions not selected
        in the forward pass receive zero gradient and hence zero relevance.
        """
        out = self.forward(self.X, self.dim, self.indices)
        grads = self.gradprop(out, self.X, safe_divide(R, out))
        # self.X is necessarily a single tensor here: torch.index_select above
        # raises for a list input, so a list branch could never be reached.
        return self.X * grads[0]
    
class Clone(RelProp):
    """Fan one tensor out into ``num`` references, e.g. to feed several branches.

    ``relprop`` combines the per-branch relevances. Each one is divided by the
    input (via ``safe_divide``), and the ratios are summed by autograd and
    multiplied back by the input.
    """

    def forward(self, input, num):
        self.num = num
        return [input] * num

    def relprop(self, R, alpha):
        """Combine the sequence of per-branch relevances ``R`` (``alpha`` unused)."""
        copies = [self.X] * self.num
        ratios = [safe_divide(r, z) for r, z in zip(R, copies)]
        grad = self.gradprop(copies, self.X, ratios)[0]
        return self.X * grad

In [6]:
# Small fixed operands shared by the PyTorch layer checks below.
A = torch.ones((2,2))
B = torch.tensor([[-2., 1], [1, -2]])

# Call the module (not .forward) so the registered forward hook caches the
# operands on mm.X; relprop later reads them from there.
mm = MatMul()
mm([A,B])

tensor([[-1., -1.],
        [-1., -1.]])

In [7]:
# All relevance on output entry (0, 0); RelPropSimple.relprop returns one
# relevance tensor per operand, i.e. [R_A, R_B].
mm.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[ 2., -1.],
         [ 0.,  0.]], grad_fn=<MulBackward0>),
 tensor([[ 2.,  0.],
         [-1., -0.]], grad_fn=<MulBackward0>)]

In [8]:
# Same operands through the Add layer; the forward hook caches [A, B] on add.X.
add = Add()
add([A,B])

tensor([[-1.,  2.],
        [ 2., -1.]])

In [9]:
# Add.relprop splits the total relevance between the addends in proportion to
# |sum| of each addend's gradient-times-input contribution.
add.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[0.3333, -0.0000],
         [-0.0000, 0.0000]], grad_fn=<MulBackward0>),
 tensor([[0.6667, 0.0000],
         [0.0000, 0.0000]], grad_fn=<MulBackward0>)]

In [10]:
# Select column 0 of B (dim=1, a single index 0).
pool = IndexSelect()
pool(B, 1, torch.zeros(1, dtype=torch.int32))

tensor([[-2.],
        [ 1.]])

In [11]:
# Relevance of the selected column flows back into B; unselected entries get 0.
pool.relprop(torch.tensor([[1.],[0]]), alpha=1)

tensor([[1., 0.],
        [0., -0.]], grad_fn=<MulBackward0>)

In [12]:
# Clone returns `num` references to the same input tensor.
clone = Clone()
clone(B, 2)

[tensor([[-2.,  1.],
         [ 1., -2.]]),
 tensor([[-2.,  1.],
         [ 1., -2.]])]

In [16]:
# One relevance tensor per clone; the recorded result is their elementwise sum.
clone.relprop([torch.tensor([[.5,.5],[0,0]]), torch.tensor([[0,.25],[.25,.5]])], alpha=1)

tensor([[0.5000, 0.7500],
        [0.2500, 0.5000]], grad_fn=<MulBackward0>)

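It looks like `Clone` is necessary because of the way PyTorch tracks gradients. Working out the math, `Clone.relprop` reduces to a sum of the incoming relevances: $R = X \odot \sum_i \frac{R_i}{X} = \sum_i R_i$. This holds up to the $10^{-9}$ stabiliser in `safe_divide`; entries where $X = 0$ receive zero relevance. The result is also intuitive: the cell above returns exactly the elementwise sum of the two relevance tensors. For this reason, I did not implement `Clone` among the Flax layers. Wherever the PyTorch version calls `clone.relprop`, the JAX version simply adds the relevances.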

In [220]:
def jax_safe_divide(a, b):
    """JAX counterpart of ``safe_divide``: elementwise ``a / b``, masked to 0 where ``b == 0``.

    The denominator is shifted by a small epsilon (and bumped again if the
    shift lands exactly on 0) so the division never divides by an exact zero;
    positions where ``b == 0`` are then multiplied by 0. Works on scalars and
    arrays alike.
    """
    eps = 1e-9
    # One-sided clips written as maximum/minimum: same values as
    # jnp.clip(b, a_min=eps) / jnp.clip(b, a_max=eps), without the a_min/a_max
    # keywords that newer JAX releases deprecate.
    den = jnp.maximum(b, eps) + jnp.minimum(b, eps)
    den = den + (den == 0).astype(den.dtype) * eps
    return a / den * (b != 0).astype(b.dtype)

class JaxRelProp(fnn.Module):
    """Base Flax module for layers that support relevance propagation.

    The default rule is the identity; subclasses override ``relprop``.
    """

    def relprop(self, R):
        """Return the relevance ``R`` unchanged."""
        return R


class JaxRelPropSimple(JaxRelProp):
    """Gradient-times-input relevance rule for JAX modules.

    ``relprop`` re-evaluates ``__call__`` on the given inputs, divides the
    relevance by the output, pulls that ratio back through the VJP and
    multiplies by the inputs.
    """

    def relprop(self, cam, *inputs, **kwargs):
        """Propagate relevance ``cam`` back to the positional ``inputs``.

        Args:
            cam: relevance w.r.t. the module output.
            *inputs: positional arguments for ``__call__``; the VJP is taken
                w.r.t. all of them.
            **kwargs: forwarded unchanged to ``__call__``.

        Returns:
            ``[rel_0, rel_1]`` for the first two inputs when more than one input
            is given, otherwise a single array for ``inputs[0]``.
        """
        def forward(*args):
            return self.__call__(*args, **kwargs)

        out, pullback = jax.vjp(forward, *inputs)
        cotangents = pullback(jax_safe_divide(cam, out))

        if len(inputs) <= 1:
            return inputs[0] * (cotangents[0])
        # Relevance is only returned w.r.t. the first two positional arguments.
        return [inputs[0] * cotangents[0], inputs[1] * cotangents[1]]

class JaxMatMul(JaxRelPropSimple):
    """Matrix product of the positional inputs; relevance via ``JaxRelPropSimple``."""

    def __call__(self, *inputs):
        """Return ``jnp.matmul`` applied to the positional ``inputs``."""
        return jnp.matmul(*inputs)
    
class JaxAdd(JaxRelPropSimple):
    """Elementwise ``jnp.add`` with a normalised relevance split between two addends."""

    def __call__(self, *inputs):
        return jnp.add(*inputs)

    def relprop(self, cam, *inputs, **kwargs):
        """Propagate ``cam`` to the two addends ``inputs[0]`` and ``inputs[1]``.

        Each addend first gets its gradient-times-input relevance. That
        relevance is then rescaled so it sums to the addend's share
        ``|s_k| / (|s_a| + |s_b|)`` of ``cam.sum()``, where ``s_k`` is the
        unscaled sum. ``kwargs`` is accepted for signature compatibility but is
        not forwarded to ``__call__``. Returns ``[rel_a, rel_b]``.
        """
        out, pullback = jax.vjp(lambda *args: self.__call__(*args), *inputs)
        grads = pullback(jax_safe_divide(cam, out))

        rel_a = inputs[0] * grads[0]
        rel_b = inputs[1] * grads[1]

        sum_a = rel_a.sum()
        sum_b = rel_b.sum()
        mag_a = jnp.abs(sum_a)
        mag_b = jnp.abs(sum_b)
        total_mag = mag_a + mag_b
        total_cam = cam.sum()

        share_a = jax_safe_divide(mag_a, total_mag) * total_cam
        share_b = jax_safe_divide(mag_b, total_mag) * total_cam

        return [rel_a * jax_safe_divide(share_a, sum_a),
                rel_b * jax_safe_divide(share_b, sum_b)]

class JaxIndexSelect(JaxRelProp):
    """Gather along an axis (``jnp.take``) with gradient-times-input relevance."""

    def __call__(self, inputs, dim, indices):
        return jnp.take(inputs, indices, axis=dim)

    def relprop(self, cam, *inputs):
        """Propagate ``cam`` back to the gathered array.

        ``inputs`` must be exactly ``(x, dim, indices)``, as passed to
        ``__call__``; the VJP is taken w.r.t. ``x`` only. Returns the relevance
        w.r.t. ``x``, which has the same shape as ``x``.
        """
        x, dim, indices = inputs
        out, pullback = jax.vjp(lambda arr: self.__call__(arr, dim, indices), x)
        (grad_x,) = pullback(jax_safe_divide(cam, out))
        return x * grad_x

In [221]:
# Same operands as the PyTorch checks, now as JAX arrays. This rebinds the
# global names A and B; the PyTorch modules above already cached their inputs.
A = jnp.ones((2,2))
B = jnp.array([[-2., 1], [1, -2]])

jmm = JaxMatMul()
jmm(A,B)

DeviceArray([[-1., -1.],
             [-1., -1.]], dtype=float32)

In [222]:
# Unlike the PyTorch modules, the JAX layers take their inputs explicitly in
# relprop; compare with the PyTorch MatMul relevances above.
jmm.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 2., -1.],
              [ 0.,  0.]], dtype=float32),
 DeviceArray([[ 2.,  0.],
              [-1., -0.]], dtype=float32)]

In [223]:
# JAX counterpart of the PyTorch Add check.
j_add = JaxAdd()
j_add(A,B)

DeviceArray([[-1.,  2.],
             [ 2., -1.]], dtype=float32)

In [224]:
# Should agree with add.relprop above (recorded outputs: 1/3 and 2/3 at (0, 0)).
j_add.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 0.33333334, -0.        ],
              [-0.        ,  0.        ]], dtype=float32),
 DeviceArray([[0.6666667, 0.       ],
              [0.       , 0.       ]], dtype=float32)]

In [225]:
# Select column 0 of B, as in the PyTorch IndexSelect check.
jax_pool = JaxIndexSelect()
jax_pool(B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[-2.],
             [ 1.]], dtype=float32)

In [226]:
# relprop takes the same (inputs, dim, indices) that were passed to __call__.
jax_pool.relprop(jnp.array([[1.],[0]]), B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[ 1.,  0.],
             [ 0., -0.]], dtype=float32)

In [245]:
class JaxDense(fnn.Dense, JaxRelProp):
    """``flax.linen.Dense`` with an alpha-beta relevance-propagation rule.

    Positive and negative parts of the kernel and of the input are propagated
    separately and combined as ``alpha * activators - beta * inhibitors``,
    with ``beta = alpha - 1``. The bias is not used in the relevance computation.
    """

    def relprop(self, R, *inputs, alpha=1):
        """Propagate relevance ``R`` from the layer output back to its input.

        Args:
            R: relevance w.r.t. the layer output (shaped like ``self(x)``).
            *inputs: positional arguments the layer was called with; only
                ``inputs[0]`` (the input activations ``x``) is used.
            alpha: weight of the activator term; ``beta = alpha - 1`` weights
                the inhibitor term.

        Returns:
            Relevance w.r.t. ``inputs[0]``, with the same shape as it.

        Must be called on a bound module (e.g. ``JaxDense(...).bind(variables)``)
        so that ``self.variables`` holds the kernel.
        """
        # Use the input passed to this call, never a notebook-global `x`.
        x = inputs[0]
        beta = alpha - 1
        w = self.variables["params"]["kernel"]  # (in_features, features)
        # Sign-split kernel and input. jnp.maximum/minimum give the same values
        # as one-sided jnp.clip without the deprecated a_min/a_max keywords.
        pw = jnp.maximum(w, 0)
        nw = jnp.minimum(w, 0)
        px = jnp.maximum(x, 0)
        nx = jnp.minimum(x, 0)

        def _rule(w1, w2, x1, x2):
            """Relevance of the pairs (x1, w1) and (x2, w2), normalised by their joint output."""
            z1, vjp_x1 = jax.vjp(lambda v: jnp.dot(v, w1), x1)
            z2, vjp_x2 = jax.vjp(lambda v: jnp.dot(v, w2), x2)
            # Both contributions share the same normaliser R / (z1 + z2).
            s = jax_safe_divide(R, z1 + z2)
            c1 = x1 * vjp_x1(s)[0]
            c2 = x2 * vjp_x2(s)[0]
            return c1 + c2

        activator_relevances = _rule(pw, nw, px, nx)
        inhibitor_relevances = _rule(nw, pw, px, nx)
        return alpha * activator_relevances - beta * inhibitor_relevances
    

In [243]:
# A 20 -> 2 Dense layer with deterministically initialised parameters.
jax_dense = JaxDense(2)
x = jnp.ones((1,20))
variables = jax_dense.init(jax.random.PRNGKey(0), x)
# Bind the parameters so methods such as relprop can read self.variables.
model = jax_dense.bind(variables)
# Sanity-check the forward pass instead of discarding its result.
assert model(x).shape == (1, 2)
print(model.variables["params"])

FrozenDict({
    kernel: DeviceArray([[ 1.53342038e-01,  3.29581231e-01],
                 [ 4.70121145e-01,  1.81581691e-01],
                 [ 1.80420190e-01,  3.35969597e-01],
                 [-1.62113383e-01, -2.02395543e-01],
                 [-1.18919984e-01, -6.62900805e-02],
                 [-4.10243064e-01,  7.33509436e-02],
                 [ 4.96547073e-02, -2.71931272e-02],
                 [-2.19073966e-01,  4.65010494e-01],
                 [ 1.54693127e-02,  1.10250756e-01],
                 [-3.89764100e-01,  8.84109288e-02],
                 [-3.79199862e-01, -2.35682800e-02],
                 [ 3.24670374e-01,  2.04515398e-01],
                 [ 1.44210760e-04,  1.65998340e-01],
                 [-3.36234242e-01, -3.94779295e-01],
                 [-1.40488684e-01, -5.03639765e-02],
                 [-2.11226270e-02,  2.08313853e-01],
                 [ 9.56470072e-02,  2.05471292e-01],
                 [-9.69679356e-02, -3.74182016e-01],
                 [-3.6075

In [244]:
# NOTE(review): the output below was recorded by In[244], before the JaxDense
# cell (In[245]) was last executed, so it may come from an older definition.
# Restart the kernel and run all cells to refresh it.
model.relprop(jnp.array([[1.,0]]), x)

DeviceArray([[1.1891876e-01, 3.6458510e-01, 1.3991822e-01, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 3.8507875e-02, 0.0000000e+00,
              1.1996655e-02, 0.0000000e+00, 0.0000000e+00, 2.5178614e-01,
              1.1183733e-04, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              7.4175507e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],            dtype=float32)

In [158]:
# Debugging check of which relprop implementation `model` resolves to.
# NOTE(review): the recorded output (In[158]) shows JaxRelProp.relprop, i.e. it
# predates JaxDense overriding relprop, and is stale; re-run or remove this cell.
print(model.relprop)

<bound method JaxRelProp.relprop of JaxDense()>
