In [1]:
from bigbird.core import modeling
import tensorflow.compat.v2 as tf
from tqdm import tqdm
from transformers import BigBirdPegasusConfig, BigBirdPegasusModel, BigBirdPegasusForConditionalGeneration
import torch
import numpy as np
import os

# Switch TensorFlow to 2.x behavior up front, before any model is built below.
tf.enable_v2_behavior()

In [2]:
def difference_between_tensors(tf_tensor, pt_tensor):
    """Return the largest element-wise absolute gap between two tensors.

    Args:
        tf_tensor: TensorFlow tensor (or any array-like `np.asarray` accepts).
        pt_tensor: PyTorch tensor. It is detached from autograd and copied to
            host memory before the comparison.

    Returns:
        NumPy scalar equal to `max(|tf_tensor - pt_tensor|)`.

    Raises:
        ValueError: if the two tensors do not have exactly the same shape.
    """
    tf_np = np.asarray(tf_tensor)
    # `.cpu()` is a no-op for CPU tensors and makes accelerator tensors convertible.
    pt_np = pt_tensor.detach().cpu().numpy()
    # Refuse to broadcast: a (1, n) vs (n, 1) comparison would otherwise yield a
    # number that looks like a parity score but compares unrelated elements.
    if tf_np.shape != pt_np.shape:
        raise ValueError(
            f"shape mismatch: TF tensor {tf_np.shape} vs PyTorch tensor {pt_np.shape}"
        )
    return np.max(np.abs(tf_np - pt_np))

# Path handed to the TF checkpoint reader further down. Despite the "_DIR"
# suffix, the value names a checkpoint prefix ("model.ckpt-0"), not a directory.
TF_CKPT_DIR = "tf_ckpt/bigbird-pegasus-large-arxiv/model.ckpt-0"
# Joined with "pytorch_model.bin" when the HF weights are loaded, so this is
# used as a local directory path.
HF_CKPT_DIR = "google/bigbird-pegasus-large-arxiv"


In [3]:
# `couple_encoder_decoder` switches between the pegasus and encoder-decoder setups.

# Hyper-parameters for the TF reference implementation. They are also the
# source for most of the HuggingFace config built right after.
bbc = dict(
    # transformer basic configs
    couple_encoder_decoder=False,
    vocab_size=96103,
    attention_probs_dropout_prob=0.0,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    hidden_size=1024,
    initializer_range=0.02,
    intermediate_size=4096,
    max_position_embeddings=4096,
    num_attention_heads=16,
    num_hidden_layers=16,
    num_decoder_layer=16,
    type_vocab_size=1,
    use_bias=True,
    rescale_embedding=False,
    scope="pegasus",
    # sparse mask configs
    attention_type="block_sparse",  # alternatives: "block_sparse" "original_full"
    norm_type="prenorm",
    block_size=16,
    num_rand_blocks=3,
    # common bert configs
    max_encoder_length=128,
    max_decoder_length=16,
    batch_size=1,
    beam_size=5,
    alpha=0.1,
)

# Encoder and decoder share the same head count and feed-forward width.
shared_heads = bbc["num_attention_heads"]
shared_ffn_dim = bbc["intermediate_size"]

# HuggingFace-side config mirroring `bbc`; values not read from `bbc` are
# hard-coded here.
hf_bigbird_config = BigBirdPegasusConfig(
    vocab_size=bbc["vocab_size"],
    max_position_embeddings=bbc["max_position_embeddings"],
    encoder_layers=bbc["num_hidden_layers"],
    encoder_ffn_dim=shared_ffn_dim,
    encoder_attention_heads=shared_heads,
    decoder_layers=bbc["num_decoder_layer"],
    decoder_ffn_dim=shared_ffn_dim,
    decoder_attention_heads=shared_heads,
    encoder_layerdrop=0.0,
    decoder_layerdrop=0.0,
    use_cache=True,
    is_encoder_decoder=True,
    # NOTE(review): literal differs from bbc["hidden_act"] ("gelu"); presumably
    # picked to match the TF activation numerically - confirm.
    activation_function="gelu_fast",
    d_model=bbc["hidden_size"],
    # NOTE(review): differs from bbc["hidden_dropout_prob"] (0.0) - confirm intent.
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=bbc["initializer_range"],
    decoder_start_token_id=2,
    classifier_dropout=0.0,
    scale_embedding=bbc["rescale_embedding"],
    gradient_checkpointing=False,
    pad_token_id=1,
    bos_token_id=0,
    eos_token_id=2,
    attention_type=bbc["attention_type"],  # only for encoder
    block_size=bbc["block_size"],
    num_random_blocks=bbc["num_rand_blocks"],
    use_bias=bbc["use_bias"],
)

