-
Notifications
You must be signed in to change notification settings - Fork 25.3k
/
tokenization_whisper.py
1325 lines (1176 loc) · 53 KB
/
tokenization_whisper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Whisper."""
import json
import os
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
import regex as re
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
# File name used for each tokenizer resource.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "tokenizer_file": "tokenizer.json",
    "merges_file": "merges.txt",
    "normalizer_file": "normalizer.json",
}
# Download URL of each resource, keyed first by resource name and then by checkpoint name.
# NOTE(review): the merges URL ends in `merges_file.txt` while VOCAB_FILES_NAMES uses `merges.txt` -
# confirm which file name the hub repository actually serves.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "openai/whisper-base": "https://huggingface.co/openai/whisper-base/resolve/main/vocab.json",
    },
    "merges_file": {"openai/whisper-base": "https://huggingface.co/openai/whisper-base/resolve/main/merges_file.txt"},
    "normalizer_file": {
        "openai/whisper-base": "https://huggingface.co/openai/whisper-base/resolve/main/normalizer.json"
    },
}
# NOTE(review): presumably the maximum input length (in tokens) per checkpoint - confirm against the model config.
MAX_MODEL_INPUT_SIZES = {
    "openai/whisper-base": 448,
}
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
    """
    Build the byte -> unicode-character table used by the byte-level BPE.

    Bytes whose code points fall in the three ranges `!`..`~`, `¡`..`¬` and `®`..`ÿ` map to the character with the
    same code point. Every other byte value is mapped, in increasing byte order, to the code points starting at 256,
    so that no byte is represented by one of the excluded (whitespace/control) characters.

    Returns:
        `dict`: 256 entries mapping each byte value (`int`) to a distinct one-character string.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in self_mapped}
    shift = 0
    for byte in range(2**8):
        if byte in table:
            continue
        # Remaining bytes are shifted past the 8-bit range, one fresh code point each.
        table[byte] = chr(2**8 + shift)
        shift += 1
    return table
