In [1]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-roberta-base"

def get_difference(torch_array, jx_array):
    """Return the largest absolute element-wise difference between two arrays.

    Used to check that the PyTorch and Flax BigBird models produce the same
    values for the same input.

    Args:
        torch_array: a ``torch.Tensor``. It is detached from autograd and moved
            to CPU before conversion, so CUDA tensors and tensors that require
            grad are accepted.
        jx_array: a JAX array (or anything ``np.asarray`` accepts) with the same
            shape as ``torch_array``.

    Returns:
        ``max(|torch_array - jx_array|)`` as a NumPy scalar; ``0`` when both
        inputs are empty.

    Raises:
        ValueError: if the two arrays have different shapes. Without this check
            NumPy would broadcast them silently and report a meaningless number.
    """
    torch_np = torch_array.detach().cpu().numpy()
    jx_np = np.asarray(jx_array)
    if torch_np.shape != jx_np.shape:
        raise ValueError(
            f"shape mismatch: torch {torch_np.shape} vs jax {jx_np.shape}"
        )
    # abs() is essential: a signed max hides every entry where the JAX value is
    # larger than the PyTorch one. initial=0 makes empty inputs return 0 instead
    # of raising (it cannot change the result otherwise, since abs() >= 0).
    return np.max(np.abs(torch_np - jx_np), initial=0)

In [2]:
# Load the same checkpoint in both frameworks; the Flax weights are converted
# from the PyTorch checkpoint on the fly (from_pt=True).
# Both models are given the same attention_type so the equivalence check compares
# identically configured models; previously only the PyTorch model overrode the
# config default.
fx_model = FlaxBigBirdForPreTraining.from_pretrained(
    MODEL_ID, from_pt=True, attention_type="original_full"
)
model = BigBirdForPreTraining.from_pretrained(MODEL_ID, attention_type="original_full")

Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing FlaxBigBirdForPreTraining: {('cls', 'predictions', 'decoder', 'kernel'), ('bert', 'embeddings', 'position_ids'), ('cls', 'predictions', 'decoder', 'bias')}
- This IS expected if you are initializing FlaxBigBirdForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBigBirdForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Random token ids fed to both models. The seed is fixed so the inputs (and
# therefore the logits compared later) are reproducible across runs.
BATCH_SIZE = 2
SEQ_LEN = 512
MAX_TOKEN_ID = 512  # exclusive upper bound passed to randint (lower bound is 1)
# NOTE(review): the captured output of this cell shows ids such as 1654, which
# randint(1, 512) cannot produce — that output predates the current code.
# Restart the kernel and run all cells to refresh it.
np.random.seed(0)
array = np.random.randint(1, MAX_TOKEN_ID, size=(BATCH_SIZE, SEQ_LEN))

# The same ids in each framework: a JAX array for Flax, an int64 tensor for PyTorch.
jx_array = jnp.array(array)
torch_array = torch.tensor(array, dtype=torch.long)

jx_array, torch_array

(DeviceArray([[ 685,  560, 1654, ..., 1515, 1078,  817],
              [1119,   60,  337, ...,   51,  243,  871]], dtype=int32),
 tensor([[ 685,  560, 1654,  ..., 1515, 1078,  817],
         [1119,   60,  337,  ...,   51,  243,  871]]))

In [4]:
# Flax forward pass on the shared input ids. Only element 0 of the returned
# output is kept; it is compared against element 0 of the PyTorch output below.
fx_outputs = fx_model(jx_array)
fx_out = fx_outputs[0]
fx_out

DeviceArray([[[ -8.147562 ,  -9.594704 ,  -7.671412 , ...,  -8.259321 ,
                -4.6958823,  -7.7564836],
              [ -7.5017667, -10.02726  ,  -9.095716 , ...,  -5.152514 ,
                -5.898193 ,  -7.626057 ],
              [ -8.320391 , -10.639635 ,  -9.125765 , ...,  -5.7325406,
                -7.0324764,  -7.56709  ],
              ...,
              [ -6.192581 , -10.898497 ,  -8.557592 , ...,  -7.6627326,
                -7.0683737,  -8.802221 ],
              [ -6.2537403,  -8.485796 ,  -7.414158 , ...,  -6.094995 ,
                -8.348713 ,  -6.63145  ],
              [ -7.934296 , -10.426262 ,  -9.944581 , ...,  -8.070741 ,
                -7.530947 ,  -8.704152 ]],

             [[ -8.230836 ,  -9.261623 ,  -7.3377166, ...,  -8.558403 ,
                -4.5617714,  -8.134953 ],
              [ -6.553763 ,  -8.414845 ,  -8.131858 , ...,  -2.8671875,
                -4.0934815,  -7.7758017],
              [ -7.222993 ,  -7.805408 ,  -8.3035555, ...,  -3.1251

In [5]:
# PyTorch forward pass on the same ids. Gradients are not needed for an
# equivalence check, so autograd is disabled for the call; element 0 of the
# output is the tensor compared against fx_out.
with torch.no_grad():
    torch_outputs = model(torch_array)
torch_out = torch_outputs[0]
torch_out

tensor([[[ -8.1476,  -9.5947,  -7.6714,  ...,  -8.2593,  -4.6959,  -7.7565],
         [ -7.5018, -10.0273,  -9.0957,  ...,  -5.1525,  -5.8982,  -7.6261],
         [ -8.3204, -10.6396,  -9.1258,  ...,  -5.7325,  -7.0325,  -7.5671],
         ...,
         [ -6.1926, -10.8985,  -8.5576,  ...,  -7.6627,  -7.0684,  -8.8022],
         [ -6.2537,  -8.4858,  -7.4142,  ...,  -6.0950,  -8.3487,  -6.6314],
         [ -7.9343, -10.4263,  -9.9446,  ...,  -8.0707,  -7.5309,  -8.7042]],

        [[ -8.2308,  -9.2616,  -7.3377,  ...,  -8.5584,  -4.5618,  -8.1350],
         [ -6.5538,  -8.4148,  -8.1319,  ...,  -2.8672,  -4.0935,  -7.7758],
         [ -7.2230,  -7.8054,  -8.3035,  ...,  -3.1251,  -5.3699,  -6.2098],
         ...,
         [ -8.0287,  -9.7877,  -9.0515,  ...,  -4.4214,  -4.6920,  -8.1069],
         [ -8.0191,  -9.2181,  -8.8340,  ...,  -4.7247,  -4.6359,  -8.1067],
         [ -8.1492,  -9.0395,  -8.7315,  ...,  -4.9262,  -4.8057,  -8.1705]]])

In [6]:
print("difference in inputs:", get_difference(torch_array, jx_array))
print("difference in logits:", get_difference(torch_out, fx_out))

# Fail loudly on a regression instead of relying on a reader eyeballing the
# printed numbers. The recorded run printed ~5.8e-5 for the logits, so 1e-3
# leaves headroom for float32 noise while still catching a real mismatch.
LOGITS_ATOL = 1e-3
np.testing.assert_array_equal(torch_array.cpu().numpy(), np.asarray(jx_array))
np.testing.assert_allclose(
    torch_out.cpu().numpy(), np.asarray(fx_out), atol=LOGITS_ATOL, rtol=0
)

difference in inputs: 0
difference in logits: 5.8412552e-05
