-
Notifications
You must be signed in to change notification settings - Fork 55
/
tokenizer.py
100 lines (85 loc) · 3.22 KB
/
tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Optional, Set, final
from fairseq2.data.text import SentencePieceEncoder, SentencePieceTokenizer
from fairseq2.typing import Device, override
@final
class S2TTransformerTokenizer(SentencePieceTokenizer):
    """Represents an S2T Transformer tokenizer."""

    # The task this tokenizer was constructed for ('transcription' or
    # 'translation').
    _task: str
    # The target languages accepted by `create_encoder()`.
    _target_langs: Set[str]
    # The language used by `create_encoder()` when none is specified.
    _default_target_lang: str

    def __init__(
        self,
        path: Path,
        task: str,
        target_langs: Set[str],
        default_target_lang: str,
    ) -> None:
        """
        :param path:
            The path to the SentencePiece model file.
        :param task:
            The task for which to generate token indices. The valid values are
            'transcription' and 'translation'.
        :param target_langs:
            The list of supported target languages.
        :param default_target_lang:
            The fall-back language if no target language is specified.

        :raises ValueError:
            If ``task`` is neither 'transcription' nor 'translation'.
        """
        super().__init__(path)

        if task not in ("transcription", "translation"):
            raise ValueError(
                f"`task` must be 'transcription' or 'translation', but is '{task}' instead."
            )

        self._task = task
        self._target_langs = target_langs
        self._default_target_lang = default_target_lang

    @override
    def create_encoder(
        self,
        *,
        task: Optional[str] = None,
        lang: Optional[str] = None,
        mode: Optional[str] = None,
        device: Optional[Device] = None,
        pin_memory: bool = False,
    ) -> SentencePieceEncoder:
        """Create a token encoder.

        :param task:
            Must match the task this tokenizer was constructed with. If
            ``None``, defaults to that task.
        :param lang:
            One of the target languages this tokenizer was constructed with.
            If ``None``, defaults to the default target language.
        :param mode:
            Must be 'target'. If ``None``, defaults to 'target'.
        :param device:
            The device on which to construct tensors.
        :param pin_memory:
            If ``True``, uses pinned memory while constructing tensors.

        :returns:
            A :class:`SentencePieceEncoder` configured with the prefix tokens
            for the requested language.

        :raises ValueError:
            If ``task`` or ``mode`` is specified but does not match the
            expected value, or if the resolved language is not one of the
            supported target languages.
        """
        if task is not None and task != self._task:
            raise ValueError(f"`task` must be '{self._task}', but is '{task}' instead.")

        if mode is not None and mode != "target":
            raise ValueError(f"`mode` must be 'target', but is '{mode}' instead.")

        if lang is None:
            lang = self._default_target_lang

        # The check runs after the fall-back, so an unsupported default
        # language is rejected here as well.
        if lang not in self._target_langs:
            raise ValueError(
                f"`lang` must be a supported language, but is '{lang}' instead."
            )

        # For multilingual speech translation we prepend the language token.
        if self._task == "translation" and len(self._target_langs) > 1:
            prefix_tokens = ["</s>", f"<lang:{lang}>"]
        else:
            prefix_tokens = ["</s>"]

        # NOTE(review): `_model` is not assigned in this class; presumably it
        # is set by `SentencePieceTokenizer.__init__()` - confirm.
        return SentencePieceEncoder(
            self._model,
            prefix_tokens=prefix_tokens,
            device=device,
            pin_memory=pin_memory,
        )