In [1]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-roberta-base"

def get_difference(torch_array, jx_array):
    torch_array = torch_array.detach().numpy()
    jx_array = np.array(jx_array)
    return np.max(torch_array - jx_array)

In [2]:
fx_model = FlaxBigBirdForPreTraining.from_pretrained(MODEL_ID, attention_type="block_sparse", from_pt=True, hidden_act="gelu_new")
model = BigBirdForPreTraining.from_pretrained(MODEL_ID, attention_type="block_sparse", hidden_act="gelu_new")

Some weights of the model checkpoint at google/bigbird-roberta-base were not used when initializing FlaxBigBirdForPreTraining: {('cls', 'predictions', 'decoder', 'bias'), ('bert', 'embeddings', 'position_ids'), ('cls', 'predictions', 'decoder', 'kernel')}
- This IS expected if you are initializing FlaxBigBirdForPreTraining from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBigBirdForPreTraining from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
np.random.seed(0)
array = np.random.randint(1, 1024, size=(1, 1024))

jx_array = jnp.array(array)
torch_array = torch.tensor(array, dtype=torch.long)

jx_array, torch_array

(DeviceArray([[685, 560, 630, ...,  51, 243, 871]], dtype=int32),
 tensor([[685, 560, 630,  ...,  51, 243, 871]]))

In [4]:
fx_out = fx_model(jx_array)[0]
fx_out

DeviceArray([[[ -7.758166 ,  -9.172134 ,  -7.3885   , ...,  -8.788565 ,
                -4.6005898,  -7.278014 ],
              [ -7.915648 ,  -9.131889 ,  -8.026362 , ...,  -7.385435 ,
                -4.9368143,  -7.2617083],
              [ -5.399763 ,  -7.598421 ,  -7.061778 , ...,  -7.9555287,
                -5.8667946,  -4.1060424],
              ...,
              [ -6.743006 ,  -9.424944 ,  -8.606188 , ...,  -5.065666 ,
                -4.483654 ,  -7.036038 ],
              [ -7.3438916,  -9.684366 ,  -8.672882 , ...,  -5.657706 ,
                -3.5089154,  -7.0832176],
              [ -8.258332 , -10.002783 ,  -8.916176 , ...,  -5.5625167,
                -4.4211135,  -7.6033983]]], dtype=float32)

In [5]:
with torch.no_grad():
    torch_out = model(torch_array)[0]
torch_out

tensor([[[ -7.7582,  -9.1721,  -7.3885,  ...,  -8.7886,  -4.6006,  -7.2780],
         [ -7.9156,  -9.1319,  -8.0264,  ...,  -7.3854,  -4.9368,  -7.2617],
         [ -5.3998,  -7.5984,  -7.0618,  ...,  -7.9555,  -5.8668,  -4.1060],
         ...,
         [ -6.7430,  -9.4249,  -8.6062,  ...,  -5.0657,  -4.4837,  -7.0360],
         [ -7.3439,  -9.6844,  -8.6729,  ...,  -5.6577,  -3.5089,  -7.0832],
         [ -8.2583, -10.0028,  -8.9162,  ...,  -5.5625,  -4.4211,  -7.6034]]])

In [6]:
print("difference in inputs:", get_difference(torch_array, jx_array))
print("difference in logits:", get_difference(torch_out, fx_out))

difference in inputs: 0
difference in logits: 9.727478e-05
