In [1]:
# This notebook shows how to convert a BERT model to the tf_transformers framework.
# We load the model from TF Hub (downloaded locally) and convert it into a
# tf_transformers model.

In [1]:
# Imports: stdlib, then third-party, then tf_transformers.
import json

import tensorflow as tf
import tensorflow_hub as hub
from absl import logging

from tf_transformers.models import BERTEncoder

# Show absl INFO messages (tf_transformers reports encoder construction through them).
logging.set_verbosity("INFO")

In [2]:
# TensorFlow version this conversion was run with.
tf.__version__

'2.3.0-rc0'

In [3]:
# Source model: BERT uncased L-12/H-768/A-12 from TF Hub
# (https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1), downloaded and unzipped here.
HUB_MODEL_DIR = "../../../pretrained_models/bert_uncased/"
bert_hub = hub.KerasLayer(HUB_MODEL_DIR, trainable=True)

In [4]:
# Number of variables exposed by the hub layer (cf. the variable count of the
# original checkpoint listed further down, which is one higher).
len(bert_hub.variables)

200

In [5]:
# Encoder hyper-parameters; also used by the weight-copy cell to reshape the
# attention kernels/biases. `with` closes the file deterministically instead of
# leaving the handle to the garbage collector.
with open("../model_directory/bert_base/bert_config.json") as config_file:
    config = json.load(config_file)

# Optional overrides for a larger model (values look like BERT-large — confirm
# against that model's own config). Uncomment them together: embedding_size is the
# hidden size used for the per-head reshape, so it must be num_attention_heads * 64.
# config['embedding_size'] = 1024
# config['max_position_embeddings'] = 1024
# config['intermediate_size'] = 4096
# config['num_attention_heads'] = 16
# config['num_hidden_layers'] = 24
config

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'intermediate_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'embedding_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 30522,
 'layer_norm_epsilon': 1e-12}

In [18]:
# Scratch arithmetic: 16 heads x 64 dims per head = 1024. Presumably the hidden size
# matching the larger-model overrides in the config cell — TODO confirm or remove.
16 * 64

1024

In [6]:
# Build the tf_transformers encoder that will receive the hub weights.
#
# `bert_hub` and `config` were already loaded by the cells above. They are not
# reloaded here, because re-reading the SavedModel is wasted work and re-reading
# the JSON would silently discard any overrides applied to `config` earlier.
#
# Hub model source (download and unzip into ../../../pretrained_models/bert_uncased/):
# https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1

# We have Keras/Legacy Layer here (Not keras.model)
model_layer = BERTEncoder(config=config,
                          name='bert',
                          mask_mode='user_defined',
                          is_training=False)

INFO:absl:We are overwriding `is_training` is False to `is_training` to True with `use_dropout` is False, no effects on your inference pipeline
INFO:absl:Inputs -->
INFO:absl:input_ids ---> Tensor("input_ids:0", shape=(None, None), dtype=int32)
INFO:absl:input_mask ---> Tensor("input_mask:0", shape=(None, None), dtype=int32)
INFO:absl:input_type_ids ---> Tensor("input_type_ids:0", shape=(None, None), dtype=int32)
INFO:absl:Initialized Variables


In [7]:
# List the variables of the BERT-Base checkpoint at this path, to cross-check the
# count against the hub layer.
# NOTE(review): absolute, user-specific path — point this at your local copy.
BERT_BASE_CKPT_PATH = "/Users/PRVATE/pretrained_models/bert_base/bert_model.ckpt"
# Reader kept for ad-hoc inspection (e.g. ckpt_reader.get_tensor(name)); named so it
# is not confused with the tf.train.Checkpoint created when saving later.
ckpt_reader = tf.train.load_checkpoint(BERT_BASE_CKPT_PATH)
model_vars = tf.train.list_variables(BERT_BASE_CKPT_PATH)
len(model_vars)

201

In [8]:
# Re-display the config the encoder was built with.
config

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'intermediate_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'embedding_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 30522,
 'layer_norm_epsilon': 1e-12}

In [9]:
# Copy the BERT hub weights into the tf_transformers encoder.
#
# Names are matched after stripping 'bert_model/encoder/' from hub names and
# 'tf_transformers/bert/transformer/' from tf_transformers names. Weights whose
# names differ beyond those prefixes (embeddings, embedding layer norm, pooler)
# are renamed explicitly via `mapping_dict`.

