diff --git a/bert/embeddings.py b/bert/embeddings.py index 88769a3..9bef90c 100644 --- a/bert/embeddings.py +++ b/bert/embeddings.py @@ -44,9 +44,7 @@ def call(self, inputs, **kwargs): # that seq_len is less than max_position_embeddings seq_len = inputs - assert_op = tf.compat.v1.assert_less_equal(seq_len, self.params.max_position_embeddings) - # TODO: TF < v2.0 - # assert_op = tf.assert_less_equal(seq_len, self.params.max_position_embeddings) + assert_op = tf.compat.v2.debugging.assert_less_equal(seq_len, self.params.max_position_embeddings) with tf.control_dependencies([assert_op]): # slice to seq_len @@ -154,6 +152,7 @@ def build(self, input_shape): input_dim=self.params.extra_tokens_vocab_size + 1, # +1 is for a /0 vector output_dim=embedding_size, mask_zero=self.params.mask_zero, + embeddings_initializer=self.create_initializer(), name="extra_word_embeddings" ) diff --git a/bert/layer.py b/bert/layer.py index cdb7f50..d67fd41 100644 --- a/bert/layer.py +++ b/bert/layer.py @@ -17,10 +17,7 @@ class Params(pf.Layer.Params): initializer_range = 0.02 def create_initializer(self): - return tf.compat.v1.initializers.truncated_normal(stddev=self.params.initializer_range) - # return tf.compat.v2.initializers.TruncatedNormal(stddev=self.params.initializer_range) - # TODO: TF < v2.0 - # return tf.truncated_normal_initializer(stddev=self.params.initializer_range) + return tf.keras.initializers.TruncatedNormal(stddev=self.params.initializer_range) @staticmethod def get_activation(activation_string): diff --git a/bert/transformer.py b/bert/transformer.py index 9dff430..381ebfd 100644 --- a/bert/transformer.py +++ b/bert/transformer.py @@ -48,12 +48,12 @@ def build(self, input_shape): if self.params.adapter_size is not None: self.adapter_down = keras.layers.Dense(units=self.params.adapter_size, - kernel_initializer=tf.compat.v1.initializers.truncated_normal( + kernel_initializer=tf.keras.initializers.TruncatedNormal( stddev=self.params.adapter_init_scale), activation=self.get_activation(self.params.adapter_activation), name="adapter-down") self.adapter_up = keras.layers.Dense(units=self.params.hidden_size, - kernel_initializer=tf.compat.v1.initializers.truncated_normal( + kernel_initializer=tf.keras.initializers.TruncatedNormal( stddev=self.params.adapter_init_scale), name="adapter-up")