In [1]:
import official.nlp.projects.triviaqa.modeling as modeling
from official.nlp.configs.encoders import EncoderConfig, build_encoder
import os

from transformers import BigBirdConfig, BigBirdModel, BigBirdForQuestionAnswering

import tensorflow as tf
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F

def difference_between_tensors(tf_tensor, pt_tensor):
    tf_np = np.array(tf_tensor)
    pt_np = np.array(pt_tensor.detach())
    return np.max(np.abs(tf_np - pt_np))

model_id = "bigbird-base-trivia-itc"
TF_CKPT_DIR = f"ckpt/{model_id}/model.ckpt-0"
HF_CKPT_DIR = f"google/{model_id}/pytorch_model.bin"

In [2]:
seqlen = 1024 # min seqlen we can keep in this case
config = EncoderConfig(type="bigbird")
config.bigbird.block_size = 16
hf_config = BigBirdConfig(num_hidden_layers=config.bigbird.num_layers, hidden_act="gelu_fast", attention_type="block_sparse", num_random_blocks=config.bigbird.num_rand_blocks, **config.bigbird.__dict__)

In [3]:
np.random.seed(0)
arr = np.random.randint(1, seqlen, size=seqlen).reshape(1, seqlen)
sep_pos = 9
arr[:, sep_pos] = 66 # sep_id

input_ids = tf.convert_to_tensor(arr, dtype=tf.int32)
question_lengths = tf.constant([sep_pos+1], dtype=tf.int32)

hf_input_ids = torch.from_numpy(arr).long()

In [4]:
input_ids.numpy()[0, :128]

array([ 685,  560,  630,  193,  836,  764,  708,  360,   10,   66,  278,
        755,  805,  600,   71,  473,  601,  397,  315,  706,  487,  552,
         88,  175,  601,  850,  678,  538,  846,   73,  778,  917,  116,
        977,  756,  710, 1023,  848,  432,  449,  851,  100,  985,  178,
        756,  798,  660,  148,  911,  424,  289,  962,  266,  698,  640,
        545,  544,  715,  245,  152,  676,  511,  460,  883,  184,   29,
        803,  129,  129,  933,   54,  902,  551,  489,  757,  274,  336,
        389,  618,   43,  443,  544,  889,  258,  322, 1000,  938,   58,
        292,  871,  120,  780,  431,   83,   92,  897,  399,  612,  566,
        909,  634,  939,   85,  204,  325,  775,  965,   48,  640, 1013,
        132,  973,  869,  181, 1001,  847,  144,  661,  228,  955,  792,
        720,  910,  374,  854,  561,  306,  582], dtype=int32)

In [5]:
# loading tf weights
savedmodel = tf.saved_model.load(os.path.join("ckpt", model_id))
model = modeling.TriviaQaModel(config, seqlen)
#encoder = build_encoder(config)
# qa_head = modeling.TriviaQaHead(
#         config.get().intermediate_size,
#         dropout_rate=config.get().dropout_rate,
#         attention_dropout_rate=config.get().attention_dropout_rate)

_ = model(dict(
    token_ids=input_ids,
    question_lengths=question_lengths
))

# x = model.encoder(dict(
#     input_word_ids=inputs['token_ids'],
#     input_mask=tf.cast(inputs['token_ids'] > 0, tf.int32),
#     input_type_ids=1 - tf.sequence_mask(inputs['question_lengths'], seqlen, tf.int32)
# ))

# _ = model.qa_head(dict(
#     token_embeddings=x['sequence_output'], 
#     token_ids=inputs['token_ids'],
#     question_lengths=inputs['question_lengths']
# ))

# enc_vars = [v.name for v in encoder.variables]
# qa_vars = [v.name for v in qa_head.variables]

# encoder.set_weights([v.numpy() for v in tqdm(savedmodel.variables) if v.name in enc_vars])
# qa_head.set_weights([v.numpy() for v in tqdm(savedmodel.variables) if v.name in qa_vars])
model.set_weights([v.numpy() for v in tqdm(savedmodel.variables)])
del savedmodel
# encoder.trainable = False
# qa_head.trainable = False
model.trainable = False

# loading hf weights
hf_model = BigBirdForQuestionAnswering.from_pretrained(f"google/{model_id}")
hf_model.eval()

"model weights loaded"

