In [1]:
import official.nlp.projects.triviaqa.modeling as modeling
from official.nlp.configs.encoders import EncoderConfig
import os

from transformers import BigBirdConfig, BigBirdModel, BigBirdForQuestionAnswering

import tensorflow as tf
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F

def difference_between_tensors(tf_tensor, pt_tensor):
    tf_np = np.array(tf_tensor)
    pt_np = np.array(pt_tensor.detach())
    return np.max(np.abs(tf_np - pt_np))

model_id = "bigbird-base-trivia-itc"
TF_CKPT_DIR = f"ckpt/{model_id}/model.ckpt-0"
HF_CKPT_DIR = f"google/{model_id}/pytorch_model.bin"

In [2]:
seqlen = 4096
config = EncoderConfig(type="bigbird")
# block_size=16, num_rand_blocks=3
hf_config = BigBirdConfig(num_hidden_layers=config.bigbird.num_layers, hidden_act="gelu_fast", attention_type="block_sparse", num_random_blocks=config.bigbird.num_rand_blocks, **config.bigbird.__dict__)

In [3]:
np.random.seed(0)
arr = np.random.randint(1, seqlen, size=seqlen).reshape(1, seqlen)

input_ids = tf.convert_to_tensor(arr, dtype=tf.int32)
hf_input_ids = torch.from_numpy(arr).long()
input_ids.shape, tf.constant([8]).shape

(TensorShape([1, 4096]), TensorShape([1]))

In [4]:
# loading tf weights
savedmodel = tf.saved_model.load(os.path.join("ckpt", model_id))
model = modeling.TriviaQaModel(config, seqlen)
_ = model(dict(
    token_ids=input_ids,
    question_lengths=tf.constant([8], dtype=tf.int32,)
))
model.set_weights([v.numpy() for v in tqdm(savedmodel.trainable_variables)])
del savedmodel

# loading hf weights
hf_model = BigBirdForQuestionAnswering.from_pretrained(f"google/{model_id}")

INFO:absl:Encoder class: BigBirdEncoder to build...
INFO:absl:TransformerScaffold configs: {'name': 'transformer_scaffold', 'trainable': True, 'dtype': 'float32', 'attention_cls': <official.nlp.projects.bigbird.attention.BigBirdAttention object at 0x7fed2fc9cd30>, 'feedforward_cls': None, 'num_attention_heads': 12, 'intermediate_size': 3072, 'intermediate_activation': <function gelu at 0x7fedf1b47790>, 'dropout_rate': 0.1, 'attention_dropout_rate': 0.1, 'norm_first': False, 'kernel_initializer': {'class_name': 'TruncatedNormal', 'config': {'mean': 0.0, 'stddev': 0.02, 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}
INFO:absl:TransformerScaffold configs: {'name': 'transformer_scaffold_1', 'trainable': True, 'dtype': 'float32', 'attention_cls': <official.nlp.projects.bigbird.attention.BigBirdAttention object at 0x7fed2fc9cbb0>,

In [5]:
out = model(dict(
    token_ids=input_ids,
    question_lengths=tf.constant([8], dtype=tf.int32,)
))
start_logits, end_logits = out[:,:,0], out[:,:,1]

In [6]:
hf_start_logits, hf_end_logits = hf_model(hf_input_ids).to_tuple()

In [12]:
start_logits, hf_start_logits

(<tf.Tensor: shape=(1, 4096), dtype=float32, numpy=
 array([[-9.99999625e+05, -1.00000994e+06, -1.00000800e+06, ...,
         -1.31911230e+01,  3.67996097e-01, -9.05163383e+00]], dtype=float32)>,
 tensor([[ 1.2715, -1.1993, -0.1114,  ...,  1.2939,  1.1098,  1.3245]],
        grad_fn=<SqueezeBackward1>))

In [9]:
end_logits, hf_end_logits

(<tf.Tensor: shape=(1, 4096), dtype=float32, numpy=
 array([[-1.00000200e+06, -1.00001206e+06, -1.00001181e+06, ...,
         -1.49152393e+01, -2.01539540e+00, -1.23520317e+01]], dtype=float32)>,
 tensor([[ 1.5141,  0.2146, -0.2144,  ...,  1.5398,  1.3746,  1.5566]],
        grad_fn=<SqueezeBackward1>))

In [10]:
# model = modeling.TriviaQaModel(config, seqlen)

# # building all the weights before setting-up :)
# sequence_output, pooler_output = model(input_ids, training=False)

# hf_model = BigBirdForPreTraining(hf_bigbird_config)

In [11]:
# print("difference bw input_ids:", difference_between_tensors(model.input_ids, hf_model.bert.input_ids))
# print("difference bw word_embeddings:", difference_between_tensors(model.word_embeddings, hf_model.bert.word_embeddings))

# print("difference bw l1 layer_input", difference_between_tensors(model.encoder.l1_layer_input, hf_model.bert.encoder.l1_layer_input))

# print("difference bw l1 layer_output", difference_between_tensors(model.encoder.l1_layer_output, hf_model.bert.encoder.l1_layer_output))
# print("difference bw last layer_output", difference_between_tensors(model.encoder.last_layer_output,hf_model.bert.encoder.last_layer_output))

# print("difference bw bigbird sequence out", difference_between_tensors(sequence_output, hf_sequence_output), end="\n\n")

# print("difference bw bigbird pooler output", difference_between_tensors(pooler_output, hf_pooler_output), end="\n\n")

# print("difference bw bigbird masked_lm_log_probs", difference_between_tensors(masked_lm_log_probs, hf_masked_lm_log_probs), end="\n\n")
# print("difference bw bigbird next_sentence_log_probs", difference_between_tensors(next_sentence_log_probs, hf_next_sentence_log_probs), end="\n\n")