In [1]:
import official.nlp.projects.triviaqa.modeling as modeling
from official.nlp.configs.encoders import EncoderConfig, build_encoder
import os

from transformers import BigBirdConfig, BigBirdModel, BigBirdForQuestionAnswering

import tensorflow as tf
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F

def difference_between_tensors(tf_tensor, pt_tensor):
    """Return the maximum absolute element-wise difference between two tensors.

    Used as a parity check between the TensorFlow reference model and the
    PyTorch port: a result close to 0 means the two tensors agree.

    Args:
        tf_tensor: TensorFlow tensor/variable, or anything ``np.asarray`` accepts.
        pt_tensor: PyTorch tensor on any device (CPU or CUDA, with or without
            autograd history), or any array-like.

    Returns:
        ``max(|tf_tensor - pt_tensor|)`` as a NumPy scalar, or ``0.0`` when both
        inputs are empty.

    Raises:
        ValueError: if the two inputs do not have identical shapes. Without this
            check NumPy would broadcast e.g. (1, n) against (n, 1) and report a
            number that says nothing about parity.
    """
    tf_np = np.asarray(tf_tensor)
    if isinstance(pt_tensor, torch.Tensor):
        # detach() drops autograd history; cpu() is required before numpy()
        # for tensors that live on a GPU.
        pt_np = pt_tensor.detach().cpu().numpy()
    else:
        pt_np = np.asarray(pt_tensor)

    if tf_np.shape != pt_np.shape:
        raise ValueError(
            f"shape mismatch: tf {tf_np.shape} vs pt {pt_np.shape}"
        )
    if tf_np.size == 0:
        # np.max raises on zero-size arrays; two empty tensors trivially agree.
        return 0.0
    return np.max(np.abs(tf_np - pt_np))

# Checkpoint name, used for both the TF SavedModel directory and the PyTorch weights path.
model_id = "bigbird-base-trivia-itc"

In [2]:
# Sequence length shared by both models (min seqlen we can keep in this case).
seqlen = 1024

# TF BigBird encoder config, with the attention block size lowered to 16.
config = EncoderConfig(type="bigbird")
config.bigbird.block_size = 16

# Mirror the TF encoder hyper-parameters into an HF BigBirdConfig. Names that
# differ between the two configs are passed explicitly; everything else is
# forwarded verbatim from the TF config's attributes.
# NOTE(review): TF names such as `dropout_rate` presumably do not map onto the HF
# dropout fields; that is harmless here only because the HF model runs in eval().
hf_config = BigBirdConfig(
    num_hidden_layers=config.bigbird.num_layers,
    hidden_act="gelu_fast",
    attention_type="block_sparse",
    num_random_blocks=config.bigbird.num_rand_blocks,
    **vars(config.bigbird),
)

In [3]:
# One deterministic synthetic sequence of random token ids, fed to both frameworks.
# Ids are drawn from [1, seqlen), so none of them is 0 (the padding test used below).
sep_pos = 9  # position of the separator; the "question" spans positions 0..sep_pos
np.random.seed(0)
arr = np.random.randint(1, seqlen, size=seqlen).reshape(1, -1)
arr[0, sep_pos] = 66  # sep_id
# NOTE(review): 66 can also be drawn at other positions by randint.

# TensorFlow inputs.
question_lengths = tf.constant([sep_pos + 1], dtype=tf.int32)
input_ids = tf.convert_to_tensor(arr, dtype=tf.int32)

# PyTorch input: the same ids as an int64 tensor.
hf_input_ids = torch.from_numpy(arr).long()

In [4]:
# loading tf weights
# Build a fresh TriviaQaModel and copy the exported SavedModel's variables into it.
savedmodel = tf.saved_model.load(os.path.join("ckpt", model_id))
model = modeling.TriviaQaModel(config, seqlen)

# One forward pass so the Keras model creates its variables before set_weights.
_ = model(dict(
    token_ids=input_ids,
    question_lengths=question_lengths
))

# NOTE: set_weights assigns positionally, so this relies on savedmodel.variables
# being ordered exactly like the model's own weights. A count or shape mismatch is
# rejected, but two same-shaped variables in swapped order would load silently;
# the per-tensor diffs below are what catch that.
model.set_weights([v.numpy() for v in tqdm(savedmodel.variables)])
del savedmodel
model.trainable = False

