In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
import torch.nn
import jax
import jax.numpy as jnp
from jax import lax
import flax
import flax.linen as fnn
import numpy as np
from typing import Callable, Optional, Tuple, Union, Sequence, Iterable

sys.path.append("../bert")
from transformers import BertConfig, BertTokenizer
import modeling_flax_bert as bert_layers
import bert_explainability_layers as ours

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dir(ours)

['Add',
 'Callable',
 'Dense',
 'Dropout',
 'GeLU',
 'IndexSelect',
 'LayerNorm',
 'MatMul',
 'Optional',
 'ReLU',
 'RelProp',
 'RelPropSimple',
 'Softmax',
 'Tanh',
 'Tuple',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'einsum',
 'flax',
 'jax',
 'jnp',
 'lax',
 'nn',
 'np',
 'safe_divide']

In [4]:
dir(bert_layers)

['ACT2FN',
 'BERT_INPUTS_DOCSTRING',
 'BERT_START_DOCSTRING',
 'BertConfig',
 'Callable',
 'FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING',
 'FLAX_BERT_FOR_PRETRAINING_DOCSTRING',
 'FlaxBaseModelOutputWithPastAndCrossAttentions',
 'FlaxBaseModelOutputWithPooling',
 'FlaxBaseModelOutputWithPoolingAndCrossAttentions',
 'FlaxBertAttention',
 'FlaxBertEmbeddings',
 'FlaxBertEncoder',
 'FlaxBertForCausalLM',
 'FlaxBertForCausalLMModule',
 'FlaxBertForMaskedLM',
 'FlaxBertForMaskedLMModule',
 'FlaxBertForMultipleChoice',
 'FlaxBertForMultipleChoiceModule',
 'FlaxBertForNextSentencePrediction',
 'FlaxBertForNextSentencePredictionModule',
 'FlaxBertForPreTraining',
 'FlaxBertForPreTrainingModule',
 'FlaxBertForPreTrainingOutput',
 'FlaxBertForQuestionAnswering',
 'FlaxBertForQuestionAnsweringModule',
 'FlaxBertForSequenceClassification',
 'FlaxBertForSequenceClassificationModule',
 'FlaxBertForTokenClassification',
 'FlaxBertForTokenClassificationModule',
 'FlaxBertIntermediate',
 'FlaxBertLMPredicti

In [5]:
def safe_divide(a, b):
    den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    den = den + den.eq(0).type(den.type()) * 1e-9
    return a / den * b.ne(0).type(b.type())


def forward_hook(self, input, output):
    if type(input[0]) in (list, tuple):
        self.X = []
        for i in input[0]:
            x = i.detach()
            x.requires_grad = True
            self.X.append(x)
    else:
        self.X = input[0].detach()
        self.X.requires_grad = True

    self.Y = output


class RelProp(torch.nn.Module):
    def __init__(self):
        super(RelProp, self).__init__()
        # if not self.training:
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        C = torch.autograd.grad(Z, X, S, retain_graph=True)
        return C

    def relprop(self, R, alpha):
        return R


class RelPropSimple(RelProp):
    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs

class MatMul(RelPropSimple):
    def forward(self, inputs):
        return torch.matmul(*inputs)
    
class Add(RelPropSimple):
    def forward(self, inputs):
        return torch.add(*inputs)

    def relprop(self, R, alpha):
        Z = self.forward(self.X)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        a = self.X[0] * C[0]
        b = self.X[1] * C[1]

        a_sum = a.sum()
        b_sum = b.sum()

        a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
        b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()

        a = a * safe_divide(a_fact, a.sum())
        b = b * safe_divide(b_fact, b.sum())

        outputs = [a, b]
        return outputs
    
class IndexSelect(RelProp):
    def forward(self, inputs, dim, indices):
        self.__setattr__('dim', dim)
        self.__setattr__('indices', indices)

        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X) == False:
            outputs = []
            outputs.append(self.X[0] * C[0])
            outputs.append(self.X[1] * C[1])
        else:
            outputs = self.X * (C[0])
        return outputs
    
class Clone(RelProp):
    def forward(self, input, num):
        self.__setattr__('num', num)
        outputs = []
        for _ in range(num):
            outputs.append(input)

        return outputs

    def relprop(self, R, alpha):
        Z = []
        for _ in range(self.num):
            Z.append(self.X)
        S = [safe_divide(r, z) for r, z in zip(R, Z)]
        C = self.gradprop(Z, self.X, S)[0]

        R = self.X * C

        return R

