# Linear Layer RelProp Demo

**Alok Kamatar and Mike Tynes**
**2022.11.22**

Here we show that we can reproduce, in JAX, the relevance propagation (relprop) of the linear layer from the reference paper.

In [1]:
import torch 
from torch import nn

from modules import layers_ours as layers

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Their linear layer and relprop

We put a single instance of their linear layer into a network, which serves as the reference we test against.

In [2]:
class TestNet(nn.Module):
    """Minimal network wrapping one bias-free ``layers.Linear`` from the
    reference implementation; used as ground truth for the JAX port below."""

    def __init__(self, in_features=5, out_features=5):
        super().__init__()
        # bias=False: the JAX side (``j_linear``) has no bias term either.
        self.linear = layers.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        out = self.linear(x)
        return out

    def relprop(self, y, alpha=1):
        """Propagate relevance ``y`` back through the wrapped linear layer.

        ``alpha`` is passed as a keyword, the way the reference relprop repo
        calls it; the default of 1 reproduces the original behaviour.

        NOTE(review): presumably ``layers.Linear.relprop`` uses the input cached
        by the most recent ``forward`` call, so call ``forward`` first — confirm
        in ``modules/layers_ours``.
        """
        return self.linear.relprop(y, alpha=alpha)

In [3]:
# Seed torch's global RNG so this input (and, presumably, TestNet's weight
# initialisation in a later cell) is reproducible under Restart & Run All.
# torch.rand samples from [0, 1), so every entry of X_t is non-negative.
SEED = 0
torch.manual_seed(SEED)
X_t = torch.rand((1,5))

In [4]:
X_t

tensor([[0.2221, 0.9772, 0.2961, 0.9482, 0.9726]])

In [5]:
test_net = TestNet()

In [6]:
out_t = test_net.forward(X_t)
out_t

tensor([[-0.6279, -0.0045,  0.1656,  0.7667, -0.3432]], grad_fn=<MmBackward0>)

In [7]:
relprop_t = test_net.relprop(out_t)
relprop_t

tensor([[-0.0014, -0.4511,  0.0207,  0.1390,  0.2494]], grad_fn=<SubBackward0>)

## Jax implementation of linear layer relprop

In [8]:
import jax.numpy as jnp
import jax.nn as jnn
import jax

We test the JAX linear layer and relprop on the same input as the torch reference.

In [9]:
X_j = jnp.array(X_t)
X_j



DeviceArray([[0.2221241 , 0.9772073 , 0.29611993, 0.94821244, 0.97255856]],            dtype=float32)

Copy the weights from the torch network so that both implementations use identical parameters.

In [10]:
params = jnp.array(test_net.linear.weight.detach().numpy())
params

DeviceArray([[-0.44147232,  0.05627925, -0.06419598, -0.35039422,
              -0.24016897],
             [ 0.40366843, -0.0708417 , -0.3224802 , -0.12766537,
               0.19697148],
             [-0.19308433, -0.34504202, -0.02525529,  0.39824575,
               0.1804341 ],
             [-0.18666616,  0.19073626,  0.4374971 ,  0.29672757,
               0.2167627 ],
             [-0.16843243, -0.12480278,  0.05739747,  0.04225772,
              -0.24766016]], dtype=float32)

In [11]:
def j_linear(A, x, b=None):
    """JAX counterpart of ``torch.nn.functional.linear``: ``x @ A.T (+ b)``.

    Note the argument order is (weight, input), the reverse of
    ``torch.nn.functional.linear(input, weight, bias)``.

    Args:
        A: 2-D weight matrix of shape (out_features, in_features).
        x: input whose last axis has size in_features.
        b: optional bias added to (and broadcast against) the output;
            ``None`` (the default) means no bias.

    Returns:
        ``x @ A.T``, plus ``b`` when it is given.
    """
    out = x @ A.T
    if b is not None:
        out = out + b
    return out

In [12]:
def forward(params, x):
    """Forward pass of the bias-free linear layer, i.e. ``j_linear(params, x)``."""
    return j_linear(params, x)

In [13]:
# Jacobian of `forward` with respect to its second argument `x` (argnums=1),
# built with forward-mode autodiff (jax.jacfwd). Despite the name, this is NOT
# the relprop backward pass; for a 1-D `x` the Jacobian is just `params`,
# as the cells below show.
backward = jax.jacfwd(forward, 1)

In [14]:
backward(jnp.asarray([[2., 0.], [0., 2.]]), jnp.asarray([1.,1.]))

DeviceArray([[2., 0.],
             [0., 2.]], dtype=float32)

In [15]:
params.shape

(5, 5)

In [16]:
out_j = forward(params, X_j)
out_j

DeviceArray([[-0.62790143, -0.00454296,  0.16555943,  0.7666526 ,
              -0.34316927]], dtype=float32)

In [17]:
out_t

tensor([[-0.6279, -0.0045,  0.1656,  0.7667, -0.3432]], grad_fn=<MmBackward0>)

In [18]:
backward(params, X_j.squeeze())

DeviceArray([[-0.44147232,  0.05627925, -0.06419598, -0.35039422,
              -0.24016897],
             [ 0.40366843, -0.0708417 , -0.3224802 , -0.12766537,
               0.19697148],
             [-0.19308433, -0.34504202, -0.02525529,  0.39824575,
               0.1804341 ],
             [-0.18666616,  0.19073626,  0.4374971 ,  0.29672757,
               0.2167627 ],
             [-0.16843243, -0.12480278,  0.05739747,  0.04225772,
              -0.24766016]], dtype=float32)

In [19]:
import numpy as np

Is the forward pass OK? 

In [20]:
np.allclose(np.array(out_j), out_t.detach().numpy())

True

Cool, now for the hard part: relevance propagation.

In [21]:
from jax import jit

In [22]:
@jit
def safe_divide(a, b):
    """Elementwise ``a / b`` that yields 0 wherever ``b == 0`` (for finite ``a``).

    ``a`` and ``b`` broadcast against each other.

    The denominator is ``max(b, eps) + min(b, eps)``. The pair {max, min} is just
    {b, eps}, so this equals ``b + eps``. Where that sum is exactly 0 it is bumped
    by another ``eps``, so the division itself never divides by 0. Entries with
    ``b == 0`` are then zeroed by the final mask.

    ``jnp.maximum``/``jnp.minimum`` give the same values as
    ``jnp.clip(b, a_min=eps)`` / ``jnp.clip(b, a_max=eps)``, but avoid the
    ``a_min``/``a_max`` keywords that recent JAX releases deprecate.
    """
    eps = 1e-9
    den = jnp.maximum(b, eps) + jnp.minimum(b, eps)
    den = den + (den == 0).astype(den.dtype) * eps
    return a / den * (b != 0).astype(b.dtype)

In [23]:
safe_divide(params, X_j)

DeviceArray([[-1.987503  ,  0.05759193, -0.21679048, -0.36953133,
              -0.24694552],
             [ 1.8173103 , -0.07249404, -1.089019  , -0.13463794,
               0.20252916],
             [-0.8692633 , -0.3530899 , -0.08528736,  0.41999632,
               0.18552516],
             [-0.8403688 ,  0.19518507,  1.4774321 ,  0.31293362,
               0.22287883],
             [-0.75828075, -0.12771372,  0.19383185,  0.04456566,
              -0.25464806]], dtype=float32)

In [24]:
def linear_relprop(R, x, params, alpha=1):
    """Alpha-beta relevance propagation through a bias-free linear layer.

    JAX version of the reference linear-layer relprop: the relevance ``R`` at
    the layer output is redistributed onto the input ``x``, with
    ``beta = alpha - 1``.

    Args:
        R: relevance at the layer output; in this notebook it is the forward
            output ``j_linear(params, x)``.
        x: layer input.
        params: weight matrix as passed to ``j_linear``
            (out_features, in_features).
        alpha: weight of the activator relevances. The default 1 gives
            ``beta = 0``, i.e. only activator relevances are kept.

    Returns:
        Relevance for each element of ``x``.
    """
    beta = alpha - 1
    # Split weights and inputs into their positive / negative parts.
    pw = jnp.maximum(params, 0)
    nw = jnp.minimum(params, 0)
    px = jnp.maximum(x, 0)
    nx = jnp.minimum(x, 0)

    def f(w1, w2, x1, x2):
        """Relevance carried by the two paths ``x1 -> w1`` and ``x2 -> w2``."""
        z1 = j_linear(w1, x1)
        z2 = j_linear(w2, x2)
        # One normalised relevance serves both terms (the two were identical).
        s = safe_divide(R, z1 + z2)
        # Each term needs the VJP of *its own* map: dz1/dx1 uses w1, dz2/dx2 uses
        # w2. Reusing the w1 VJP for the second term gives wrong relevances as
        # soon as x2 (the negative part of x) is non-zero.
        _, vjp1 = jax.vjp(lambda v: j_linear(w1, v), x1)
        _, vjp2 = jax.vjp(lambda v: j_linear(w2, v), x2)
        c1 = x1 * vjp1(s)[0]
        c2 = x2 * vjp2(s)[0]
        return c1 + c2

    activator_relevances = f(pw, nw, px, nx)
    inhibitor_relevances = f(nw, pw, px, nx)
    return alpha * activator_relevances - beta * inhibitor_relevances

In [25]:
relevance_j = linear_relprop(out_j, X_j, params)
relevance_j

DeviceArray([[-0.00144843, -0.4510759 ,  0.02069469,  0.13899776,
               0.24943018]], dtype=float32)

In [26]:
relevance_t = test_net.relprop(out_t)
relevance_t

tensor([[-0.0014, -0.4511,  0.0207,  0.1390,  0.2494]], grad_fn=<SubBackward0>)

In [27]:
np.allclose(np.asarray(relevance_j), 
            relevance_t.detach().numpy())

True

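Bullseye: the JAX relevances match the torch reference.

Caveat: `torch.rand` samples from $[0, 1)$, so every input here is non-negative and the negative-input (`nx`) terms of the relprop are identically zero. This comparison therefore does not exercise that path; rerun it with signed inputs (e.g. `torch.randn`) to cover it.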