In [1]:
import torch 
from torch import nn

from modules import layers_ours as layers

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Their linear layer and relprop

We put a single instance of thier linear layer into a network to test agains 

In [2]:
class TestNet(nn.Module): 
    
    def __init__(self): 
        super().__init__()
        self.linear = layers.Linear(5, 5, bias=False)
    
    def forward(self, x):
        out = self.linear(x)
        return out
    
    def relprop(self, y): 
        kwargs = {'alpha': 1} # this is how they do it in the relprop repo
        return self.linear.relprop(y, **kwargs)

In [3]:
X_t = torch.rand((1,5))

In [4]:
X_t

tensor([[0.6789, 0.6927, 0.1304, 0.1509, 0.1421]])

In [5]:
test_net = TestNet()

In [6]:
out_t = test_net.forward(X_t)
out_t

tensor([[-0.3315, -0.2612,  0.0834, -0.1305, -0.3100]], grad_fn=<MmBackward0>)

In [7]:
relprop_t = test_net.relprop(out_t)
relprop_t

tensor([[0.0101, 0.0580, 0.2691, 0.0455, 0.0475]], grad_fn=<MmBackward0>)
tensor([[-32.7458,  -4.5056,   0.3097,  -2.8677,  -6.5285]],
       grad_fn=<MulBackward0>)
tensor([[0.6789, 0.6927, 0.1304, 0.1509, 0.1421]], requires_grad=True)
RETURNING JUST C


tensor([[ 0.0000,  0.0378, -4.5270,  0.0000, -2.7136]])

## Jax implementation of linear layer relprop

In [8]:
import jax.numpy as jnp
import jax.nn as jnn
import jax

We test the linear layer and relprop against the same input

In [9]:
X_j = jnp.array(X_t)
X_j



DeviceArray([[0.67891246, 0.69272023, 0.13040155, 0.15086126, 0.14212435]],            dtype=float32)

Grab the params from the torch net

In [10]:
params = jnp.array(test_net.linear.weight.detach().numpy())
params

DeviceArray([[-0.07250727, -0.4038238 ,  0.05131372, -0.08358306,
               0.0241376 ],
             [-0.33375564, -0.04740531,  0.44456163, -0.16668098,
              -0.24337842],
             [-0.20267841,  0.32940978, -0.05643311, -0.27050206,
               0.28800365],
             [-0.13518845,  0.0223879 ,  0.2300261 , -0.15898027,
              -0.4238099 ],
             [-0.07810157, -0.35893023,  0.02818899, -0.37012824,
               0.30825394]], dtype=float32)

In [11]:
def j_linear(A, x): 
    """like torch.F.linear but in jax"""
    return x @ A.T

In [12]:
def forward(params, x): 
    out = j_linear(params, x)
    return out

In [13]:
backward = jax.jacfwd(forward, 1)

In [14]:
backward(jnp.asarray([[2., 0.], [0., 2.]]), jnp.asarray([1.,1.]))

DeviceArray([[2., 0.],
             [0., 2.]], dtype=float32)

In [15]:
params.shape

(5, 5)

In [16]:
out_j = forward(params, X_j)
out_j

DeviceArray([[-0.33145052, -0.26119366,  0.08335301, -0.13049448,
              -0.3100141 ]], dtype=float32)

In [17]:
out_t

tensor([[-0.3315, -0.2612,  0.0834, -0.1305, -0.3100]], grad_fn=<MmBackward0>)

In [18]:
backward(params, X_j.squeeze())

DeviceArray([[-0.07250727, -0.4038238 ,  0.05131372, -0.08358306,
               0.0241376 ],
             [-0.33375564, -0.04740531,  0.44456163, -0.16668098,
              -0.24337842],
             [-0.20267841,  0.32940978, -0.05643311, -0.27050206,
               0.28800365],
             [-0.13518845,  0.0223879 ,  0.2300261 , -0.15898027,
              -0.4238099 ],
             [-0.07810157, -0.35893023,  0.02818899, -0.37012824,
               0.30825394]], dtype=float32)

In [19]:
import numpy as np

Is the forward pass OK? 

In [20]:
np.allclose(np.array(out_j), out_t.detach().numpy())

True

cool, now for the hard part. 

In [21]:
from jax import jit