In [6]:
A = torch.ones((2,2))
B = torch.tensor([[-2., 1], [1, -2]])

mm = MatMul()
mm([A,B])

tensor([[-1., -1.],
        [-1., -1.]])

In [7]:
mm.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[ 2., -1.],
         [ 0.,  0.]], grad_fn=<MulBackward0>),
 tensor([[ 2.,  0.],
         [-1., -0.]], grad_fn=<MulBackward0>)]

In [8]:
add = Add()
add([A,B])

tensor([[-1.,  2.],
        [ 2., -1.]])

In [9]:
add.relprop(torch.tensor([[1.,0],[0,0]]), alpha=1)

[tensor([[0.3333, -0.0000],
         [-0.0000, 0.0000]], grad_fn=<MulBackward0>),
 tensor([[0.6667, 0.0000],
         [0.0000, 0.0000]], grad_fn=<MulBackward0>)]

In [10]:
pool = IndexSelect()
pool(B, 1, torch.zeros(1, dtype=torch.int32))

tensor([[-2.],
        [ 1.]])

In [11]:
pool.relprop(torch.tensor([[1.],[0]]), alpha=1)

tensor([[1., 0.],
        [0., -0.]], grad_fn=<MulBackward0>)

In [12]:
clone = Clone()
clone(B, 2)

[tensor([[-2.,  1.],
         [ 1., -2.]]),
 tensor([[-2.,  1.],
         [ 1., -2.]])]

In [13]:
clone.relprop([torch.tensor([[.5,.5],[0,0]]), torch.tensor([[0,.25],[.25,.5]])], alpha=1)

tensor([[0.5000, 0.7500],
        [0.2500, 0.5000]], grad_fn=<MulBackward0>)

It looks like clone is necessary because of the way pyTorch tracks gradients. Working out the math, I think the relprop works out to a sum over all of the relevances (which intuitively also makes sense). For this reason, I didn't include "clone" in the Flax layers I implemented and whenever I saw a clone.relprop in the pyTorch version, I added the relevances in the Jax version

In [14]:
A = jnp.ones((2,2))
B = jnp.array([[-2., 1], [1, -2]])

jmm = ours.MatMul()
jmm(A,B)



DeviceArray([[-1., -1.],
             [-1., -1.]], dtype=float32)

In [15]:
jmm.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 2., -1.],
              [ 0.,  0.]], dtype=float32),
 DeviceArray([[ 2.,  0.],
              [-1., -0.]], dtype=float32)]

In [16]:
j_add = ours.Add()
j_add(A,B)

DeviceArray([[-1.,  2.],
             [ 2., -1.]], dtype=float32)

In [17]:
j_add.relprop(jnp.array([[1.,0],[0,0]]), A, B)

[DeviceArray([[ 0.33333334, -0.        ],
              [-0.        ,  0.        ]], dtype=float32),
 DeviceArray([[0.6666667, 0.       ],
              [0.       , 0.       ]], dtype=float32)]

In [18]:
jax_pool = ours.IndexSelect()
jax_pool(B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[-2.],
             [ 1.]], dtype=float32)

In [19]:
jax_pool.relprop(jnp.array([[1.],[0]]), B, 1, jnp.zeros(1, dtype=jnp.int32))

DeviceArray([[ 1.,  0.],
             [ 0., -0.]], dtype=float32)

In [20]:
jax_dense = ours.Dense(2)
x = jnp.ones((1,20))
variables = jax_dense.init(jax.random.PRNGKey(0), x)
model = jax_dense.bind(variables)
model(x)
variables
print(model.variables["params"])

