In [1]:
import torch 
from torch import nn

from modules import layers_ours as layers

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Their linear layer and relprop

We put a single instance of thier linear layer into a network to test agains 

In [2]:
class TestNet(nn.Module): 
    
    def __init__(self): 
        super().__init__()
        self.linear = layers.Linear(5, 5, bias=False)
    
    def forward(self, x):
        out = self.linear(x)
        return out
    
    def relprop(self, y): 
        kwargs = {'alpha': 1} # this is how they do it in the relprop repo
        return self.linear.relprop(y, **kwargs)

In [3]:
X_t = torch.rand((1,5))

In [4]:
X_t

tensor([[0.0144, 0.0796, 0.9815, 0.9279, 0.1152]])

In [5]:
test_net = TestNet()

In [6]:
out_t = test_net.forward(X_t)
out_t

tensor([[-0.1706,  0.4224, -0.1271, -0.1420,  0.0921]], grad_fn=<MmBackward0>)

In [7]:
relprop_t = test_net.relprop(out_t)
relprop_t

RETURNING JUST C


tensor([[ 0.0000, -0.0290,  0.1348, -0.0006, -0.0304]], grad_fn=<MulBackward0>)

## Jax implementation of linear layer relprop

In [8]:
import jax.numpy as jnp
import jax.nn as jnn
import jax

We test the linear layer and relprop against the same input

In [9]:
X_j = jnp.array(X_t)
X_j



DeviceArray([[0.0144251 , 0.07961643, 0.9814939 , 0.9279042 , 0.11520195]],            dtype=float32)

Grab the params from the torch net

In [10]:
params = jnp.array(test_net.linear.weight.detach().numpy())
params

DeviceArray([[-1.5172173e-01, -3.8182023e-01, -1.3576640e-01,
               1.5150848e-02, -1.6323280e-01],
             [-2.1489854e-01,  4.4933017e-02,  1.5109259e-01,
               3.2077172e-01, -2.0871787e-01],
             [-1.9395326e-01, -3.7563645e-04,  1.2869158e-01,
              -3.1283948e-01,  3.4492922e-01],
             [-1.0875816e-01,  4.2586431e-01, -2.4989186e-01,
               1.2373009e-01, -3.8127521e-01],
             [-2.0082198e-01, -3.6653295e-01,  3.7867019e-01,
              -2.2911526e-01, -3.0295074e-01]], dtype=float32)

In [11]:
def j_linear(A, x): 
    """like torch.F.linear but in jax"""
    return x @ A.T

In [12]:
def forward(params, x): 
    out = j_linear(params, x)
    return out

In [27]:
backward = jax.jacfwd(forward, 1)

In [37]:
backward(jnp.asarray([[2., 0.], [0., 2.]]), jnp.asarray([1.,1.]))

DeviceArray([[2., 0.],
             [0., 2.]], dtype=float32)

In [38]:
params.shape

(5, 5)

In [39]:
out_j = forward(params, X_j)
out_j

DeviceArray([[-0.17058787,  0.42237467, -0.12706624, -0.14204438,
               0.09208602]], dtype=float32)

In [40]:
out_t

tensor([[-0.1706,  0.4224, -0.1271, -0.1420,  0.0921]], grad_fn=<MmBackward0>)

In [41]:
backward(params, X_j);

In [42]:
import numpy as np

Is the forward pass OK? 

In [43]:
np.allclose(np.array(out_j), out_t.detach().numpy())

True

cool, now for the hard part. 

In [44]:
from jax import jit

In [45]:
@jit
def safe_divide(a, b):
    den = jnp.clip(b, a_min=1e-9) + jnp.clip(b, a_max=1e-9)
    den = den + (den == 0).astype(den.dtype) * 1e-9
    return a / den * (b != 0).astype(b.dtype)

In [46]:
safe_divide(params, X_j)

DeviceArray([[-1.0517899e+01, -4.7957468e+00, -1.3832629e-01,
               1.6328031e-02, -1.4169275e+00],
             [-1.4897543e+01,  5.6436867e-01,  1.5394145e-01,
               3.4569487e-01, -1.8117564e+00],
             [-1.3445540e+01, -4.7180769e-03,  1.3111807e-01,
              -3.3714631e-01,  2.9941266e+00],
             [-7.5395083e+00,  5.3489504e+00, -2.5460359e-01,
               1.3334361e-01, -3.3096247e+00],
             [-1.3921704e+01, -4.6037354e+00,  3.8581002e-01,
              -2.4691693e-01, -2.6297362e+00]], dtype=float32)

In [85]:
def linear_relprop(R, x, params, alpha=1): 
    beta = alpha - 1
    pw = jnp.clip(params, a_min=0)
    nw = jnp.clip(params, a_max=0)
    px = jnp.clip(x,      a_min=0)
    nx = jnp.clip(x,      a_max=0)
    
    def f(w1, w2, x1, x2): 
        z1 = j_linear(w1, x1) 
        z2 = j_linear(w2, x2)
        s1 = safe_divide(R, z1+z2) # why are there two of these
        s2 = safe_divide(R, z1+z2) # for gradient reasons? 
        
        print(f'R: {R.shape}, s:{s1.shape}, z: {z1.shape}, x: {x1.shape}')
        b = backward(params, x1.squeeze())
        print(params.shape, b.shape)
        c1 = x1 * (s1 @ b.T)
        print(f'R: {R.shape}, s:{s1.shape}, z: {z1.shape}, x: {x1.shape}, b:{b.shape}, c:{c1.shape}')
        return c1
    
    return f(pw, nw, px, nx)

In [86]:
relprop_t

tensor([[ 0.0000, -0.0290,  0.1348, -0.0006, -0.0304]], grad_fn=<MulBackward0>)

In [87]:
linear_relprop(out_j, X_j, params)

R: (1, 5), s:(1, 5), z: (1, 5), x: (1, 5)
(5, 5) (5, 5)
R: (1, 5), s:(1, 5), z: (1, 5), x: (1, 5), b:(5, 5), c:(1, 5)


DeviceArray([[0.02208806, 0.17325366, 2.5900505 , 1.5759643 , 0.22422883]],            dtype=float32)

Right shapes, but not the right numbers. I dont think that our backward matches the torch autograd grad