INFO:absl:Encoder class: BigBirdEncoder to build...
INFO:absl:TransformerScaffold configs: {'name': 'transformer_scaffold', 'trainable': True, 'dtype': 'float32', 'attention_cls': <official.nlp.projects.bigbird.attention.BigBirdAttention object at 0x7f8469e88280>, 'feedforward_cls': None, 'num_attention_heads': 12, 'intermediate_size': 3072, 'intermediate_activation': <function gelu at 0x7f852cab0af0>, 'dropout_rate': 0.1, 'attention_dropout_rate': 0.1, 'norm_first': False, 'kernel_initializer': {'class_name': 'TruncatedNormal', 'config': {'mean': 0.0, 'stddev': 0.02, 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}
INFO:absl:TransformerScaffold configs: {'name': 'transformer_scaffold_1', 'trainable': True, 'dtype': 'float32', 'attention_cls': <official.nlp.projects.bigbird.attention.BigBirdAttention object at 0x7f842c0891f0>,

'model weights loaded'

In [6]:
# @tf.function
def fwd(input_ids, question_lengths):
    encoder_out = model.encoder(dict(
        input_word_ids=input_ids,
        input_mask=tf.cast(input_ids > 0, tf.int32),
        input_type_ids=1 - tf.sequence_mask(question_lengths, seqlen, tf.int32)
    ))
    out = model.qa_head(dict(
        token_embeddings=encoder_out["sequence_output"], 
        token_ids=input_ids,
        question_lengths=question_lengths,
    ))
    return out, encoder_out["sequence_output"]

out, sequence_output = fwd(input_ids, question_lengths)
start_logits, end_logits = out[:,:,0], out[:,:,1]

In [7]:
# hf_model = BigBirdForQuestionAnswering.from_pretrained(f"google/{model_id}", block_size=16)
hf_start_logits, hf_end_logits = hf_model(hf_input_ids).to_tuple()
hf_sequence_output = hf_model.encoder_out

In [11]:
# print("difference bw input_ids:", difference_between_tensors(model.input_ids, hf_model.bert.input_ids))
# print("difference bw word_embeddings:", difference_between_tensors(model.word_embeddings, hf_model.bert.word_embeddings))

# print("difference bw l1 layer_input", difference_between_tensors(model.encoder.l1_layer_input, hf_model.bert.encoder.l1_layer_input))

# print("difference bw l1 layer_output", difference_between_tensors(model.encoder.l1_layer_output, hf_model.bert.encoder.l1_layer_output))
# print("difference bw last layer_output", difference_between_tensors(model.encoder.last_layer_output,hf_model.bert.encoder.last_layer_output))

print("difference bw encoder sequence out", difference_between_tensors(sequence_output, hf_sequence_output), end="\n\n")

# print("difference bw bigbird-qa logits", difference_between_tensors(pooler_output, hf_pooler_output), end="\n\n")

# print("difference bw bigbird masked_lm_log_probs", difference_between_tensors(masked_lm_log_probs, hf_masked_lm_log_probs), end="\n\n")
# print("difference bw bigbird next_sentence_log_probs", difference_between_tensors(next_sentence_log_probs, hf_next_sentence_log_probs), end="\n\n")

difference bw encoder sequence out 11.681164



In [14]:
hf_sequence_output.shape, hf_sequence_output

(torch.Size([1, 1024, 768]),
 tensor([[[-1.0275e-01,  1.2287e-01,  8.9848e-02,  ...,  6.8178e-01,
           -2.5683e-01,  2.1847e-02],
          [-2.1408e-01,  1.1707e-01,  1.0406e-01,  ...,  6.1583e-01,
           -2.1482e-01,  2.6167e-03],
          [-2.2626e-01,  1.0723e-01,  9.2723e-02,  ...,  5.8291e-01,
           -2.0993e-01,  1.2698e-02],
          ...,
          [-1.3344e-02,  4.8234e-02,  6.4713e-02,  ...,  6.3735e-01,
           -2.0217e-01,  4.9484e-03],
          [-1.4869e-01,  1.2185e-01,  8.7017e-02,  ...,  6.6509e-01,
           -2.1420e-01,  4.0150e-04],
          [-1.7408e-01,  1.4138e-01,  1.2934e-01,  ...,  7.4175e-01,
           -2.4669e-01, -1.9718e-02]]], grad_fn=<NativeLayerNormBackward>))

In [13]:
sequence_output

<tf.Tensor: shape=(1, 1024, 768), dtype=float32, numpy=
array([[[-0.6282492 ,  0.29299787, -0.5213406 , ...,  0.44641694,
         -0.43187696, -0.1144705 ],
        [-0.6360485 , -0.12935445, -0.1261341 , ...,  0.31974655,
         -0.09079605,  0.04832615],
        [-0.46901706,  0.06238546, -0.41553906, ..., -0.00190075,
          0.1418799 , -0.11939595],
        ...,
        [ 0.32888898, -0.1809986 , -0.39110774, ...,  0.5249772 ,
          0.27071732,  0.17822777],
        [-0.4122319 ,  0.08259138, -0.28863332, ...,  0.41247615,
          0.02627292, -0.08267803],
        [-0.40902504,  0.17081302, -0.15940398, ...,  0.70709133,
          0.02209761, -0.09360607]]], dtype=float32)>