In [1]:
from bigbird.core import modeling
import tensorflow.compat.v2 as tf
from tqdm import tqdm
from transformers import BigBirdPegasusConfig, BigBirdPegasusModel, BigBirdPegasusForConditionalGeneration, BigBirdPegasusTokenizer
from bigbird.summarization.run_summarization import serving_input_fn_builder
import torch
import numpy as np
import os

# Enable TensorFlow 2.x behavior globally; this runs before any model is built.
tf.enable_v2_behavior()

In [2]:
# t = BigBirdPegasusTokenizer("tf_ckpt/spiece.model")
# t.save_pretrained("google/bigbird-pegasus-large-arxiv")
# t = BigBirdPegasusTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")

In [3]:
# o = t("This is a long example input string containing special characters .\n$?-, numbers 2872 234 12 and words.", max_length=30, padding="max_length").input_ids

# print(o)
# # t.convert_ids_to_tokens(o)
# ifn = serving_input_fn_builder(batch_size=1, max_encoder_length=30, vocab_model_file="tf_ckpt/spiece.model", substitute_newline=False)

# ifn()

In [4]:
def difference_between_tensors(tf_tensor, pt_tensor):
    """Return the largest element-wise absolute difference between two tensors.

    Args:
        tf_tensor: TensorFlow tensor (or any object `np.asarray` accepts).
        pt_tensor: PyTorch tensor. It is detached from the autograd graph and
            moved to the CPU before being converted to NumPy.

    Returns:
        NumPy scalar equal to `max(|tf_tensor - pt_tensor|)`.

    Raises:
        ValueError: if the two tensors do not have exactly the same shape.
    """
    tf_np = np.asarray(tf_tensor)
    # `.cpu()` keeps the conversion working when the tensor lives on an accelerator.
    pt_np = pt_tensor.detach().cpu().numpy()
    # Refuse to broadcast: a silently broadcast subtraction would compare the
    # wrong elements and could report a misleadingly small difference.
    if tf_np.shape != pt_np.shape:
        raise ValueError(
            f"shape mismatch: TF tensor {tf_np.shape} vs PyTorch tensor {pt_np.shape}"
        )
    return np.max(np.abs(tf_np - pt_np))


# Path handed to the TensorFlow checkpoint reader.
TF_CKPT_DIR = "tf_ckpt/bigbird-pegasus-large-pubmed/model.ckpt-300000"
# Directory that contains the converted PyTorch weights (`pytorch_model.bin`).
HF_CKPT_DIR = "google/bigbird-pegasus-large-pubmed"


In [5]:
# `couple_encoder_decoder` switches between the Pegasus and the encoder-decoder variants.

# Hyper-parameters handed to the TensorFlow BigBird implementation.
bigbird_config = {
    # --- transformer basics ---
    "couple_encoder_decoder": False,
    "vocab_size": 96103,
    "attention_probs_dropout_prob": 0.0,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 1024,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "max_position_embeddings": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 16,
    "num_decoder_layer": 16,
    "type_vocab_size": 1,
    "use_bias": True,
    "rescale_embedding": False,
    "scope": "pegasus",
    # --- sparse attention mask ---
    "attention_type": "original_full",  # alternatives: "block_sparse", "original_full"
    "norm_type": "prenorm",
    "block_size": 16,
    "num_rand_blocks": 3,
    # --- sequence lengths, batching and beam search ---
    "max_encoder_length": 1024,
    "max_decoder_length": 16,
    "batch_size": 1,
    "beam_size": 5,
    "alpha": 0.1,
}

# Short alias; both names refer to the very same dict object.
bbc = bigbird_config