# BERT hub variables to BERT model
mapping_dict = {
    'bert_model/word_embeddings/embeddings:0': 'tf_transformers/bert/word_embeddings/embeddings:0',
    'bert_model/embedding_postprocessor/type_embeddings:0': 'tf_transformers/bert/type_embeddings/embeddings:0',
    'bert_model/embedding_postprocessor/position_embeddings:0': 'tf_transformers/bert/positional_embeddings/embeddings:0',
    'bert_model/embedding_postprocessor/layer_norm/gamma:0': 'tf_transformers/bert/embeddings/layer_norm/gamma:0',
    'bert_model/embedding_postprocessor/layer_norm/beta:0': 'tf_transformers/bert/embeddings/layer_norm/beta:0',
    'bert_model/pooler_transform/kernel:0': 'tf_transformers/bert/pooler_transform/kernel:0',
    'bert_model/pooler_transform/bias:0': 'tf_transformers/bert/pooler_transform/bias:0'
}

# Per-head split used for the attention query/key/value weights.
hidden_size = config['embedding_size']
num_heads = config['num_attention_heads']
if hidden_size % num_heads != 0:
    raise ValueError(f"embedding_size ({hidden_size}) must be divisible by "
                     f"num_attention_heads ({num_heads})")
head_size = hidden_size // num_heads

# Fetch the variable list once so every index below refers to the same list.
tf_vars = model_layer.variables

# tf_transformers variable name (encoder prefix stripped) -> position in tf_vars
tf_transformers_bert_index_dict = {
    var.name.replace('tf_transformers/bert/transformer/', ''): index
    for index, var in enumerate(tf_vars)
}

qkv_kernel_names = ('query/kernel:0', 'key/kernel:0', 'value/kernel:0')
qkv_bias_names = ('query/bias:0', 'key/bias:0', 'value/bias:0')

# tf_transformers <-- hub
assigned_map = []         # (hub variable name, tf_transformers variable name)
assigned_map_values = []  # (sum of hub value, sum of assigned variable) for spot checks
assigned_indices = set()
for var in bert_hub.variables:
    # NOTE(review): presumably a hub bookkeeping variable with no encoder counterpart — confirm.
    if 'Variable:0' in var.name:
        continue

    hub_name = var.name.replace('bert_model/encoder/', '')
    target_name = mapping_dict.get(hub_name, hub_name)
    if target_name not in tf_transformers_bert_index_dict:
        raise KeyError(f"No tf_transformers variable matches hub variable "
                       f"{var.name!r} (looked up {target_name!r})")
    index = tf_transformers_bert_index_dict[target_name]

    if hub_name in mapping_dict:
        # Embedding / pooler weights: renamed only, copied as-is.
        value = var
    elif any(name in hub_name for name in qkv_kernel_names):
        # 2D hub kernel -> tf_transformers [hidden, num_heads, head_size]
        value = tf.reshape(var, (hidden_size, num_heads, head_size))
    elif any(name in hub_name for name in qkv_bias_names):
        # flat hub bias -> tf_transformers [num_heads, head_size]
        value = tf.reshape(var, (num_heads, head_size))
    else:
        # Remaining encoder weights line up one-to-one.
        value = var

    target = tf_vars[index]
    target.assign(value)  # raises if the shapes are incompatible
    assigned_indices.add(index)
    assigned_map.append((var.name, target.name))
    assigned_map_values.append((tf.reduce_sum(var).numpy(), tf.reduce_sum(target).numpy()))

# Anything not written above keeps its initializer value; make that visible
# instead of silently converting a partially random model.
unassigned = [v.name for i, v in enumerate(tf_vars) if i not in assigned_indices]
if unassigned:
    logging.warning("%d tf_transformers variables were not assigned from the hub: %s",
                    len(unassigned), unassigned)

logging.info("Done assigning variables weights")

INFO:absl:Done assigning variables weights


In [None]:
# Reference sums recorded from an earlier tf_transformers run (apparently on the
# 3-token sample used below). Kept as comments: as bare text this cell is a
# SyntaxError and breaks "Run All".
# cls_output --> tf.Tensor(-2.06532, shape=(), dtype=float32) --> (1, 768)
# token_embeddings --> tf.Tensor(-26.426851, shape=(), dtype=float32) --> (1, 3, 768)
# token_logits --> tf.Tensor(-17970.629, shape=(), dtype=float32) --> (1, 3, 30522)
# last_token_logits --> tf.Tensor(-4169.4917, shape=(), dtype=float32) --> (1, 30522)

