# Modules that need to be verified


## Basic Building Blocks (bb)
All the components in the layers file.

- FlaxBertForSequenceClassification
    - FlaxBertForSequenceClassificationModule
        - FlaxBertModule
            - FlaxBertEmbeddings
                - Add
                - LayerNorm
                - Dropout
            - FlaxBertEncoder
                - FlaxBertLayerCollection
                    - FlaxBertLayer
                        - FlaxBertAttention
                            - FlaxBertSelfAttention(bb only)
                            - ~~FlaxBertSelfOutput(bb only)~~ ?
                        - ~~FlaxBertIntermediate(bb only)~~
                        - FlaxBertOutput(bb only)
                    - FlaxBertCheckpointLayer(cond.)
            - FlaxBertPooler(bb only)
        - Dropout
        - Dense

## Basic Building Block verification.
- FlaxBertSelfAttention vs.     BertSelfAttention
- FlaxBertSelfOutput    vs.     BertSelfOutput
- FlaxBertIntermediate  vs.     BertIntermediate
- FlaxBertOutput        vs.     BertOutput
- FlaxBertPooler        vs.     BertPooler

In [1]:
import sys
import torch
import torch.nn as nn
import jax
import jax.numpy as jnp
import numpy as np
from clu import parameter_overview
import flax.linen as fnn

sys.path.insert(0, '../')

from bert_torch.BERT import *
from bert.modeling_flax_bert import *
import bert.modeling_flax_bert as layers
from transformers import BertConfig, BertTokenizer
import bert_torch.layers as tl
import bert.bert_explainability_layers as fl


