In [2]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining, FlaxBigBirdForQuestionAnswering, BigBirdForQuestionAnswering
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-base-trivia-itc"

def get_difference(torch_array, jx_array):
    torch_array = torch_array.detach().numpy()
    jx_array = np.array(jx_array)
    return np.max(torch_array - jx_array)

In [3]:
fx_model = FlaxBigBirdForQuestionAnswering.from_pretrained(MODEL_ID, attention_type="block_sparse", from_pt=True, hidden_act="gelu_new")
# model = BigBirdForQuestionAnswering.from_pretrained(MODEL_ID, attention_type="block_sparse", hidden_act="gelu_new")

Some weights of the model checkpoint at google/bigbird-base-trivia-itc were not used when initializing FlaxBigBirdForQuestionAnswering: {('bert', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxBigBirdForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBigBirdForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
np.random.seed(0)
array = np.random.randint(1, 1024, size=(1, 768))

jx_array = jnp.array(array)
torch_array = torch.tensor(array, dtype=torch.long)

jx_array.shape, torch_array.shape

((1, 768), torch.Size([1, 768]))

In [17]:
@jax.jit
def forward(jx_array):
    return fx_model(jx_array)

fx_out = forward(jx_array)

In [16]:
fx_out = fx_model(jx_array)

In [11]:
fx_out

FlaxQuestionAnsweringModelOutput(start_logits=DeviceArray([[-1.00001419e+06, -9.48659801e+00, -1.10608101e+01,
              -1.45308685e+01, -1.08723917e+01, -1.05584536e+01,
              -1.03852825e+01, -1.27956619e+01,  1.04537177e+00,
              -1.01275349e+01, -1.15132895e+01, -1.03173103e+01,
              -1.25697193e+01, -1.11597233e+01,  1.14302731e+00,
              -9.31546021e+00, -1.15054226e+01, -1.05190983e+01,
              -1.28111706e+01, -1.27645092e+01, -9.04795647e+00,
              -1.11875362e+01,  1.66731906e+00, -8.75372124e+00,
              -1.29193478e+01, -1.15552654e+01, -1.18228273e+01,
              -1.34917250e+01, -5.79090261e+00, -1.63757706e+00,
              -1.17210541e+01, -8.59663773e+00, -1.43713589e+01,
              -1.35778408e+01, -1.34649134e+01, -1.17007799e+01,
              -1.08012562e+01, -1.23075209e+01, -1.17841101e+01,
              -1.15530958e+01, -1.17487135e+01, -1.60052261e+01,
              -1.33426704e+01, -1.41677761e+

In [5]:
with torch.no_grad():
    torch_out = model(torch_array)
# torch_out

In [6]:
print("difference in inputs:", get_difference(torch_array, jx_array))

print("difference in start logits", get_difference(torch_out["start_logits"], fx_out["start_logits"]))

print("difference in end logits", get_difference(torch_out["end_logits"], fx_out["end_logits"]))

# print("difference in logits:", get_difference(torch_out["prediction_logits"], fx_out["prediction_logits"]))

# print("difference in seq-logits:", get_difference(torch_out["seq_relationship_logits"], fx_out["seq_relationship_logits"]))

difference in inputs: 0
difference in start logits 1.335144e-05
difference in end logits 1.04904175e-05


In [7]:
fx_out["start_logits"].shape, fx_out["end_logits"].shape

((1, 768), (1, 768))

In [1]:
from transformers import FlaxBigBirdForMultipleChoice, BigBirdConfig, FlaxBertForMultipleChoice, BertConfig

In [2]:
m = FlaxBigBirdForMultipleChoice(BigBirdConfig())



In [9]:
m = FlaxBertForMultipleChoice(BertConfig())