In [1]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining, FlaxBigBirdForQuestionAnswering, BigBirdForQuestionAnswering
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-base-trivia-itc"

def get_difference(torch_array, jx_array):
    torch_array = torch_array.detach().numpy()
    jx_array = np.array(jx_array)
    return np.max(torch_array - jx_array)

In [2]:
fx_model = FlaxBigBirdForQuestionAnswering.from_pretrained(MODEL_ID, attention_type="block_sparse", from_pt=True, hidden_act="gelu_new")
model = BigBirdForQuestionAnswering.from_pretrained(MODEL_ID, attention_type="block_sparse", hidden_act="gelu_new")

hidden_states [[[ 1.790036    1.2783235   0.8206438  ...  0.05368678  0.58942956
   -2.2866602 ]
  [ 2.1765451   0.8971078   1.0146816  ...  0.7316095   0.38268682
   -1.7438451 ]
  [ 1.9386533   0.8805409   0.7943621  ... -0.00664092  1.0473495
   -1.7510742 ]
  ...
  [ 1.288487    1.0392525   0.8836512  ...  0.7352993   1.2431761
   -1.4505898 ]
  [ 0.9936909   1.4865384   0.66211474 ...  0.5886477   0.47717008
   -2.124638  ]
  [ 1.0428754   1.4412599   0.9249919  ...  0.23503263  0.64811
   -1.8501263 ]]]
logits [[[ 0.37246317 -0.06472513]
  [ 0.3184571  -0.82387245]
  [-0.10829401 -1.0957636 ]
  ...
  [ 0.7186227  -0.2963915 ]
  [ 0.23671004  0.05929509]
  [-0.08464466 -0.35181844]]]
Some weights of the model checkpoint at google/bigbird-base-trivia-itc were not used when initializing FlaxBigBirdForQuestionAnswering: {('bert', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxBigBirdForQuestionAnswering from the checkpoint of a model trained on another 

In [3]:
np.random.seed(0)
array = np.random.randint(1, 1024, size=(1, 768))

jx_array = jnp.array(array)
torch_array = torch.tensor(array, dtype=torch.long)

jx_array.shape, torch_array.shape

((1, 768), torch.Size([1, 768]))

In [4]:
fx_out = fx_model(jx_array)
# fx_out

hidden_states [[[-0.6101385   0.3201888  -0.49640045 ...  0.40424678 -0.5002909
   -0.08858022]
  [-0.4044776  -0.01814827 -0.29021034 ...  0.27513313 -0.3184294
    0.08828235]
  [-0.60638195  0.09818833 -0.37506604 ... -0.07753819 -0.24246731
   -0.03164982]
  ...
  [-0.35966563  0.09503029 -0.26183507 ...  0.27247858 -0.09599186
    0.02604759]
  [-0.4004357   0.19957614 -0.26888436 ...  0.24486455 -0.11021097
    0.05327293]
  [-0.38653636  0.20197348 -0.21651582 ...  0.45053104 -0.0606011
   -0.01957168]]]
logits [[[-14.178629 -13.06072 ]
  [ -9.486593 -13.158725]
  [-11.060806 -11.636639]
  ...
  [-12.697244 -13.91427 ]
  [-13.595952 -13.285418]
  [-13.735154 -13.877777]]]


In [5]:
with torch.no_grad():
    torch_out = model(torch_array)
# torch_out

In [6]:
print("difference in inputs:", get_difference(torch_array, jx_array))

print("difference in start logits", get_difference(torch_out["start_logits"], fx_out["start_logits"]))

print("difference in end logits", get_difference(torch_out["end_logits"], fx_out["end_logits"]))

# print("difference in logits:", get_difference(torch_out["prediction_logits"], fx_out["prediction_logits"]))

# print("difference in seq-logits:", get_difference(torch_out["seq_relationship_logits"], fx_out["seq_relationship_logits"]))

difference in inputs: 0
difference in start logits 1.335144e-05
difference in end logits 1.04904175e-05


In [7]:
fx_out["start_logits"].shape, fx_out["end_logits"].shape

((1, 768), (1, 768))