FrozenDict({
    kernel: DeviceArray([[ 1.53342038e-01,  3.29581231e-01],
                 [ 4.70121145e-01,  1.81581691e-01],
                 [ 1.80420190e-01,  3.35969597e-01],
                 [-1.62113383e-01, -2.02395543e-01],
                 [-1.18919984e-01, -6.62900805e-02],
                 [-4.10243064e-01,  7.33509436e-02],
                 [ 4.96547073e-02, -2.71931272e-02],
                 [-2.19073966e-01,  4.65010494e-01],
                 [ 1.54693127e-02,  1.10250756e-01],
                 [-3.89764100e-01,  8.84109288e-02],
                 [-3.79199862e-01, -2.35682800e-02],
                 [ 3.24670374e-01,  2.04515398e-01],
                 [ 1.44210760e-04,  1.65998340e-01],
                 [-3.36234242e-01, -3.94779295e-01],
                 [-1.40488684e-01, -5.03639765e-02],
                 [-2.11226270e-02,  2.08313853e-01],
                 [ 9.56470072e-02,  2.05471292e-01],
                 [-9.69679356e-02, -3.74182016e-01],
                 [-3.6075

In [21]:
model.relprop(jnp.array([[1.,0]]), x)

DeviceArray([[1.1891876e-01, 3.6458510e-01, 1.3991822e-01, 0.0000000e+00,
              0.0000000e+00, 0.0000000e+00, 3.8507875e-02, 0.0000000e+00,
              1.1996655e-02, 0.0000000e+00, 0.0000000e+00, 2.5178614e-01,
              1.1183733e-04, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
              7.4175507e-02, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],            dtype=float32)

In [22]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
configuration = BertConfig()

In [23]:
inputs = tokenizer("Hello world!")
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

In [24]:
embedding_layer = bert_layers.FlaxBertEmbeddings(configuration)
variables = embedding_layer.init(jax.random.PRNGKey(0), input_ids, token_ids,  position_ids, attention_mask)
output = embedding_layer.apply(variables, input_ids, token_ids, position_ids, attention_mask)
output.shape

(5, 768)

In [25]:
cam = jnp.array(np.random.rand(*output.shape))
cam = cam / cam.sum()
cam = embedding_layer.apply(variables, output, input_ids, token_ids, position_ids, attention_mask, method=embedding_layer.relprop)
cam

[DeviceArray([[ 1.47529153e-08, -2.83236457e-08, -2.98519609e-08, ...,
               -9.80268879e-08,  3.19056568e-08, -4.26698819e-08],
              [ 2.07204494e-08,  2.83888308e-08,  2.44977816e-08, ...,
               -7.83912455e-08, -5.39528457e-08,  3.94658128e-09],
              [ 2.63997428e-08,  2.46574903e-08,  3.82257674e-08, ...,
               -5.37965015e-08, -3.58449519e-08,  1.20567904e-07],
              [-1.14016965e-08, -6.57746302e-09, -1.69149015e-08, ...,
               -5.53824497e-09, -6.18229024e-08, -3.85396532e-08],
              [-7.03338898e-09,  2.34075426e-08,  4.81068483e-08, ...,
               -3.22811111e-08, -4.31609166e-08, -5.09730675e-08]],            dtype=float32),
 DeviceArray([[-1.8991523e-08,  1.8220303e-08,  3.6113799e-09, ...,
               -9.9331722e-08, -3.6784922e-08,  3.1697240e-08],
              [-3.8814896e-08, -3.0151741e-09,  6.9854250e-10, ...,
                7.1177952e-09,  8.2972127e-09, -2.5059039e-08],
              [ 3.

In [26]:
output = output[jnp.newaxis]
output.shape

(1, 5, 768)

In [27]:
attention = bert_layers.FlaxBertAttention(configuration)
variables = attention.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = attention.apply(variables, output, attention_mask, None)

In [28]:
print(hidden_states)

(DeviceArray([[[-0.5629446 ,  1.0679047 ,  1.1388767 , ...,  0.3926163 ,
               -1.4752828 ,  1.9199177 ],
              [-1.1342765 , -0.7030668 , -0.18691306, ...,  2.3455682 ,
                1.5618546 , -0.4823021 ],
              [ 0.35637692, -0.3917392 ,  0.12976551, ...,  1.4595735 ,
                0.59808064,  0.1932004 ],
              [ 0.6631112 ,  1.5822634 ,  1.1960372 , ..., -0.13444664,
                1.0535946 ,  1.4545076 ],
              [ 0.5495985 , -0.10840818, -0.42316666, ...,  0.59822196,
                1.625095  ,  1.389647  ]]], dtype=float32),)


In [29]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam

DeviceArray([[[1.8484204e-04, 2.5492642e-04, 7.8359575e-05, ...,
               9.8414523e-05, 5.9304934e-05, 3.1589242e-04],
              [2.4367780e-04, 4.8970850e-04, 4.8835133e-04, ...,
               1.4320457e-04, 2.2933554e-04, 3.5219468e-04],
              [3.0232969e-04, 1.6612618e-04, 2.3883308e-04, ...,
               1.3362999e-04, 1.8775632e-04, 1.7151897e-04],
              [2.6469497e-04, 4.4329997e-04, 4.1071716e-04, ...,
               4.3890130e-04, 4.1915450e-04, 4.6043098e-04],
              [1.0859840e-04, 1.6374406e-04, 2.7995577e-04, ...,
               2.0164947e-04, 4.1660364e-04, 3.2140972e-04]]],            dtype=float32)

In [32]:
%pdb

Automatic pdb calling has been turned ON


In [None]:
cam = attention.apply(variables, cam, output, None, None, method=attention.relprop)
cam.sum()

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (5, 768) and (5, 768, 12, 64)

> [0;32m/home/mike/miniconda3/envs/TransformerExplainability/lib/python3.10/site-packages/jax/core.py[0m(1510)[0;36mdivide_shape_sizes[0;34m()[0m
[0;32m   1508 [0;31m      [0;32mreturn[0m [0;36m1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1509 [0;31m    [0;32mif[0m [0msz1[0m [0;34m%[0m [0msz2[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1510 [0;31m      [0;32mraise[0m [0mInconclusiveDimensionOperation[0m[0;34m([0m[0;34mf"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1511 [0;31m    [0;32mreturn[0m [0msz1[0m [0;34m//[0m [0msz2[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1512 [0;31m[0;34m[0m[0m
[0m


In [None]:
layer = bert_layers.FlaxBertLayer(configuration)
variables = layer.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = layer.apply(variables, output, attention_mask, None)

In [None]:
hidden_states[0].shape

In [None]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam = layer.apply(variables, cam, output, attention_mask, None, method=layer.relprop)

In [None]:
cam.shape
cam.sum()

In [None]:
layer_collection = bert_layers.FlaxBertLayerCollection(configuration)
variables = layer_collection.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = layer_collection.apply(variables, output, attention_mask, None)

In [None]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam = layer_collection.apply(variables, cam, output, attention_mask, None, method=layer_collection.relprop)

In [None]:
cam.sum()

In [None]:
encoder = bert_layers.FlaxBertEncoder(configuration)
variables = encoder.init(jax.random.PRNGKey(0), output, attention_mask, None)
hidden_states = encoder.apply(variables, output, attention_mask, None)

In [None]:
hidden_states

In [None]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
cam = encoder.apply(variables, cam, output, attention_mask, None, method=encoder.relprop)

In [None]:
cam.sum()

In [None]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
print(input_ids.shape)

bert_module = bert_layers.FlaxBertModule(configuration)
variables = bert_module.init(jax.random.PRNGKey(0), input_ids, attention_mask, None)
hidden_states = bert_module.apply(variables, input_ids, attention_mask, None)

In [None]:
hidden_states

In [None]:
cam = jnp.array(np.random.rand(*hidden_states[1].shape))
cam = cam / cam.sum()
cam = bert_module.apply(variables, cam, input_ids, attention_mask, None, method=bert_module.relprop)

In [None]:
cam

In [None]:
print(len(cam))

In [None]:
#cam[0] corresponds to position/token-type embeddings and cam[1] corresponds to token embeddings
cam[0].sum() + cam[1].sum()

In [None]:
inputs = tokenizer(["Hello world!",])
input_ids = jnp.array(inputs["input_ids"])
token_ids = jnp.array(inputs["token_type_ids"])
attention_mask = jnp.array(inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

seq_class_module = bert_layers.FlaxBertForSequenceClassificationModule(configuration)
variables = seq_class_module.init(jax.random.PRNGKey(0), input_ids, attention_mask, token_ids, position_ids, None)
hidden_states = seq_class_module.apply(variables, input_ids, attention_mask, token_ids, position_ids, None)

In [None]:
cam = jnp.array(np.random.rand(*hidden_states[0].shape))
cam = cam / cam.sum()
print(cam, input_ids.dtype, attention_mask.dtype)
cam = seq_class_module.apply(variables, cam, input_ids, attention_mask, token_ids, position_ids, None, method=seq_class_module.relprop)

In [None]:
model = bert_layers.FlaxBertForSequenceClassification(configuration)

In [None]:
output = model(input_ids)
output

In [None]:
cam = jnp.array(np.random.rand(*output[0].shape))
model.relprop(cam, input_ids)

In [None]:
model.apply?