In [1]:
%load_ext autoreload
%autoreload 2

In [31]:
import sys
import torch
import torch.nn
import jax
import jax.numpy as jnp
from jax import lax
import flax
import flax.linen as fnn
import numpy as np
from typing import Callable, Optional, Tuple, Union, Sequence, Iterable

sys.path.append("../bert")
from transformers import BertConfig, BertTokenizer
import modeling_flax_bert as bert_layers
import bert_explainability_layers as ours

In [32]:
def safe_divide(a, b):
    """Element-wise ``a / b`` that yields 0 wherever ``b`` is exactly 0.

    Args:
        a: Numerator tensor.
        b: Denominator tensor (broadcastable against ``a``).

    Returns:
        Tensor holding ``a / (b + eps)`` where ``b != 0`` and 0 where
        ``b == 0``, with ``eps = 1e-9``.
    """
    eps = 1e-9
    # clamp(min=eps) + clamp(max=eps) shifts every entry of b by eps.
    shifted = b.clamp(min=eps) + b.clamp(max=eps)
    # If the shift itself landed on zero, nudge that entry away from it.
    landed_on_zero = shifted.eq(0).type(shifted.type())
    shifted = shifted + landed_on_zero * eps
    # Mask that zeroes the quotient wherever the original denominator was 0.
    keep = b.ne(0).type(b.type())
    return a / shifted * keep


def forward_hook(self, input, output):
    """Cache a module's first positional input and its output on the module.

    After the call, ``self.X`` holds a detached copy of ``input[0]`` with
    ``requires_grad`` enabled (or a list of such copies when ``input[0]`` is a
    list/tuple), and ``self.Y`` holds ``output`` unchanged.

    Args:
        self: Object the hook is attached to; receives ``X`` and ``Y``.
        input: Tuple of positional arguments passed to the forward call.
        output: Value returned by the forward call.
    """

    def _fresh_leaf(tensor):
        # Detach from the incoming graph, then re-enable gradient tracking so
        # gradients can later be taken with respect to this cached input.
        leaf = tensor.detach()
        leaf.requires_grad = True
        return leaf

    first_arg = input[0]
    if type(first_arg) in (list, tuple):
        self.X = [_fresh_leaf(item) for item in first_arg]
    else:
        self.X = _fresh_leaf(first_arg)

    self.Y = output


