diff --git a/src/transformers/models/cvt/modeling_tf_cvt.py b/src/transformers/models/cvt/modeling_tf_cvt.py index 87ad121380089f..448bfd23028854 100644 --- a/src/transformers/models/cvt/modeling_tf_cvt.py +++ b/src/transformers/models/cvt/modeling_tf_cvt.py @@ -180,8 +180,8 @@ def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: name="convolution", groups=embed_dim, ) - # Using the same default epsilon & momentum as PyTorch - self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization") + # Using the same default epsilon as PyTorch; TF momentum is (1 - PyTorch momentum), so 0.1 becomes 0.9 + self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization") def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: hidden_state = self.convolution(self.padding(hidden_state)) @@ -538,10 +538,11 @@ def __init__(self, config: CvtConfig, stage: int, **kwargs): if self.config.cls_token[self.stage]: self.cls_token = self.add_weight( shape=(1, 1, self.config.embed_dim[-1]), - initializer="zeros", + initializer=get_initializer(self.config.initializer_range), trainable=True, name="cvt.encoder.stages.2.cls_token", ) + self.embedding = TFCvtEmbeddings( self.config, patch_size=config.patch_sizes[self.stage],