In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
import torch.nn
import jax
import jax.numpy as jnp
from jax import lax
import flax
import flax.linen as fnn
import numpy as np
from typing import Callable, Optional, Tuple, Union, Sequence, Iterable

sys.path.append("../bert")
from transformers import BertConfig, BertTokenizer
import modeling_flax_bert as bert_layers
import bert_explainability_layers as ours

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def safe_divide(a, b):
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    den = den + den.eq(0).type(den.type()) * 1e-9
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    if type(input[0]) in (list, tuple):
        self.X = []
        for i in input[0]:
            x = i.detach()
            x.requires_grad = True
            self.X.append(x)
    else:
        self.X = input[0].detach()
        self.X.requires_grad = True

    self.Y = output


class RelProp(torch.nn.Module):
    def __init__(self):
        super(RelProp, self).__init__()
        # if not self.training:
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def relprop(self, R, alpha):
        return R


class RelPropSimple(RelProp):
    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs

class MatMul(RelPropSimple):
    def forward(self, inputs):
        return torch.matmul(*inputs)
    
class Add(RelPropSimple):
    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        a = self.X[0] * C[0]
        b = self.X[1] * C[1]

        a_sum = a.sum()
        b_sum = b.sum()

        a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
        b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()

        a = a * safe_divide(a_fact, a.sum())
        b = b * safe_divide(b_fact, b.sum())

        outputs = [a, b]
        return outputs
    
class IndexSelect(RelProp):
    def forward(self, inputs, dim, indices):
        self.__setattr__('dim', dim)
        self.__setattr__('indices', indices)

        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs
    
class Clone(RelProp):
    def forward(self, input, num):
        self.__setattr__('num', num)
        outputs = []
        for _ in range(num):
            outputs.append(input)

        return outputs

    def relprop(self, R, alpha):
        Z = []
        for _ in range(self.num):
            Z.append(self.X)
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        C = self.gradprop(Z, self.X, S)[0]

        R = self.X * C

        return R

In [4]:
A = torch.ones((2,2))
B = torch.tensor([[-2., 1], [1, -2]])

mm = MatMul()
mm([A,B])

tensor([[-1., -1.],
        [-1., -1.]])

In [5]:
mm.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[ 2., -1.],
         [ 0.,  0.]], grad_fn=<MulBackward0>),
 tensor([[ 2.,  0.],
         [-1., -0.]], grad_fn=<MulBackward0>)]

In [6]:
add = Add()
add([A,B])

tensor([[-1.,  2.],
        [ 2., -1.]])

In [7]:
add.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[0.3333, -0.0000],
         [-0.0000, 0.0000]], grad_fn=<MulBackward0>),
 tensor([[0.6667, 0.0000],
         [0.0000, 0.0000]], grad_fn=<MulBackward0>)]

In [8]:
pool = IndexSelect()
pool(B, 1, torch.zeros(1, dtype=torch.int32))

tensor([[-2.],
        [ 1.]])

In [9]:
pool.relprop(torch.tensor([[1.],[0]]), alpha=1)

tensor([[1., 0.],
        [0., -0.]], grad_fn=<MulBackward0>)

In [10]:
clone = Clone()
clone(B, 2)

[tensor([[-2.,  1.],
         [ 1., -2.]]),
 tensor([[-2.,  1.],
         [ 1., -2.]])]

In [11]:
clone.relprop([torch.tensor([[.5,.5],[0,0]]), torch.tensor([[0,.25],[.25,.5]])], alpha=1)

tensor([[0.5000, 0.7500],
        [0.2500, 0.5000]], grad_fn=<MulBackward0>)

It looks like clone is necessary because of the way pyTorch tracks gradients. Working out the math, I think the relprop works out to a sum over all of the relevances (which intuitively also makes sense). For this reason, I didn't include "clone" in the Flax layers I implemented and whenever I saw a clone.relprop in the pyTorch version, I added the relevances in the Jax version

In [12]:
A = jnp.ones((2,2))
B = jnp.array([[-2., 1], [1, -2]])

jmm = ours.MatMul()
jmm(A,B)



DeviceArray([[-1., -1.],
             [-1., -1.]], dtype=float32)

In [13]:
jmm.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 2., -1.],
              [ 0.,  0.]], dtype=float32),
 DeviceArray([[ 2.,  0.],
              [-1., -0.]], dtype=float32)]

In [14]:
j_add = ours.Add()
j_add(A,B)

DeviceArray([[-1.,  2.],
             [ 2., -1.]], dtype=float32)

In [15]:
j_add.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 0.33333334, -0.        ],
              [-0.        ,  0.        ]], dtype=float32),
 DeviceArray([[0.6666667, 0.       ],
              [0.       , 0.       ]], dtype=float32)]