In [10]:
# Compare the results from Hub and tf_transformers on the same 3-token sample.
sample_inputs = {'input_ids': tf.constant([[1, 2, 3]]),
                 'input_mask': tf.constant([[1, 1, 1]]),
                 'input_type_ids': tf.constant([[0, 0, 0]])}

# The hub layer takes a positional list [input_ids, input_mask, input_type_ids];
# the tf_transformers layer takes the dict directly.
results_hub = bert_hub([sample_inputs['input_ids'],
                        sample_inputs['input_mask'],
                        sample_inputs['input_type_ids']])
results_legacy = model_layer(sample_inputs)

# The hub outputs are, in order, the pooled (cls) output and the per-token output.
# A mis-mapped weight shows up as O(1) differences; atol=1e-3 only absorbs
# float32 op-ordering noise.
tf.debugging.assert_near(results_hub[0], results_legacy['cls_output'], atol=1e-3)
tf.debugging.assert_near(results_hub[1], results_legacy['token_embeddings'], atol=1e-3)


In [11]:
# Sum/shape fingerprint of each hub output tensor.
for hub_output in results_hub:
    print(tf.reduce_sum(hub_output), '-->', hub_output.shape)

tf.Tensor(-2.0653088, shape=(), dtype=float32) --> (1, 768)
tf.Tensor(-26.426857, shape=(), dtype=float32) --> (1, 3, 768)


In [12]:
# A 2-example batch of arbitrary token ids: every position unmasked, segment id 1.
input_ids = tf.constant([[1, 9, 10, 11, 23],
                         [1, 22, 234, 432, 2349]])
input_mask = tf.ones_like(input_ids)
input_type_ids = tf.ones_like(input_ids)

hub_inputs = [input_ids, input_mask, input_type_ids]
results_hub = bert_hub(hub_inputs)
for hub_output in results_hub:
    print(tf.reduce_sum(hub_output), '-->', hub_output.shape)

tf.Tensor(-10.324542, shape=(), dtype=float32) --> (2, 768)
tf.Tensor(-99.82985, shape=(), dtype=float32) --> (2, 5, 768)


In [13]:
# Sum/shape fingerprint of each tensor output of tf_transformers (list-valued outputs skipped).
for output_name, output in results_legacy.items():
    if not isinstance(output, list):
        print(output_name, '-->', tf.reduce_sum(output), '-->', output.shape)

cls_output --> tf.Tensor(-2.0653143, shape=(), dtype=float32) --> (1, 768)
token_embeddings --> tf.Tensor(-26.426872, shape=(), dtype=float32) --> (1, 3, 768)
token_logits --> tf.Tensor(-17970.516, shape=(), dtype=float32) --> (1, 3, 30522)
last_token_logits --> tf.Tensor(-4169.453, shape=(), dtype=float32) --> (1, 30522)


In [14]:
# Save the converted encoder as a TF2 checkpoint in the directory holding bert_config.json.
checkpoint_dir = '../model_directory/bert_base'
# Named `save_ckpt` so it does not reuse the name of the checkpoint reader from the
# inspection cell (a different type of object).
save_ckpt = tf.train.Checkpoint(model=model_layer)
manager = tf.train.CheckpointManager(save_ckpt, checkpoint_dir, max_to_keep=1)
save_path = manager.save()
logging.info("Saved tf_transformers checkpoint to %s", save_path)

In [15]:
# The same 2-example batch as the hub run, now through the tf_transformers encoder.
input_ids = tf.constant([[1, 9, 10, 11, 23],
                         [1, 22, 234, 432, 2349]])
input_mask = tf.ones_like(input_ids)
input_type_ids = tf.ones_like(input_ids)

legacy_inputs = {'input_ids': input_ids,
                 'input_mask': input_mask,
                 'input_type_ids': input_type_ids}
results_legacy = model_layer(legacy_inputs)

# Sum/shape fingerprint of each tensor output (list-valued outputs skipped).
for output_name, output in results_legacy.items():
    if not isinstance(output, list):
        print(output_name, '-->', tf.reduce_sum(output), '-->', output.shape)

cls_output --> tf.Tensor(-10.324546, shape=(), dtype=float32) --> (2, 768)
token_embeddings --> tf.Tensor(-99.82982, shape=(), dtype=float32) --> (2, 5, 768)
token_logits --> tf.Tensor(-4488.868, shape=(), dtype=float32) --> (2, 5, 30522)
last_token_logits --> tf.Tensor(199.83713, shape=(), dtype=float32) --> (2, 30522)