# Keyword arguments for the Hugging Face config. Every field that has a
# counterpart in `bigbird_config` is read from it; the rest are fixed literals.
hf_config_kwargs = dict(
    vocab_size=bigbird_config["vocab_size"],
    max_position_embeddings=bigbird_config["max_position_embeddings"],
    encoder_layers=bigbird_config["num_hidden_layers"],
    encoder_ffn_dim=bigbird_config["intermediate_size"],
    encoder_attention_heads=bigbird_config["num_attention_heads"],
    decoder_layers=bigbird_config["num_decoder_layer"],
    decoder_ffn_dim=bigbird_config["intermediate_size"],
    decoder_attention_heads=bigbird_config["num_attention_heads"],
    encoder_layerdrop=0.0,
    decoder_layerdrop=0.0,
    use_cache=True,
    is_encoder_decoder=True,
    activation_function="gelu_fast",
    d_model=bigbird_config["hidden_size"],
    dropout=0.1,
    attention_dropout=0.0,
    activation_dropout=0.0,
    init_std=bigbird_config["initializer_range"],
    decoder_start_token_id=2,
    classifier_dropout=0.0,
    scale_embedding=bigbird_config["rescale_embedding"],
    gradient_checkpointing=False,
    pad_token_id=1,
    bos_token_id=0,
    eos_token_id=2,
    attention_type=bigbird_config["attention_type"],  # only for encoder
    block_size=bigbird_config["block_size"],
    num_random_blocks=bigbird_config["num_rand_blocks"],
    use_bias=bigbird_config["use_bias"],
)
hf_bigbird_config = BigBirdPegasusConfig(**hf_config_kwargs)

In [6]:
# Echo the TensorFlow-side configuration dict as the cell output.
bigbird_config

{'couple_encoder_decoder': False,
 'vocab_size': 96103,
 'attention_probs_dropout_prob': 0.0,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.0,
 'hidden_size': 1024,
 'initializer_range': 0.02,
 'intermediate_size': 4096,
 'max_position_embeddings': 4096,
 'num_attention_heads': 16,
 'num_hidden_layers': 16,
 'num_decoder_layer': 16,
 'type_vocab_size': 1,
 'use_bias': True,
 'rescale_embedding': False,
 'scope': 'pegasus',
 'attention_type': 'original_full',
 'norm_type': 'prenorm',
 'block_size': 16,
 'num_rand_blocks': 3,
 'max_encoder_length': 1024,
 'max_decoder_length': 16,
 'batch_size': 1,
 'beam_size': 5,
 'alpha': 0.1}

In [7]:
# Batch size, encoder length and decoder length all come from the shared config.
s1, s2, s3 = (
    bigbird_config[key]
    for key in ("batch_size", "max_encoder_length", "max_decoder_length")
)


def _seeded_random_ids(seq_len):
    """Reset NumPy's global seed to 0 and draw an (s1, seq_len) array of ids in [1, seq_len)."""
    np.random.seed(0)
    return np.random.randint(1, seq_len, size=s1 * seq_len).reshape(s1, seq_len)


# Encoder inputs: one NumPy array feeds both frameworks so they see identical ids.
encoder_ids = _seeded_random_ids(s2)
input_ids = tf.convert_to_tensor(encoder_ids, dtype=tf.int32)
hf_input_ids = torch.from_numpy(encoder_ids).long()

# Decoder targets, built the same way with the decoder length.
decoder_ids = _seeded_random_ids(s3)
target_ids = tf.convert_to_tensor(decoder_ids, dtype=tf.int32)
hf_target_ids = torch.from_numpy(decoder_ids).long()

# Keep the generic name `arr` bound to the most recently generated array.
arr = decoder_ids

In [8]:
# Instantiate the Hugging Face model from the config, then load the converted weights.
hf_model = BigBirdPegasusForConditionalGeneration(hf_bigbird_config)
# `map_location="cpu"` lets the checkpoint load on a machine without the device
# it was saved from; `load_state_dict` copies the values into the model afterwards.
hf_state_dict = torch.load(
    os.path.join(HF_CKPT_DIR, "pytorch_model.bin"), map_location="cpu"
)
hf_model.load_state_dict(hf_state_dict)
# Inference mode: switch off dropout and stop tracking gradients for every parameter.
hf_model.eval()
for p in hf_model.parameters():
    p.requires_grad_(False)

In [9]:
# Show which attention type the loaded Hugging Face model is configured with.
hf_model.config.attention_type