# The TF model receives the very same dict object (alias, not a copy).
bigbird_config = bbc

In [4]:
# Echo the dict that is handed to the TF model, as a quick sanity check.
bigbird_config

{'couple_encoder_decoder': False,
 'vocab_size': 96103,
 'attention_probs_dropout_prob': 0.0,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.0,
 'hidden_size': 1024,
 'initializer_range': 0.02,
 'intermediate_size': 4096,
 'max_position_embeddings': 4096,
 'num_attention_heads': 16,
 'num_hidden_layers': 16,
 'num_decoder_layer': 16,
 'type_vocab_size': 1,
 'use_bias': True,
 'rescale_embedding': False,
 'scope': 'pegasus',
 'attention_type': 'block_sparse',
 'norm_type': 'prenorm',
 'block_size': 16,
 'num_rand_blocks': 3,
 'max_encoder_length': 128,
 'max_decoder_length': 16,
 'batch_size': 1,
 'beam_size': 5,
 'alpha': 0.1}

In [5]:
# Tensor shapes for the random parity-check inputs come from the shared config:
# s1 = batch size, s2 = encoder length, s3 = decoder length.
s1, s2, s3 = (
    bigbird_config[key]
    for key in ("batch_size", "max_encoder_length", "max_decoder_length")
)

# Encoder ids: one seeded draw from [1, s2), fed to both frameworks so the TF
# and PyTorch models see identical tokens.
np.random.seed(0)
arr = np.random.randint(1, s2, size=(s1, s2))
input_ids = tf.convert_to_tensor(arr, dtype=tf.int32)
hf_input_ids = torch.from_numpy(arr).long()

# Decoder ids: the generator is re-seeded with the same seed, then ids are
# drawn from [1, s3).
np.random.seed(0)
arr = np.random.randint(1, s3, size=(s1, s3))
target_ids = tf.convert_to_tensor(arr, dtype=tf.int32)
hf_target_ids = torch.from_numpy(arr).long()

In [6]:
# Build the HF model from the hand-written config, then overwrite its
# parameters with the checkpoint stored under HF_CKPT_DIR.
hf_model = BigBirdPegasusForConditionalGeneration(hf_bigbird_config)
# map_location="cpu" lets the checkpoint load on machines without the device it
# was saved from; the model itself is never moved off the CPU in this notebook.
# NOTE: torch.load unpickles the file - only point it at a checkpoint you trust.
hf_model.load_state_dict(
    torch.load(os.path.join(HF_CKPT_DIR, "pytorch_model.bin"), map_location="cpu")
)
# Inference mode plus frozen parameters: this notebook only compares forward
# activations, no gradients are needed.
hf_model.eval()
for p in hf_model.parameters():
    p.requires_grad_(False)

In [7]:
# Instantiate the TF reference model and run one throw-away forward pass,
# presumably so its variables exist before weights are assigned - confirm.
model = modeling.TransformerModel(bigbird_config)
warmup_output = model(input_ids, target_ids=target_ids)
del warmup_output

# Pull one array per trainable variable out of the TF checkpoint. The last two
# characters of each variable name are dropped before the lookup (presumably
# the ":0" suffix - confirm).
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(TF_CKPT_DIR)
restored_weights = []
for variable in tqdm(model.trainable_weights, position=0):
    restored_weights.append(ckpt_reader.get_tensor(variable.name[:-2]))
model.set_weights(restored_weights)
# Drop the temporary list so the arrays are not kept alive by the notebook namespace.
del restored_weights

model.trainable = False

INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradien

In [8]:
# Forward pass of the TF model after the checkpoint weights were assigned.
# Later cells read tf_out[0][1] (compared with the HF logits) and tf_out[1]
# (compared with the HF encoder_last_hidden_state).
tf_out = model(input_ids, target_ids=target_ids)

INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using block sparse attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** U

In [9]:
# Same token ids through the HF port; the decoder target ids are passed as `labels`.
hf_out = hf_model(input_ids=hf_input_ids, labels=hf_target_ids)

