In [1]:
from transformers import BigBirdForPreTraining, FlaxBigBirdForPreTraining, FlaxBigBirdForQuestionAnswering, BigBirdForQuestionAnswering
import numpy as np
import jax
import torch
import jax.numpy as jnp

MODEL_ID = "google/bigbird-base-trivia-itc"

def get_difference(torch_array, jx_array):
    torch_array = torch_array.detach().numpy()
    jx_array = np.array(jx_array)
    return np.max(torch_array - jx_array)

In [2]:
fx_model = FlaxBigBirdForQuestionAnswering.from_pretrained(MODEL_ID, attention_type="block_sparse")
model = BigBirdForQuestionAnswering.from_pretrained(MODEL_ID, attention_type="block_sparse")

Downloading: 100%|██████████| 527M/527M [01:17<00:00, 6.80MB/s]


In [3]:
np.random.seed(0)
array = np.random.randint(1, 1024, size=(1, 768))

jx_array = jnp.array(array)
torch_array = torch.tensor(array, dtype=torch.long)

jx_array.shape, torch_array.shape

((1, 768), torch.Size([1, 768]))

In [4]:
@jax.jit
def forward(jx_array):
    return fx_model(jx_array)

fx_out = forward(jx_array)

In [5]:
fx_out = fx_model(jx_array)

In [6]:
fx_out

FlaxQuestionAnsweringModelOutput(start_logits=DeviceArray([[-1.00001419e+06, -9.48659325e+00, -1.10608063e+01,
              -1.45308657e+01, -1.08723812e+01, -1.05584497e+01,
              -1.03852854e+01, -1.27956553e+01,  1.04537451e+00,
              -1.01275368e+01, -1.15132866e+01, -1.03173122e+01,
              -1.25697193e+01, -1.11597233e+01,  1.14303100e+00,
              -9.31545639e+00, -1.15054264e+01, -1.05190964e+01,
              -1.28111725e+01, -1.27645082e+01, -9.04795742e+00,
              -1.11875372e+01,  1.66732037e+00, -8.75372314e+00,
              -1.29193583e+01, -1.15552578e+01, -1.18228283e+01,
              -1.34917221e+01, -5.79090452e+00, -1.63757038e+00,
              -1.17210541e+01, -8.59663773e+00, -1.43713608e+01,
              -1.35778427e+01, -1.34649124e+01, -1.17007818e+01,
              -1.08012581e+01, -1.23075247e+01, -1.17841110e+01,
              -1.15530949e+01, -1.17487135e+01, -1.60052261e+01,
              -1.33426685e+01, -1.41677809e+

In [7]:
with torch.no_grad():
    torch_out = model(torch_array)
# torch_out

In [8]:
print("difference in inputs:", get_difference(torch_array, jx_array))

print("difference in start logits", get_difference(torch_out["start_logits"], fx_out["start_logits"]))

print("difference in end logits", get_difference(torch_out["end_logits"], fx_out["end_logits"]))

# print("difference in logits:", get_difference(torch_out["prediction_logits"], fx_out["prediction_logits"]))

# print("difference in seq-logits:", get_difference(torch_out["seq_relationship_logits"], fx_out["seq_relationship_logits"]))

difference in inputs: 0
difference in start logits 1.335144e-05
difference in end logits 1.04904175e-05


In [9]:
fx_out["start_logits"].shape, fx_out["end_logits"].shape

((1, 768), (1, 768))

In [10]:
from transformers import FlaxBigBirdForMultipleChoice, BigBirdConfig, FlaxBertForMultipleChoice, BertConfig

In [11]:
m = FlaxBigBirdForMultipleChoice(BigBirdConfig())

In [12]:
m = FlaxBertForMultipleChoice(BertConfig())