In [1]:
from bigbird.core import modeling
from bigbird.core import utils
from bigbird.pretrain.run_pretraining import MaskedLMLayer, NSPLayer, serving_input_fn_builder

from transformers import BigBirdForPreTraining, BigBirdConfig, BigBirdTokenizer, BigBirdForMaskedLM, BigBirdModel

import tensorflow.compat.v2 as tf
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F

tf.enable_v2_behavior()

def difference_between_tensors(tf_tensor, pt_tensor):
    tf_np = np.array(tf_tensor)
    pt_np = np.array(pt_tensor.detach())
    return np.max(np.abs(tf_np - pt_np))

TF_CKPT_DIR = "ckpt/bigbr_base/model.ckpt-0"
HF_CKPT_DIR = "google/bigbird-base/pytorch_model.bin"

In [2]:
# ifn = serving_input_fn_builder(batch_size=1, max_encoder_length=30,
#                     vocab_model_file="google/bigbird-base/gpt2.model", substitute_newline=False)

# # t = BigBirdTokenizer("google/bigbird-base/gpt2.model")
# # t.save_pretrained("google/bigbird-base")
# t = BigBirdTokenizer.from_pretrained("google/bigbird-base/")

# print(t(["This is a long example input string containing special characters .$?-, numbers 2872 234 12 and words."]).input_ids)
# # ifn()

In [3]:
bigbird_config = {
      # transformer basic configs
      "vocab_size": 50358,
      "attention_probs_dropout_prob": 0.1,
      "hidden_act": "gelu",
      "hidden_dropout_prob": 0.1,
      "hidden_size": 768,
      "initializer_range": 0.02,
      "intermediate_size": 3072,
      "max_position_embeddings": 4096,
      "num_attention_heads": 12,
      "num_hidden_layers": 12,
      "type_vocab_size": 2,
      "use_bias": True,
      "rescale_embedding": False,
      "scope": "bert",
      # sparse mask configs
      "attention_type": "block_sparse", # "block_sparse" "original_full" "simulated_sparse"
      "norm_type": "postnorm",
      "block_size": 128,
      "num_rand_blocks": 1,
      # common bert configs
      "max_encoder_length": 1024,
      "batch_size": 1,
}

hf_bigbird_config = BigBirdConfig.from_dict(bigbird_config)
hf_bigbird_config.hidden_act = "gelu_fast"
hf_bigbird_config.num_random_blocks = bigbird_config["num_rand_blocks"]

In [4]:
# tf.compat.v1.set_random_seed(0)
np.random.seed(0)

s1 = bigbird_config["batch_size"]
s2 = bigbird_config["max_encoder_length"]

arr = np.random.randint(1, s2, size=s1*s2).reshape(s1, s2)

input_ids = tf.convert_to_tensor(arr, dtype=tf.int32)
hf_input_ids = torch.from_numpy(arr).long()

In [5]:
model = modeling.BertModel(bigbird_config)
masked_lm = MaskedLMLayer(bigbird_config["hidden_size"], bigbird_config["vocab_size"], model.embeder, activation_fn=utils.get_activation(bigbird_config["hidden_act"]))
next_sentence = NSPLayer(bigbird_config["hidden_size"])

# building all the weights before setting-up :)
sequence_output, pooler_output = model(input_ids, training=False)
_, _ = masked_lm(sequence_output)
_, _ = next_sentence(pooler_output)

hf_model = BigBirdForPreTraining(hf_bigbird_config)

INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****


In [6]:
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(TF_CKPT_DIR)
model.set_weights([ckpt_reader.get_tensor(v.name[:-2]) for v in tqdm(model.trainable_weights, position=0)])
masked_lm.set_weights([ckpt_reader.get_tensor(v.name[:-2]) for v in tqdm(masked_lm.trainable_weights, position=0)])
next_sentence.set_weights([ckpt_reader.get_tensor(v.name[:-2]) for v in tqdm(next_sentence.trainable_weights, position=0)])
masked_lm.trainable = False
next_sentence.trainable = False


hf_model.load_state_dict(torch.load(HF_CKPT_DIR))
hf_model.eval()
"model weights loaded"

100%|██████████| 199/199 [00:01<00:00, 159.60it/s]
100%|██████████| 5/5 [00:00<00:00, 350.95it/s]
100%|██████████| 2/2 [00:00<00:00, 1341.10it/s]


'model weights loaded'

In [7]:
sequence_output, pooler_output = model(input_ids, training=False)
masked_lm_loss, masked_lm_log_probs = masked_lm(sequence_output)
next_sentence_loss, next_sentence_log_probs = next_sentence(pooler_output)

hf_out = hf_model(hf_input_ids, output_attentions=True)
hf_sequence_output = hf_model.sequence_output
hf_pooler_output = hf_model.pooler_output

hf_masked_lm_log_probs = F.log_softmax(hf_out.prediction_logits, dim=-1)
hf_next_sentence_log_probs = F.log_softmax(hf_out.seq_relationship_logits, dim=-1)

INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****


In [8]:
ap = hf_model.bert.encoder.layer[0].attention.self.ap
my_cl = hf_model.bert.encoder.layer[0].attention.self.my_cl
final_cl = hf_model.bert.encoder.layer[0].attention.self.final_cl.transpose(1,2)
my_cl.shape, final_cl.shape

(torch.Size([1, 12, 1024, 64]), torch.Size([1, 12, 1024, 64]))

In [9]:
seqlen = bigbird_config["max_encoder_length"]
block_size = bigbird_config["block_size"]
for k in range(0, seqlen, block_size):
    d = (my_cl[:,:,k:k+block_size,:]-final_cl[:,:,k:k+block_size,:]).max()
    print(d)

tensor(0., grad_fn=<MaxBackward1>)
tensor(0.3152, grad_fn=<MaxBackward1>)
tensor(0.3984, grad_fn=<MaxBackward1>)
tensor(0.7096, grad_fn=<MaxBackward1>)
tensor(0.2550, grad_fn=<MaxBackward1>)
tensor(0.4019, grad_fn=<MaxBackward1>)
tensor(1.1921e-06, grad_fn=<MaxBackward1>)
tensor(0., grad_fn=<MaxBackward1>)


In [10]:
k = 1
(my_cl[:,:,128:128*2,:]-final_cl[:,:,128:256,:]).max()

tensor(0.3152, grad_fn=<MaxBackward1>)

In [11]:
# print(sequence_output.shape, hf_sequence_output.shape, pooler_output.shape, hf_pooler_output.shape)

In [12]:
# print("model input_ids", model.input_ids, end="\n\n")
# print("embeddding", model.word_embeddings, end="\n\n")

# print("l1 layer_input", model.encoder.l1_layer_input, end="\n\n")
# print("l1 attn_mask", model.encoder.l1_attention_mask, end="\n\n")
# print("l1 encoder_from_mask", model.encoder.l1_encoder_from_mask, end="\n\n")
# print("l1 encoder_to_mask", model.encoder.l1_encoder_to_mask, end="\n\n")
# print("l1 blocked_encoder_mask", model.encoder.l1_blocked_encoder_mask, end="\n\n")
# print("l1_training", model.encoder.l1_training, end="\n\n")

# print("l1 layer_output", model.encoder.l1_layer_output, end="\n\n")
# print("last layer_output", model.encoder.last_layer_output, end="\n\n")

# print("bigbird sequence out", sequence_output, end="\n\n")
# print("bigbird pooler output", pooled_output, end="\n\n")

In [13]:
# # RUN THIS FOR WEIGHTS CONVERSION

# from transformers import BigBirdForPreTraining, BigBirdConfig, load_tf_weights_in_big_bird

# config = BigBirdConfig()
# model = BigBirdForPreTraining(config)

# old = model.state_dict()

# model = load_tf_weights_in_big_bird(model, "ckpt/bigbr_base/model.ckpt-0")

# model.save_pretrained("google/bigbird-base")

In [14]:
print("difference bw input_ids:", difference_between_tensors(model.input_ids, hf_model.bert.input_ids))
print("difference bw word_embeddings:", difference_between_tensors(model.word_embeddings, hf_model.bert.word_embeddings))

print("difference bw l1 layer_input", difference_between_tensors(model.encoder.l1_layer_input, hf_model.bert.encoder.l1_layer_input))

print("difference bw l1 layer_output", difference_between_tensors(model.encoder.l1_layer_output, hf_model.bert.encoder.l1_layer_output))
print("difference bw last layer_output", difference_between_tensors(model.encoder.last_layer_output,hf_model.bert.encoder.last_layer_output))

print("difference bw bigbird sequence out", difference_between_tensors(sequence_output, hf_sequence_output), end="\n\n")

print("difference bw bigbird pooler output", difference_between_tensors(pooler_output, hf_pooler_output), end="\n\n")

print("difference bw bigbird masked_lm_log_probs", difference_between_tensors(masked_lm_log_probs, hf_masked_lm_log_probs), end="\n\n")
print("difference bw bigbird next_sentence_log_probs", difference_between_tensors(next_sentence_log_probs, hf_next_sentence_log_probs), end="\n\n")


difference bw input_ids: 0
difference bw word_embeddings: 0.0
difference bw l1 layer_input 7.1525574e-07
difference bw l1 layer_output 5.2452087e-06
difference bw last layer_output 4.5657158e-05
difference bw bigbird sequence out 4.5657158e-05

difference bw bigbird pooler output 1.5199184e-06

difference bw bigbird masked_lm_log_probs 0.00032615662

difference bw bigbird next_sentence_log_probs 0.0



In [15]:
# tf.train.list_variables("ckpt/bigbr_base/model.ckpt-0")

In [16]:
# from transformers.models.big_bird.modeling_big_bird import BigBirdAttention
# from bigbird.core.attention import MultiHeadedAttentionLayer

In [17]:
# # layer-0 debugging