# loading hf weights
hf_model = BigBirdForQuestionAnswering(hf_config)
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine. torch.load unpickles the file: only use it on trusted checkpoints.
state_dict = torch.load(f"google/{model_id}/pytorch_model.bin", map_location="cpu")
hf_model.load_state_dict(state_dict)
hf_model.eval()

"model weights loaded"

INFO:absl:Encoder class: BigBirdEncoder to build...
INFO:absl:TransformerScaffold configs: {'name': 'transformer_scaffold', 'trainable': True, 'dtype': 'float32', 'attention_cls': <official.nlp.projects.bigbird.attention.BigBirdAttention object at 0x7f9554d80280>, 'feedforward_cls': None, 'num_attention_heads': 12, 'intermediate_size': 3072, 'intermediate_activation': <function gelu at 0x7f9622af0af0>, 'dropout_rate': 0.1, 'attention_dropout_rate': 0.1, 'norm_first': False, 'kernel_initializer': {'class_name': 'TruncatedNormal', 'config': {'mean': 0.0, 'stddev': 0.02, 'seed': None}}, 'bias_initializer': {'class_name': 'Zeros', 'config': {}}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}
INFO:absl:TransformerScaffold configs: {'name': 'transformer_scaffold_1', 'trainable': True, 'dtype': 'float32', 'attention_cls': <official.nlp.projects.bigbird.attention.BigBirdAttention object at 0x7f95220a7f70>,

'model weights loaded'

In [5]:
# BUILD BIGBIRD QA MODEL
# Re-run the TF encoder step by step (embeddings -> norm -> transformer layers ->
# QA head), keeping each intermediate so later cells can compare it with the
# PyTorch port.

from official.nlp.projects.bigbird import attention


word_ids = input_ids
# Token id 0 is treated as padding; every other position is attended to.
mask = tf.cast(input_ids > 0, tf.int32)
# 0 for the first `question_lengths` positions (question), 1 for the rest.
type_ids = 1 - tf.sequence_mask(question_lengths, seqlen, tf.int32)

word_embeddings = model.encoder._embedding_layer(word_ids)
position_embeddings = model.encoder._position_embedding_layer(word_embeddings)
type_embeddings = model.encoder._type_embedding_layer(type_ids)
embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
# `embeddings` is rebound to the normalized sum; that post-norm tensor is what
# the later parity checks compare.
embeddings = model.encoder._embedding_norm_layer(embeddings)

encoder_config = model.encoder.get_config()
block_size = encoder_config["block_size"]
num_layers = encoder_config["num_layers"]

masks = attention.BigBirdMasks(block_size=block_size)(
        tf.cast(mask, embeddings.dtype))

# `l[i]` holds the output of transformer layer i.
data = embeddings
l = []
for layer in model.encoder._transformer_layers[:num_layers]:
    data = layer([data, masks])
    l.append(data)
sequence_output = data

out = model.qa_head(dict(
        token_embeddings=sequence_output,
        token_ids=input_ids,
        question_lengths=question_lengths,
    ))
# Last axis of the head output: index 0 = start logits, index 1 = end logits.
start_logits, end_logits = out[:, :, 0], out[:, :, 1]

In [6]:
# Run the PyTorch port on the same ids. No gradients are needed for a parity
# check, so skip building the autograd graph.
# NOTE(review): no question_lengths / token_type_ids are passed here, so the HF
# model presumably derives them from the first sep token (id 66) in hf_input_ids.
# Confirm this matches the TF `question_lengths`.
with torch.no_grad():
    hf_outputs = hf_model(hf_input_ids)
# The first two output fields are the start and end logits; slicing keeps the
# unpacking valid if the output ever carries extra fields.
hf_start_logits, hf_end_logits = hf_outputs.to_tuple()[:2]
# Presumably a debugging attribute the instrumented HF model sets during forward
# (not part of the stock transformers API).
hf_sequence_output = hf_model.encoder_out

In [7]:
def _print_diff(label, tf_value, pt_value, end="\n"):
    """Print `label` followed by the max absolute difference of the two tensors."""
    print(label, difference_between_tensors(tf_value, pt_value), end=end)


# Inputs and the embedding stage.
_print_diff("difference bw input_ids:", input_ids, hf_input_ids)
_print_diff("difference bw token_type_ids:", type_ids, hf_model.tti)
_print_diff("difference bw word embeddings:", word_embeddings, hf_model.bert.embeddings.we)
_print_diff("difference bw position embeddings:", position_embeddings, hf_model.bert.embeddings.pe)
_print_diff("difference bw token type embeddings:", type_embeddings, hf_model.bert.embeddings.tte)
_print_diff("difference bw embeddings:", embeddings, hf_model.bert.embed)

# Encoder stack: first layer, last layer, and the final sequence output.
_print_diff("difference bw l1 layer_output", l[0], hf_model.bert.encoder.l[0])
_print_diff("difference bw last layer_output", l[-1], hf_model.bert.encoder.l[-1])
_print_diff("difference bw encoder sequence out", sequence_output, hf_sequence_output, end="\n\n")

# QA head logits.
_print_diff("difference bw start logits", start_logits, hf_start_logits, end="\n\n")
_print_diff("difference bw end logits", end_logits, hf_end_logits, end="\n\n")


difference bw input_ids: 0
difference bw token_type_ids: 0
difference bw word embeddings: 0.0
difference bw position embeddings: 0.0
difference bw token type embeddings: 0.0
difference bw embeddings: 1.4305115e-06
difference bw l1 layer_output 2.3841858e-06
difference bw last layer_output 0.00012207031
difference bw encoder sequence out 0.00012207031

difference bw start logits 0.020858765

difference bw end logits 0.0625



In [15]:
# QA-head internals: each entry pairs a TF intermediate with its PyTorch
# counterpart. The pairs are wrapped in lambdas so each attribute is looked up
# only when its line is printed, one comparison at a time.
qa_head_checks = [
    ("difference bw encoder out", lambda: (model.qa_head.qa_te, hf_model.qa_classifier.qa_te)),
    ("difference bw inter", lambda: (model.qa_head.inter, hf_model.qa_classifier.inter)),
    ("difference bw o1", lambda: (model.qa_head.o_1, hf_model.qa_classifier.output.o_1)),
    ("difference bw it", lambda: (model.qa_head.it, hf_model.qa_classifier.output.it)),
    ("difference bw hs", lambda: (model.qa_head.hs, hf_model.qa_classifier.output.hs)),
    ("difference bw o", lambda: (model.qa_head.o, hf_model.qa_classifier.o)),
    ("difference bw l", lambda: (model.qa_head.l, hf_model.l)),
    ("difference bw final logits", lambda: (out, hf_model.fl)),
    ("difference bw lmask", lambda: (model.qa_head.lmask, hf_model.lmask)),
]
for check_label, get_pair in qa_head_checks:
    print(check_label, difference_between_tensors(*get_pair()), end="\n\n")

difference bw encoder out 0.00012207031

difference bw inter 7.390976e-06

difference bw o1 2.0980835e-05

difference bw it 0.00012207031

difference bw hs 2.0980835e-05

difference bw o 0.011413574

difference bw l 0.021133423

difference bw final logits 0.0625

difference bw lmask 0.0



In [13]:
# Spot-check one loaded weight: the QA head's output LayerNorm bias.
# NOTE(review): index -5 is assumed to be that bias in the TF variable list;
# confirm if the model's layers change.
tf_ln_bias = model.variables[-5]
pt_ln_bias = hf_model.state_dict()["qa_classifier.output.LayerNorm.bias"]
difference_between_tensors(tf_ln_bias, pt_ln_bias)

0.0

In [9]:
# tf_q = model.encoder._transformer_layers[0]._attention_layer.q
# py_q = hf_model.bert.encoder.layer[0].attention.self.q

# print("difference bw q", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.q, hf_model.bert.encoder.layer[0].attention.self.q))

# print("difference bw k", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.k, hf_model.bert.encoder.layer[0].attention.self.k))

# print("difference bw v",difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.v, hf_model.bert.encoder.layer[0].attention.self.v))

In [10]:
# # deep inside attention layer

# print("difference bw l1 attn out", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.clo, hf_model.bert.encoder.layer[0].attention.self.clo))

# print("difference bw bqm", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.bqm, hf_model.bert.encoder.layer[0].attention.self.bqm))

# print("difference bw bkm", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.bkm, hf_model.bert.encoder.layer[0].attention.self.bkm))

# print("difference bw bvm", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.bvm, hf_model.bert.encoder.layer[0].attention.self.bvm))

# print("difference bw ra", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.ra, hf_model.bert.encoder.layer[0].attention.self.ra))

# print("difference bw fcl", difference_between_tensors(model.encoder._transformer_layers[0]._attention_layer.fcl, hf_model.bert.encoder.layer[0].attention.self.fcl))