In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
import torch.nn
import jax
import jax.numpy as jnp
from jax import lax
import flax
import flax.linen as fnn
import numpy as np
from typing import Callable, Optional, Tuple, Union, Sequence, Iterable

sys.path.append("../bert")
from transformers import BertConfig, BertTokenizer
import modeling_flax_bert as bert_layers
import bert_explainability_layers as ours

  from .autonotebook import tqdm as notebook_tqdm
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


Moving 0 files to the new cache system


0it [00:00, ?it/s]


In [3]:
def safe_divide(a, b):
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    den = den + den.eq(0).type(den.type()) * 1e-9
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    if type(input[0]) in (list, tuple):
        self.X = []
        for i in input[0]:
            x = i.detach()
            x.requires_grad = True
            self.X.append(x)
    else:
        self.X = input[0].detach()
        self.X.requires_grad = True

    self.Y = output


class RelProp(torch.nn.Module):
    def __init__(self):
        super(RelProp, self).__init__()
        # if not self.training:
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def relprop(self, R, alpha):
        return R


class RelPropSimple(RelProp):
    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs

class MatMul(RelPropSimple):
    def forward(self, inputs):
        return torch.matmul(*inputs)
    
class Add(RelPropSimple):
    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        a = self.X[0] * C[0]
        b = self.X[1] * C[1]

        a_sum = a.sum()
        b_sum = b.sum()

        a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
        b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()

        a = a * safe_divide(a_fact, a.sum())
        b = b * safe_divide(b_fact, b.sum())

        outputs = [a, b]
        return outputs
    
class IndexSelect(RelProp):
    def forward(self, inputs, dim, indices):
        self.__setattr__('dim', dim)
        self.__setattr__('indices', indices)

        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs
    
class Clone(RelProp):
    def forward(self, input, num):
        self.__setattr__('num', num)
        outputs = []
        for _ in range(num):
            outputs.append(input)

        return outputs

    def relprop(self, R, alpha):
        Z = []
        for _ in range(self.num):
            Z.append(self.X)
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        C = self.gradprop(Z, self.X, S)[0]

        R = self.X * C

        return R

In [4]:
A = torch.ones((2,2))
B = torch.tensor([[-2., 1], [1, -2]])

mm = MatMul()
mm([A,B])

tensor([[-1., -1.],
        [-1., -1.]])

In [5]:
mm.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[ 2., -1.],
         [ 0.,  0.]], grad_fn=<MulBackward0>),
 tensor([[ 2.,  0.],
         [-1., -0.]], grad_fn=<MulBackward0>)]

In [6]:
add = Add()
add([A,B])

tensor([[-1.,  2.],
        [ 2., -1.]])

In [7]:
add.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[0.3333, -0.0000],
         [-0.0000, 0.0000]], grad_fn=<MulBackward0>),
 tensor([[0.6667, 0.0000],
         [0.0000, 0.0000]], grad_fn=<MulBackward0>)]

In [8]:
pool = IndexSelect()
pool(B, 1, torch.zeros(1, dtype=torch.int32))

tensor([[-2.],
        [ 1.]])

In [9]:
pool.relprop(torch.tensor([[1.],[0]]), alpha=1)

tensor([[1., 0.],
        [0., -0.]], grad_fn=<MulBackward0>)

In [10]:
clone = Clone()
clone(B, 2)

[tensor([[-2.,  1.],
         [ 1., -2.]]),
 tensor([[-2.,  1.],
         [ 1., -2.]])]

In [11]:
clone.relprop([torch.tensor([[.5,.5],[0,0]]), torch.tensor([[0,.25],[.25,.5]])], alpha=1)

tensor([[0.5000, 0.7500],
        [0.2500, 0.5000]], grad_fn=<MulBackward0>)

It looks like clone is necessary because of the way pyTorch tracks gradients. Working out the math, I think the relprop works out to a sum over all of the relevances (which intuitively also makes sense). For this reason, I didn't include "clone" in the Flax layers I implemented and whenever I saw a clone.relprop in the pyTorch version, I added the relevances in the Jax version

In [12]:
A = jnp.ones((2,2))
B = jnp.array([[-2., 1], [1, -2]])

jmm = ours.MatMul()
jmm(A,B)



DeviceArray([[-1., -1.],
             [-1., -1.]], dtype=float32)

In [13]:
jmm.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 2., -1.],
              [ 0.,  0.]], dtype=float32),
 DeviceArray([[ 2.,  0.],
              [-1., -0.]], dtype=float32)]

In [14]:
j_add = ours.Add()
j_add(A,B)