In [22]:
@jit
def safe_divide(a, b):
    den = jnp.clip(b, a_min=1e-9) + jnp.clip(b, a_max=1e-9)
    den = den + (den == 0).astype(den.dtype) * 1e-9
    return a / den * (b != 0).astype(b.dtype)

In [23]:
safe_divide(params, X_j)

DeviceArray([[-0.10679914, -0.58295363,  0.39350545, -0.55403924,
               0.16983438],
             [-0.49160334, -0.06843356,  3.4091744 , -1.1048627 ,
              -1.7124329 ],
             [-0.29853395,  0.47553074, -0.4327641 , -1.7930518 ,
               2.02642   ],
             [-0.199125  ,  0.03231882,  1.7639828 , -1.0538176 ,
              -2.9819653 ],
             [-0.11503923, -0.51814604,  0.2161707 , -2.4534347 ,
               2.168903  ]], dtype=float32)

In [30]:
?jax.jvp

[0;31mSignature:[0m [0mjax[0m[0;34m.[0m[0mjvp[0m[0;34m([0m[0mfun[0m[0;34m:[0m [0;34m'Callable'[0m[0;34m,[0m [0mprimals[0m[0;34m,[0m [0mtangents[0m[0;34m,[0m [0mhas_aux[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m)[0m [0;34m->[0m [0;34m'Tuple[Any, ...]'[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Computes a (forward-mode) Jacobian-vector product of ``fun``.

Args:
  fun: Function to be differentiated. Its arguments should be arrays, scalars,
    or standard Python containers of arrays or scalars. It should return an
    array, scalar, or standard Python container of arrays or scalars.
  primals: The primal values at which the Jacobian of ``fun`` should be
    evaluated. Should be either a tuple or a list of arguments,
    and its length should be equal to the number of positional parameters of
    ``fun``.
  tangents: The tangent vector for which the Jacobian-vector product should be
    evaluated. Should be either a tuple

In [62]:
def linear_relprop(R, x, params, alpha=1): 
    beta = alpha - 1
    pw = jnp.clip(params, a_min=0)
    nw = jnp.clip(params, a_max=0)
    px = jnp.clip(x,      a_min=0)
    nx = jnp.clip(x,      a_max=0)
    
    def f(w1, w2, x1, x2): 
        z1 = j_linear(w1, x1) 
        z2 = j_linear(w2, x2)
        s1 = safe_divide(R, z1+z2) # why are there two of these
        s2 = safe_divide(R, z1+z2) # for gradient reasons? 
        
        print(z1)
        print(s1)
        print(x1)
        
        print(f'R: {R.shape}, s:{s1.shape}, z: {z1.shape}, x: {x1.shape}')
        f = lambda x : j_linear(w1, x)
        _, b_func = jax.vjp(f, x1)
        c1 = b_func(s1)
        #c1 = jax.jvp(f, (x1,), (s1,))
        #b = backward(params, x1.squeeze())
        #print(b)
        #print(params.shape, b.shape)
        #c1 = (s1 @ b.T)
        #print(f'R: {R.shape}, s:{s1.shape}, z: {z1.shape}, x: {x1.shape}, c:{c1.shape}')
        return c1
    
    return f(pw, nw, px, nx)

In [63]:
linear_relprop(out_j, X_j, params)

[[0.01012193 0.05797153 0.26912114 0.04550431 0.04748628]]
[[-32.74578    -4.505551    0.309723   -2.8677387  -6.5284977]]
[[0.67891246 0.69272023 0.13040155 0.15086126 0.14212435]]
R: (1, 5), s:(1, 5), z: (1, 5), x: (1, 5)


(DeviceArray([[ 0.        ,  0.03782313, -4.5269895 ,  0.        ,
               -2.7136385 ]], dtype=float32),)

In [64]:
test_net.relprop(out_t)

tensor([[0.0101, 0.0580, 0.2691, 0.0455, 0.0475]], grad_fn=<MmBackward0>)
tensor([[-32.7458,  -4.5056,   0.3097,  -2.8677,  -6.5285]],
       grad_fn=<MulBackward0>)
tensor([[0.6789, 0.6927, 0.1304, 0.1509, 0.1421]], requires_grad=True)
RETURNING JUST C


tensor([[ 0.0000,  0.0378, -4.5270,  0.0000, -2.7136]])