In [1]:
import torch 
from torch import nn

from modules import layers_ours as layers

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Their linear layer and relprop

We put a single instance of their linear layer into a network to test against.

In [8]:
class TestNet(nn.Module):
    """Minimal torch reference network: a single bias-free `layers.Linear`.

    The JAX implementation further down is checked against this network.
    """

    def __init__(self, in_features=5, out_features=5):
        super().__init__()
        self.linear = layers.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Apply the wrapped linear layer to `x`."""
        return self.linear(x)

    def relprop(self, y, alpha=1):
        """Propagate relevance `y` back through the wrapped linear layer.

        alpha: forwarded to `layers.Linear.relprop`; the relprop repo calls it
               with alpha=1, which stays the default here. (Presumably the alpha
               of the alpha-beta LRP rule -- TODO confirm against layers_ours.)
        """
        return self.linear.relprop(y, alpha=alpha)

In [9]:
# Seed torch's global RNG so the random input is reproducible on Restart & Run All.
torch.manual_seed(0)
X_t = torch.rand((5, 5))

In [10]:
# Display the random input; the same values are converted to a JAX array further down.
X_t

tensor([[0.0396, 0.2021, 0.0094, 0.2014, 0.2615],
        [0.3696, 0.1693, 0.2372, 0.2914, 0.5716],
        [0.6876, 0.7824, 0.4204, 0.9479, 0.5541],
        [0.3594, 0.6124, 0.4170, 0.4724, 0.7459],
        [0.2507, 0.7391, 0.8448, 0.3919, 0.1333]])

In [11]:
# Instantiate the torch reference network (one bias-free layers.Linear).
test_net = TestNet()

In [12]:
# Torch forward pass. Call the module itself rather than `.forward()` so that any
# hooks registered on `test_net` run, as PyTorch recommends; TestNet registers
# none itself, so the result is the same.
out_t = test_net(X_t)
out_t

tensor([[ 0.0271,  0.2247,  0.0595,  0.0737,  0.0288],
        [ 0.1560,  0.3779, -0.0645,  0.2782,  0.0513],
        [ 0.0212,  0.7607,  0.2313,  0.3238,  0.1256],
        [ 0.0778,  0.6604,  0.0195,  0.2062,  0.0615],
        [-0.1338,  0.4483,  0.0294, -0.1131, -0.1263]], grad_fn=<MmBackward0>)

In [13]:
# Torch relevance propagation, feeding the forward output itself back in as the
# incoming relevance. This is the reference result for the JAX port.
relprop_t = test_net.relprop(out_t)
relprop_t

tensor([[ 1.5359e-02,  1.2895e-01,  1.8755e-04,  1.0302e-01,  1.6633e-01],
        [ 1.5826e-01,  6.1305e-02,  9.8979e-03,  9.0426e-02,  4.7904e-01],
        [ 2.3083e-01,  4.6282e-01,  4.1858e-03,  4.4671e-01,  3.1806e-01],
        [ 1.0647e-01,  2.9200e-01,  8.0763e-03,  1.6335e-01,  4.5545e-01],
        [-6.5436e-02,  2.3639e-01, -3.7730e-02, -2.0724e-02, -8.0368e-03]],
       grad_fn=<SubBackward0>)

## JAX implementation of the linear layer relprop

In [22]:
import jax.numpy as jnp
import jax.nn as jnn
import jax

We test the JAX linear layer and relprop against the same input as the torch version.

In [15]:
# Same input as the torch test, converted to a JAX array.
X_j = jnp.array(X_t)
X_j



DeviceArray([[0.03955817, 0.20212638, 0.00942403, 0.2013644 , 0.2615192 ],
             [0.3695619 , 0.16926914, 0.23717767, 0.2913971 , 0.5716445 ],
             [0.6875626 , 0.7824402 , 0.42037886, 0.947896  , 0.55408573],
             [0.3594305 , 0.61237097, 0.4169714 , 0.47242063, 0.7459355 ],
             [0.250704  , 0.73905694, 0.84475124, 0.3918997 , 0.13328278]],            dtype=float32)

Grab the weights from the torch net so both implementations use identical parameters.

In [16]:
# Copy the torch layer's weight matrix (detached from autograd) into JAX so both
# implementations share parameters.
params = jnp.array(test_net.linear.weight.detach().numpy())
params

DeviceArray([[ 0.02091356, -0.3537856 ,  0.04878209,  0.10913646,
               0.28824186],
             [ 0.07911818,  0.40988153,  0.00648237,  0.17079464,
               0.39888352],
             [-0.06482629,  0.24479395, -0.28112087,  0.3266349 ,
              -0.19329758],
             [ 0.3046329 , -0.34327713, -0.09880875,  0.27790838,
               0.29062688],
             [ 0.44012907,  0.30358416, -0.42777327, -0.26360548,
               0.0272531 ]], dtype=float32)

In [17]:
def j_linear(a, b, bias=None):
    """JAX analogue of ``torch.nn.functional.linear`` -- with the argument order SWAPPED.

    ``F.linear`` is called as ``F.linear(input, weight)``; this helper is called as
    ``j_linear(weight, input)``.

    a:    weight matrix, shape (out_features, in_features) as in torch.
    b:    input array whose last axis has size in_features.
    bias: optional array broadcast-added to the result (like F.linear's bias).

    Returns ``b @ a.T`` (plus ``bias`` when given).
    """
    out = jnp.dot(b, a.T)
    if bias is not None:
        out = out + bias
    return out

In [18]:
def forward(params, x):
    """Forward pass of the single-layer JAX model: ``x @ params.T`` (no bias).

    `params` is the weight matrix copied from the torch layer; note that
    `j_linear` takes the weight first.
    """
    return j_linear(params, x)

In [23]:
def backward(params, x, cotangent=None, argnums=0):
    """Vector-Jacobian product of `forward`, usable for its matrix-valued output.

    The previous ``jax.grad(forward)`` only works for scalar-output functions, so
    it raised a TypeError for the 2-D output of `forward`. ``jax.vjp`` is the JAX
    counterpart of ``torch.autograd.grad(outputs, inputs, grad_outputs)``.

    params, x: arguments passed to `forward`.
    cotangent: weights on the output (torch's ``grad_outputs``); must match the
               output's shape/dtype. Defaults to ones, i.e. the gradient of
               ``forward(params, x).sum()`` -- identical to ``jax.grad(forward)``
               whenever the output is a scalar.
    argnums:   0 -> gradient w.r.t. `params` (what ``jax.grad`` returned);
               1 -> gradient w.r.t. `x` (the input gradient).
    """
    out, vjp_fn = jax.vjp(forward, params, x)
    if cotangent is None:
        cotangent = jnp.ones_like(out)
    # vjp_fn returns one gradient per primal argument: (d_params, d_x).
    return vjp_fn(cotangent)[argnums]

In [19]:
# JAX forward pass on the same input and weights as the torch network.
out_j = forward(params, X_j)
out_j

DeviceArray([[ 0.0271346 ,  0.22474639,  0.05948723,  0.07369954,
               0.02878817],
             [ 0.15598781,  0.37794545, -0.06451388,  0.27815622,
               0.05134932],
             [ 0.02123099,  0.7607427 ,  0.23129973,  0.3237837 ,
               0.12555586],
             [ 0.07777783,  0.66036826,  0.01950636,  0.20615955,
               0.06152911],
             [-0.13382761,  0.4483357 ,  0.02943215, -0.11314971,
              -0.12632845]], dtype=float32)

In [20]:
import numpy as np

Is the forward pass OK? 

In [21]:
# Compare the JAX forward pass with the torch reference. Assert so a mismatch
# fails loudly on Restart & Run All, and still display the boolean result.
forward_matches = np.allclose(np.array(out_j), out_t.detach().numpy())
assert forward_matches, "JAX forward pass differs from the torch reference"
forward_matches

True

Cool, the forward passes match. Now for the hard part: relevance propagation (relprop).

In [27]:
from jax import jit

In [28]:
@jit
def safe_divide(a, b, eps=1e-9):
    """Elementwise ``a / b`` with a stabilised denominator, returning 0 wherever ``b == 0``.

    The denominator is ``max(b, eps) + min(b, eps)`` (which equals ``b + eps``).
    If that is still exactly zero (``b == -eps``), it is bumped by another ``eps``.
    The quotient is then masked to 0 at positions where ``b == 0``.

    a, b: arrays that broadcast against each other.
    eps:  stabiliser; defaults to 1e-9, the value previously hard-coded.
    """
    # jnp.maximum/jnp.minimum are equivalent to jnp.clip(b, a_min=eps) /
    # jnp.clip(b, a_max=eps); they avoid clip's a_min/a_max keywords, which are
    # deprecated in recent JAX releases.
    den = jnp.maximum(b, eps) + jnp.minimum(b, eps)
    den = den + (den == 0).astype(den.dtype) * eps
    return a / den * (b != 0).astype(b.dtype)

In [29]:
# Smoke test: checks that the jitted safe_divide runs on two 5x5 arrays
# (weights / inputs, elementwise). The result is only displayed, not assigned.
safe_divide(params, X_j)

DeviceArray([[ 0.52867854, -1.7503188 ,  5.1763506 ,  0.5419849 ,
               1.1021824 ],
             [ 0.2140864 ,  2.421478  ,  0.02733127,  0.58612335,
               0.6977825 ],
             [-0.0942842 ,  0.31285962, -0.66873217,  0.3445894 ,
              -0.34885862],
             [ 0.8475433 , -0.56057054, -0.2369677 ,  0.5882647 ,
               0.38961396],
             [ 1.7555727 ,  0.41077235, -0.5063896 , -0.672635  ,
               0.20447579]], dtype=float32)

In [None]:
def linear_relprop(R, x params, alpha=1): 
    beta = alpha - 1
    pw = jax.clip(params, min=0)
    nw = jax.clip(params, max=0)
    px = jax.clip(x,      min=0)
    nx = jax.clip(x,      max=0)
    
    def f(w1, w2, x1, x2): 
        z1 = j_linear(x1, w1) 
        z2 = j_linear(x2, w2)
        s1 = safe_divide(R, z1+z2) # why are there two of these
        s2 = safe_divide(R, z1+z2) # for gradient reasons? 
        
        c1 = x1 * (backward(z1, x1) * s1)
        print(c1)
        