Skip to content

Commit

Permalink
corrected momentum value + cls_token initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
mathieujouffroy committed Oct 7, 2022
1 parent 1978d26 commit 6d2770a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/cvt/modeling_tf_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride:
name="convolution",
groups=embed_dim,
)
- # Using the same default epsilon & momentum as PyTorch
- self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
+ # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
+ self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")

def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = self.convolution(self.padding(hidden_state))
Expand Down Expand Up @@ -538,10 +538,11 @@ def __init__(self, config: CvtConfig, stage: int, **kwargs):
if self.config.cls_token[self.stage]:
self.cls_token = self.add_weight(
shape=(1, 1, self.config.embed_dim[-1]),
- initializer="zeros",
+ initializer=get_initializer(self.config.initializer_range),
trainable=True,
name="cvt.encoder.stages.2.cls_token",
)

self.embedding = TFCvtEmbeddings(
self.config,
patch_size=config.patch_sizes[self.stage],
Expand Down

0 comments on commit 6d2770a

Please sign in to comment.