DeviceArray([[-1.,  2.],
             [ 2., -1.]], dtype=float32)

In [15]:
j_add.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 0.33333334, -0.        ],
              [-0.        ,  0.        ]], dtype=float32),
 DeviceArray([[0.6666667, 0.       ],
              [0.       , 0.       ]], dtype=float32)]

In [16]:
jax_pool = ours.IndexSelect()
jax_pool(B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[-2.],
             [ 1.]], dtype=float32)

In [17]:
jax_pool.relprop(jnp.array([[1.],[0]]), B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[ 1.,  0.],
             [ 0., -0.]], dtype=float32)

In [18]:
jax_dense = ours.Dense(2)
x = jnp.ones((1,20))
variables = jax_dense.init(jax.random.PRNGKey(0), x)
model = jax_dense.bind(variables)
model(x)
variables
print(model.variables["params"])

FrozenDict({
    kernel: DeviceArray([[ 1.53342038e-01,  3.29581231e-01],
                 [ 4.70121145e-01,  1.81581691e-01],
                 [ 1.80420190e-01,  3.35969597e-01],
                 [-1.62113383e-01, -2.02395543e-01],
                 [-1.18919984e-01, -6.62900805e-02],
                 [-4.10243064e-01,  7.33509436e-02],
                 [ 4.96547073e-02, -2.71931272e-02],
                 [-2.19073966e-01,  4.65010494e-01],
                 [ 1.54693127e-02,  1.10250756e-01],
                 [-3.89764100e-01,  8.84109288e-02],
                 [-3.79199862e-01, -2.35682800e-02],
                 [ 3.24670374e-01,  2.04515398e-01],
                 [ 1.44210760e-04,  1.65998340e-01],
                 [-3.36234242e-01, -3.94779295e-01],
                 [-1.40488684e-01, -5.03639765e-02],
                 [-2.11226270e-02,  2.08313853e-01],
                 [ 9.56470072e-02,  2.05471292e-01],
                 [-9.69679356e-02, -3.74182016e-01],
                 [-3.6075

In [19]:
model.relprop(jnp.array([[1.,0]]), x)

DeviceArray([[1.1891876e-01, 3.6458510e-01, 1.3991822e-01, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 3.8507875e-02, 0.0000000e+00,
              1.1996655e-02, 0.0000000e+00, 0.0000000e+00, 2.5178614e-01,
              1.1183733e-04, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              7.4175507e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],            dtype=float32)

In [20]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
configuration = BertConfig()

Downloading: 100%|███████████████████████████████████████████████████| 232k/232k [00:00<00:00, 2.81MB/s]
Downloading: 100%|███████████████████████████████████████████████████| 28.0/28.0 [00:00<00:00, 21.6kB/s]
Downloading: 100%|██████████████████████████████████████████████████████| 570/570 [00:00<00:00, 475kB/s]


In [46]:
inputs = tokenizer("Hello world!")
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

In [47]:
embedding_layer = bert_layers.FlaxBertEmbeddings(configuration)
variables = embedding_layer.init(jax.random.PRNGKey(0), input_ids, token_ids,  position_ids, attention_mask)
output = embedding_layer.apply(variables, input_ids, token_ids, position_ids, attention_mask)
output.shape

(5, 768)

In [48]:
cam = jnp.array(np.random.rand(*output.shape))
cam = cam / cam.sum()
cam = embedding_layer.apply(variables, output, input_ids, token_ids, position_ids, attention_mask, method=embedding_layer.relprop)
cam

[DeviceArray([[ 1.47529153e-08, -2.83236457e-08, -2.98519609e-08, ...,
               -9.80268879e-08,  3.19056568e-08, -4.26698819e-08],
              [ 2.07204494e-08,  2.83888308e-08,  2.44977816e-08, ...,
               -7.83912455e-08, -5.39528457e-08,  3.94658128e-09],
              [ 2.63997428e-08,  2.46574903e-08,  3.82257674e-08, ...,
               -5.37965015e-08, -3.58449519e-08,  1.20567904e-07],
              [-1.14016965e-08, -6.57746302e-09, -1.69149015e-08, ...,
               -5.53824497e-09, -6.18229024e-08, -3.85396532e-08],
              [-7.03338898e-09,  2.34075426e-08,  4.81068483e-08, ...,
               -3.22811111e-08, -4.31609166e-08, -5.09730675e-08]],            dtype=float32),
 DeviceArray([[-1.8991523e-08,  1.8220303e-08,  3.6113799e-09, ...,
               -9.9331722e-08, -3.6784922e-08,  3.1697240e-08],
              [-3.8814896e-08, -3.0151741e-09,  6.9854250e-10, ...,
                7.1177952e-09,  8.2972127e-09, -2.5059039e-08],
              [ 3.

In [49]:
output = output[jnp.newaxis]
output.shape

(1, 5, 768)

In [50]:
self_attention = bert_layers.FlaxBertSelfAttention(configuration)
variables = self_attention.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = self_attention.apply(variables, output, attention_mask, None)

In [51]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam, attn_cams = self_attention.apply(variables, cam, output, None, None, method=self_attention.relprop)
attn_cams.shape

(1, 12, 5, 5)

In [52]:
attention = bert_layers.FlaxBertAttention(configuration)
variables = attention.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = attention.apply(variables, output, attention_mask, None)

In [53]:
print(hidden_states)

(DeviceArray([[[-0.56294453,  1.0679047 ,  1.1388768 , ...,  0.3926163 ,
               -1.4752828 ,  1.9199177 ],
              [-1.1342765 , -0.7030669 , -0.18691309, ...,  2.3455682 ,
                1.5618546 , -0.48230222],
              [ 0.3563769 , -0.39173922,  0.12976554, ...,  1.4595736 ,
                0.5980807 ,  0.19320044],
              [ 0.6631111 ,  1.5822634 ,  1.1960372 , ..., -0.13444656,
                1.0535947 ,  1.4545076 ],
              [ 0.54959863, -0.10840819, -0.42316663, ...,  0.59822196,
                1.625095  ,  1.389647  ]]], dtype=float32),)


In [54]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam

DeviceArray([[[2.0544577e-04, 1.8731359e-04, 1.5842455e-04, ...,
               3.2136851e-04, 3.2195373e-04, 3.1940680e-04],
              [3.7368256e-05, 4.3913317e-04, 4.5112115e-06, ...,
               1.6339459e-04, 1.6674273e-04, 2.8629354e-04],
              [1.9990503e-04, 4.0936793e-05, 4.6675006e-04, ...,
               2.0885885e-04, 2.7191587e-04, 8.4521111e-05],
              [3.6794922e-04, 2.4513615e-04, 4.9480487e-04, ...,
               2.4759353e-04, 5.0792940e-05, 1.2597123e-04],
              [1.9601891e-04, 2.9551189e-05, 4.4787728e-04, ...,
               2.1334381e-04, 5.0185749e-04, 1.7077646e-06]]],            dtype=float32)

In [55]:
cam, attn_cams = attention.apply(variables, cam, output, None, None, method=attention.relprop)
cam.sum()

DeviceArray(0.9999999, dtype=float32)

In [56]:
layer = bert_layers.FlaxBertLayer(configuration)
variables = layer.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = layer.apply(variables, output, attention_mask, None)

In [57]:
hidden_states[0].shape

(1, 5, 768)

In [58]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam, attn_cam = layer.apply(variables, cam, output, attention_mask, None, method=layer.relprop)

In [59]:
cam.shape
cam.sum()

DeviceArray(0.99999964, dtype=float32)

In [60]:
layer_collection = bert_layers.FlaxBertLayerCollection(configuration)
variables = layer_collection.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = layer_collection.apply(variables, output, attention_mask, None)

2022-12-04 10:28:47.009087: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-12-04 10:28:47.030524: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-12-04 10:28:47.673756: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-04 10:28:47.673920: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [61]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam, attn_cams = layer_collection.apply(variables, cam, output, attention_mask, None, method=layer_collection.relprop)

In [62]:
cam.sum()

DeviceArray(0.9999993, dtype=float32)

In [64]:
len(attn_cams)

12

In [65]:
encoder = bert_layers.FlaxBertEncoder(configuration)
variables = encoder.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = encoder.apply(variables, output, attention_mask, None)

In [66]:
hidden_states

FlaxBaseModelOutputWithPastAndCrossAttentions(last_hidden_state=DeviceArray([[[-0.73932266, -0.10496283, -0.23650123, ..., -0.29518175,
               -0.07944169,  2.5543838 ],
              [-1.0293188 , -1.4049089 ,  0.43076634, ...,  0.04329822,
                1.65496   ,  0.30194774],
              [ 0.4192415 , -1.8640374 ,  0.49091607, ...,  0.8654702 ,
               -0.00657526,  0.5579648 ],
              [ 0.07758794, -0.8224394 , -0.11835903, ..., -0.11618427,
                1.0518749 ,  1.1674784 ],
              [ 0.09991326, -0.80100965, -1.4718195 , ...,  0.5109208 ,
                0.6762812 ,  1.2827846 ]]], dtype=float32), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

In [67]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam, attn_cam = encoder.apply(variables, cam, output, attention_mask, None, method=encoder.relprop)

In [68]:
cam.sum()

DeviceArray(0.9999999, dtype=float32)

In [75]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
print(input_ids.shape)

bert_module = bert_layers.FlaxBertModule(configuration)
variables = bert_module.init(jax.random.PRNGKey(0), input_ids, attention_mask, None)
hidden_states = bert_module.apply(variables, input_ids, attention_mask, None)

(1, 5)


In [76]:
hidden_states

FlaxBaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=DeviceArray([[[-1.803636  , -0.90945095, -0.4882279 , ..., -1.7539307 ,
                0.26016185, -0.39490893],
              [-1.0041101 ,  0.06071824, -0.61647964, ..., -0.9118402 ,
               -1.5660151 , -1.7031233 ],
              [-0.70081997, -0.7901022 , -1.5094408 , ..., -0.9647325 ,
               -0.89686865,  0.2988747 ],
              [-0.44771618, -0.58907217, -2.2381587 , ...,  0.40958568,
               -0.7222764 , -1.6665676 ],
              [-1.0788714 , -1.6146749 , -1.6605315 , ..., -1.2121923 ,
               -0.27918416, -0.9997014 ]]], dtype=float32), pooler_output=DeviceArray([[-3.71137857e-01,  6.60066009e-01, -9.70576286e-01,
               7.94233978e-02, -6.85315728e-01, -2.48473790e-02,
              -3.00466567e-01,  6.13525093e-01, -3.73042107e-01,
               3.94238204e-01,  8.54212463e-01,  1.11271188e-01,
              -4.34547335e-01,  3.64358783e-01, -4.61808622e-01,
      

In [78]:
cam = jnp.array(np.random.rand(*hidden_states[1].shape))
cam = cam / cam.sum()
cam, attn_cams = bert_module.apply(variables, cam, input_ids, attention_mask, None, method=bert_module.relprop)

In [79]:
cam

DeviceArray([[[ 4.2693843e-03,  3.7984230e-04,  5.1922008e-04, ...,
                2.8064649e-03,  2.3867206e-03,  6.1647287e-03],
              [ 4.2418513e-04,  4.4299239e-05,  1.0434465e-04, ...,
               -5.9854333e-06,  2.1828346e-06,  6.4564476e-05],
              [ 5.0416053e-04,  2.4041923e-04,  1.1700042e-05, ...,
                2.6237562e-05,  1.6011950e-06,  2.3359265e-04],
              [ 2.9140397e-04, -8.1375892e-06, -1.6246898e-06, ...,
                5.1468442e-04, -2.6290909e-06,  2.5299564e-06],
              [-1.2793008e-04,  4.2963159e-04, -2.5572461e-05, ...,
               -4.1006551e-05,  2.3577675e-04,  1.3572631e-05]]],            dtype=float32)

In [80]:
#cam[0] corresponds to position/token-type embeddings and cam[1] corresponds to token embeddings
cam.sum()

DeviceArray(0.99999994, dtype=float32)

In [81]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

seq_class_module = bert_layers.FlaxBertForSequenceClassificationModule(configuration)
variables = seq_class_module.init(jax.random.PRNGKey(0), input_ids, attention_mask, token_ids, position_ids, None)
hidden_states = seq_class_module.apply(variables, input_ids, attention_mask, token_ids, position_ids, None)

In [82]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
print(cam, input_ids.dtype, attention_mask.dtype)
cam = seq_class_module.apply(variables, cam, input_ids, attention_mask, token_ids, position_ids, None, method=seq_class_module.relprop)

[[0.02014412 0.9798559 ]] int32 int32


In [90]:
output = model(input_ids)
output

FlaxSequenceClassifierOutput(logits=DeviceArray([[ 1.0231309 , -0.18921821]], dtype=float32), hidden_states=None, attentions=None)

In [93]:
cam = jnp.array(np.random.rand(*output[0].shape))
cam, attn_weights = model.relprop(cam, input_ids)

In [94]:
print(len(attn_weights))

12


In [95]:
attn_weights[0].shape

(1, 12, 5, 5)

In [96]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

In [110]:
model = bert_layers.FlaxBertForSequenceClassification(configuration)

In [111]:
def loss_fn(input_ids, attention_mask, perturbs, index=None):
    output = model(input_ids, attention_mask, perturbs=perturbs)[0]
    if index == None:
        index = jnp.argmax(output, axis=-1)

    one_hot = np.zeros((1, output.shape[-1]), dtype=jnp.float32)
    one_hot[0, index] = 1
    one_hot_vector = jnp.array(one_hot)
    loss = jnp.sum(one_hot * output)
    return loss

loss, intermediate_grads = jax.value_and_grad(loss_fn, argnums=2)(input_ids, attention_mask, model.perturbs)

In [112]:
intermediate_grads

FrozenDict({
    bert: {
        encoder: {
            layer: {
                0: {
                    attention: {
                        self: {
                            attn_weights: DeviceArray([[[[ 0.10701816]],
                            
                                          [[ 0.06321273]],
                            
                                          [[-0.15270925]],
                            
                                          [[-0.02372664]],
                            
                                          [[ 0.02161717]],
                            
                                          [[-0.1558971 ]],
                            
                                          [[-0.35827017]],
                            
                                          [[ 0.09502272]],
                            
                                          [[-0.18520816]],
                            
                                          [

In [122]:
output = model(input_ids, attention_mask)[0]
index = jnp.argmax(output, axis=-1)

one_hot = np.zeros((1, output.shape[-1]), dtype=jnp.float32)
one_hot[0, index] = 1
one_hot_vector = jnp.array(one_hot)
input_cam, attn_cams = model.relprop(one_hot_vector, input_ids, attention_mask)

In [123]:
attn_cams[0].shape

(1, 12, 5, 5)

In [126]:
cams = []
for blk, cam in zip(intermediate_grads["bert"]["encoder"]["layer"].values(), reversed(attn_cams)):
    grad = blk["attention"]["self"]["attn_weights"]
    cam = jnp.reshape(cam[0], (-1, cam.shape[-1], cam.shape[-1]))
    grad = jnp.reshape(grad[0], (-1, grad.shape[-1], grad.shape[-1]))
    cam = grad * cam
    cam = jnp.clip(cam, a_min=0).mean(axis=0)
    cams.append(cam[jnp.newaxis])
cams

[DeviceArray([[[7.9139070e-05, 3.8550003e-05, 6.1506988e-05,
                6.5829525e-05, 4.7100897e-05],
               [5.3313970e-06, 4.5402312e-06, 4.8563907e-06,
                4.3781465e-06, 3.8504477e-06],
               [3.6195697e-06, 2.6392549e-06, 3.8894268e-06,
                2.8058691e-06, 3.7412917e-06],
               [1.3279441e-06, 1.3799876e-06, 1.4229416e-06,
                1.7548898e-06, 1.7892987e-06],
               [2.2859501e-06, 2.5369898e-06, 2.9938983e-06,
                1.9571537e-06, 2.0532789e-06]]], dtype=float32),
 DeviceArray([[[1.7492193e-05, 3.0794854e-05, 4.2264081e-05,
                2.4913190e-05, 1.8518527e-05],
               [4.8676487e-05, 3.9152168e-05, 6.7535599e-05,
                3.6008161e-05, 3.5029603e-05],
               [5.7590296e-06, 5.0679819e-06, 6.8918357e-06,
                4.5817642e-06, 4.3014079e-06],
               [6.6412304e-06, 7.0618457e-06, 6.6894422e-06,
                5.6739368e-06, 8.2463848e-06],
          

In [136]:
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    eye = jnp.eye(num_tokens)[jnp.newaxis]
    all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
    matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(axis=-1, keepdims=True)
                          for i in range(len(all_layer_matrices))]
    
    joint_attention = matrices_aug[start_layer]
    for i in range(start_layer+1, len(matrices_aug)):
        print(matrices_aug[i].shape, joint_attention.shape)
        joint_attention = np.einsum('...ij,...jk', matrices_aug[i], joint_attention)
    return joint_attention

In [137]:
rollout = compute_rollout_attention(cams, start_layer=0)
rollout[:, 0, 0] = rollout[:, 0].min()
rollout[:,0]

(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)
(1, 5, 5) (1, 5, 5)


array([[0.00049795, 0.00059671, 0.00053786, 0.00049795, 0.00050486]],
      dtype=float32)

In [138]:
from explanation_generator import *

In [139]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

model = bert_layers.FlaxBertForSequenceClassification(configuration)

In [145]:
gen = Generator(model)
gen.generate_LRP(input_ids, attention_mask, start_layer=0)

DeviceArray([[0.00049795, 0.00059671, 0.00053786, 0.00049795, 0.00050486]],            dtype=float32)