In [1]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining, FlaxBigBirdForQuestionAnswering, BigBirdForQuestionAnswering
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-base-trivia-itc"

def get_difference(torch_array, jx_array):
    """Return the maximum absolute element-wise difference between two arrays.

    Used to compare a PyTorch result against the corresponding Flax/JAX result.

    Args:
        torch_array: PyTorch tensor. It is detached and moved to CPU before
            conversion, so tensors that require grad or live on an
            accelerator are accepted.
        jx_array: JAX array (or any array-like) with the same shape.

    Returns:
        max(|torch_array - jx_array|) as a NumPy scalar, or 0 for empty inputs.

    Raises:
        ValueError: if the two shapes differ. Otherwise NumPy broadcasting
            would silently compare mismatched outputs.
    """
    torch_np = torch_array.detach().cpu().numpy()
    jx_np = np.asarray(jx_array)
    if torch_np.shape != jx_np.shape:
        raise ValueError(
            f"shape mismatch: torch {torch_np.shape} vs jax {jx_np.shape}"
        )
    # The absolute value is essential: a signed max would hide every element
    # where the JAX value is larger than the PyTorch value.
    return np.max(np.abs(torch_np - jx_np), initial=0)

In [2]:
# Load the same checkpoint into both frameworks with identical overrides
# (block-sparse attention, hidden_act="gelu_new") so their outputs can be
# compared. The Flax model is built from the PyTorch weights (from_pt=True).
fx_model = FlaxBigBirdForQuestionAnswering.from_pretrained(
    MODEL_ID,
    from_pt=True,
    attention_type="block_sparse",
    hidden_act="gelu_new",
)
model = BigBirdForQuestionAnswering.from_pretrained(
    MODEL_ID,
    attention_type="block_sparse",
    hidden_act="gelu_new",
)

Some weights of the model checkpoint at google/bigbird-base-trivia-itc were not used when initializing FlaxBigBirdForQuestionAnswering: {('bert', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxBigBirdForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBigBirdForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Deterministic random token ids: one sequence of 768 ids drawn from [1, 1024).
# NOTE(review): 768 is presumably chosen to be long enough for block-sparse
# attention to actually be used — confirm against the config's block_size /
# num_random_blocks.
np.random.seed(0)
array = np.random.randint(low=1, high=1024, size=(1, 768))

# The same ids in each framework's native array type.
torch_array = torch.tensor(array, dtype=torch.long)
jx_array = jnp.array(array)

jx_array.shape, torch_array.shape

((1, 768), torch.Size([1, 768]))

In [4]:
# Flax forward pass on the shared token ids.
fx_out = fx_model(input_ids=jx_array)

In [5]:
# PyTorch reference forward pass. Autograd is switched off because only the
# outputs are compared.
with torch.set_grad_enabled(False):
    torch_out = model(input_ids=torch_array)

In [6]:
# Parity report: the inputs first (sanity check), then each QA head's logits.
print("difference in inputs:", get_difference(torch_array, jx_array))

logit_checks = (
    ("difference in start logits", "start_logits"),
    ("difference in end logits", "end_logits"),
)
for label, key in logit_checks:
    print(label, get_difference(torch_out[key], fx_out[key]))

difference in inputs: 0
difference in start logits 1.335144e-05
difference in end logits 1.04904175e-05


In [7]:
# Shapes of the Flax start/end logits, displayed as a tuple of shapes.
tuple(fx_out[name].shape for name in ("start_logits", "end_logits"))

((1, 768), (1, 768))