diff --git a/tests/ext/modeling.py b/tests/ext/modeling.py index f0d1efb..b4844cc 100644 --- a/tests/ext/modeling.py +++ b/tests/ext/modeling.py @@ -368,8 +368,9 @@ def layer_norm(input_tensor, name=None): epsilon = 1e-12 input_shape = input_tensor.shape - gamma = tf.compat.v1.get_variable(name="gamma", shape=input_shape[-1:], initializer=tf.compat.v1.initializers.ones(), trainable=True) - beta = tf.compat.v1.get_variable(name="beta", shape=input_shape[-1:], initializer=tf.compat.v1.initializers.zeros(), trainable=True) + with tf.compat.v1.variable_scope("LayerNorm"): + gamma = tf.compat.v1.get_variable(name="gamma", shape=input_shape[-1:], initializer=tf.compat.v1.initializers.ones(), trainable=True) + beta = tf.compat.v1.get_variable(name="beta", shape=input_shape[-1:], initializer=tf.compat.v1.initializers.zeros(), trainable=True) x = input_tensor if tf.__version__.startswith("2."):