In [10]:
# TF-side tensor that the parity check further down compares with hf_out['logits'].
tf_out[0][1]

<tf.Tensor: shape=(1, 16, 96103), dtype=float32, numpy=
array([[[-0.96725243,  1.990289  ,  0.        , ..., -1.721262  ,
         -0.9308357 , -0.92421764],
        [-0.07624511,  2.5005665 ,  0.        , ..., -0.5825788 ,
         -0.07323553, -0.0829926 ],
        [ 0.49531904,  1.8297524 ,  0.        , ..., -0.725717  ,
          0.48727834,  0.46995315],
        ...,
        [-0.60462195,  2.3246245 ,  0.        , ..., -1.9783067 ,
         -0.61941296, -0.6342458 ],
        [-0.6550408 ,  2.3161514 ,  0.        , ..., -1.7870656 ,
         -0.66218054, -0.6700583 ],
        [-0.63965726,  2.1932306 ,  0.        , ..., -1.5617602 ,
         -0.6443455 , -0.65117365]]], dtype=float32)>

In [11]:
# Shape of the HF logits, to eyeball against the TF tensor displayed above.
hf_out['logits'].shape

torch.Size([1, 16, 96103])

In [12]:
# a = set([v.name[:-2] for v in model.trainable_variables])
# b = set([b[0] for b in tf.train.list_variables(TF_CKPT_DIR)])

In [13]:
# Parity report: every number is max|TF - HF| for one pair of activations.

# Encoder output read from the `encoder_o` attributes on both models.
# NOTE(review): nothing in this notebook sets `encoder_o`; presumably both
# libraries were patched locally to stash it - confirm.
encoder_gap_from_attrs = difference_between_tensors(
    model.encoder_o, hf_model.model.encoder.encoder_o
)
print("difference in encoder out", encoder_gap_from_attrs)

# Encoder output again, this time taken from the values returned by each model.
encoder_gap_from_outputs = difference_between_tensors(
    tf_out[1], hf_out['encoder_last_hidden_state']
)
print("difference in encoder out", encoder_gap_from_outputs)

# Final logits of both models.
logits_gap = difference_between_tensors(tf_out[0][1], hf_out['logits'])
print("difference in final out", logits_gap)

difference in encoder out 0.00036275387
difference in encoder out 0.00036275387
difference in final out 0.00091362


In [14]:
# print("difference in embed out", difference_between_tensors(model.embed_o, hf_model.model.encoder.embed_o))

# print("difference in before_attn_o out", difference_between_tensors(model.encoder.encoder_layers[0].before_attn_o, hf_model.model.encoder.layers[0].before_attn_o))

# print("difference in after self_o out", difference_between_tensors(tf.reshape(model.encoder.encoder_layers[0].self_o, (1, 128, 1024)), hf_model.model.encoder.layers[0].self_attn.self_o))

# print("difference in after so_o out", difference_between_tensors(model.encoder.encoder_layers[0].so_o, hf_model.model.encoder.layers[0].self_attn.so_o))

# print("difference in after self-attn out", difference_between_tensors(model.encoder.encoder_layers[0].after_attn_o, hf_model.model.encoder.layers[0].after_attn_o))

# print("difference in before inter out", difference_between_tensors(model.encoder.encoder_layers[0].before_inter_o, hf_model.model.encoder.layers[0].before_inter_o))

# print("difference in after inter out", difference_between_tensors(model.encoder.encoder_layers[0].after_inter_o, hf_model.model.encoder.layers[0].after_inter_o))

# print("difference in output out", difference_between_tensors(model.encoder.encoder_layers[0].output_o, hf_model.model.encoder.layers[0].output_o))

# print("difference in l0 out", difference_between_tensors(model.encoder.l0_o, hf_model.model.encoder.l0_o))

# print("difference in l last out", difference_between_tensors(model.encoder.llast_o, hf_model.model.encoder.llast_o))

# print("difference in ki", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.ki, hf_model.model.encoder.layers[0].self_attn.self.qi))
# print("difference in qi", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.qi, hf_model.model.encoder.layers[0].self_attn.self.qi))

# print("difference in qo out", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.qo, hf_model.model.encoder.layers[0].self_attn.self.qo))

# print("difference in ko out", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.ko, hf_model.model.encoder.layers[0].self_attn.self.ko))

# print("difference in vo out", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.vo, hf_model.model.encoder.layers[0].self_attn.self.vo))