Skip to content

Commit

Permalink
[TFWav2Vec2Model] Fix input shapes in TFWav2Vec2WeightNormConv1D (#14319
Browse files Browse the repository at this point in the history
)

* Add paddings to input shapes

* Add padding comment
  • Loading branch information
anton-l authored Nov 8, 2021
1 parent e30078b commit df1f94e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/hubert/modeling_tf_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ def _normalize_kernel(self):

def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# Conv1D output shapes are checked at build time since TF 2.7, so we need to account for padding
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)

self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
self.weight_v = self.kernel

Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,11 @@ def _normalize_kernel(self):

def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# Conv1D output shapes are checked at build time since TF 2.7, so we need to account for padding
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)

self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
self.weight_v = self.kernel

Expand Down

0 comments on commit df1f94e

Please sign in to comment.