# Linear Layer RelProp Demo

**Alok Kamatar and Mike Tynes**
**2022.11.22**

Here we show that a JAX implementation reproduces the relevance propagation (`relprop`) of the linear layer from the reference paper's code.

In [1]:
import torch 
from torch import nn

from modules import layers_ours as layers

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Their linear layer and relprop

We wrap a single instance of their linear layer in a network to test against.

In [2]:
class TestNet(nn.Module):
    """Minimal torch network wrapping one bias-free ``layers.Linear`` from the reference repo.

    It exists so the JAX port below can be compared against the reference
    forward pass and ``relprop`` on identical weights.

    Args:
        in_features: input size of the wrapped layer (default 5, as before).
        out_features: output size of the wrapped layer (default 5, as before).
    """

    def __init__(self, in_features=5, out_features=5):
        super().__init__()
        self.linear = layers.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.linear(x)

    def relprop(self, y, alpha=1):
        """Propagate relevance ``y`` back through the wrapped layer.

        ``alpha`` is forwarded to ``layers.Linear.relprop``. The default of 1 is how
        the reference relprop repo calls it. Exposing it lets the notebook also
        compare non-default alpha values.
        """
        return self.linear.relprop(y, alpha=alpha)

In [3]:
# Seed torch's global RNG so this input (and, presumably, TestNet's random weight
# initialization a few cells below) is reproducible on Restart & Run All.
torch.manual_seed(0)
X_t = torch.rand((1, 5))

In [4]:
# Display the sampled torch input.
X_t

tensor([[0.9575, 0.0148, 0.8885, 0.0716, 0.8014]])

In [5]:
# Torch reference network; its weight matrix is copied into the JAX version below.
test_net = TestNet()

In [6]:
# Torch forward pass. out_t is reused below as the comparison target and as the relevance signal.
out_t = test_net.forward(X_t)
out_t

tensor([[-0.5447,  1.0457,  0.0854,  0.5045,  0.0797]], grad_fn=<MmBackward0>)

In [7]:
# Reference relevances: feed the output back in as the relevance to redistribute onto the inputs (alpha=1).
relprop_t = test_net.relprop(out_t)
relprop_t

tensor([[ 0.1852, -0.0043,  0.4465,  0.0291,  0.5139]], grad_fn=<SubBackward0>)

## Jax implementation of linear layer relprop

In [8]:
import jax.numpy as jnp
import jax.nn as jnn
import jax

We test the linear layer and relprop against the same input.

In [9]:
# Reuse the exact torch input as a JAX array (still float32, as the output shows).
X_j = jnp.asarray(X_t.numpy())
X_j



DeviceArray([[0.95748496, 0.01477087, 0.8885277 , 0.07158202, 0.80135876]],            dtype=float32)

Grab the params from the torch net

In [10]:
# Copy the torch layer's weight into a JAX array so both implementations share parameters.
params = jnp.array(test_net.linear.weight.detach().numpy())
params

DeviceArray([[ 0.05143719,  0.05079835, -0.38719702, -0.10237346,
              -0.30360562],
             [ 0.4078076 ,  0.14633663,  0.4181239 , -0.08558807,
               0.35894793],
             [ 0.38435724,  0.27652642,  0.02783004,  0.13175456,
              -0.40045315],
             [ 0.3494682 ,  0.06004746, -0.17389387,  0.35976076,
               0.371603  ],
             [-0.22262457,  0.06671791,  0.33893257,  0.4350806 ,
              -0.05049613]], dtype=float32)

In [11]:
def j_linear(A, x):
    """Bias-free linear map in JAX, equivalent to ``torch.nn.functional.linear(x, A)``.

    Note the argument order: the weight comes first here, unlike torch.

    Args:
        A: weight matrix of shape ``(out_features, in_features)``.
        x: input whose last axis has length ``in_features``.

    Returns:
        ``x @ A.T``.
    """
    weight_t = A.T
    return x @ weight_t

In [12]:
def forward(params, x):
    """Forward pass of the single-layer JAX model, i.e. ``j_linear(params, x)``."""
    return j_linear(params, x)

In [13]:
# Despite the name, this is the forward-mode Jacobian of `forward` w.r.t. its input x (argument 1).
backward = jax.jacfwd(forward, argnums=1)

In [14]:
# Sanity check: for x -> x @ A.T the Jacobian w.r.t. x is A itself, so A = 2*I should come back unchanged.
backward(jnp.asarray([[2., 0.], [0., 2.]]), jnp.asarray([1.,1.]))

DeviceArray([[2., 0.],
             [0., 2.]], dtype=float32)

In [15]:
# Weight matrix shape; the wrapped layer is Linear(5, 5), so both dimensions are 5.
params.shape

(5, 5)

In [16]:
# JAX forward pass on the same input and weights; compare with out_t below.
out_j = forward(params, X_j)
out_j

DeviceArray([[-0.54465973,  1.0456654 ,  0.08535317,  0.5045277 ,
               0.07965522]], dtype=float32)

In [17]:
# Torch output again, for a side-by-side visual comparison with out_j.
out_t

tensor([[-0.5447,  1.0457,  0.0854,  0.5045,  0.0797]], grad_fn=<MmBackward0>)

In [18]:
# Jacobian w.r.t. a single unbatched input: for a linear map it equals params, as the output shows.
backward(params, X_j.squeeze())

