In [1]:
import math

import bert
import tensorflow as tf
from bert import BertModelLayer, bert_tokenization
from bert.loader import StockBertConfig, load_stock_weights, map_stock_config_to_params
from tensorflow import keras

In [2]:
# def flatten_layers(root_layer):
#     if isinstance(root_layer, keras.layers.Layer):
#         yield root_layer
#     for layer in root_layer._layers:
#         for sub_layer in flatten_layers(layer):
#             yield sub_layer


# def freeze_bert_layers(l_bert):
#     """
#     Freezes all but LayerNorm and adapter layers as per https://arxiv.org/abs/1902.00751
#     @see https://arxiv.org/abs/1902.00751
#     """
#     for layer in flatten_layers(l_bert):
#         if layer.name in ["LayerNorm", "adapter-down", "adapter-up"]:
#             layer.trainable = True
#         elif len(layer._layers) == 0:
#             layer.trainable = False
#         l_bert.embeddings_layer.trainable = False


#  as per https://arxiv.org/abs/1902.00751
def create_learning_rate_scheduler(
    max_learn_rate=5e-5,
    end_learn_rate=1e-7,
    warmup_epoch_count=10,
    total_epoch_count=90,
):
    """
    Build a Keras ``LearningRateScheduler`` with linear warmup then exponential decay.

    Schedule, as a function of the (0-based) epoch index:
      * ``epoch < warmup_epoch_count``: linear ramp,
        ``(max_learn_rate / warmup_epoch_count) * (epoch + 1)``, so the last
        warmup epoch uses ``max_learn_rate``;
      * otherwise: exponential decay from ``max_learn_rate`` towards
        ``end_learn_rate``; the exponent reaches 1 (i.e. the rate equals
        ``end_learn_rate``) at ``epoch == total_epoch_count`` and keeps
        shrinking for later epochs.

    @see https://arxiv.org/abs/1902.00751

    Args:
        max_learn_rate: peak learning rate, reached at the end of warmup.
        end_learn_rate: learning rate reached at ``epoch == total_epoch_count``.
            Must have the same sign as ``max_learn_rate`` (their ratio is
            passed to ``math.log``).
        warmup_epoch_count: number of linear-warmup epochs.
        total_epoch_count: epoch index at which the decay hits ``end_learn_rate``.

    Returns:
        ``tf.keras.callbacks.LearningRateScheduler`` wrapping the schedule,
        created with ``verbose=1``.
    """

    def _rate_for_epoch(epoch):
        if epoch >= warmup_epoch_count:
            # Exponential-decay phase; computed only here so that a
            # non-positive ratio never reaches math.log during warmup.
            decay_steps = epoch - warmup_epoch_count + 1
            decay_span = total_epoch_count - warmup_epoch_count + 1
            log_ratio = math.log(end_learn_rate / max_learn_rate)
            return float(max_learn_rate * math.exp(log_ratio * decay_steps / decay_span))

        # Linear warmup; only evaluated when warmup_epoch_count > epoch >= 0,
        # so the division cannot hit a zero warmup length.
        per_epoch_increment = max_learn_rate / warmup_epoch_count
        return float(per_epoch_increment * (epoch + 1))

    return tf.keras.callbacks.LearningRateScheduler(_rate_for_epoch, verbose=1)

In [3]:
bert_config_file = "/app/_data/bert/bert_config.json"
bert_ckpt_file = "/app/_data/bert/bert_model.ckpt"
max_seq_len = 128
adapter_size = None  # None -> plain BERT (the adapter freeze step below is skipped)
num_classes = 6  # width of the classification head


# create the bert layer from the stock BERT JSON config
# NOTE(review): the variable `bert` below rebinds the name of the `bert`
# package imported in the first cell; later code must use the explicitly
# imported names (BertModelLayer, bert_tokenization, ...) to reach the package.
with tf.io.gfile.GFile(bert_config_file, "r") as reader:
    bc = StockBertConfig.from_json_string(reader.read())
    print(bc)
    bert_params = map_stock_config_to_params(bc)
    bert_params.adapter_size = adapter_size
    bert = BertModelLayer.from_params(bert_params, name="bert")

input_token_ids = keras.layers.Input(
    shape=(max_seq_len,), dtype="int32", name="input_ids"
)

# token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="token_type_ids")
# output         = bert([input_ids, token_type_ids])
x = bert(input_token_ids)


print("bert shape", x.shape)

# Keep only position 0 of every sequence (conventionally the [CLS] token in
# BERT inputs) as the pooled representation fed to the classifier head.
x = keras.layers.Lambda(lambda seq: seq[:, 0, :])(x)
x = keras.layers.Dropout(0.5)(x)
x = keras.layers.Dense(units=768, activation="relu")(x)
x = keras.layers.Dropout(0.5)(x)
# Emit raw logits (no activation): the loss below is configured with
# from_logits=True, so a sigmoid/softmax here would squash the scores twice.
# Apply tf.nn.softmax to model outputs to obtain class probabilities.
# NOTE(review): if the 6 outputs are independent labels (multi-label), use
# activation="sigmoid" with keras.losses.BinaryCrossentropy() and a binary
# metric instead of the sparse categorical loss/metric.
x = keras.layers.Dense(units=num_classes, activation=None)(x)

model = keras.Model(inputs=input_token_ids, outputs=x)
model.build(input_shape=(None, max_seq_len))

# load the pre-trained model weights
load_stock_weights(bert, bert_ckpt_file)

# # freeze weights if adapter-BERT is used
# if adapter_size is not None:
#     freeze_bert_layers(bert)

model.compile(
    optimizer=keras.optimizers.Adam(),
    # labels are expected as integer class ids; predictions are logits
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

model.summary()

{'attention_probs_dropout_prob': 0.1, 'hidden_act': 'gelu', 'hidden_dropout_prob': 0.1, 'hidden_size': 768, 'initializer_range': 0.02, 'intermediate_size': 3072, 'max_position_embeddings': 512, 'num_attention_heads': 12, 'num_hidden_layers': 12, 'type_vocab_size': 2, 'vocab_size': 30522, 'ln_type': None, 'embedding_size': None}
bert shape (None, 128, 768)
Done loading 196 BERT weights from: /app/_data/bert/bert_model.ckpt into <bert.model.BertModelLayer object at 0x7f513dba2f70> (prefix:bert). Count of weights not found in the checkpoint was: [0]. Count of weights with mismatched shape: [0]
Unused weights from checkpoint: 
	bert/embeddings/token_type_embeddings
	bert/pooler/dense/bias
	bert/pooler/dense/kernel
	cls/predictions/output_bias
	cls/predictions/transform/LayerNorm/beta
	cls/predictions/transform/LayerNorm/gamma
	cls/predictions/transform/dense/bias
	cls/predictions/transform/dense/kernel
	cls/seq_relationship/output_bias
	cls/seq_relationship/output_weights
Model: "functiona

In [4]:
import numpy as np

# Sanity check for the `seq[:, 0, :]` Lambda in the classifier head:
# on a (batch=2, seq=3, features=1) array, indexing position 0 of axis 1
# keeps the first element of every sequence and drops that axis -> (2, 1).
a = np.array(
    [
        [[1], [11], [111]],
        [[2], [22], [222]],
    ]
)
a[:, 0, :]

array([[1],
       [2]])