Skip to content

Commit

Permalink
[Whisper] Move decoder id method to tokenizer (#20589)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Dec 5, 2022
1 parent 9ffbed2 commit e7e6d18
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 32 deletions.
32 changes: 1 addition & 31 deletions src/transformers/models/whisper/processing_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,37 +42,7 @@ def __init__(self, feature_extractor, tokenizer):
self._in_target_context_manager = False

def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """
    Build the `forced_decoder_ids` used to prompt Whisper's decoder.

    This method simply forwards to the tokenizer's
    [`~WhisperTokenizer.get_decoder_prompt_ids`], which owns the language/task
    validation and the prefix-token construction. Please refer to the docstring
    of that method for more information.

    Args:
        task (`str`, *optional*): task token, e.g. `"transcribe"` or `"translate"`.
        language (`str`, *optional*): language token for the transcription.
        no_timestamps (`bool`, *optional*, defaults to `True`):
            whether to suppress timestamp prediction.

    Returns:
        `List[Tuple[int, int]]`: `(position, token_id)` pairs suitable for
        `generation_config.forced_decoder_ids`.
    """
    return self.tokenizer.get_decoder_prompt_ids(task=task, language=language, no_timestamps=no_timestamps)

def __call__(self, *args, **kwargs):
"""
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,13 @@ def prefix_tokens(self) -> List[int]:
self.language = self.language.lower()
if self.language in TO_LANGUAGE_CODE:
language_id = TO_LANGUAGE_CODE[self.language]
elif self.language in TO_LANGUAGE_CODE.values():
language_id = self.language
else:
is_language_code = len(self.language) == 2
raise ValueError(
f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}"
f"Unsupported language: {self.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)

if self.task is not None:
Expand Down Expand Up @@ -577,3 +581,7 @@ def _build_conversation_input_ids(self, conversation) -> List[int]:
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
return input_ids

def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """
    Build `forced_decoder_ids` from the tokenizer's prefix tokens.

    Args:
        task (`str`, *optional*): task token, e.g. `"transcribe"` or `"translate"`.
        language (`str`, *optional*): language of the transcription text.
        no_timestamps (`bool`, *optional*, defaults to `True`):
            whether timestamp tokens should be suppressed during generation.

    Returns:
        `List[Tuple[int, int]]`: `(position, token_id)` pairs for
        `generation_config.forced_decoder_ids`.
    """
    # `no_timestamps=True` means timestamps are *not* predicted — hence the negation
    # when mapping onto `predict_timestamps`.
    self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
    # prefix tokens are of the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
    # we don't want to force the bos token at position 1, as this is the starting token
    # when `decoder_input_ids` are built — so drop the leading <|startoftranscript|>.
    forced_tokens = self.prefix_tokens[1:]
    # positions are 1-indexed: position 0 is the decoder start token itself.
    forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
    return forced_decoder_ids

0 comments on commit e7e6d18

Please sign in to comment.