# print("difference bw k:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.k, hf_model.bert.encoder.layer[0].attention.self.k
# ))

# print("difference bw q:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.q, hf_model.bert.encoder.layer[0].attention.self.q
# ))

# print("difference bw v:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.v, hf_model.bert.encoder.layer[0].attention.self.v
# ))

# print("difference bw v:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.v, hf_model.bert.encoder.layer[0].attention.self.v
# ))

# print("difference bw attn_sc:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.attn_sc, hf_model.bert.encoder.layer[0].attention.self.attn_sc
# ))

# print("difference bw attn_p:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.attn_p, hf_model.bert.encoder.layer[0].attention.self.attn_p
# ))


# print("difference bw attn_o:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.attn_o, hf_model.bert.encoder.layer[0].attention.self.attn_o
# ))

# print("difference bw attn_proj_o:", difference_between_tensors(model.encoder.encoder_layers[0].attn_proj_o, hf_model.bert.encoder.layer[0].attn_proj_o
# ))
# 
# print("difference bw do:", difference_between_tensors(model.encoder.encoder_layers[0].do, hf_model.bert.encoder.layer[0].output.do
# ))

# print("difference bw l_o:", difference_between_tensors(model.encoder.encoder_layers[0].l_o, hf_model.bert.encoder.layer[0].l_o
# ))

In [18]:
# print("difference bw hs:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.hs, hf_model.bert.encoder.layer[0].attention.self.hs))
# print("difference bw bm:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.bm, hf_model.bert.encoder.layer[0].attention.self.bm))                                                     
# print("difference bw fm:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.fm, hf_model.bert.encoder.layer[0].attention.self.fm))
# print("difference bw tm:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.tm, hf_model.bert.encoder.layer[0].attention.self.tm))
# print("difference bw fbm:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.fbm, hf_model.bert.encoder.layer[0].attention.self.fbm))
# print("difference bw tbm:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.tbm, hf_model.bert.encoder.layer[0].attention.self.tbm))

# print("difference bw ran:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.ran, hf_model.bert.encoder.layer[0].attention.self.ran))

# print("difference bw rand_mask:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.rand_mask, hf_model.bert.encoder.layer[0].attention.self.rand_mask))


In [19]:
# print("difference bw gk:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.gk, hf_model.bert.encoder.layer[0].attention.self.gk))
# print("difference bw gv:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.gv, hf_model.bert.encoder.layer[0].attention.self.gv))

# print("difference bw fcl:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.fcl, hf_model.bert.encoder.layer[0].attention.self.fcl))

# print("difference bw fcl:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.scl, hf_model.bert.encoder.layer[0].attention.self.scl))
# print("difference bw cl:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.cl, hf_model.bert.encoder.layer[0].attention.self.cl))

# print("difference bw slcl:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.cl, hf_model.bert.encoder.layer[0].attention.self.slcl))


# print("difference bw final_cl:", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.final_cl, hf_model.bert.encoder.layer[0].attention.self.final_cl))


In [20]:
# replacement of tf.gather in torch

In [21]:
# def torch_gather_b2(params, indices):
#     batch_dims = 2
#     assert params.shape[:batch_dims] == indices.shape[:batch_dims]
#     out_shape = indices.shape + params.shape[-1:]

#     out = torch.stack(
#         [torch.stack(
#             [p2[i2.flatten()] for p2, i2 in zip(p1, i1)]
#         ) for p1, i1 in zip(params, indices)]
#     )
#     return out.view(out_shape)

In [22]:
# import tensorflow as tf
# import torch
# import numpy as np

# np.random.seed(0)

# params = np.random.randn(2, 12, 256, 16, 3)
# indices = np.random.randint(2, dtype=np.int32, size=(2, 12, 256, 3))

# tf_p = tf.convert_to_tensor(params)
# tf_i = tf.convert_to_tensor(indices)

# py_p = torch.from_numpy(params)
# py_i = torch.from_numpy(indices).long()

# # output.shape = params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:]

# out_tf = tf.gather(tf_p, tf_i, batch_dims=3)
# out_pt = torch_gather_b3(py_p, py_i).view((2,12,256,3,3))
# # out_tf = tf.gather(tf_p, tf_i, batch_dims=1)
# # params = py_p
# # indices = py_i
# # out_pt = torch.stack([p1[i1.flatten()] for p1, i1 in zip(params, indices)]).view(indices.shape + params.shape[-2:])
# np.max(np.abs(out_pt.numpy() - out_tf.numpy()))

In [23]:
hf_model.bert.encoder.layer[0].attention.self.gk.shape

torch.Size([1, 12, 6, 128, 64])

In [24]:
hf_model.bert.encoder.layer[0].attention.self.ran.shape

torch.Size([1, 12, 6, 1])

In [25]:
hf_model.bert.encoder.layer[0].attention.self.ran[0, 0]

tensor([[5],
        [4],
        [7],
        [2],
        [2],
        [1]])

In [26]:
seqlen = 128
block_size = 16