Fix skip_special_tokens for Wav2Vec2CTCTokenizer._decode (#29311)
* Fix skip_special_tokens handling for Wav2Vec2CTCTokenizer._decode

* Fix skip_special_tokens for Wav2Vec2CTCTokenizer._decode

* Exclude pad_token filtering since it is used as CTC-blank token

* Add small test for skip_special_tokens

* Update decoding test for the newly added token
msublee committed Apr 2, 2024
1 parent cb5927c commit 15cd687
Showing 2 changed files with 15 additions and 7 deletions.
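
For context, the user-visible effect of the fix (an illustrative sketch, not part of the commit). The ids and expected strings mirror the updated test below; the sketch assumes the facebook/wav2vec2-base-960h checkpoint is downloadable, and that with its 32-entry vocab the added tokens land at 34 ("<new_tokens>") and 35 ("$$$").

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer.add_tokens(["!", "?", "<new_tokens>"])    # regular added tokens
tokenizer.add_special_tokens({"cls_token": "$$$"})  # a true special token

# First row of sample_ids from the updated test below.
sample_ids = [
    [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33,
     tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34, 35, 35],
]

print(tokenizer.batch_decode(sample_ids))
# ['HELLO<unk>!?!?<new_tokens>$$$']
print(tokenizer.batch_decode(sample_ids, skip_special_tokens=True))
# ['HELO!?!?<new_tokens>']  # <unk> and "$$$" dropped; "<new_tokens>" kept

Note the HELO (not HELLO) in the second expectation, as asserted by the updated test: once the pad ids no longer separate them, the repeated L's collapse during CTC grouping.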
src/transformers/models/wav2vec2/tokenization_wav2vec2.py (6 additions, 3 deletions)

@@ -113,7 +113,6 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput):
 
 
 class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
-
     """
     Constructs a Wav2Vec2CTC tokenizer.
@@ -420,7 +419,9 @@ def _decode(
 
         result = []
         for token in filtered_tokens:
-            if skip_special_tokens and token in self.all_special_ids:
+            if skip_special_tokens and (
+                token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens)
+            ):
                 continue
             result.append(token)
 
@@ -881,7 +882,9 @@ def _decode(
 
         result = []
         for token in filtered_tokens:
-            if skip_special_tokens and token in self.all_special_ids:
+            if skip_special_tokens and (
+                token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens)
+            ):
                 continue
             result.append(token)
 
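
Why the new condition exempts pad_token: in Wav2Vec2's CTC decoding the pad token doubles as the blank, and the blank has to be visible while repeated tokens are collapsed, then stripped afterwards. A toy sketch of that ordering (a hypothetical helper, not the library's implementation):

def ctc_collapse(tokens, blank="<pad>"):
    collapsed, previous = [], None
    for tok in tokens:
        if tok != previous:  # collapse consecutive duplicates
            collapsed.append(tok)
        previous = tok
    return [tok for tok in collapsed if tok != blank]  # strip blanks last

tokens = ["H", "E", "L", "<pad>", "L", "O"]
print("".join(ctc_collapse(tokens)))  # -> "HELLO": the blank keeps the two L's apart
print("".join(ctc_collapse([t for t in tokens if t != "<pad>"])))
# -> "HELO": with the blank filtered out beforehand, the repeated L's merge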
tests/models/wav2vec2/test_tokenization_wav2vec2.py (9 additions, 4 deletions)

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for the Wav2Vec2 tokenizer."""
+
 import inspect
 import json
 import os
@@ -144,8 +145,10 @@ def test_tokenizer_decode_added_tokens(self):
             [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
         ]
         batch_tokens = tokenizer.batch_decode(sample_ids)
+        batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True)
 
         self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
+        self.assertEqual(batch_tokens_2, ["HELO!?!?", "BYE BYE"])
 
     def test_call(self):
         # Tests that all call wrap to encode_plus and batch_encode_plus
@@ -452,18 +455,20 @@ def test_tokenizer_decode_special(self):
 
     def test_tokenizer_decode_added_tokens(self):
         tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
-        tokenizer.add_tokens(["!", "?"])
+        tokenizer.add_tokens(["!", "?", "<new_tokens>"])
         tokenizer.add_special_tokens({"cls_token": "$$$"})
 
         # fmt: off
         sample_ids = [
-            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
-            [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
+            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34, 35, 35],
+            [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34, 35, 35],
         ]
         # fmt: on
         batch_tokens = tokenizer.batch_decode(sample_ids)
+        batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True)
 
-        self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
+        self.assertEqual(batch_tokens, ["HELLO<unk>!?!?<new_tokens>$$$", "BYE BYE<unk><new_tokens>$$$"])
+        self.assertEqual(batch_tokens_2, ["HELO!?!?<new_tokens>", "BYE BYE<new_tokens>"])
 
     def test_special_characters_in_vocab(self):
         sent = "ʈʰ æ æ̃ ˧ kʰ"