DeviceArray([[ 0.05143719,  0.05079835, -0.38719702, -0.10237346,
              -0.30360562],
             [ 0.4078076 ,  0.14633663,  0.4181239 , -0.08558807,
               0.35894793],
             [ 0.38435724,  0.27652642,  0.02783004,  0.13175456,
              -0.40045315],
             [ 0.3494682 ,  0.06004746, -0.17389387,  0.35976076,
               0.371603  ],
             [-0.22262457,  0.06671791,  0.33893257,  0.4350806 ,
              -0.05049613]], dtype=float32)

In [19]:
import numpy as np

Is the forward pass OK? 

In [20]:
# Fail loudly on Restart & Run All if the JAX forward pass diverges from torch,
# while still displaying the boolean result.
forward_matches = np.allclose(np.asarray(out_j), out_t.detach().numpy())
assert forward_matches, "JAX forward pass does not match torch"
forward_matches

True

Cool, now for the hard part: relevance propagation.

In [21]:
from jax import jit

In [22]:
@jit
def safe_divide(a, b):
    """Divide ``a`` by ``b`` elementwise (with broadcasting), returning 0 wherever ``b == 0``.

    JAX port of the reference ``safe_divide``.

    Args:
        a: numerator.
        b: denominator; entries exactly equal to 0 yield 0 in the result.

    Returns:
        ``a / (b + 1e-9)``, masked to 0 where ``b == 0``.
    """
    eps = 1e-9
    # The reference computes clamp(b, min=eps) + clamp(b, max=eps). Since
    # max(b, eps) + min(b, eps) == b + eps exactly, write that directly.
    # This also avoids jnp.clip's deprecated a_min/a_max keywords.
    den = b + eps
    # b == -eps would still give a zero denominator; bump it off zero.
    den = den + (den == 0).astype(den.dtype) * eps
    # Mask the result to 0 wherever the original denominator is exactly 0.
    return a / den * (b != 0).astype(b.dtype)

In [23]:
# Smoke test only: broadcasts the (5, 5) weights against the (1, 5) input, dividing each column by the matching input entry.
safe_divide(params, X_j)

DeviceArray([[ 0.05372115,  3.439091  , -0.43577373, -1.4301561 ,
              -0.37886354],
             [ 0.4259154 ,  9.907112  ,  0.4705806 , -1.1956643 ,
               0.44792414],
             [ 0.40142378, 18.72107   ,  0.03132153,  1.8406098 ,
              -0.49971768],
             [ 0.3649856 ,  4.065263  , -0.19571012,  5.0258536 ,
               0.46371618],
             [-0.23250973,  4.516858  ,  0.38145414,  6.078071  ,
              -0.06301314]], dtype=float32)

In [24]:
def linear_relprop(R, x, params, alpha=1):
    """Alpha-beta relevance propagation through a bias-free linear layer (JAX port).

    Mirrors the reference ``Linear.relprop``. Weights and inputs are split into
    positive and negative parts, and the output relevance ``R`` is redistributed
    onto the inputs in proportion to each input's share of the pre-activation.

    Args:
        R: relevance at the layer output (same shape as ``j_linear(params, x)``).
        x: layer input.
        params: weight matrix, in the ``(out_features, in_features)`` layout used by ``j_linear``.
        alpha: weight on the activator relevances. The inhibitor relevances are
            weighted by ``beta = alpha - 1``.

    Returns:
        Relevance at the layer input, with the same shape as ``x``.
    """
    beta = alpha - 1
    # jnp.maximum/minimum instead of jnp.clip(a_min=..., a_max=...), whose keyword
    # names are deprecated in recent JAX releases.
    pw = jnp.maximum(params, 0)
    nw = jnp.minimum(params, 0)
    px = jnp.maximum(x, 0)
    nx = jnp.minimum(x, 0)

    def f(w1, w2, x1, x2):
        # Pre-activation contributions of the (w1, x1) and (w2, x2) pairings.
        z1 = j_linear(w1, x1)
        z2 = j_linear(w2, x2)
        # The reference computes S1 and S2 separately, but they are the same value.
        s = safe_divide(R, z1 + z2)
        # The VJP of x -> x @ w.T with cotangent s is s @ w. This is the closed form
        # of torch.autograd.grad(Z, x, S) in the reference. Each term must use its
        # own weight: x1 pairs with w1 and x2 pairs with w2.
        c1 = x1 * (s @ w1)
        c2 = x2 * (s @ w2)
        return c1 + c2

    activator_relevances = f(pw, nw, px, nx)
    inhibitor_relevances = f(nw, pw, px, nx)
    return alpha * activator_relevances - beta * inhibitor_relevances

In [25]:
# JAX relevances for the same output signal; the default alpha=1 matches the torch call above.
relevance_j = linear_relprop(out_j, X_j, params)
relevance_j

DeviceArray([[ 0.18518934, -0.00425183,  0.44652203,  0.0291398 ,
               0.5139424 ]], dtype=float32)

In [26]:
# Recompute the torch reference (same call as relprop_t above) for a side-by-side comparison.
relevance_t = test_net.relprop(out_t)
relevance_t

tensor([[ 0.1852, -0.0043,  0.4465,  0.0291,  0.5139]], grad_fn=<SubBackward0>)

In [27]:
# Fail loudly if the JAX relprop diverges from the torch reference, while still
# displaying the boolean result.
# NOTE: X_t comes from torch.rand (non-negative) and alpha=1 makes beta=0, so this
# check exercises neither negative inputs nor the inhibitor term.
relprop_matches = np.allclose(np.asarray(relevance_j),
                              relevance_t.detach().numpy())
assert relprop_matches, "JAX relprop does not match the torch reference"
relprop_matches

True

Bullseye