In [16]:
jax_pool = ours.IndexSelect()
jax_pool(B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[-2.],
             [ 1.]], dtype=float32)

In [17]:
jax_pool.relprop(jnp.array([[1.],[0]]), B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[ 1.,  0.],
             [ 0., -0.]], dtype=float32)

In [18]:
jax_dense = ours.Dense(2)
x = jnp.ones((1,20))
variables = jax_dense.init(jax.random.PRNGKey(0), x)
model = jax_dense.bind(variables)
model(x)
variables
print(model.variables["params"])

FrozenDict({
    kernel: DeviceArray([[ 1.53342038e-01,  3.29581231e-01],
                 [ 4.70121145e-01,  1.81581691e-01],
                 [ 1.80420190e-01,  3.35969597e-01],
                 [-1.62113383e-01, -2.02395543e-01],
                 [-1.18919984e-01, -6.62900805e-02],
                 [-4.10243064e-01,  7.33509436e-02],
                 [ 4.96547073e-02, -2.71931272e-02],
                 [-2.19073966e-01,  4.65010494e-01],
                 [ 1.54693127e-02,  1.10250756e-01],
                 [-3.89764100e-01,  8.84109288e-02],
                 [-3.79199862e-01, -2.35682800e-02],
                 [ 3.24670374e-01,  2.04515398e-01],
                 [ 1.44210760e-04,  1.65998340e-01],
                 [-3.36234242e-01, -3.94779295e-01],
                 [-1.40488684e-01, -5.03639765e-02],
                 [-2.11226270e-02,  2.08313853e-01],
                 [ 9.56470072e-02,  2.05471292e-01],
                 [-9.69679356e-02, -3.74182016e-01],
                 [-3.6075

In [19]:
model.relprop(jnp.array([[1.,0]]), x)

DeviceArray([[1.1891876e-01, 3.6458510e-01, 1.3991822e-01, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 3.8507875e-02, 0.0000000e+00,
              1.1996655e-02, 0.0000000e+00, 0.0000000e+00, 2.5178614e-01,
              1.1183733e-04, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              7.4175507e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],            dtype=float32)

In [20]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
configuration = BertConfig()

In [21]:
inputs = tokenizer("Hello world!")
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

In [22]:
embedding_layer = bert_layers.FlaxBertEmbeddings(configuration)
variables = embedding_layer.init(jax.random.PRNGKey(0), input_ids, token_ids,  position_ids, attention_mask)
output = embedding_layer.apply(variables, input_ids, token_ids, position_ids, attention_mask)
output.shape

(5, 768)

In [23]:
cam = jnp.array(np.random.rand(*output.shape))
cam = cam / cam.sum()
cam = embedding_layer.apply(variables, output, input_ids, token_ids, position_ids, attention_mask, method=embedding_layer.relprop)
cam

[DeviceArray([[ 3.1433138e-09, -4.8694666e-09, -4.3969450e-09, ...,
                2.4628964e-12,  6.5122796e-09, -7.5907156e-09],
              [ 5.2464384e-09,  4.1711883e-09,  3.5693684e-09, ...,
               -1.1491520e-08, -7.9963085e-09,  2.2653799e-09],
              [-7.7218343e-11,  4.1342711e-09, -7.0949807e-10, ...,
               -8.1966034e-09, -1.1211086e-08, -3.1586398e-13],
              [-1.8080911e-09, -5.1735189e-09, -3.4792007e-09, ...,
                1.5778143e-09, -1.3092965e-08, -6.1562648e-09],
              [-1.3209663e-09, -4.1299006e-10,  8.9867500e-09, ...,
                3.6140119e-10, -7.1364568e-09, -7.4246453e-09]],            dtype=float32),
 DeviceArray([[-1.7690714e-09,  1.9062234e-09,  4.7452042e-10, ...,
               -1.8748968e-10, -3.4874317e-09,  3.2353644e-09],
              [-3.4204888e-09, -4.0048584e-10,  1.0476618e-10, ...,
                9.5655728e-10,  1.0658153e-09, -1.9571462e-09],
              [ 5.5020560e-10,  1.5006854e-09, -

In [24]:
output = output[jnp.newaxis]
output.shape

(1, 5, 768)

In [25]:
attention = bert_layers.FlaxBertAttention(configuration)
variables = attention.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = attention.apply(variables, output, attention_mask, None)

In [26]:
print(hidden_states)

(DeviceArray([[[-0.56294453,  1.0679047 ,  1.1388768 , ...,  0.3926163 ,
               -1.4752828 ,  1.9199177 ],
              [-1.1342765 , -0.7030669 , -0.18691309, ...,  2.3455682 ,
                1.5618546 , -0.48230222],
              [ 0.3563769 , -0.39173922,  0.12976554, ...,  1.4595736 ,
                0.5980807 ,  0.19320044],
              [ 0.6631111 ,  1.5822634 ,  1.1960372 , ..., -0.13444656,
                1.0535947 ,  1.4545076 ],
              [ 0.54959863, -0.10840819, -0.42316663, ...,  0.59822196,
                1.625095  ,  1.389647  ]]], dtype=float32),)


In [27]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam

DeviceArray([[[3.47962632e-04, 3.03811335e-04, 3.28713504e-04, ...,
               8.99931110e-05, 3.77409306e-04, 1.58041759e-04],
              [4.52530629e-04, 4.54374327e-04, 8.31495490e-05, ...,
               3.55336437e-04, 3.31967371e-04, 4.95868990e-05],
              [4.04895400e-04, 3.29074363e-04, 4.75295499e-04, ...,
               1.99961578e-04, 4.47486265e-04, 1.11088084e-04],
              [9.64094797e-05, 3.58464313e-04, 2.52198399e-04, ...,
               3.76594340e-04, 4.42620403e-05, 4.32472967e-04],
              [1.06285097e-05, 2.65545590e-04, 3.56618839e-04, ...,
               2.84285634e-04, 4.23955134e-06, 4.51202039e-04]]],            dtype=float32)

In [28]:
cam = attention.apply(variables, cam, output, None, None, method=attention.relprop)
cam.sum()

DeviceArray(1.0000001, dtype=float32)

In [29]:
layer = bert_layers.FlaxBertLayer(configuration)
variables = layer.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = layer.apply(variables, output, attention_mask, None)

In [30]:
hidden_states[0].shape

(1, 5, 768)

In [31]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam = layer.apply(variables, cam, output, attention_mask, None, method=layer.relprop)

In [32]:
cam.shape
cam.sum()

DeviceArray(0.9999999, dtype=float32)

In [33]:
layer_collection = bert_layers.FlaxBertLayerCollection(configuration)
variables = layer_collection.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = layer_collection.apply(variables, output, attention_mask, None)

2022-11-30 11:28:25.208107: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-11-30 11:28:25.228489: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-30 11:28:25.771252: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-11-30 11:28:25.771369: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [34]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam = layer_collection.apply(variables, cam, output, attention_mask, None, method=layer_collection.relprop)

In [35]:
cam.sum()

DeviceArray(0.9999996, dtype=float32)

In [36]:
encoder = bert_layers.FlaxBertEncoder(configuration)
variables = encoder.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = encoder.apply(variables, output, attention_mask, None)

In [37]:
hidden_states

FlaxBaseModelOutputWithPastAndCrossAttentions(last_hidden_state=DeviceArray([[[-0.73932266, -0.10496283, -0.23650123, ..., -0.29518175,
               -0.07944169,  2.5543838 ],
              [-1.0293188 , -1.4049089 ,  0.43076634, ...,  0.04329822,
                1.65496   ,  0.30194774],
              [ 0.4192415 , -1.8640374 ,  0.49091607, ...,  0.8654702 ,
               -0.00657526,  0.5579648 ],
              [ 0.07758794, -0.8224394 , -0.11835903, ..., -0.11618427,
                1.0518749 ,  1.1674784 ],
              [ 0.09991326, -0.80100965, -1.4718195 , ...,  0.5109208 ,
                0.6762812 ,  1.2827846 ]]], dtype=float32), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

In [38]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam = encoder.apply(variables, cam, output, attention_mask, None, method=encoder.relprop)

In [39]:
cam.sum()

DeviceArray(1., dtype=float32)

In [40]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
print(input_ids.shape)

bert_module = bert_layers.FlaxBertModule(configuration)
variables = bert_module.init(jax.random.PRNGKey(0), input_ids, attention_mask, None)
hidden_states = bert_module.apply(variables, input_ids, attention_mask, None)

(1, 5)


In [41]:
hidden_states

FlaxBaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=DeviceArray([[[-1.803636  , -0.90945095, -0.4882279 , ..., -1.7539307 ,
                0.26016185, -0.39490893],
              [-1.0041101 ,  0.06071824, -0.61647964, ..., -0.9118402 ,
               -1.5660151 , -1.7031233 ],
              [-0.70081997, -0.7901022 , -1.5094408 , ..., -0.9647325 ,
               -0.89686865,  0.2988747 ],
              [-0.44771618, -0.58907217, -2.2381587 , ...,  0.40958568,
               -0.7222764 , -1.6665676 ],
              [-1.0788714 , -1.6146749 , -1.6605315 , ..., -1.2121923 ,
               -0.27918416, -0.9997014 ]]], dtype=float32), pooler_output=DeviceArray([[-3.71137857e-01,  6.60066009e-01, -9.70576286e-01,
               7.94233978e-02, -6.85315728e-01, -2.48473790e-02,
              -3.00466567e-01,  6.13525093e-01, -3.73042107e-01,
               3.94238204e-01,  8.54212463e-01,  1.11271188e-01,
              -4.34547335e-01,  3.64358783e-01, -4.61808622e-01,
      

In [42]:
cam = jnp.array(np.random.rand(*hidden_states[1].shape))
cam = cam / cam.sum()
cam = bert_module.apply(variables, cam, input_ids, attention_mask, None, method=bert_module.relprop)

In [43]:
cam

[DeviceArray([[[ 3.2232713e-03,  1.8840600e-04,  2.4332281e-04, ...,
                 1.4012388e-03,  1.9539725e-03,  2.0035658e-02],
               [ 5.3010933e-04, -9.6512129e-05,  8.0519101e-05, ...,
                -2.8975915e-06,  2.4951586e-07,  2.2727574e-05],
               [ 2.8188466e-04,  1.6813600e-04,  2.7707119e-06, ...,
                 8.7276339e-06, -2.5991199e-05,  1.7369130e-04],
               [ 5.7742797e-04, -4.3127539e-08,  1.2448423e-07, ...,
                 4.5172137e-04,  1.6437768e-04, -1.2961204e-06],
               [ 2.0184058e-04,  4.6511533e-04,  1.3332981e-04, ...,
                -5.1738010e-05,  2.7215676e-04, -8.2387458e-05]]],            dtype=float32),
 DeviceArray([[[ 9.4673916e-04,  1.8274126e-04,  1.3683217e-04, ...,
                 1.4695181e-03,  5.6467979e-04, -1.3139717e-02],
               [-2.5422843e-05,  1.4336115e-04,  5.7703070e-05, ...,
                 4.9630376e-07, -1.3612440e-06,  5.6747649e-05],
               [-1.3085324e-04,  

In [44]:
print(len(cam))

2


In [45]:
#cam[0] corresponds to position/token-type embeddings and cam[1] corresponds to token embeddings
cam[0].sum() + cam[1].sum()

DeviceArray(1.0000002, dtype=float32)

In [71]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

seq_class_module = bert_layers.FlaxBertForSequenceClassificationModule(configuration)
variables = seq_class_module.init(jax.random.PRNGKey(0), input_ids, attention_mask, token_ids, position_ids, None)
hidden_states = seq_class_module.apply(variables, input_ids, attention_mask, token_ids, position_ids, None)

In [73]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
print(cam, input_ids.dtype, attention_mask.dtype)
cam = seq_class_module.apply(variables, cam, input_ids, attention_mask, token_ids, position_ids, None, method=seq_class_module.relprop)

[[0.82040316 0.17959687]] int32 int32


In [62]:
model = bert_layers.FlaxBertForSequenceClassification(configuration)

In [64]:
output = model(input_ids)
output

FlaxSequenceClassifierOutput(logits=DeviceArray([[ 1.0231309 , -0.18921821]], dtype=float32), hidden_states=None, attentions=None)

In [78]:
cam = jnp.array(np.random.rand(*output[0].shape))
model.relprop(cam, input_ids)

[DeviceArray([[[-9.8497968e-04,  2.5075275e-07,  3.2105399e-04, ...,
                -6.8772458e-03, -2.3458910e-03, -9.2084100e-03],
               [ 1.1722860e-02, -6.9105008e-05,  1.0484050e-02, ...,
                -9.4426163e-03, -4.2159725e-03,  5.6059042e-04],
               [-9.0922178e-05, -3.5654928e-08,  7.7627064e-06, ...,
                -2.0889578e-04, -3.1042389e-05, -2.5963900e-05],
               [ 4.9927854e-03,  1.3115408e-04, -2.3248696e-04, ...,
                -1.6548664e-03,  5.0835941e-05, -2.5872296e-05],
               [-9.9246390e-05, -4.6292516e-06, -2.3977894e-04, ...,
                -6.1016774e-05, -2.4277764e-05,  1.8905023e-07]]],            dtype=float32),
 DeviceArray([[[-5.6007382e-04,  1.6211176e-06, -9.3178962e-05, ...,
                 3.3156369e-03,  7.2201586e-04, -3.9590094e-03],
               [-1.8556701e-03,  3.9240844e-05, -7.3660696e-03, ...,
                -7.0078223e-04,  5.8974233e-04,  1.0548327e-03],
               [-1.7764745e-05,  