Skip to content

Commit

Permalink
removing some v1.compat APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
kpe committed Nov 19, 2019
1 parent 6400082 commit 1872731
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
5 changes: 2 additions & 3 deletions bert/embeddings.py
Expand Up @@ -44,9 +44,7 @@ def call(self, inputs, **kwargs):
# that seq_len is less than max_position_embeddings
seq_len = inputs

assert_op = tf.compat.v1.assert_less_equal(seq_len, self.params.max_position_embeddings)
# TODO: TF < v2.0
# assert_op = tf.assert_less_equal(seq_len, self.params.max_position_embeddings)
assert_op = tf.compat.v2.debugging.assert_less_equal(seq_len, self.params.max_position_embeddings)

with tf.control_dependencies([assert_op]):
# slice to seq_len
Expand Down Expand Up @@ -154,6 +152,7 @@ def build(self, input_shape):
input_dim=self.params.extra_tokens_vocab_size + 1, # +1 is for a <pad>/0 vector
output_dim=embedding_size,
mask_zero=self.params.mask_zero,
embeddings_initializer=self.create_initializer(),
name="extra_word_embeddings"
)

Expand Down
5 changes: 1 addition & 4 deletions bert/layer.py
Expand Up @@ -17,10 +17,7 @@ class Params(pf.Layer.Params):
initializer_range = 0.02

def create_initializer(self):
    """Return a weight initializer for this layer.

    Returns:
        A ``tf.keras.initializers.TruncatedNormal`` instance whose standard
        deviation is ``self.params.initializer_range`` (BERT's conventional
        0.02 default — see the Params class above).
    """
    # Use the TF2-native keras initializer; the tf.compat.v1 equivalent
    # (tf.compat.v1.initializers.truncated_normal) was removed. The first
    # return statement previously left here made everything below it
    # unreachable dead code, so only the keras form is kept.
    return tf.keras.initializers.TruncatedNormal(stddev=self.params.initializer_range)

@staticmethod
def get_activation(activation_string):
Expand Down
4 changes: 2 additions & 2 deletions bert/transformer.py
Expand Up @@ -48,12 +48,12 @@ def build(self, input_shape):

if self.params.adapter_size is not None:
self.adapter_down = keras.layers.Dense(units=self.params.adapter_size,
kernel_initializer=tf.compat.v1.initializers.truncated_normal(
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.params.adapter_init_scale),
activation=self.get_activation(self.params.adapter_activation),
name="adapter-down")
self.adapter_up = keras.layers.Dense(units=self.params.hidden_size,
kernel_initializer=tf.compat.v1.initializers.truncated_normal(
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.params.adapter_init_scale),
name="adapter-up")

Expand Down

0 comments on commit 1872731

Please sign in to comment.