'original_full'

In [10]:
# Build the TensorFlow model and run one throw-away forward pass before its
# weights are replaced (presumably so the variables exist -- TODO confirm).
model = modeling.TransformerModel(bigbird_config)
o = model(input_ids, target_ids=target_ids)
del o

# Read one checkpoint tensor per trainable variable, in the model's own order.
ckpt_reader = tf.compat.v1.train.NewCheckpointReader(TF_CKPT_DIR)
ckpt_weights = []
for variable in tqdm(model.trainable_weights, position=0):
    # The last two characters of the variable name are dropped to form the checkpoint key.
    ckpt_key = variable.name[:-2]
    ckpt_weights.append(ckpt_reader.get_tensor(ckpt_key))
model.set_weights(ckpt_weights)

# Freeze the model once the checkpoint values are in place.
model.trainable = False

INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure

In [11]:
# Forward pass of the TensorFlow model on the random encoder ids and decoder targets.
tf_out = model(input_ids, target_ids=target_ids)

INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****
INFO:absl:**** Using original full attention ****


In [12]:
# Run the Hugging Face model on the same ids; the decoder targets are supplied through `labels`.
hf_out = hf_model(input_ids=hf_input_ids, labels=hf_target_ids)

In [13]:
# Inspect element [0][1] of the TF output; this is the tensor later compared against the HF logits.
tf_out[0][1]

<tf.Tensor: shape=(1, 16, 96103), dtype=float32, numpy=
array([[[-5.5893642e-01,  4.0546751e+00,  3.8444996e-06, ...,
          8.1415701e-01, -2.2685637e-01, -4.3123407e+00],
        [ 5.5288756e-01,  3.6804142e+00, -1.3571173e-02, ...,
          2.2924857e+00,  8.5090071e-01, -5.4608064e+00],
        [ 1.4902123e+00,  2.2696016e+00,  9.4765723e-03, ...,
          2.3964798e+00,  1.6562212e+00, -7.2515621e+00],
        ...,
        [ 5.9318435e-01,  2.9363060e+00, -1.1102553e-02, ...,
          1.9723936e+00,  9.1420323e-01, -5.7508049e+00],
        [-5.3281975e-01,  9.6828508e-01,  6.5652058e-03, ...,
         -2.0260565e+00, -2.8542823e-01, -3.1168013e+00],
        [-4.5965534e-01,  5.6229985e-01, -2.2256300e-03, ...,
         -1.6695052e+00, -2.7297193e-01, -2.8035493e+00]]], dtype=float32)>

In [14]:
# Shape of the Hugging Face logits, to check against the shape of the TF tensor.
hf_out['logits'].shape

torch.Size([1, 16, 96103])

In [15]:
# a = set([v.name[:-2] for v in model.trainable_variables])
# b = set([b[0] for b in tf.train.list_variables(TF_CKPT_DIR)])

In [16]:
# Encoder output captured on the model objects themselves (the `encoder_o` attributes).
encoder_attr_diff = difference_between_tensors(
    model.encoder_o, hf_model.model.encoder.encoder_o
)
print("difference in encoder out", encoder_attr_diff)

# Encoder output as returned by the two forward passes.
encoder_return_diff = difference_between_tensors(
    tf_out[1], hf_out['encoder_last_hidden_state']
)
print("difference in encoder out", encoder_return_diff)

# Final logits of both models.
logits_diff = difference_between_tensors(tf_out[0][1], hf_out['logits'])
print("difference in final out", logits_diff)

difference in encoder out 0.0009796619
difference in encoder out 0.0009796619
difference in final out 0.0019226074


In [17]:
# Peek at a small window of the HF logits: first batch entry, positions 4-7, vocabulary ids 128-155.
hf_out.logits[0, 4:8, 128:156]

