# Linear Layer RelProp Demo

**Alok Kamatar and Mike Tynes**
**2022.11.22**

Here we show that we can reproduce the relprop on the linear layer from the reference paper. 

In [1]:
import torch 
from torch import nn

from modules import layers_ours as layers

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Their linear layer and relprop

We put a single instance of their linear layer into a network to test against.

In [2]:
class TestNet(nn.Module):
    """Minimal network wrapping a single reference ``layers.Linear(5, 5, bias=False)``.

    It serves as the baseline that the JAX re-implementation is compared with.
    """

    def __init__(self):
        super().__init__()
        self.linear = layers.Linear(5, 5, bias=False)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x`` and return its output."""
        return self.linear(x)

    def relprop(self, y):
        """Propagate relevance ``y`` back through the wrapped linear layer.

        ``alpha`` is pinned to 1, which is how the relprop repo invokes it.
        """
        return self.linear.relprop(y, alpha=1)

In [3]:
# Random (1, 5) test input. No seed is set in the cells shown, so values differ per fresh run.
# NOTE(review): call torch.manual_seed(...) beforehand if the recorded outputs must be reproducible.
X_t = torch.rand((1,5))

In [4]:
X_t

tensor([[0.4541, 0.6316, 0.5566, 0.0331, 0.1095]])

In [5]:
test_net = TestNet()

In [6]:
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks, whereas calling .forward() directly silently skips them.
out_t = test_net(X_t)
out_t

tensor([[-1.5071e-01,  2.3412e-02,  3.9160e-04,  4.0253e-01, -2.0664e-02]],
       grad_fn=<MmBackward0>)

In [7]:
relprop_t = test_net.relprop(out_t)
relprop_t

tensor([[ 0.0832,  0.1991,  0.0089,  0.0008, -0.0371]], grad_fn=<SubBackward0>)

## Jax implementation of linear layer relprop

In [8]:
import jax.numpy as jnp
import jax.nn as jnn
import jax

We test the linear layer and relprop against the same input

In [9]:
X_j = jnp.array(X_t)
X_j



DeviceArray([[0.45414388, 0.63164246, 0.5565549 , 0.03309029, 0.10949355]],            dtype=float32)

Grab the params from the torch net

In [10]:
params = jnp.array(test_net.linear.weight.detach().numpy())
params

DeviceArray([[-0.30776274, -0.25277573,  0.20577456, -0.30647692,
               0.4049387 ],
             [ 0.20985027,  0.02401242, -0.21625122,  0.08439169,
               0.27860352],
             [ 0.13254546,  0.02433965, -0.11566351,  0.2926152 ,
              -0.18710423],
             [ 0.14981171,  0.32025668,  0.24071638,  0.05160305,
              -0.03176383],
             [-0.4107939 ,  0.06375179,  0.2790354 ,  0.43949497,
              -0.40380582]], dtype=float32)

In [11]:
def j_linear(A, x):
    """Bias-free linear map: multiply ``x`` by the transpose of the weight ``A``.

    JAX counterpart of ``torch.nn.functional.linear(x, A)`` without a bias term.
    """
    weight_t = A.T
    return x @ weight_t

In [12]:
def forward(params, x):
    """Forward pass of the one-layer net: send ``x`` through ``j_linear`` with weights ``params``."""
    return j_linear(params, x)

In [13]:
# Jacobian of `forward` with respect to its second argument (x), via forward-mode autodiff.
backward = jax.jacfwd(forward, argnums=1)

In [14]:
# Sanity check on a 2x2 case: the Jacobian of x @ A.T with respect to x is A itself.
backward(jnp.asarray([[2., 0.], [0., 2.]]), jnp.asarray([1.,1.]))

DeviceArray([[2., 0.],
             [0., 2.]], dtype=float32)

In [15]:
params.shape

(5, 5)

In [16]:
out_j = forward(params, X_j)
out_j

DeviceArray([[-1.5071085e-01,  2.3411637e-02,  3.9159169e-04,
               4.0252531e-01, -2.0663811e-02]], dtype=float32)

In [17]:
out_t

tensor([[-1.5071e-01,  2.3412e-02,  3.9160e-04,  4.0253e-01, -2.0664e-02]],
       grad_fn=<MmBackward0>)

In [18]:
# Squeeze the (1, 5) input to 1-D so the Jacobian comes back as a plain (5, 5) matrix;
# for this linear map it equals `params`.
backward(params, X_j.squeeze())

DeviceArray([[-0.30776274, -0.25277573,  0.20577456, -0.30647692,
               0.4049387 ],
             [ 0.20985027,  0.02401242, -0.21625122,  0.08439169,
               0.27860352],
             [ 0.13254546,  0.02433965, -0.11566351,  0.2926152 ,
              -0.18710423],
             [ 0.14981171,  0.32025668,  0.24071638,  0.05160305,
              -0.03176383],
             [-0.4107939 ,  0.06375179,  0.2790354 ,  0.43949497,
              -0.40380582]], dtype=float32)

In [19]:
import numpy as np

Is the forward pass OK? 

In [20]:
np.allclose(np.array(out_j), out_t.detach().numpy())

True

cool, now for the hard part. 

In [21]:
from jax import jit

In [22]:
@jit
def safe_divide(a, b):
    """Element-wise ``a / b`` that returns 0 wherever ``b`` is exactly 0.

    Args:
        a: numerator array.
        b: denominator array, broadcast-compatible with ``a``.

    Returns:
        ``a / den`` masked to 0 where ``b == 0``, with ``den`` a copy of ``b``
        shifted by a tiny epsilon so the division itself never hits 0.
    """
    eps = 1e-9
    # jnp.maximum / jnp.minimum stand in for jnp.clip(b, a_min=eps) and
    # jnp.clip(b, a_max=eps): the a_min / a_max keywords are deprecated in JAX,
    # while these two give the same values on old and new releases alike.
    den = jnp.maximum(b, eps) + jnp.minimum(b, eps)
    # The shifted denominator can itself land on 0; nudge that case off 0.
    den = den + (den == 0).astype(den.dtype) * eps
    # Zero the result wherever the original denominator was exactly 0.
    return a / den * (b != 0).astype(b.dtype)

In [23]:
safe_divide(params, X_j)

DeviceArray([[-0.67767674, -0.400188  ,  0.36972913, -9.261838  ,
               3.698288  ],
             [ 0.4620788 ,  0.03801585, -0.38855326,  2.550346  ,
               2.5444741 ],
             [ 0.29185784,  0.03853391, -0.20782048,  8.842932  ,
              -1.708815  ],
             [ 0.3298772 ,  0.5070221 ,  0.43251148,  1.5594617 ,
              -0.29009774],
             [-0.9045457 ,  0.10093018,  0.50136185, 13.281689  ,
              -3.6879416 ]], dtype=float32)

In [24]:
def linear_relprop(R, x, params, alpha=1):
    """Alpha-beta relevance propagation through a bias-free linear layer.

    Args:
        R: relevance arriving at the layer output.
        x: the input that was fed to the layer.
        params: the layer weight, as consumed by ``j_linear(params, x)``.
        alpha: weight of the activator term; the inhibitor term is weighted
            by ``beta = alpha - 1``.

    Returns:
        Relevance redistributed onto the layer input, with the same shape as ``x``.
    """
    beta = alpha - 1
    # Split weights and inputs into positive and negative parts.
    # jnp.maximum / jnp.minimum replace jnp.clip(..., a_min=0) / jnp.clip(..., a_max=0),
    # whose a_min / a_max keywords are deprecated in JAX.
    pw = jnp.maximum(params, 0)
    nw = jnp.minimum(params, 0)
    px = jnp.maximum(x, 0)
    nx = jnp.minimum(x, 0)

    def f(w1, w2, x1, x2):
        # One VJP over both branches: z = x1 @ w1.T + x2 @ w2.T, and the pullback
        # returns the cotangent for x1 and for x2 from the same ratio s.
        z, pullback = jax.vjp(
            lambda a, b: j_linear(w1, a) + j_linear(w2, b), x1, x2
        )
        # The original computed this ratio twice (s1, s2); the two were identical.
        s = safe_divide(R, z)
        g1, g2 = pullback(s)
        return x1 * g1 + x2 * g2

    activator_relevances = f(pw, nw, px, nx)
    inhibitor_relevances = f(nw, pw, px, nx)
    return alpha * activator_relevances - beta * inhibitor_relevances

In [25]:
relevance_j = linear_relprop(out_j, X_j, params)
relevance_j

DeviceArray([[ 0.08324914,  0.19913508,  0.00890309,  0.00076189,
              -0.03709533]], dtype=float32)

In [26]:
relevance_t = test_net.relprop(out_t)
relevance_t

tensor([[ 0.0832,  0.1991,  0.0089,  0.0008, -0.0371]], grad_fn=<SubBackward0>)

In [27]:
np.allclose(np.asarray(relevance_j), 
            relevance_t.detach().numpy())

True

Bullseye