In [1]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-roberta-base"

def get_difference(torch_array, jx_array):
    torch_array = torch_array.detach().numpy()
    jx_array = np.array(jx_array)
    return np.max(torch_array - jx_array)

In [2]:
fx_model = FlaxBigBirdForPreTraining.from_pretrained(MODEL_ID, from_pt=True)
model = BigBirdForPreTraining.from_pretrained(MODEL_ID, attention_type="original_full")

Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing FlaxBigBirdForPreTraining: {('cls', 'predictions', 'decoder', 'kernel'), ('cls', 'predictions', 'decoder', 'bias'), ('bert', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxBigBirdForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBigBirdForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
np.random.seed(0)
array = np.random.randint(1, 512, size=(2, 512))

jx_array = jnp.array(array)
torch_array = torch.tensor(array, dtype=torch.long)

jx_array, torch_array

(DeviceArray([[173,  48, 118, ..., 491,  54, 305],
              [ 95,  60, 337, ...,  51, 243, 359]], dtype=int32),
 tensor([[173,  48, 118,  ..., 491,  54, 305],
         [ 95,  60, 337,  ...,  51, 243, 359]]))

In [4]:
fx_out = fx_model(jx_array)[0]
fx_out

DeviceArray([[[-4.9603143 , -6.466676  , -4.6835575 , ..., -4.8491845 ,
               -3.3872652 , -2.416497  ],
              [-4.8720617 , -6.6423264 , -5.374215  , ..., -6.417832  ,
               -2.9115303 , -4.7255917 ],
              [ 0.7174561 , -0.50911397,  3.9928522 , ..., -7.978772  ,
               -7.620496  ,  4.976294  ],
              ...,
              [-5.566427  , -7.214348  , -4.3618603 , ..., -5.987189  ,
               -4.2361917 , -3.9470434 ],
              [-4.689574  , -7.061965  , -4.666668  , ..., -4.087977  ,
               -2.3270102 , -4.136928  ],
              [-5.369287  , -6.8011727 , -4.333805  , ..., -4.2402    ,
               -2.8998585 , -3.902777  ]],

             [[-4.069118  , -6.0870595 , -4.890304  , ..., -6.692041  ,
               -3.4898877 , -1.5123243 ],
              [-3.883913  , -6.4441943 , -4.5847635 , ..., -4.6294613 ,
               -2.660959  , -2.9854906 ],
              [-4.6196632 , -6.3770995 , -3.9329724 , ..., -4.71687

In [5]:
with torch.no_grad():
    torch_out = model(torch_array)[0]
torch_out

tensor([[[-4.9603, -6.4667, -4.6836,  ..., -4.8492, -3.3873, -2.4165],
         [-4.8721, -6.6423, -5.3742,  ..., -6.4178, -2.9115, -4.7256],
         [ 0.7175, -0.5091,  3.9928,  ..., -7.9788, -7.6205,  4.9763],
         ...,
         [-5.5664, -7.2143, -4.3619,  ..., -5.9872, -4.2362, -3.9470],
         [-4.6896, -7.0620, -4.6667,  ..., -4.0880, -2.3270, -4.1369],
         [-5.3693, -6.8012, -4.3338,  ..., -4.2402, -2.8999, -3.9028]],

        [[-4.0691, -6.0871, -4.8903,  ..., -6.6920, -3.4899, -1.5123],
         [-3.8839, -6.4442, -4.5848,  ..., -4.6295, -2.6610, -2.9855],
         [-4.6197, -6.3771, -3.9330,  ..., -4.7169, -2.1266, -2.8034],
         ...,
         [-3.6946, -5.7839, -4.1050,  ..., -3.6339, -3.0817, -2.7614],
         [-3.5686, -5.4104, -3.5697,  ..., -4.8652, -2.6640, -2.2420],
         [-2.2415, -4.9702, -3.0008,  ..., -5.3960, -2.2235, -1.1518]]])

In [6]:
print("difference in inputs:", get_difference(torch_array, jx_array))
print("difference in logits:", get_difference(torch_out, fx_out))

difference in inputs: 0
difference in logits: 6.771088e-05