2022-12-01 23:04:18.167464: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-01 23:04:18.167502: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
def count_parameters(model):
    """Return the total number of elements in the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad == False are left out of the count.
        if param.requires_grad:
            total += param.numel()
    return total

In [3]:
# Load the "bert-base-uncased" tokenizer and build a BertConfig with no
# overrides; `configuration` is shared by every module constructed below.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
configuration = BertConfig()
# Print the cross-attention flag of this configuration.
print(configuration.add_cross_attention)

False


In [4]:
# Tokenize the same one-sentence batch twice: once converted to JAX arrays
# (fl_* names) and once as PyTorch tensors (pt_* names).
text_batch = ["Hello world!"]
fl_inputs = tokenizer(text_batch)
# print(fl_inputs.__class__)
# print(dict(inputs).__class__)

# Replace every tokenizer field in place with a JAX array.
for k in fl_inputs:
    fl_inputs[k] = jnp.array(fl_inputs[k])
# NOTE(review): shape[0] is the first axis of input_ids. The recorded output
# shows input_ids as a 1x5 array, so this produces position_ids == [0] rather
# than one position per token. Confirm whether shape[-1] was intended; the
# same applies to the three other position_ids constructions in this cell.
fl_inputs['position_ids'] = jnp.arange(fl_inputs['input_ids'].shape[0])

# Stand-alone copies of the individual fields (already JAX arrays at this point).
input_ids = jnp.array(fl_inputs["input_ids"])
token_ids = jnp.array(fl_inputs["token_type_ids"])
attention_mask = jnp.array(fl_inputs["attention_mask"])
position_ids = jnp.arange(input_ids.shape[0])

# PyTorch counterpart of the inputs above.
pt_inputs = tokenizer(text_batch, return_tensors='pt')
pt_inputs['position_ids'] = torch.tensor(np.arange(pt_inputs['input_ids'].shape[0]))
pt_input_ids = pt_inputs["input_ids"]
pt_token_ids = pt_inputs["token_type_ids"]
pt_attention_mask = pt_inputs["attention_mask"]
pt_position_ids = torch.tensor(np.arange(pt_input_ids.shape[0]))

print(fl_inputs, pt_inputs)

{'input_ids': DeviceArray([[ 101, 7592, 2088,  999,  102]], dtype=int32), 'token_type_ids': DeviceArray([[0, 0, 0, 0, 0]], dtype=int32), 'attention_mask': DeviceArray([[1, 1, 1, 1, 1]], dtype=int32), 'position_ids': DeviceArray([0], dtype=int32)} {'input_ids': tensor([[ 101, 7592, 2088,  999,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), 'position_ids': tensor([0])}


In [5]:
# pt_layer = BertEmbeddings(configuration).eval()
# fl_layer = FlaxBertEmbeddings(configuration)

In [119]:
def nested_set(dic, keys, value, create_missing=True):
    """Assign ``value`` at the nested path ``keys`` inside ``dic`` and return ``dic``.

    Args:
        dic: Nested dictionary, modified in place.
        keys: Sequence of keys; all but the last select intermediate
            dictionaries, the last one names the entry to assign.
        value: Object stored at the end of the path.
        create_missing: When true, missing intermediate dictionaries and a
            missing final key are created. When false, the assignment only
            happens if the whole path already exists; otherwise ``dic`` is
            returned unchanged.

    Returns:
        The same ``dic`` object that was passed in.
    """
    node = dic
    for name in keys[:-1]:
        if name not in node:
            if not create_missing:
                # Path is broken and we may not extend it: nothing to do.
                return dic
            node[name] = {}
        node = node[name]

    leaf = keys[-1]
    if leaf in node or create_missing:
        node[leaf] = value
    return dic


def pt2fl(pt: nn.Module, debug=False):
    """Build a Flax-style ``{'params': ...}`` tree from a PyTorch module's parameters.

    Each dotted PyTorch parameter name becomes a nested path in the tree. The
    last path component is renamed as follows:

    * parameters whose owner name ends with ``embeddings`` -> ``embedding``
    * ``weight`` owned by ``LayerNorm``                    -> ``scale``
    * any other ``weight``                                  -> ``kernel``

    Kernels owned by ``dense``/``key``/``value``/``query``, and every top-level
    (undotted) ``weight``, are transposed with axes ``(1, 0)``.

    Args:
        pt: PyTorch module whose ``named_parameters()`` are converted.
        debug: When true, print each dotted parameter name with its size, the
            parameter count and a parameter overview table.

    Returns:
        A dict ``{'params': tree}`` with the converted arrays.
    """
    transposed_owners = ('dense', 'key', 'value', 'query')
    tree = {}

    for name, tensor in pt.named_parameters():
        path = name.split('.')

        if len(path) < 2:
            # Parameter sits directly on `pt`: only 'weight' is renamed, and it
            # is always transposed.
            if path[-1] == 'weight':
                path[-1] = 'kernel'
                nested_set(tree, path, value=jnp.transpose(tensor.detach().numpy(), (1, 0)))
            else:
                nested_set(tree, path, value=tensor.detach().numpy())
            continue

        if debug:
            print(name, tensor.size())

        owner = path[-2]
        if owner.endswith('embeddings'):
            path[-1] = 'embedding'
        if path[-1] == 'weight':
            path[-1] = 'scale' if owner == 'LayerNorm' else 'kernel'

        if path[-1] == 'kernel' and owner in transposed_owners:
            nested_set(tree, path, value=jnp.transpose(tensor.detach().numpy(), (1, 0)))
        else:
            nested_set(tree, path, value=tensor.detach().numpy())

    params = {'params': tree}
    if debug:
        print('Num Prams: ', count_parameters(pt))
        print(parameter_overview.get_parameter_overview(params))
    return params


def verify_module(pt:nn.Module, fl:fnn.Module, pt_kwargs=None, fl_kwargs=None, debug=False):
    """Compare a PyTorch module and a Flax module on forward output and relprop.

    The PyTorch parameters are converted with ``pt2fl`` and applied to the Flax
    module, so both sides run with the same weights. The function prints the
    summed absolute difference of the forward outputs, then feeds a random
    relevance map (normalised to sum to 1, shaped like the Flax output) through
    both ``relprop`` implementations and prints the summed absolute difference
    of the results together with their totals.

    Args:
        pt: PyTorch module; must provide ``relprop(cam, alpha=...)``.
        fl: Flax module; must provide a ``relprop`` method callable via ``apply``.
        pt_kwargs: Keyword arguments for the PyTorch forward call. ``None`` is
            treated as "no arguments".
        fl_kwargs: Keyword arguments for ``fl.apply`` (forward and relprop).
            ``None`` is treated as "no arguments".
        debug: Forwarded to ``pt2fl``.

    Returns:
        Tuple ``(pt_out, fl_out)`` of the unwrapped forward outputs.
    """
    # The defaults are None; unpacking None with ** would raise TypeError.
    pt_kwargs = {} if pt_kwargs is None else pt_kwargs
    fl_kwargs = {} if fl_kwargs is None else fl_kwargs

    params = pt2fl(pt, debug=debug)

    pt_out = pt(**pt_kwargs)
    fl_out = fl.apply(params, **fl_kwargs)

    # Reduce structured outputs to the single array that is compared:
    # first element of tuples / base outputs, second element of pooled outputs.
    if isinstance(pt_out, tuple):
        pt_out = pt_out[0]
    if isinstance(fl_out, tuple):
        fl_out = fl_out[0]
    if isinstance(pt_out, BaseModelOutput):
        pt_out = pt_out[0]

    if isinstance(fl_out, FlaxBaseModelOutputWithPastAndCrossAttentions):
        fl_out = fl_out[0]
    if isinstance(pt_out, BaseModelOutputWithPooling):
        pt_out = pt_out[1]
    if isinstance(fl_out, FlaxBaseModelOutputWithPoolingAndCrossAttentions):
        fl_out = fl_out[1]
    print("Forward diff: ",np.abs(pt_out.detach().numpy() - fl_out).sum())

    # Same random relevance map for both frameworks, normalised to sum to 1.
    cam = np.random.rand(*fl_out.shape)
    fl_cam = jnp.array(cam / cam.sum())
    pt_cam = torch.tensor(cam / cam.sum())

    fl_cams = fl.apply(params, fl_cam, **fl_kwargs, method=fl.relprop)
    pt_cams = pt.relprop(pt_cam, alpha=1)

    print('Flax relprop:', len(fl_cams))
    print('Pt relprop:', len(pt_cams))
    pt_sum = None
    fl_sum = None
    if isinstance(fl_cams, (list, tuple)):
        # Several relevance maps: compare them pairwise and accumulate totals.
        for i, c in enumerate(fl_cams):
            print("Relprop size:", pt_cams[i].size())
            print("Relprop", i, " diff:", np.abs(np.array(c) - pt_cams[i].detach().numpy()).sum())
            # Compare against None explicitly; a running total of exactly 0
            # is falsy and must not be mistaken for "not started yet".
            pt_part = pt_cams[i].sum()
            pt_sum = pt_part if pt_sum is None else pt_sum + pt_part

            fl_part = c.sum()
            fl_sum = fl_part if fl_sum is None else fl_sum + fl_part
    else:
        pt_sum = pt_cams.sum()
        fl_sum = fl_cams.sum()
        print("Relprop size:", pt_cams.size())
        print("Relprop diff:", np.abs(np.array(fl_cams) - pt_cams.detach().numpy()).sum())
    print("Cam sum:", pt_sum, fl_sum)
    return pt_out, fl_out

def verify_module_args(pt:nn.Module, fl:fnn.Module, pt_kwargs=None, fl_kwargs=None, debug=False):
    """Positional-argument variant of ``verify_module``.

    Converts the PyTorch parameters with ``pt2fl``, runs both modules, prints
    the summed absolute difference of the forward outputs, then compares the
    ``relprop`` results for a random relevance map normalised to sum to 1.

    Args:
        pt: PyTorch module; must provide ``relprop(cam, alpha=...)``.
        fl: Flax module; must provide a ``relprop`` method callable via ``apply``.
        pt_kwargs: Despite the name, a sequence of *positional* arguments for
            the PyTorch forward call. It is star-unpacked, so pass a tuple or
            list; a bare array is iterated along its first axis and each slice
            becomes a separate argument. ``None`` means "no arguments".
        fl_kwargs: Same as ``pt_kwargs`` for ``fl.apply`` (forward and relprop).
        debug: Forwarded to ``pt2fl``; also prints the element-wise forward
            difference.

    Returns:
        Tuple ``(pt_out, fl_out)`` of the unwrapped forward outputs.
    """
    # The defaults are None; unpacking None with * would raise TypeError.
    pt_kwargs = () if pt_kwargs is None else pt_kwargs
    fl_kwargs = () if fl_kwargs is None else fl_kwargs

    params = pt2fl(pt, debug=debug)

    pt_out = pt(*pt_kwargs)
    fl_out = fl.apply(params, *fl_kwargs)

    # Reduce structured outputs to the single array that is compared:
    # first element of tuples / base outputs, second element of pooled outputs.
    if isinstance(pt_out, tuple):
        pt_out = pt_out[0]
    if isinstance(fl_out, tuple):
        fl_out = fl_out[0]
    if isinstance(pt_out, BaseModelOutput):
        pt_out = pt_out[0]

    if isinstance(fl_out, FlaxBaseModelOutputWithPastAndCrossAttentions):
        fl_out = fl_out[0]
    if isinstance(pt_out, BaseModelOutputWithPooling):
        pt_out = pt_out[1]
    if isinstance(fl_out, FlaxBaseModelOutputWithPoolingAndCrossAttentions):
        fl_out = fl_out[1]
    print("Forward diff: ",np.abs(pt_out.detach().numpy() - fl_out).sum())
    if debug:
        print(pt_out.detach().numpy() - fl_out)

    # Same random relevance map for both frameworks, normalised to sum to 1.
    cam = np.random.rand(*fl_out.shape)
    fl_cam = jnp.array(cam / cam.sum())
    pt_cam = torch.tensor(cam / cam.sum())

    fl_cams = fl.apply(params, fl_cam, *fl_kwargs, method=fl.relprop)
    pt_cams = pt.relprop(pt_cam, alpha=1)

    print('Flax relprop:', len(fl_cams))
    print('Pt relprop:', len(pt_cams))

    if isinstance(fl_cams, (list, tuple)):
        # Several relevance maps: compare them pairwise.
        for i, c in enumerate(fl_cams):
            print("Relprop size:", pt_cams[i].size())
            print("Relprop", i, " diff:", np.abs(np.array(c) - pt_cams[i].detach().numpy()).sum())
    else:
        print("Relprop size:", pt_cams.size(), ". Cam sum:", fl_cams.sum(), pt_cams.sum() )
        print("Relprop diff:", np.abs(np.array(fl_cams) - pt_cams.detach().numpy()).sum())
    return pt_out, fl_out

In [123]:
"""Verified"""
# Compare the embedding layers of both frameworks on the tokenized inputs.
pt_m = BertEmbeddings(configuration).eval()
fl_m = FlaxBertEmbeddings(configuration)
# Work on copies so the pop() below does not alter the shared input dicts.
fl_in = fl_inputs.copy()
pt_in = pt_inputs.copy()
print(fl_in, pt_in)
# fl_in.pop('attention_mask', None)
# Only the PyTorch call drops attention_mask; the Flax call still receives it.
pt_in.pop('attention_mask', None)
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in, debug=True)

print(pt_out.shape)

{'input_ids': DeviceArray([[ 101, 7592, 2088,  999,  102]], dtype=int32), 'token_type_ids': DeviceArray([[0, 0, 0, 0, 0]], dtype=int32), 'attention_mask': DeviceArray([[1, 1, 1, 1, 1]], dtype=int32), 'position_ids': DeviceArray([0], dtype=int32)} {'input_ids': tensor([[ 101, 7592, 2088,  999,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]]), 'position_ids': tensor([0])}
word_embeddings.weight torch.Size([30522, 768])
position_embeddings.weight torch.Size([512, 768])
token_type_embeddings.weight torch.Size([2, 768])
LayerNorm.weight torch.Size([768])
LayerNorm.bias torch.Size([768])
Num Prams:  23837184
+----------------------------------------+--------------+------------+-----------+------+
| Name                                   | Shape        | Size       | Mean      | Std  |
+----------------------------------------+--------------+------------+-----------+------+
| params/LayerNorm/bias                  | (768,)       | 768        | 

In [124]:
# Compare the project's PyTorch Linear layer (tl) with its Flax Dense layer (fl)
# using hidden_size inputs and all_head_size outputs.
all_head_size = configuration.num_attention_heads * int(configuration.hidden_size / configuration.num_attention_heads)
q = tl.Linear(configuration.hidden_size, all_head_size).eval()
k = fl.Dense(configuration.hidden_size, all_head_size)
# Random input shaped like fl_out from whichever cell ran before this one.
inputs = np.random.rand(*fl_out.shape)
fl_in = jnp.array(inputs)
pt_in =  torch.Tensor(inputs)
# NOTE(review): pt_in / fl_in are bare arrays and verify_module_args
# star-unpacks its arguments, so the layers receive slices along the first
# axis. The recorded output (a 5-row diff and "relprop: 5") is consistent with
# a single (5, 768) slice being passed. Wrap them in a tuple if the full batch
# was intended.
lpt_out, lfl_out = verify_module_args(q, k, pt_in, fl_in, debug=True)
print(lpt_out.dtype)
print(lfl_out.dtype)

Num Prams:  590592
+---------------+------------+---------+-----------+--------+
| Name          | Shape      | Size    | Mean      | Std    |
+---------------+------------+---------+-----------+--------+
| params/bias   | (768,)     | 768     | -0.000103 | 0.0201 |
| params/kernel | (768, 768) | 589,824 | 1.83e-05  | 0.0208 |
+---------------+------------+---------+-----------+--------+
Total: 590,592
Forward diff:  0.29175353
[[ 5.63561916e-05 -1.45345926e-04 -6.81877136e-05 ... -5.32269478e-05
  -1.88350677e-05  1.03414059e-05]
 [ 2.51531601e-05  4.08291817e-06 -6.37769699e-05 ... -9.78410244e-05
   1.17883086e-04 -4.16040421e-05]
 [ 3.07857990e-05  3.15755606e-05 -6.27040863e-05 ... -1.45435333e-05
  -3.48836184e-05 -7.14063644e-05]
 [ 8.72313976e-05  6.85453415e-07 -3.86536121e-05 ... -9.53450799e-05
   8.67843628e-05 -5.93662262e-05]
 [ 1.25767663e-04  1.02996826e-04 -7.15255737e-07 ...  9.35196877e-05
  -2.11298466e-05 -1.33275986e-04]]
Flax relprop: 5
Pt relprop: 5
Relprop size

In [125]:
# <- BertEmbeddings
# Compare the full attention block (self-attention plus output projection).
pt_m = BertAttention(configuration).eval()
fl_m = FlaxBertAttention(configuration)
# Random hidden states shaped like fl_out from the previously executed cell.
inputs = np.random.rand(*fl_out.shape)
fl_in = {'hidden_states': jnp.array(inputs)}

pt_in = {'hidden_states': torch.Tensor(inputs)}

# print(np.sum(np.array(fl_in['hidden_states']) - pt_in['hidden_states'].numpy()))
# The head-mask argument is named differently on the two sides.
fl_in['layer_head_mask'] = None
pt_in['head_mask'] = None

fl_in['attention_mask'] = fl_inputs['attention_mask']
# torch.Tensor(...) re-creates the mask from its NumPy values.
pt_in['attention_mask'] = torch.Tensor(pt_inputs['attention_mask'].numpy())
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in, debug=True)
# att_in is another name for the same dict object as fl_in, not a copy.
att_in = fl_in

self.query.weight torch.Size([768, 768])
self.query.bias torch.Size([768])
self.key.weight torch.Size([768, 768])
self.key.bias torch.Size([768])
self.value.weight torch.Size([768, 768])
self.value.bias torch.Size([768])
output.dense.weight torch.Size([768, 768])
output.dense.bias torch.Size([768])
output.LayerNorm.weight torch.Size([768])
output.LayerNorm.bias torch.Size([768])
Num Prams:  2363904
+-------------------------------+------------+---------+-----------+--------+
| Name                          | Shape      | Size    | Mean      | Std    |
+-------------------------------+------------+---------+-----------+--------+
| params/output/LayerNorm/bias  | (768,)     | 768     | 0.0       | 0.0    |
| params/output/LayerNorm/scale | (768,)     | 768     | 1.0       | 0.0    |
| params/output/dense/bias      | (768,)     | 768     | 0.000191  | 0.0208 |
| params/output/dense/kernel    | (768, 768) | 589,824 | 4.15e-06  | 0.0208 |
| params/self/key/bias          | (768,)     | 768  

In [126]:
"""Verified"""
# <- BertAttention
# Compare the intermediate (feed-forward expansion) layers.
pt_m = BertIntermediate(configuration).eval()
fl_m = FlaxBertIntermediate(configuration)
# inputs = np.random.rand(*fl_out.shape)
# Random stand-in input shaped like fl_out from the previously executed cell.
att_out = np.random.rand(*fl_out.shape)

fl_in = {'hidden_states': jnp.array(att_out)}
pt_in = {'hidden_states': torch.Tensor(att_out)}
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in, debug=True)
print(pt_out.size())

dense.weight torch.Size([3072, 768])
dense.bias torch.Size([3072])
Num Prams:  2362368
+---------------------+-------------+-----------+----------+--------+
| Name                | Shape       | Size      | Mean     | Std    |
+---------------------+-------------+-----------+----------+--------+
| params/dense/bias   | (3072,)     | 3,072     | -6.8e-05 | 0.021  |
| params/dense/kernel | (768, 3072) | 2,359,296 | 1e-05    | 0.0208 |
+---------------------+-------------+-----------+----------+--------+
Total: 2,362,368
Forward diff:  0.64422125
Flax relprop: 1
Pt relprop: 1
Relprop size: torch.Size([1, 5, 768])
Relprop diff: 9.828516e-06
Cam sum: tensor(1., grad_fn=<SumBackward0>) 0.99999726
torch.Size([1, 5, 3072])


In [127]:
"""Verified??"""
#TODO: make sure it's verified
# Compare the output layers, which take two inputs.
pt_m = BertOutput(configuration).eval()
fl_m = FlaxBertOutput(configuration)
# inputs = np.random.rand(*fl_out.shape)
# att_out = np.random.rand(*att_out.shape)
# Shapes are hard-coded here instead of being taken from earlier cells.
inputs = np.random.rand(1,5,3072)
att_out = np.random.rand(1,5,768)
print(inputs.shape)
print(att_out.shape)

fl_in = {'hidden_states': jnp.array(inputs)}
pt_in = {'hidden_states': torch.Tensor(inputs)}

# print(np.sum(np.array(fl_in['hidden_states']) - pt_in['hidden_states'].numpy()))
# fl_in['layer_head_mask'] = None
# pt_in['head_mask'] = None

# The second input is named differently on the two sides:
# 'attention_output' for Flax, 'input_tensor' for PyTorch.
fl_in['attention_output'] = jnp.array(att_out)
pt_in['input_tensor'] = torch.Tensor(att_out)
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in, debug=True)
print(pt_out.size())

(1, 5, 3072)
(1, 5, 768)
dense.weight torch.Size([768, 3072])
dense.bias torch.Size([768])
LayerNorm.weight torch.Size([768])
LayerNorm.bias torch.Size([768])
Num Prams:  2361600
+------------------------+-------------+-----------+-----------+--------+
| Name                   | Shape       | Size      | Mean      | Std    |
+------------------------+-------------+-----------+-----------+--------+
| params/LayerNorm/bias  | (768,)      | 768       | 0.0       | 0.0    |
| params/LayerNorm/scale | (768,)      | 768       | 1.0       | 0.0    |
| params/dense/bias      | (768,)      | 768       | 0.000673  | 0.0106 |
| params/dense/kernel    | (3072, 768) | 2,359,296 | -2.99e-06 | 0.0104 |
+------------------------+-------------+-----------+-----------+--------+
Total: 2,361,600
Forward diff:  0.64069426
Flax relprop: 2
Pt relprop: 2
Relprop size: torch.Size([1, 5, 3072])
Relprop 0  diff: 0.20483156
Relprop size: torch.Size([1, 5, 768])
Relprop 1  diff: 0.21466121
Cam sum: tensor(1., gra

In [51]:
#TODO: not verified
# Compare the self-attention layers (query/key/value projections only).
pt_m = BertSelfAttention(configuration).eval()
fl_m = FlaxBertSelfAttention(configuration)
# Random hidden states shaped like fl_out from the previously executed cell.
inputs = np.random.rand(*fl_out.shape)
fl_in = {'hidden_states': jnp.array(inputs)}
pt_in = {'hidden_states': torch.Tensor(inputs)}

# print(np.sum(np.array(fl_in['hidden_states']) - pt_in['hidden_states'].numpy()))
# The head-mask argument is named differently on the two sides.
fl_in['layer_head_mask'] = None
pt_in['head_mask'] = None

fl_in['attention_mask'] = fl_inputs['attention_mask']
pt_in['attention_mask'] = torch.Tensor(pt_inputs['attention_mask'].numpy())
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in, debug=True)
print(pt_out.size())

query.weight torch.Size([768, 768])
query.bias torch.Size([768])
key.weight torch.Size([768, 768])
key.bias torch.Size([768])
value.weight torch.Size([768, 768])
value.bias torch.Size([768])
Num Prams:  1771776
+---------------------+------------+---------+-----------+--------+
| Name                | Shape      | Size    | Mean      | Std    |
+---------------------+------------+---------+-----------+--------+
| params/key/bias     | (768,)     | 768     | -0.00107  | 0.0211 |
| params/key/kernel   | (768, 768) | 589,824 | -4.47e-05 | 0.0208 |
| params/query/bias   | (768,)     | 768     | -0.000719 | 0.0209 |
| params/query/kernel | (768, 768) | 589,824 | -2.14e-05 | 0.0208 |
| params/value/bias   | (768,)     | 768     | -0.000857 | 0.0208 |
| params/value/kernel | (768, 768) | 589,824 | 3.79e-05  | 0.0208 |
+---------------------+------------+---------+-----------+--------+
Total: 1,771,776
Forward diff:  0.24423632
Flax relprop: 1
Pt relprop: 1
Relprop diff: 2.471684
torch.Size([1

In [52]:
"""Verified?"""
# Compare the self-output layers, which take two inputs.
pt_m = BertSelfOutput(configuration).eval()
fl_m = FlaxBertSelfOutput(configuration)
# Both shapes come from kernel state: fl_in and fl_out are whatever the
# previously executed cell left behind.
inputs = np.random.rand(*fl_in['hidden_states'].shape)
att_out = np.random.rand(*fl_out.shape)

fl_in = {'hidden_states': jnp.array(inputs)}
pt_in = {'hidden_states': torch.Tensor(inputs)}

# print(np.sum(np.array(fl_in['hidden_states']) - pt_in['hidden_states'].numpy()))
# fl_in['layer_head_mask'] = None
# pt_in['head_mask'] = None

# Here both sides use the same argument name for the second input.
fl_in['input_tensor'] = jnp.array(att_out)
pt_in['input_tensor'] = torch.Tensor(att_out)
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in, debug=True)
print(pt_out.size())

dense.weight torch.Size([768, 768])
dense.bias torch.Size([768])
LayerNorm.weight torch.Size([768])
LayerNorm.bias torch.Size([768])
Num Prams:  592128
+------------------------+------------+---------+-----------+--------+
| Name                   | Shape      | Size    | Mean      | Std    |
+------------------------+------------+---------+-----------+--------+
| params/LayerNorm/bias  | (768,)     | 768     | 0.0       | 0.0    |
| params/LayerNorm/scale | (768,)     | 768     | 1.0       | 0.0    |
| params/dense/bias      | (768,)     | 768     | 0.000911  | 0.0207 |
| params/dense/kernel    | (768, 768) | 589,824 | -4.27e-06 | 0.0208 |
+------------------------+------------+---------+-----------+--------+
Total: 592,128
Forward diff:  0.6735077
Flax relprop: 2
Pt relprop: 2
Relprop 0  diff: 0.07470034
Relprop 1  diff: 0.11423703
torch.Size([1, 5, 768])


In [10]:
# Compare one complete transformer layer.
pt_m = BertLayer(configuration).eval()
fl_m = FlaxBertLayer(configuration)
# Random hidden states shaped like fl_out from the previously executed cell.
inputs = np.random.rand(*fl_out.shape)
fl_in = {'hidden_states': jnp.array(inputs)}
pt_in = {'hidden_states': torch.Tensor(inputs)}

# The head-mask argument is named differently on the two sides.
fl_in['layer_head_mask'] = None
pt_in['head_mask'] = None

fl_in['attention_mask'] = fl_inputs['attention_mask']
pt_in['attention_mask'] = torch.Tensor(pt_inputs['attention_mask'].numpy())
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in)

Forward diff:  0.7278961
Flax relprop: 1
Pt relprop: 1
Relprop diff: 2.263708


In [11]:
# Compare the full encoder stacks.
pt_m = BertEncoder(configuration).eval()
fl_m = FlaxBertEncoder(configuration)

# requires: hidden_states, attention_mask, head_mask (default None)
# Random hidden states shaped like fl_out from the previously executed cell.
inputs = np.random.rand(*fl_out.shape)
fl_in = {'hidden_states': jnp.array(inputs)}
pt_in = {'hidden_states': torch.Tensor(inputs)}

# At encoder level both sides use the name 'head_mask'.
fl_in['head_mask'] = None
pt_in['head_mask'] = None

fl_in['attention_mask'] = fl_inputs['attention_mask']
pt_in['attention_mask'] = torch.Tensor(pt_inputs['attention_mask'].numpy())

# print()


# print(fl_in, pt_in)
# fl_in.pop('attention_mask', None)
# pt_in.pop('attention_mask', None)
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in)

Forward diff:  1.5823481
Flax relprop: 1
Pt relprop: 1
Relprop diff: 2.1209507


In [10]:
# fl_in = (input_ids, token_ids, position_ids, attention_mask)
# pt_in = (pt_input_ids, pt_token_ids, pt_position_ids, None)
# fl_in = fl_inputs
# pt_in = pt_inputs
# pt_out, fl_out = verify_module(pt_layer, fl_layer, pt_in, fl_in)

In [89]:
# Compare the complete models on the tokenized inputs.
pt_m = BertModel(configuration).eval()
fl_m = FlaxBertModule(configuration)
# These are references to the shared input dicts, not copies (unlike the
# embeddings cell, which uses .copy()).
fl_in = fl_inputs
pt_in = pt_inputs
pt_out, fl_out = verify_module(pt_m, fl_m, pt_in, fl_in)



Forward diff:  0.19136032
Flax relprop: 1
Pt relprop: 1
Relprop diff: 2.0647373


In [None]:
# NOTE(review): get_parameter_overview is called here without arguments,
# whereas every other call in this file passes a params tree. This cell has no
# execution count, so it was presumably never run. Confirm or remove.
print(parameter_overview.get_parameter_overview())
# pt_layer must already exist in the kernel; its only live assignment in this
# file appears in a later cell.
print(pt_layer, count_parameters(pt_layer))

In [None]:
# List every PyTorch parameter name together with its shape.
for k, v in pt_layer.named_parameters():
    print(k, v.shape)

# Hand-built Flax params tree for the embedding layer; each leaf starts as
# None and is filled from the matching PyTorch parameter below.
vf = {
    'params': {
        'word_embeddings': {'embedding': None},
        'position_embeddings': {'embedding': None},
        'token_type_embeddings': {'embedding': None},
        'LayerNorm': {'bias': None, 'scale': None},
    }
}

# PyTorch parameter name -> (Flax sub-module, Flax leaf name).
_pt_to_flax = {
    'word_embeddings.weight': ('word_embeddings', 'embedding'),
    'position_embeddings.weight': ('position_embeddings', 'embedding'),
    'token_type_embeddings.weight': ('token_type_embeddings', 'embedding'),
    'LayerNorm.weight': ('LayerNorm', 'scale'),
    'LayerNorm.bias': ('LayerNorm', 'bias'),
}
for k, v in pt_layer.named_parameters():
    if k in _pt_to_flax:
        _module_name, _leaf_name = _pt_to_flax[k]
        vf['params'][_module_name][_leaf_name] = v.detach().numpy()

print(vf)


In [8]:
# vf = f_layer.init(jax.random.PRNGKey(0), input_ids, token_ids,  position_ids, attention_mask)
# Run both layers with the same weights: the Flax layer takes the hand-built
# `vf` tree and positional inputs, the PyTorch layer takes keyword inputs.
# f_layer and pt_layer must already exist in the kernel.
f_out = f_layer.apply(vf, input_ids, token_ids, position_ids, attention_mask)
pt_out = pt_layer(input_ids=pt_input_ids, token_type_ids=pt_token_ids, position_ids=pt_position_ids)

In [9]:
# Print the converted embedding table's shape and both output shapes, then the
# summed absolute difference between the PyTorch and Flax forward outputs.
print(vf['params']['word_embeddings']['embedding'].shape)
print(f_out.shape)
print(pt_out.shape)
print(np.abs(pt_out.detach().numpy() - f_out).sum())

(30522, 768)
(1, 5, 768)
torch.Size([1, 5, 768])
3.368198e-05


In [10]:
# Random relevance map shaped like the forward output, normalised to sum to 1
# and shared by both frameworks.
cam = np.random.rand(*f_out.shape)
f_cam = jnp.array( cam / cam.sum())
pt_cam = torch.tensor(cam / cam.sum())

# Both relprop implementations return two relevance maps for this layer.
f_cam1, f_cam2 = f_layer.apply(vf, f_cam, input_ids, token_ids, position_ids, attention_mask, method=f_layer.relprop)
kwargs = {'alpha': 1}
pt_cam1, pt_cam2 = pt_layer.relprop(pt_cam, **kwargs)


# Summed absolute differences: matching pairs (1 vs 1, 2 vs 2) first.
print(np.abs(np.array(f_cam1) - pt_cam1.detach().numpy()).sum())
print(np.abs(np.array(f_cam2) - pt_cam2.detach().numpy()).sum())
# NOTE(review): this line compares Flax map 1 with PyTorch map 2. Presumably a
# deliberate check for swapped outputs; confirm.
print(np.abs(np.array(f_cam1) - pt_cam2.detach().numpy()).sum())
# Total relevance returned by each framework.
print(np.sum(f_cam1) + np.sum(f_cam2))
print(np.sum(pt_cam1.detach().numpy()) + np.sum(pt_cam2.detach().numpy()))
print(f_cam2.shape)

relprop 2
2.7899752
2.787298
2.7194073
1.0000002
1.0
(1, 5, 768)


In [11]:
# Show the converted Flax parameter tree next to the PyTorch module and its
# trainable-parameter count.
print(parameter_overview.get_parameter_overview(vf))
print(pt_layer, count_parameters(pt_layer))

+----------------------------------------+--------------+------------+-----------+-------+
| Name                                   | Shape        | Size       | Mean      | Std   |
+----------------------------------------+--------------+------------+-----------+-------+
| params/LayerNorm/bias                  | (768,)       | 768        | 0.0       | 0.0   |
| params/LayerNorm/scale                 | (768,)       | 768        | 1.0       | 0.0   |
| params/position_embeddings/embedding   | (512, 768)   | 393,216    | -0.00286  | 0.999 |
| params/token_type_embeddings/embedding | (2, 768)     | 1,536      | -0.00341  | 1.01  |
| params/word_embeddings/embedding       | (30522, 768) | 23,440,896 | -7.33e-05 | 1.0   |
+----------------------------------------+--------------+------------+-----------+-------+
Total: 23,837,184
BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)

In [12]:
# Rebind pt_layer / f_layer to the encoder pair; later cells that use these
# names now operate on the encoders.
pt_layer = BertEncoder(configuration).eval()
f_layer = FlaxBertEncoder(configuration)

In [13]:
# List every PyTorch encoder parameter name together with its shape.
for k,v in pt_layer.named_parameters():
    print(k, v.shape)

layer.0.attention.self.query.weight torch.Size([768, 768])
layer.0.attention.self.query.bias torch.Size([768])
layer.0.attention.self.key.weight torch.Size([768, 768])
layer.0.attention.self.key.bias torch.Size([768])
layer.0.attention.self.value.weight torch.Size([768, 768])
layer.0.attention.self.value.bias torch.Size([768])
layer.0.attention.output.dense.weight torch.Size([768, 768])
layer.0.attention.output.dense.bias torch.Size([768])
layer.0.attention.output.LayerNorm.weight torch.Size([768])
layer.0.attention.output.LayerNorm.bias torch.Size([768])
layer.0.intermediate.dense.weight torch.Size([3072, 768])
layer.0.intermediate.dense.bias torch.Size([3072])
layer.0.output.dense.weight torch.Size([768, 3072])
layer.0.output.dense.bias torch.Size([768])
layer.0.output.LayerNorm.weight torch.Size([768])
layer.0.output.LayerNorm.bias torch.Size([768])
layer.1.attention.self.query.weight torch.Size([768, 768])
layer.1.attention.self.query.bias torch.Size([768])
layer.1.attention.self.k