# Module-level logger, obtained from the package's `logging` utilities and named after this module.
logger = logging.get_logger(__name__)
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    Args:
        word: Sliceable sequence of symbols (symbols being variable-length strings), e.g. a tuple of strings.

    Returns:
        `set`: Every `(left, right)` pair of neighbouring symbols. The set is empty when `word` holds fewer than
        two symbols, including when it is empty.
    """
    # Zipping the word with itself shifted by one yields nothing for empty or single-symbol words, whereas
    # indexing `word[0]` would raise IndexError on an empty word.
    return set(zip(word, word[1:]))
# Language code -> lowercase English language name.
# The insertion order is significant: `WhisperTokenizer.prefix_tokens` computes a language token id from the
# position of the language code among these keys, so entries must not be reordered.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}
# Language code lookup by name: the inverse of LANGUAGES, extended with a few alternative language names
# (aliases) that resolve to an existing code.
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
# Task names accepted by `WhisperTokenizer.prefix_tokens`; any other non-None task raises a ValueError there.
TASK_IDS = ["translate", "transcribe"]
class WhisperTokenizer(PreTrainedTokenizer):
    """
    Construct a Whisper tokenizer.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
    the superclass for more information regarding such methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        normalizer_file (`str`, *optional*):
            Path to the normalizer file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The beginning of sequence token. The `decoder_start_token_id` is used to set the first token as
            `"<|startoftranscript|>"` when generating.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        pad_token (`str`, *optional*):
            The token used for padding, for example when batching sequences of different lengths.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word.
        language (`str`, *optional*):
            The language of the transcription text. The corresponding language id token is appended to the start of the
            sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token
            `"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only.
        task (`str`, *optional*):
            Task identifier to append at the start of sequence (if any). This should be used for multilingual
            fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
        predict_timestamps (`bool`, *optional*, defaults to `False`):
            Whether to omit the `<|notimestamps|>` token at the start of the sequence.
    """

    # Class-level configuration, taken from the module-level tables defined above.
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = MAX_MODEL_INPUT_SIZES
    # Names of the model inputs associated with this tokenizer.
    model_input_names = ["input_ids", "attention_mask"]
def __init__(
    self,
    vocab_file,
    merges_file,
    normalizer_file=None,
    errors="replace",
    unk_token="<|endoftext|>",
    bos_token="<|endoftext|>",
    eos_token="<|endoftext|>",
    pad_token=None,
    add_prefix_space=False,
    language=None,
    task=None,
    predict_timestamps=False,
    **kwargs,
):
    """
    Load the vocabulary, the BPE merges and the optional spelling normalizer from disk, then initialise the base
    tokenizer. See the class docstring for the meaning of each argument.
    """

    def _as_special(token):
        # Plain strings become special, non-normalized, non-stripping `AddedToken`s; anything else is kept as is.
        if isinstance(token, str):
            return AddedToken(token, lstrip=False, rstrip=False, normalized=False, special=True)
        return token

    bos_token = _as_special(bos_token)
    eos_token = _as_special(eos_token)
    unk_token = _as_special(unk_token)
    pad_token = _as_special(pad_token)

    # Token -> id table and its inverse.
    with open(vocab_file, encoding="utf-8") as vocab_reader:
        self.encoder = json.load(vocab_reader)
    self.decoder = {index: token for token, index in self.encoder.items()}
    self.errors = errors  # how to handle errors in decoding

    # Byte <-> printable-character tables used by the byte-level BPE.
    self.byte_encoder = bytes_to_unicode()
    self.byte_decoder = {char: byte for byte, char in self.byte_encoder.items()}

    # The first line of the merges file and whatever follows the last newline are dropped.
    with open(merges_file, encoding="utf-8") as merges_reader:
        merge_lines = merges_reader.read().split("\n")[1:-1]
    merge_pairs = [tuple(line.split()) for line in merge_lines]
    self.bpe_ranks = {pair: rank for rank, pair in enumerate(merge_pairs)}
    self.cache = {}
    self.add_prefix_space = add_prefix_space

    if normalizer_file is None:
        self.english_spelling_normalizer = None
    else:
        with open(normalizer_file, encoding="utf-8") as normalizer_reader:
            self.english_spelling_normalizer = json.load(normalizer_reader)

    # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
    self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")

    # `language` is assigned before the base-class initialiser runs; `task` and `predict_timestamps` after it.
    self.language = language
    super().__init__(
        errors=errors,
        unk_token=unk_token,
        bos_token=bos_token,
        eos_token=eos_token,
        pad_token=pad_token,
        add_prefix_space=add_prefix_space,
        **kwargs,
    )

    self.task = task
    self.predict_timestamps = predict_timestamps
@property
def vocab_size(self) -> int:
    """Number of entries in the base vocabulary (`self.encoder`)."""
    base_vocab = self.encoder
    return len(base_vocab)
def get_vocab(self):
    """
    Return the token -> id mapping: one entry per id below `self.vocab_size`, then updated with
    `self.added_tokens_encoder`.
    """
    vocab = {}
    for index in range(self.vocab_size):
        vocab[self.convert_ids_to_tokens(index)] = index
    vocab.update(self.added_tokens_encoder)
    return vocab
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe with GPT2 -> Whisper
def bpe(self, token):
    """
    Apply the byte-pair merges in `self.bpe_ranks` to a single pre-token.

    Args:
        token (`str`): Pre-token whose individual characters are the initial symbols.

    Returns:
        `str`: The merged symbols joined by single spaces, or `token` itself when it has no adjacent symbol pair.
        Merged results are memoized in `self.cache`.
    """
    if token in self.cache:
        return self.cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    # Nothing to merge; note that this early exit does not populate the cache.
    if not pairs:
        return token

    while True:
        # Pick the pair with the lowest rank; pairs that have no rank compare as infinity.
        bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
        if bigram not in self.bpe_ranks:
            # No remaining pair is a known merge.
            break
        first, second = bigram
        new_word = []
        i = 0
        # Rebuild the word, fusing each occurrence of `first` that is immediately followed by `second`.
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # `first` does not occur again: copy the remainder unchanged.
                new_word.extend(word[i:])
                break
            else:
                new_word.extend(word[i:j])
                i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
            # Fully merged into a single symbol.
            break
        else:
            pairs = get_pairs(word)
    word = " ".join(word)
    self.cache[token] = word
    return word
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None):
    """
    Update the settings from which the prefix tokens are built. Only the arguments that are not `None` are applied;
    the other settings keep their current value. Example:

    ```python
    >>> # instantiate the tokenizer and set the prefix token to Spanish
    >>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
    >>> # now switch the prefix token from Spanish to French
    >>> tokenizer.set_prefix_tokens(language="french")
    ```

    Args:
        language (`str`, *optional*, defaults to `None`):
            The language of the transcription text.
        task (`str`, *optional*, defaults to `None`):
            Task identifier to append at the start of sequence (if any).
        predict_timestamps (`bool`, *optional*, defaults to `None`):
            Whether to omit the `<|notimestamps|>` token at the start of the sequence.
    """
    overrides = {
        "language": language,
        "task": task,
        "predict_timestamps": predict_timestamps,
    }
    for attribute, value in overrides.items():
        if value is not None:
            setattr(self, attribute, value)
@property
def prefix_tokens(self) -> List[int]:
    """
    Token ids placed in front of a sequence: `<|startoftranscript|>`, then the language token when a language is
    set, the task token when a task is set, and `<|notimestamps|>` unless `predict_timestamps` is truthy.

    As a side effect, `self.language` is replaced by its lowercase form.

    Raises:
        ValueError: If the language is neither a known name/alias nor a known code, or if the task is not in
            `TASK_IDS`.
    """
    start_id = self.convert_tokens_to_ids("<|startoftranscript|>")
    translate_id = self.convert_tokens_to_ids("<|translate|>")
    transcribe_id = self.convert_tokens_to_ids("<|transcribe|>")
    no_timestamps_id = self.convert_tokens_to_ids("<|notimestamps|>")
    language_codes = tuple(LANGUAGES.keys())

    has_language = self.language is not None
    if has_language:
        self.language = self.language.lower()
        if self.language in TO_LANGUAGE_CODE:
            # A language name or alias.
            language_id = TO_LANGUAGE_CODE[self.language]
        elif self.language in TO_LANGUAGE_CODE.values():
            # Already a language code.
            language_id = self.language
        else:
            is_language_code = len(self.language) == 2
            supported = list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())
            raise ValueError(
                f"Unsupported language: {self.language}. Language should be one of:" f" {supported}."
            )

    has_task = self.task is not None
    if has_task and self.task not in TASK_IDS:
        raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")

    sequence = [start_id]
    if has_language:
        # The language token id is derived from the position of the code within LANGUAGES.
        sequence.append(start_id + 1 + language_codes.index(language_id))
    if has_task:
        sequence.append(transcribe_id if self.task == "transcribe" else translate_id)
    if not self.predict_timestamps:
        sequence.append(no_timestamps_id)
    return sequence
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
    """Build model inputs by putting the prefix tokens in front of the sequence(s) and `eos_token_id` at the end."""
    sequence = self.prefix_tokens + token_ids_0
    # We don't expect to process pairs, but leave the pair logic for API consistency
    if token_ids_1 is not None:
        sequence = sequence + token_ids_1
    return sequence + [self.eos_token_id]
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Build the special-tokens mask for a sequence (or pair): ones for the prefix tokens and the final token, zeros
    for the sequence tokens. When the ids already contain special tokens the call is delegated to the superclass.

    Args:
        token_ids_0 (`List[int]`):
            List of IDs.
        token_ids_1 (`List[int]`, *optional*):
            Optional second list of IDs for sequence pairs.
        already_has_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not the token list is already formatted with special tokens for the model.

    Returns:
        `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
    """
    if already_has_special_tokens:
        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
        )

    prefix_length = len(self.prefix_tokens)
    sequence_length = len(token_ids_0)
    if token_ids_1 is not None:
        sequence_length += len(token_ids_1)
    return [1] * prefix_length + [0] * sequence_length + [1]
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
def _tokenize(self, text):
    """Split `text` with `self.pat`, map each piece to byte-level characters and return its BPE sub-tokens."""
    pieces = []
    for chunk in re.findall(self.pat, text):
        # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
        mapped = "".join(self.byte_encoder[byte] for byte in chunk.encode("utf-8"))
        merged = self.bpe(mapped)
        pieces.extend(merged.split(" "))
    return pieces
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id with GPT2 -> Whisper
def _convert_token_to_id(self, token):
    """Look up the id of a token (str) in the vocab, falling back to the id stored for `self.unk_token`."""
    vocab = self.encoder
    unknown_id = vocab.get(self.unk_token)
    return vocab.get(token, unknown_id)
def _convert_id_to_token(self, index):
    """
    Look up the token (str) of an index (integer) in the vocab. Whisper's base tokenizer always decodes OOV tokens
    as "", so an unknown index yields the empty string rather than the `unk_token`.
    """
    if index in self.decoder:
        return self.decoder[index]
    return ""
def _normalize(self, text):
    """
    Run `text` through an `EnglishTextNormalizer` built from `self.english_spelling_normalizer` and return the
    result. Intended for English text.
    """
    english_normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
    normalized = english_normalizer(text)
    return normalized
@staticmethod
def _basic_normalize(text, remove_diacritics=False):
    """
    Run `text` through a `BasicTextNormalizer` and return the result. Intended for multilingual text;
    `remove_diacritics` is forwarded to the normalizer.
    """
    basic_normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
    normalized = basic_normalizer(text)
    return normalized
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
    """
    Decode `token_ids` while rendering timestamp tokens inline, e.g. "<|1.08|>". Ids greater than the last entry
    of `self.all_special_ids` are treated as timestamps; `decode()` on its own ignores them.
    """
    first_timestamp_id = self.all_special_ids[-1] + 1
    # Alternating runs: lists of ordinary token ids and already-formatted timestamp strings.
    segments = [[]]
    latest_timestamp = 0.0
    elapsed_before_segment = 0.0
    for token_id in token_ids:
        if not token_id >= first_timestamp_id:
            segments[-1].append(token_id)
            continue
        timestamp = float((token_id - first_timestamp_id) * time_precision)
        if timestamp < latest_timestamp:
            # next segment has started
            elapsed_before_segment += latest_timestamp
        latest_timestamp = timestamp
        segments.append(f"<|{(timestamp + elapsed_before_segment):.2f}|>")
        segments.append([])

    rendered = []
    for segment in segments:
        if isinstance(segment, str):
            rendered.append(segment)
        else:
            rendered.append(self.decode(segment, skip_special_tokens=skip_special_tokens))
    return "".join(rendered)
def _compute_offsets(self, token_ids, time_precision=0.02):
    """
    Compute offsets for a given tokenized input: one entry per span of tokens delimited by timestamp tokens.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        time_precision (`float`, `optional`, defaults to 0.02):
            The time ratio to convert from token to time.

    Returns:
        `list`: Dicts with a `"text"` entry and a `"timestamp"` entry holding a `(start, end)` tuple. Empty when
        the input has no pair of consecutive timestamp tokens and at most one timestamp token.

    Raises:
        ValueError: If the input holds more than one sequence.
    """
    # ensure torch tensor of token ids is placed on cpu
    looks_like_torch = "torch" in str(type(token_ids))
    if looks_like_torch and hasattr(token_ids, "cpu") and callable(token_ids.cpu):
        token_ids = token_ids.cpu()
    token_ids = np.array(token_ids)
    if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
        raise ValueError("Can only process a single input at a time")

    timestamp_begin = self.all_special_ids[-1] + 1
    is_timestamp = token_ids >= timestamp_begin

    # Index of the second token of every pair of consecutive timestamp tokens.
    boundaries = np.where(is_timestamp[:-1] & is_timestamp[1:])[0] + 1
    if boundaries.shape[0] == 0 and is_timestamp.sum() <= 1:
        # either there are no timestamps or there are no consecutive ones
        return []

    timestamp_positions = np.where(is_timestamp)[0]
    closing_boundary = timestamp_positions[-1] + 1
    if closing_boundary not in boundaries:
        # we add the final timestamp if it is not already in the list
        boundaries = np.append(boundaries, closing_boundary)

    offsets = []
    span_start = timestamp_positions[0]
    for span_end in boundaries:
        span = token_ids[span_start:span_end]
        if len(span) > 1:
            start_position = span[0].item() - timestamp_begin
            end_position = span[-1].item() - timestamp_begin
            # strip timestamp tokens from the text output
            span = self._preprocess_token_ids(span)
            span_text = self._filter_timestamp_ids(self._decode(span))
            offsets.append(
                {
                    "text": span_text,
                    "timestamp": (start_position * time_precision, end_position * time_precision),
                }
            )
        span_start = span_end

    return offsets
def timestamp_ids(self, time_precision=0.02):
    """
    Compute the timestamp token ids for a given precision, caching the result on the instance.

    The cache lives in the instance's own `__dict__` rather than in a class-level `functools.lru_cache`: a
    decorator-level cache is keyed on `self` and therefore keeps every tokenizer instance (and its vocabulary)
    alive for as long as the cache exists.

    Args:
        time_precision (`float`, `optional`, defaults to 0.02):
            The time ratio to convert from token to time.

    Returns:
        The result of `self.convert_tokens_to_ids` for the 1501 tokens `"<|%.2f|>" % (i * time_precision)` with
        `i` in `0..1500`. Repeated calls with the same precision return the same cached object.
    """
    cache = vars(self).setdefault("_timestamp_ids_cache", {})
    if time_precision not in cache:
        cache[time_precision] = self.convert_tokens_to_ids(
            [("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]
        )
    return cache[time_precision]
def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
    """
    Pre-process the token ids for decoding. When `skip_special_tokens` is set, the ids are passed through
    `self._strip_prompt` to drop a leading prompt; otherwise they are returned untouched.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
            removed.
    """
    if not skip_special_tokens:
        return token_ids

    prompt_marker_id = self.convert_tokens_to_ids("<|startofprev|>")
    transcript_start_id = self.convert_tokens_to_ids("<|startoftranscript|>")
    return self._strip_prompt(token_ids, prompt_marker_id, transcript_start_id)
def _filter_timestamp_ids(self, token_ids):
    """Return `token_ids` (a string) with every match of `self.timestamp_pat` removed."""
    timestamp_pattern = self.timestamp_pat
    return timestamp_pattern.sub("", token_ids)
def decode(
    self,
    token_ids,
    skip_special_tokens: bool = False,
    clean_up_tokenization_spaces: bool = None,
    output_offsets: bool = False,
    time_precision: float = 0.02,
    decode_with_timestamps: bool = False,
    normalize: bool = False,
    basic_normalize: bool = False,
    remove_diacritics: bool = False,
    **kwargs,
) -> str:
    """
    Turn a sequence of ids into a string, optionally stripping special tokens, rendering timestamps inline and
    returning per-span offsets.

    Args:
        token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
            List of tokenized input ids. Can be obtained using the `__call__` method.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.
        clean_up_tokenization_spaces (`bool`, *optional*):
            Whether or not to clean up the tokenization spaces. If `None`, will default to
            `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
        output_offsets (`bool`, *optional*, defaults to `False`):
            Whether or not to output the offsets of the tokens. This should only be set if the model predicted
            timestamps.
        time_precision (`float`, `optional`, defaults to 0.02):
            The time ratio to convert from token to time.
        decode_with_timestamps (`bool`, *optional*, defaults to `False`):
            Whether or not to decode with timestamps included in the raw text.
        normalize (`bool`, *optional*, defaults to `False`):
            Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
            target text is in English. Otherwise, the basic text normalizer should be applied.
        basic_normalize (`bool`, *optional*, defaults to `False`):
            Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
            target text.
        remove_diacritics (`bool`, *optional*, defaults to `False`):
            Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
            destroy information in the decoded text, hence it should be used with caution.
        kwargs (additional keyword arguments, *optional*):
            Will be passed to the underlying model specific decode method.

    Returns:
        `str`: The decoded sentence, or a dict with `"text"` and `"offsets"` entries when `output_offsets` is set.
    """
    prompt_free_ids = self._preprocess_token_ids(
        token_ids,
        skip_special_tokens=skip_special_tokens,
    )

    plain_text = super().decode(
        prompt_free_ids,
        skip_special_tokens=skip_special_tokens,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        normalize=normalize,
        basic_normalize=basic_normalize,
        remove_diacritics=remove_diacritics,
        **kwargs,
    )

    if not decode_with_timestamps:
        text = self._filter_timestamp_ids(plain_text)
    else:
        # legacy method to decode timestamps when not included in the tokenizer vocabulary
        text = self._decode_with_timestamps(
            prompt_free_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
        )

    if not output_offsets:
        return text

    # retrieve offsets (computed from the ids as passed in, not the prompt-stripped ones)
    offsets = self._compute_offsets(token_ids, time_precision=time_precision)
    return {"text": text, "offsets": offsets}
def _decode(
    self,
    token_ids: Union[int, List[int]],
    skip_special_tokens: bool = False,
    normalize: bool = False,
    basic_normalize: bool = False,
    remove_diacritics: bool = False,
    **kwargs,
) -> str:
    """
    Convert ids to text. Runs of ordinary tokens go through `convert_tokens_to_string`, while tokens found in
    `self.added_tokens_encoder` are inserted verbatim. The result is optionally passed through the English
    normalizer (`normalize`) or, failing that, the basic normalizer (`basic_normalize`).
    """
    self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
    tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)

    # To avoid mixing byte-level and unicode for byte-level BPT
    # we need to build string separately for added tokens and byte-level tokens
    # cf. https://github.com/huggingface/transformers/issues/1133
    chunks = []
    pending = []
    for token in tokens:
        if skip_special_tokens and token in self.all_special_ids:
            continue
        if token not in self.added_tokens_encoder:
            pending.append(token)
            continue
        # An added token closes the current run of byte-level tokens.
        if pending:
            chunks.append(self.convert_tokens_to_string(pending))
            pending = []
        chunks.append(token)
    if pending:
        chunks.append(self.convert_tokens_to_string(pending))

    text = "".join(chunks)

    if normalize:
        return self._normalize(text)
    if basic_normalize:
        return self._basic_normalize(text, remove_diacritics=remove_diacritics)
    return text
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string with GPT2 -> Whisper
def convert_tokens_to_string(self, tokens):
    """Join tokens (strings), map each character back to its byte and decode the bytes as UTF-8."""
    joined = "".join(tokens)
    raw_bytes = bytearray(self.byte_decoder[char] for char in joined)
    return raw_bytes.decode("utf-8", errors=self.errors)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """
    Write the vocabulary, the merges and (when one is loaded) the English spelling normalizer to `save_directory`.

    Args:
        save_directory (`str`): Existing directory to write into.
        filename_prefix (`str`, *optional*): When given, `"<prefix>-"` is put in front of every file name.

    Returns:
        The vocab, merges and normalizer paths, in that order. The normalizer path is returned even when no
        normalizer file was written. Returns `None` (after logging an error) if `save_directory` is not a directory.
    """
    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        return

    prefix = filename_prefix + "-" if filename_prefix else ""
    vocab_file, merge_file, normalizer_file = (
        os.path.join(save_directory, prefix + VOCAB_FILES_NAMES[resource])
        for resource in ("vocab_file", "merges_file", "normalizer_file")
    )

    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

    expected_rank = 0
    with open(merge_file, "w", encoding="utf-8") as merges_writer:
        merges_writer.write("#version: 0.2\n")
        merges_by_rank = sorted(self.bpe_ranks.items(), key=lambda item: item[1])
        for merge_pair, rank in merges_by_rank:
            if expected_rank != rank:
                # A gap in the ranks is reported; counting then resumes from the rank actually found.
                logger.warning(
                    f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                    " Please check that the tokenizer is not corrupted!"
                )
                expected_rank = rank
            merges_writer.write(" ".join(merge_pair) + "\n")
            expected_rank += 1

    if self.english_spelling_normalizer is not None:
        with open(normalizer_file, "w", encoding="utf-8") as normalizer_writer:
            normalizer_writer.write(
                json.dumps(self.english_spelling_normalizer, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
            )

    return vocab_file, merge_file, normalizer_file
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.prepare_for_tokenization with GPT2 -> Whisper
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
    """
    Put a space in front of `text` when the input is pre-split into words or when `add_prefix_space` (popped from
    `kwargs`, defaulting to `self.add_prefix_space`) is truthy. Returns the text and the remaining kwargs.
    """
    wants_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
    prepared = " " + text if (is_split_into_words or wants_prefix_space) else text
    return (prepared, kwargs)
@property
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
def default_chat_template(self):
    """
    Fallback chat template: the content of every message followed by the EOS token, with role information
    ignored. Logs a one-time warning that the class default is in use.
    """
    class_name = self.__class__.__name__
    notice = (
        "\nNo chat template is defined for this tokenizer - using the default template "
        f"for the {class_name} class. If the default is not appropriate for "
        "your model, please set `tokenizer.chat_template` to an appropriate template. "
        "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
    )
    logger.warning_once(notice)
    return "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
    """
    Update the prefix settings and return the forced decoder ids as `(position, token_id)` pairs, with positions
    starting at 1.
    """
    self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
    # prefix tokens are of the form: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
    # we don't want to force the bos token at position 1, as this is the starting token
    # when we generate, so we slice the prefix tokens to: <|lang_id|> <|task|> <|notimestamps|>
    # to get the forced tokens
    tokens_after_start = self.prefix_tokens[1:]
    return list(enumerate(tokens_after_start, start=1))
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
    """
    Forward to the module-level `_decode_asr` helper, passing this tokenizer as its first argument and the
    keyword-only options through unchanged.
    """
    decoding_options = {
        "return_timestamps": return_timestamps,
        "return_language": return_language,
        "time_precision": time_precision,
    }
    return _decode_asr(self, model_outputs, **decoding_options)
def get_prompt_ids(self, text: str, return_tensors="np"):
    """
    Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`].

    The stripped text, prefixed with one space, is encoded after a `<|startofprev|>` marker without adding
    special tokens. Every ID after the first is then checked against `self.all_special_ids[0]`.

    Raises:
        ValueError: If an ID after the first is greater than or equal to `self.all_special_ids[0]`.
    """
    encoded = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)

    # Reject the prompt at the first special token found past the leading marker.
    for candidate_id in encoded["input_ids"][1:]:
        if candidate_id >= self.all_special_ids[0]:
            token = self.convert_ids_to_tokens(candidate_id)
            raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")

    encoded.convert_to_tensors(tensor_type=return_tensors)
    return encoded["input_ids"]
@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
    """
    Drop a leading prompt section from `token_ids`.

    When `token_ids` is a non-empty list whose first element equals `prompt_token_id`, the slice starting at
    the first occurrence of `decoder_start_token_id` is returned, or an empty list if that token is absent.
    In every other case `token_ids` is returned unchanged.
    """
    starts_with_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
    if not starts_with_prompt:
        return token_ids

    try:
        start_position = token_ids.index(decoder_start_token_id)
    except ValueError:
        # The prompt marker is present but the decoder start token never appears.
        return []
    return token_ids[start_position:]
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
"""
Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
the various options not allowed in other seq2seq models
"""
# =========== Overview ============
# - iterate over all outputs
# - all tokens within output
# - Each token can be
# - language token
# - special token
# - timestamp token
# - text token
# - We accumulate the text tokens.
# - We split on end timestamps
# - Lots of complexity comes from stride and timestamps
last_language = None
def new_chunk():
return {"language": last_language, "timestamp": [None, None], "text": ""}
# Welcome to the state machine !
chunks = []
chunk = new_chunk()
time_offset = 0.0
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
previous_tokens = []
previous_token_timestamps = []
skip = False
right_stride_start = None
all_special_ids = set(tokenizer.all_special_ids)
# - iterate over all outputs
for chunk_id, output in enumerate(model_outputs):
# We can drop everything to Python list, it's going to make
# our lives easier
token_ids = output["tokens"][0].tolist()
if return_timestamps == "word":
token_timestamps = output["token_timestamps"][0].tolist()
# Those keep track of timestamps within strides
# Which need to be skipped and resolve all tokens in a single
# chunk.
last_timestamp = None
first_timestamp = timestamp_begin
if "stride" in output:
chunk_len, stride_left, stride_right = output["stride"]
# Offset the timings to account for the other `model_outputs`.
time_offset -= stride_left
right_stride_start = chunk_len - stride_right
# Keeping track of timestamps within strides
# We're going to NOT split on those, and delay until we're
# out of BOTH stride. Otherwise lots of issues occur and
# corner cases
if stride_left:
first_timestamp = stride_left / time_precision + timestamp_begin
if stride_right:
for token in reversed(token_ids):
if token >= timestamp_begin:
# There can be several token in the right stride
# But the last one is ALWAYS going to be skipped
if (
last_timestamp is not None
and (token - timestamp_begin) * time_precision < right_stride_start
):
break
last_timestamp = token
current_tokens = []
current_token_timestamps = []
# - all tokens within output
for i, token in enumerate(token_ids):
# 4 possible states for each token
# - 1/ Language code
# - 2/ all other special tokens (which we ignore)
# - 3/ Timestamp
# - 4/ Regular text
if token in all_special_ids:
# Either language code or other
text = tokenizer.decode([token])
# Removing outer shell <|XX|>
text = text[2:-2]
language = LANGUAGES.get(text, None)
if language is not None:
# 1/ Indeed some language
# TODO Handle when language is different from the previous
# one, and we cannot use timestamped tokens to create chunks
if last_language and language != last_language and not return_timestamps:
previous_tokens.append(current_tokens)
resolved_tokens = _find_longest_common_sequence(previous_tokens)
resolved_text = tokenizer.decode(resolved_tokens)
chunk["text"] = resolved_text
chunks.append(chunk)
# Flush all our temporary context
previous_tokens = []
current_tokens = []
chunk = new_chunk()
chunk["language"] = language
last_language = language
else:
# 2/ This is a regular special token, ignoring it
pass
elif token >= timestamp_begin:
# 3/ Timestamp token
time = (token - timestamp_begin) * time_precision + time_offset
time = round(time, 2)
if last_timestamp and token >= last_timestamp:
# Whisper outputted a timestamp token, but it falls within
# our stride, so we're going to skip it for the time being
# and resolve this later
# Skip is necessary because timestamp tokens always come
# by pair, so we need to skip the next one too (which would mark the start of another chunk).
skip = True
elif skip or (previous_tokens and token < first_timestamp):
skip = False
elif chunk["timestamp"][0] is None:
chunk["timestamp"][0] = time
else:
# This is the end of the timestamp chunk
if time == chunk["timestamp"][0]:
# This is a bug in timestamp token output
# where we're taking the duplicate token
# as a stop where it should be a start.
# This is an issue in the underlying model output
# Let's just skip it so it becomes de facto
# a start again
pass
else: