In [1]:
# let's debug bigbird

# vocab is same as roberta/gpt2
# for running `simulated_sparse`, encoder max seqlen must be 4096

# fix pooler head on top

In [2]:
from bigbird.core import modeling
from bigbird.core import utils

from transformers import BigBirdForMaskedLM, BigBirdConfig

import tensorflow.compat.v2 as tf
from tqdm import tqdm
import numpy as np
import torch

# Turn on TF2 behaviour for the `tensorflow.compat.v2` module imported above,
# before any model is constructed.
tf.enable_v2_behavior()

In [3]:
# TensorFlow checkpoint path handed to `NewCheckpointReader` when weights are loaded.
TF_CKPT_DIR = "ckpt/bigbr_base/model.ckpt-0"
# PyTorch state-dict path handed to `torch.load`.
# NOTE(review): despite the `_DIR` suffix this names a file, not a directory.
HF_CKPT_DIR = "google/bigbird-base/pytorch_model.bin"

In [4]:
# One set of hyper-parameters for both implementations: the plain dict is used
# as-is and is also converted into a Hugging Face `BigBirdConfig`.
bigbird_config = dict(
    # --- transformer basics ---
    vocab_size=50358,
    attention_probs_dropout_prob=0.1,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    hidden_size=768,
    initializer_range=0.02,
    intermediate_size=3072,
    max_position_embeddings=4096,
    num_attention_heads=12,
    num_hidden_layers=12,
    type_vocab_size=2,
    use_bias=True,
    rescale_embedding=False,
    scope="bert",
    # --- sparse mask settings ---
    # attention_type options: "block_sparse" "original_full" "simulated_sparse"
    attention_type="original_full",
    norm_type="postnorm",
    block_size=16,
    num_rand_blocks=3,
    # --- common bert settings ---
    max_encoder_length=128,
    batch_size=2,
)

hf_bigbird_config = BigBirdConfig.from_dict(bigbird_config)

In [5]:
# tf.compat.v1.set_random_seed(0)
np.random.seed(0)  # fixed seed: the same ids are produced on every run

s1, s2 = bigbird_config["batch_size"], bigbird_config["max_encoder_length"]

# Random ids in [1, s2), drawn directly as an (s1, s2) grid.
arr = np.random.randint(1, s2, size=(s1, s2))

# The same array feeds both frameworks: int64 for PyTorch, int32 for TensorFlow.
hf_input_ids = torch.from_numpy(arr).long()
input_ids = tf.convert_to_tensor(arr, dtype=tf.int32)

In [6]:
# TF reference model, configured from the plain config dict.
model = modeling.BertModel(bigbird_config)
_, _ = model(input_ids, training=False) # one forward pass so all weights exist before values are assigned to them

# Hugging Face model built from the config object; no checkpoint is read on this line.
hf_model = BigBirdForMaskedLM(hf_bigbird_config)

INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****


In [7]:
# --- TF side ---------------------------------------------------------------
# Each trainable variable is looked up in the checkpoint by its name minus the
# trailing ":<n>" suffix (e.g. "foo/kernel:0" -> "foo/kernel"). Splitting on
# the last ":" does not assume the suffix is exactly two characters long.
# NOTE(review): `set_weights` receives values for `trainable_weights` only, so
# this assumes the model has no non-trainable weights — confirm.
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(TF_CKPT_DIR)
tf_weights = [
    ckpt_reader.get_tensor(v.name.rsplit(":", 1)[0])
    for v in tqdm(model.trainable_weights, position=0)
]
model.set_weights(tf_weights)

# --- HF side ---------------------------------------------------------------
# map_location="cpu" lets the state dict deserialize even when the device it
# was saved from is not available on this machine.
hf_model.load_state_dict(torch.load(HF_CKPT_DIR, map_location="cpu"))
hf_model.eval()  # switch the HF model to eval mode before comparing outputs
"model weights loaded"

100%|██████████| 199/199 [00:00<00:00, 244.67it/s]


'model weights loaded'

In [8]:
# Forward pass of the TF model with training=False; the call returns two values.
# NOTE(review): later cells read intermediates such as `model.word_embeddings`
# off the model — presumably stashed during this call; confirm in the patched
# bigbird sources.
sequence_output, pooled_output = model(input_ids, training=False)

INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****


In [9]:
# Forward pass of the HF model on the same ids. The result is discarded.
# Inference only, so run under no_grad and skip recording an autograd graph.
with torch.no_grad():
    _ = hf_model(hf_input_ids)

In [10]:
# print("model input_ids", model.input_ids, end="\n\n")
# print("embedding", model.word_embeddings, end="\n\n")

# print("l1 layer_input", model.encoder.l1_layer_input, end="\n\n")
# print("l1 attn_mask", model.encoder.l1_attention_mask, end="\n\n")
# print("l1 encoder_from_mask", model.encoder.l1_encoder_from_mask, end="\n\n")
# print("l1 encoder_to_mask", model.encoder.l1_encoder_to_mask, end="\n\n")
# print("l1 blocked_encoder_mask", model.encoder.l1_blocked_encoder_mask, end="\n\n")
# print("l1_training", model.encoder.l1_training, end="\n\n")

# print("l1 layer_output", model.encoder.l1_layer_output, end="\n\n")
# print("last layer_output", model.encoder.last_layer_output, end="\n\n")

# print("bigbird sequence out", sequence_output, end="\n\n")
# print("bigbird pooled output", pooled_output, end="\n\n")

In [11]:
def difference_between_tensors(tf_tensor, pt_tensor):
    """Return the largest element-wise absolute difference between two tensors.

    Args:
        tf_tensor: TensorFlow tensor, or anything `np.asarray` accepts.
        pt_tensor: `torch.Tensor`. It is detached and copied to host memory
            first, so tensors that require grad or live on an accelerator
            are accepted.

    Returns:
        NumPy scalar equal to `max(|tf_tensor - pt_tensor|)`.

    Raises:
        ValueError: if the two shapes differ. Without this check NumPy would
            broadcast the operands and still report a number, hiding a layout
            mismatch between the two implementations.
    """
    tf_np = np.asarray(tf_tensor)
    pt_np = pt_tensor.detach().cpu().numpy()
    if tf_np.shape != pt_np.shape:
        raise ValueError(
            f"shape mismatch: tf tensor {tf_np.shape} vs pt tensor {pt_np.shape}"
        )
    return np.max(np.abs(tf_np - pt_np))

In [12]:
## RUN THIS FOR WEIGHTS CONVERSION

# from transformers import BigBirdForMaskedLM, BigBirdConfig, load_tf_weights_in_big_bird

# config = BigBirdConfig()
# model = BigBirdForMaskedLM(config)

# old = model.state_dict()

# model = load_tf_weights_in_big_bird(model, "ckpt/bigbr_base/model.ckpt-0")

# model.save_pretrained("google/bigbird-base")

In [13]:
# Top-level checkpoints of the forward pass, compared TF vs HF in execution order.
tf_encoder = model.encoder
hf_bert = hf_model.bert
hf_encoder = hf_bert.encoder

comparisons = [
    ("difference bw input_ids:", model.input_ids, hf_bert.input_ids),
    ("difference bw word_embeddings:", model.word_embeddings, hf_bert.word_embeddings),
    ("difference bw l1 layer_input", tf_encoder.l1_layer_input, hf_encoder.l1_layer_input),
    ("difference bw l1 layer_output", tf_encoder.l1_layer_output, hf_encoder.l1_layer_output),
    ("difference bw last layer_output", tf_encoder.last_layer_output, hf_encoder.last_layer_output),
]

for label, tf_value, hf_value in comparisons:
    print(label, difference_between_tensors(tf_value, hf_value))

# print("difference bw bigbird sequence out", difference_between_tensors(sequence_output, hf_sequence_output), end="\n\n")
# print("difference bw bigbird pooled output", difference_between_tensors(pooled_output, hf_pooled_output), end="\n\n")

difference bw input_ids: 0
difference bw word_embeddings: 0.0
difference bw l1 layer_input 7.1525574e-07
difference bw l1 layer_output 0.0035440922
difference bw last layer_output 0.0568192


In [14]:
# tf.train.list_variables("ckpt/bigbr_base/model.ckpt-0")

In [15]:
# from transformers.models.big_bird.modeling_big_bird import BigBirdAttention
# from bigbird.core.attention import MultiHeadedAttentionLayer

In [16]:
# layer-0 debugging
#
# Walk through the intermediates of the first encoder layer in execution order
# and report the max abs difference between the TF and HF tensors, to see at
# which step the two implementations start to diverge.

tf_layer0 = model.encoder.encoder_layers[0]
hf_layer0 = hf_model.bert.encoder.layer[0]

# (attribute name, TF object holding it, HF object holding it)
layer0_checks = [
    # self-attention internals
    ("k", tf_layer0.attn_layer, hf_layer0.attention.self),
    ("q", tf_layer0.attn_layer, hf_layer0.attention.self),
    ("v", tf_layer0.attn_layer, hf_layer0.attention.self),
    ("attn_sc", tf_layer0.attn_layer, hf_layer0.attention.self),
    ("attn_p", tf_layer0.attn_layer, hf_layer0.attention.self),
    ("attn_o", tf_layer0.attn_layer, hf_layer0.attention.self),
    # tensors stored on the layer itself
    ("attn_proj_o", tf_layer0, hf_layer0),
    ("int_o", tf_layer0, hf_layer0),
    # on the HF side these are stored on the layer's `output` sub-module
    ("io", tf_layer0, hf_layer0.output),
    ("o", tf_layer0, hf_layer0.output),
    # ("do", tf_layer0, hf_layer0.output),
    # ("l_o", tf_layer0, hf_layer0),
]

for name, tf_owner, hf_owner in layer0_checks:
    print(
        f"difference bw {name}:",
        difference_between_tensors(getattr(tf_owner, name), getattr(hf_owner, name)),
    )

difference bw k: 5.2452087e-06
difference bw q: 5.722046e-06
difference bw v: 1.4305115e-06
difference bw v: 1.4305115e-06
difference bw attn_sc: 3.0517578e-05
difference bw attn_p: 4.708767e-06
difference bw attn_o: 1.013279e-05
difference bw attn_proj_o: 8.583069e-06
difference bw int_o: 0.0004749298
difference bw io: 0.0004749298
difference bw o: 0.011826634