tensor([[ 3.7736,  0.6459,  5.9393, -2.0550,  1.3957,  1.6994,  1.7002,  4.3194,
          6.7270,  0.8877,  2.7457,  0.3128, -0.3091,  3.7636,  6.4191,  3.2155,
         -0.9953,  7.4407,  3.8938,  0.4070,  3.7436,  3.7248,  5.6073,  3.8378,
         -1.9400,  5.2315,  4.6829,  2.0397],
        [ 3.8075,  0.5993,  5.9881, -1.9268,  1.7395,  1.9801,  1.4785,  4.4040,
          6.9427,  0.6825,  2.8742,  0.7088, -0.6241,  3.3309,  6.5836,  3.2848,
         -1.1375,  7.2144,  4.1101,  0.8657,  3.9520,  3.5079,  5.4696,  3.9301,
         -2.2243,  5.2562,  4.6510,  1.9688],
        [ 3.9393,  0.5811,  6.1118, -1.9829,  1.9584,  2.0622,  1.6118,  4.5815,
          7.1832,  0.6703,  2.9474,  0.8766, -0.7241,  3.3090,  6.7720,  3.4544,
         -1.0948,  7.0197,  4.2286,  1.1543,  4.0334,  3.4939,  5.5613,  4.1545,
         -2.2169,  5.4238,  4.7881,  1.9614],
        [ 4.0004,  0.5768,  6.1671, -2.1092,  2.0556,  2.0222,  1.5487,  4.5812,
          7.2975,  0.7099,  3.0134,  0.9069, -0.8406

In [18]:
# print("difference in embed out", difference_between_tensors(model.embed_o, hf_model.model.encoder.embed_o))

# print("difference in before_attn_o out", difference_between_tensors(model.encoder.encoder_layers[0].before_attn_o, hf_model.model.encoder.layers[0].before_attn_o))

# print("difference in after self_o out", difference_between_tensors(tf.reshape(model.encoder.encoder_layers[0].self_o, (1, 128, 1024)), hf_model.model.encoder.layers[0].self_attn.self_o))

# print("difference in after so_o out", difference_between_tensors(model.encoder.encoder_layers[0].so_o, hf_model.model.encoder.layers[0].self_attn.so_o))

# print("difference in after self-attn out", difference_between_tensors(model.encoder.encoder_layers[0].after_attn_o, hf_model.model.encoder.layers[0].after_attn_o))

# print("difference in before inter out", difference_between_tensors(model.encoder.encoder_layers[0].before_inter_o, hf_model.model.encoder.layers[0].before_inter_o))

# print("difference in after inter out", difference_between_tensors(model.encoder.encoder_layers[0].after_inter_o, hf_model.model.encoder.layers[0].after_inter_o))

# print("difference in output out", difference_between_tensors(model.encoder.encoder_layers[0].output_o, hf_model.model.encoder.layers[0].output_o))

# print("difference in l0 out", difference_between_tensors(model.encoder.l0_o, hf_model.model.encoder.l0_o))

# print("difference in l last out", difference_between_tensors(model.encoder.llast_o, hf_model.model.encoder.llast_o))

# print("difference in ki", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.ki, hf_model.model.encoder.layers[0].self_attn.self.ki))
# print("difference in qi", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.qi, hf_model.model.encoder.layers[0].self_attn.self.qi))

# print("difference in qo out", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.qo, hf_model.model.encoder.layers[0].self_attn.self.qo))

# print("difference in ko out", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.ko, hf_model.model.encoder.layers[0].self_attn.self.ko))

# print("difference in vo out", difference_between_tensors(model.encoder.encoder_layers[0].attn_layer.vo, hf_model.model.encoder.layers[0].self_attn.self.vo))

In [19]:
# bigbird pegasus large pubmed
# difference in encoder out 0.0002682209
# difference in encoder out 0.0002682209
# difference in final out 0.0008444786

# bigbird pegasus large bigpatent
# difference in encoder out 0.00029605627
# difference in encoder out 0.00029605627
# difference in final out 0.00074386597

# bigbird pegasus large arxiv
# difference in encoder out 0.0005502105
# difference in encoder out 0.0005502105
# difference in final out 0.00051164627


# bigbird pegasus large
# difference in encoder out 0.00011986494
# difference in encoder out 0.00011986494
# difference in final out 0.012252808