class RelProp(torch.nn.Module):
    """Base module for layers that support relevance propagation.

    Every instance registers ``forward_hook`` so that each forward call
    caches the inputs (``self.X``) and output (``self.Y``) on the module.
    """

    def __init__(self):
        super().__init__()
        # Registered unconditionally, i.e. regardless of train/eval mode.
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Return the gradients of ``Z`` w.r.t. ``X`` using ``S`` as grad_outputs.

        The graph is retained so the same outputs can be differentiated again.
        """
        grads = torch.autograd.grad(Z, X, S, retain_graph=True)
        return grads

    def relprop(self, R, alpha):
        """Default propagation rule: hand the relevance ``R`` back unchanged."""
        return R


class RelPropSimple(RelProp):
    """Relevance propagation via the input-times-gradient rule.

    Subclasses only need to provide ``forward``; the inputs cached by the
    forward hook (``self.X``) are replayed through it here.
    """

    def relprop(self, R, alpha):
        """Redistribute the relevance ``R`` onto the cached input(s).

        Args:
            R: Relevance assigned to this layer's output.
            alpha: Accepted for interface compatibility; not used by this rule.

        Returns:
            A tensor when the cached input is a single tensor, otherwise a
            list with one relevance tensor per cached input.
        """
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X):
            return self.X * C[0]
        # autograd returns one gradient per input, so pair them up; this
        # handles any number of inputs rather than exactly two.
        return [x * c for x, c in zip(self.X, C)]

class MatMul(RelPropSimple):
    """Matrix product layer whose operands arrive as a single sequence."""

    def forward(self, inputs):
        """Unpack ``inputs`` into ``torch.matmul`` and return the product."""
        product = torch.matmul(*inputs)
        return product
    
class Add(RelPropSimple):
    """Addition layer whose operands arrive as a single sequence."""

    def forward(self, inputs):
        """Unpack ``inputs`` into ``torch.add`` and return the sum."""
        total = torch.add(*inputs)
        return total

    def relprop(self, R, alpha):
        """Split the relevance ``R`` between the two cached operands.

        Each operand first gets an input-times-gradient map. The total
        relevance ``R.sum()`` is then divided between the operands in
        proportion to the absolute sums of those maps, and each map is
        rescaled to carry its share.

        Args:
            R: Relevance assigned to this layer's output.
            alpha: Accepted for interface compatibility; not used here.

        Returns:
            ``[a, b]``: the relevance maps for the first and second operand.
        """
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        first, second = self.X[0] * C[0], self.X[1] * C[1]
        first_sum, second_sum = first.sum(), second.sum()

        # Share of the total relevance each operand is entitled to.
        abs_total = first_sum.abs() + second_sum.abs()
        relevance_total = R.sum()
        first_share = safe_divide(first_sum.abs(), abs_total) * relevance_total
        second_share = safe_divide(second_sum.abs(), abs_total) * relevance_total

        # Rescale each map so it carries its share.
        first = first * safe_divide(first_share, first_sum)
        second = second * safe_divide(second_share, second_sum)

        return [first, second]
    
class IndexSelect(RelProp):
    """``torch.index_select`` layer that remembers its selection arguments."""

    def forward(self, inputs, dim, indices):
        """Select ``indices`` along ``dim`` of ``inputs``.

        ``dim`` and ``indices`` are stored on the module so that ``relprop``
        can replay the same selection.
        """
        self.dim = dim
        self.indices = indices
        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        """Redistribute ``R`` onto the cached input via input-times-gradient.

        Args:
            R: Relevance assigned to the selected output.
            alpha: Accepted for interface compatibility; not used here.

        Returns:
            Relevance for the cached input (a tensor, or a two-element list
            when the cached input is not a single tensor).
        """
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X):
            return self.X * C[0]
        return [self.X[0] * C[0], self.X[1] * C[1]]
    
class Clone(RelProp):
    """Fan-out layer: returns the same input object ``num`` times."""

    def forward(self, input, num):
        """Return a list holding ``input`` repeated ``num`` times.

        ``num`` is stored on the module for use in ``relprop``.
        """
        self.num = num
        return [input for _ in range(num)]

    def relprop(self, R, alpha):
        """Merge the per-copy relevances in ``R`` into one map for the input.

        Args:
            R: Sequence of relevance tensors, one per copy.
            alpha: Accepted for interface compatibility; not used here.

        Returns:
            A single relevance tensor for the cached input.
        """
        copies = [self.X for _ in range(self.num)]
        ratios = [safe_divide(r, z) for r, z in zip(R, copies)]
        # Every copy is the cached input itself, so its gradient contribution
        # is the matching ratio tensor; autograd accumulates them.
        grad = self.gradprop(copies, self.X, ratios)[0]

        return self.X * grad

In [33]:
# PyTorch reference: two 2x2 operands used to exercise the relprop layers.
A = torch.ones((2,2))
B = torch.tensor([[-2., 1], [1, -2]])

# Call the module itself (not .forward) so the registered forward hook runs
# and caches the inputs that relprop reads later.
mm = MatMul()
mm([A,B])

tensor([[-1., -1.],
        [-1., -1.]])

In [34]:
# Propagate a relevance map with all mass on output element (0, 0);
# alpha is accepted but not used by this relprop rule.
mm.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[ 2., -1.],
         [ 0.,  0.]], grad_fn=<MulBackward0>),
 tensor([[ 2.,  0.],
         [-1., -0.]], grad_fn=<MulBackward0>)]

In [35]:
# Same operands through the PyTorch Add layer; the call caches A and B on `add`.
add = Add()
add([A,B])

tensor([[-1.,  2.],
        [ 2., -1.]])

In [36]:
# Relevance concentrated on output element (0, 0), split between the two operands.
add.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[0.3333, -0.0000],
         [-0.0000, 0.0000]], grad_fn=<MulBackward0>),
 tensor([[0.6667, 0.0000],
         [0.0000, 0.0000]], grad_fn=<MulBackward0>)]

In [37]:
# Select index 0 along dim 1 of B (i.e. its first column).
pool = IndexSelect()
pool(B, 1, torch.zeros(1, dtype=torch.int32))

tensor([[-2.],
        [ 1.]])

In [38]:
# Relevance on the first selected row only, propagated back onto B.
pool.relprop(torch.tensor([[1.],[0]]), alpha=1)

tensor([[1., 0.],
        [0., -0.]], grad_fn=<MulBackward0>)

In [39]:
# Fan B out into two copies.
clone = Clone()
clone(B, 2)

[tensor([[-2.,  1.],
         [ 1., -2.]]),
 tensor([[-2.,  1.],
         [ 1., -2.]])]

In [40]:
# One relevance tensor per copy; compare the result with their element-wise sum.
clone.relprop([torch.tensor([[.5,.5],[0,0]]), torch.tensor([[0,.25],[.25,.5]])], alpha=1)

tensor([[0.5000, 0.7500],
        [0.2500, 0.5000]], grad_fn=<MulBackward0>)

It looks like `Clone` is necessary because of the way PyTorch tracks gradients. Working out the math, I think its relprop reduces to a sum over all of the incoming relevances (which also makes intuitive sense). For this reason, I didn't include `Clone` in the Flax layers I implemented; wherever the PyTorch version calls `clone.relprop`, the JAX version simply adds the relevances.

In [41]:
# JAX counterparts of the operands above. Note that A and B are rebound here
# from torch tensors to JAX arrays.
A = jnp.ones((2,2))
B = jnp.array([[-2., 1], [1, -2]])

# The JAX layer takes the operands as separate arguments rather than a list.
jmm = ours.MatMul()
jmm(A,B)

DeviceArray([[-1., -1.],
             [-1., -1.]], dtype=float32)

In [42]:
# Same relevance map as the PyTorch MatMul check; the inputs are passed explicitly.
jmm.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 2., -1.],
              [ 0.,  0.]], dtype=float32),
 DeviceArray([[ 2.,  0.],
              [-1., -0.]], dtype=float32)]

In [43]:
# JAX Add layer on the same operands.
j_add = ours.Add()
j_add(A,B)

DeviceArray([[-1.,  2.],
             [ 2., -1.]], dtype=float32)

In [44]:
# Same relevance map as the PyTorch Add check; the inputs are passed explicitly.
j_add.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 0.33333334, -0.        ],
              [-0.        ,  0.        ]], dtype=float32),
 DeviceArray([[0.6666667, 0.       ],
              [0.       , 0.       ]], dtype=float32)]

In [45]:
# JAX IndexSelect: index 0 along axis 1 of B.
jax_pool = ours.IndexSelect()
jax_pool(B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[-2.],
             [ 1.]], dtype=float32)

In [46]:
# Same relevance as the PyTorch IndexSelect check; the forward arguments are passed again.
jax_pool.relprop(jnp.array([[1.],[0]]), B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[ 1.,  0.],
             [ 0., -0.]], dtype=float32)

In [19]:
# Initialise the JAX Dense layer on a 1x20 all-ones input with a fixed PRNG key,
# then bind the variables so the layer can be called directly.
jax_dense = ours.Dense(2)
x = jnp.ones((1,20))
variables = jax_dense.init(jax.random.PRNGKey(0), x)
model = jax_dense.bind(variables)
model(x)
# NOTE(review): this bare expression is not the last statement in the cell,
# so it displays nothing.
variables
print(model.variables["params"])

FrozenDict({
    kernel: DeviceArray([[ 1.53342038e-01,  3.29581231e-01],
                 [ 4.70121145e-01,  1.81581691e-01],
                 [ 1.80420190e-01,  3.35969597e-01],
                 [-1.62113383e-01, -2.02395543e-01],
                 [-1.18919984e-01, -6.62900805e-02],
                 [-4.10243064e-01,  7.33509436e-02],
                 [ 4.96547073e-02, -2.71931272e-02],
                 [-2.19073966e-01,  4.65010494e-01],
                 [ 1.54693127e-02,  1.10250756e-01],
                 [-3.89764100e-01,  8.84109288e-02],
                 [-3.79199862e-01, -2.35682800e-02],
                 [ 3.24670374e-01,  2.04515398e-01],
                 [ 1.44210760e-04,  1.65998340e-01],
                 [-3.36234242e-01, -3.94779295e-01],
                 [-1.40488684e-01, -5.03639765e-02],
                 [-2.11226270e-02,  2.08313853e-01],
                 [ 9.56470072e-02,  2.05471292e-01],
                 [-9.69679356e-02, -3.74182016e-01],
                 [-3.6075

In [20]:
# Propagate relevance placed on the first output unit back onto the 20 inputs.
model.relprop(jnp.array([[1.,0]]), x)

DeviceArray([[1.1891876e-01, 3.6458510e-01, 1.3991822e-01, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 3.8507875e-02, 0.0000000e+00,
              1.1996655e-02, 0.0000000e+00, 0.0000000e+00, 2.5178614e-01,
              1.1183733e-04, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              7.4175507e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],            dtype=float32)

In [47]:
# Pretrained tokenizer plus a BertConfig built with its default arguments.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
configuration = BertConfig()

In [48]:
# Tokenize one sentence and convert each tokenizer field to a JAX array.
inputs = tokenizer("Hello world!")
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
# One position index per token: 0 .. len(input_ids) - 1.
position_ids = jnp.arange(input_ids.shape[0])

In [49]:
# Build the embedding layer and initialise its variables with a fixed PRNG key.
embedding_layer = bert_layers.FlaxBertEmbeddings(configuration)
variables = embedding_layer.init(jax.random.PRNGKey(0), input_ids, token_ids,  position_ids, attention_mask)

In [50]:
# Forward pass through the embeddings.
output = embedding_layer.apply(variables, input_ids, token_ids, position_ids, attention_mask)
# NOTE(review): the relprop result below is neither stored nor displayed;
# the call only checks that relprop runs.
embedding_layer.apply(variables, output, input_ids, token_ids, position_ids, attention_mask, method=embedding_layer.relprop)
# Prepend a leading axis (presumably the batch dimension) for the attention layer.
output = output[jnp.newaxis]
output.shape

(1, 5, 768)

In [51]:
# Initialise the attention block on the embedded input and run a forward pass.
# Note that `variables` is rebound here from the embedding layer's variables
# to the attention layer's.
attention = bert_layers.FlaxBertAttention(configuration)
variables = attention.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = attention.apply(variables, output, attention_mask, None)

In [52]:
# The forward result is a tuple; the attention output tensor is hidden_states[0].
print(hidden_states)

(DeviceArray([[[-0.56294453,  1.0679047 ,  1.1388769 , ...,  0.39261633,
               -1.4752828 ,  1.9199177 ],
              [-1.1342766 , -0.7030669 , -0.18691307, ...,  2.3455684 ,
                1.5618547 , -0.48230216],
              [ 0.35637683, -0.39173928,  0.1297656 , ...,  1.4595735 ,
                0.5980807 ,  0.19320047],
              [ 0.66311115,  1.5822634 ,  1.1960373 , ..., -0.13444665,
                1.0535946 ,  1.4545076 ],
              [ 0.5495986 , -0.10840828, -0.42316678, ...,  0.598222  ,
                1.625095  ,  1.389647  ]]], dtype=float32),)


In [64]:
# Random relevance map shaped like the attention output, normalised to sum to 1.
# NOTE(review): no seed is set in the visible cells, so these values (and the
# relprop output that follows) will differ on every run.
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam

DeviceArray([[[1.5862346e-04, 3.4134588e-04, 4.3500646e-04, ...,
               2.4813239e-04, 2.8730359e-05, 2.3901957e-04],
              [6.4698499e-05, 1.3880285e-04, 8.1970022e-05, ...,
               1.4927814e-04, 6.8905219e-06, 3.4883863e-04],
              [2.5487348e-04, 4.2201468e-04, 4.9790664e-04, ...,
               6.4399610e-05, 3.4115190e-04, 1.1278792e-04],
              [3.4541208e-05, 4.9971341e-04, 8.6375505e-05, ...,
               1.8742368e-04, 1.1919927e-04, 2.5305376e-04],
              [4.9182947e-04, 3.0515232e-04, 2.1039491e-04, ...,
               4.5212384e-04, 4.7184789e-04, 1.6803121e-04]]],            dtype=float32)

In [66]:
# Propagate the random relevance map back through the attention block.
# NOTE(review): the "CAM:" / "Dropout:" / "Input Tensor:" lines in the output
# are presumably debug prints inside the relprop implementation; confirm and remove them.
attention.apply(variables, cam, output, None, None, method=attention.relprop)

CAM: [[[1.5862346e-04 3.4134588e-04 4.3500646e-04 ... 2.4813239e-04
   2.8730359e-05 2.3901957e-04]
  [6.4698499e-05 1.3880285e-04 8.1970022e-05 ... 1.4927814e-04
   6.8905219e-06 3.4883863e-04]
  [2.5487348e-04 4.2201468e-04 4.9790664e-04 ... 6.4399610e-05
   3.4115190e-04 1.1278792e-04]
  [3.4541208e-05 4.9971341e-04 8.6375505e-05 ... 1.8742368e-04
   1.1919927e-04 2.5305376e-04]
  [4.9182947e-04 3.0515232e-04 2.1039491e-04 ... 4.5212384e-04
   4.7184789e-04 1.6803121e-04]]]
Dropout: [[[ 0.22248834 -0.01172732  0.37502575 ...  0.43499708  0.11476218
    0.20412494]
  [ 0.24330458  0.01846182  0.37363368 ...  0.39203686  0.13324705
    0.19083147]
  [ 0.26380512  0.02816542  0.36352172 ...  0.44176722  0.12059104
    0.1863118 ]
  [ 0.29731175  0.00489232  0.34569857 ...  0.43914324  0.13150395
    0.17871302]
  [ 0.24196245  0.02608497  0.34836513 ...  0.44917923  0.08554332
    0.17359   ]]]
Input Tensor: [[[-0.8081451   1.1146805   0.80141324 ... -0.03124941 -1.64507
    1.7810174 

DeviceArray([[[ 1.98149457e-04,  3.42299114e-04,  2.94019294e-04, ...,
               -1.79213130e-05,  8.56812949e-06,  2.19850524e-04],
              [-2.83737597e-03, -8.74250603e-04, -6.51159207e-05, ...,
                8.60234490e-04,  1.82910822e-03,  9.70344350e-04],
              [ 1.34632573e-04,  1.56228809e-04, -1.17633003e-03, ...,
                1.49306247e-03,  6.89799432e-04,  3.04724472e-06],
              [ 6.37599296e-05,  1.55628280e-04,  9.63787606e-05, ...,
                6.21967425e-04, -8.81010637e-05,  1.30146625e-04],
              [ 2.55513762e-04,  3.39580321e-04,  4.00387042e-04, ...,
                1.08731794e-04,  3.51434166e-04,  1.12680173e-04]]],            dtype=float32)