-
Notifications
You must be signed in to change notification settings - Fork 113
/
configurations.py
369 lines (307 loc) · 16.4 KB
/
configurations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
# MIT License
#
# Copyright (c) 2021 Soohwan Kim and Sangchun Ha and Soyoung Cho
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, List, Optional
from omegaconf import MISSING
@dataclass
class OpenspeechDataclass:
    """Base dataclass for OpenSpeech configurations.

    Provides introspection helpers over ``__dataclass_fields__``: listing
    attribute names and fetching a field's metadata, name, default value,
    type annotation, and help text.
    """

    def _get_all_attributes(self) -> List[str]:
        # Field order follows the declaration order of the dataclass.
        return list(self.__dataclass_fields__)

    def _get_meta(self, attribute_name: str, meta: str, default: Optional[Any] = None) -> Any:
        # Look up one entry of the metadata mapping attached via field(metadata=...).
        field_info = self.__dataclass_fields__[attribute_name]
        return field_info.metadata.get(meta, default)

    def _get_name(self, attribute_name: str) -> str:
        return self.__dataclass_fields__[attribute_name].name

    def _get_default(self, attribute_name: str) -> Any:
        # OmegaConf-style interpolation references ("${...}") are returned
        # verbatim as strings; an instance value that differs from the declared
        # default wins over the default itself.
        declared = self.__dataclass_fields__[attribute_name]
        if hasattr(self, attribute_name):
            current = getattr(self, attribute_name)
            if str(current).startswith("${"):
                return str(current)
            if str(declared.default).startswith("${"):
                return str(declared.default)
            if current != declared.default:
                return current
        # Fall back to the declared default, honoring default_factory if set.
        if not isinstance(declared.default_factory, _MISSING_TYPE):
            return declared.default_factory()
        return declared.default

    def _get_type(self, attribute_name: str) -> Any:
        return self.__dataclass_fields__[attribute_name].type

    def _get_help(self, attribute_name: str) -> Any:
        return self._get_meta(attribute_name, "help")
@dataclass
class LibriSpeechConfigs(OpenspeechDataclass):
    """Configuration for the LibriSpeech dataset."""

    # Dataset identifier; selects which dataset pipeline is used.
    dataset: str = field(
        default="librispeech", metadata={"help": "Select dataset for training (librispeech, ksponspeech, aishell, lm)"}
    )
    # Filesystem path to the dataset root (required; MISSING means the user must supply it).
    dataset_path: str = field(default=MISSING, metadata={"help": "Path of dataset"})
    # Whether to download the dataset automatically.
    dataset_download: bool = field(
        default=True, metadata={"help": "Flag indication whether to download dataset or not."}
    )
    # Path of the manifest file for this dataset.
    manifest_file_path: str = field(
        default="../../../LibriSpeech/libri_subword_manifest.txt", metadata={"help": "Path of manifest file"}
    )
@dataclass
class KsponSpeechConfigs(OpenspeechDataclass):
    """Configuration for the KsponSpeech dataset."""

    # Dataset identifier; selects which dataset pipeline is used.
    dataset: str = field(
        default="ksponspeech", metadata={"help": "Select dataset for training (librispeech, ksponspeech, aishell, lm)"}
    )
    # Filesystem path to the training dataset root (required; no usable default).
    dataset_path: str = field(default=MISSING, metadata={"help": "Path of dataset"})
    # Filesystem path to the evaluation dataset root (required; no usable default).
    test_dataset_path: str = field(default=MISSING, metadata={"help": "Path of evaluation dataset"})
    # Path of the training manifest file.
    manifest_file_path: str = field(
        default="../../../ksponspeech_manifest.txt", metadata={"help": "Path of manifest file"}
    )
    # Directory containing the test manifest files.
    test_manifest_dir: str = field(
        default="../../../KsponSpeech_scripts", metadata={"help": "Path of directory contains test manifest files"}
    )
    # Transcript preprocessing mode: "phonetic" or "spelling".
    preprocess_mode: str = field(
        default="phonetic",
        metadata={"help": "KsponSpeech preprocess mode {phonetic, spelling}"},
    )
@dataclass
class AIShellConfigs(OpenspeechDataclass):
    """Configuration for the AISHELL dataset."""

    # Dataset identifier; selects which dataset pipeline is used.
    dataset: str = field(
        default="aishell", metadata={"help": "Select dataset for training (librispeech, ksponspeech, aishell, lm)"}
    )
    # Filesystem path to the dataset root (required; no usable default).
    dataset_path: str = field(default=MISSING, metadata={"help": "Path of dataset"})
    # Whether to download the dataset automatically.
    dataset_download: bool = field(
        default=True, metadata={"help": "Flag indication whether to download dataset or not."}
    )
    # Path of the manifest file for this dataset.
    manifest_file_path: str = field(
        default="../../../data_aishell/aishell_manifest.txt", metadata={"help": "Path of manifest file"}
    )
@dataclass
class LMConfigs(OpenspeechDataclass):
    """Configuration for the language-model (text-only) dataset."""

    # Dataset identifier; selects which dataset pipeline is used.
    dataset: str = field(
        default="lm", metadata={"help": "Select dataset for training (librispeech, ksponspeech, aishell, lm)"}
    )
    # Filesystem path to the dataset (required; no usable default).
    dataset_path: str = field(default=MISSING, metadata={"help": "Path of dataset"})
    # Fraction of data held out for validation.
    valid_ratio: float = field(default=0.05, metadata={"help": "Ratio of validation data"})
    # Fraction of data held out for testing.
    test_ratio: float = field(default=0.05, metadata={"help": "Ratio of test data"})
@dataclass
class AugmentConfigs(OpenspeechDataclass):
    """Configuration of data-augmentation options (spec augment, noise,
    joining, and time stretch)."""

    apply_spec_augment: bool = field(
        default=False, metadata={"help": "Flag indication whether to apply spec augment or not"}
    )
    apply_noise_augment: bool = field(
        default=False,
        metadata={
            # Fixed: the option defined below is `noise_dataset_dir`, not `noise_dataset_path`.
            "help": "Flag indication whether to apply noise augment or not "
            "Noise augment requires `noise_dataset_dir`. "
            "`noise_dataset_dir` should be contain audio files."
        },
    )
    apply_joining_augment: bool = field(
        default=False,
        metadata={
            "help": "Flag indication whether to apply joining augment or not "
            "If true, create a new audio file by connecting two audio randomly"
        },
    )
    # Fixed copy-pasted help text: this flag controls time stretch, not spec augment.
    apply_time_stretch_augment: bool = field(
        default=False, metadata={"help": "Flag indication whether to apply time stretch augment or not"}
    )
    freq_mask_para: int = field(
        default=27, metadata={"help": "Hyper Parameter for freq masking to limit freq masking length"}
    )
    freq_mask_num: int = field(default=2, metadata={"help": "How many freq-masked area to make"})
    time_mask_num: int = field(default=4, metadata={"help": "How many time-masked area to make"})
    # Fixed copy-pasted help text: this is the noise dataset directory, not a mask count.
    # NOTE(review): the default is the string "None", not the None object — presumably
    # checked as a string downstream; confirm before changing.
    noise_dataset_dir: str = field(default="None", metadata={"help": "Directory that contains noise audio files"})
    noise_level: float = field(default=0.7, metadata={"help": "Noise adjustment level"})
    time_stretch_min_rate: float = field(default=0.7, metadata={"help": "Minimum rate of audio time stretch"})
    time_stretch_max_rate: float = field(default=1.4, metadata={"help": "Maximum rate of audio time stretch"})
@dataclass
class BaseTrainerConfigs(OpenspeechDataclass):
    """Base trainer dataclass shared by the CPU/GPU/TPU trainer configs."""

    seed: int = field(default=1, metadata={"help": "Seed for training."})
    accelerator: str = field(
        default="dp", metadata={"help": "Previously known as distributed_backend (dp, ddp, ddp2, etc…)."}
    )
    accumulate_grad_batches: int = field(
        default=1, metadata={"help": "Accumulates grads every k batches or as set up in the dict."}
    )
    num_workers: int = field(default=4, metadata={"help": "The number of cpu cores"})
    batch_size: int = field(default=32, metadata={"help": "Size of batch"})
    check_val_every_n_epoch: int = field(default=1, metadata={"help": "Check val every n train epochs."})
    gradient_clip_val: float = field(default=5.0, metadata={"help": "0 means don’t clip."})
    logger: str = field(default="wandb", metadata={"help": "Training logger. {wandb, tensorboard}"})
    max_epochs: int = field(default=20, metadata={"help": "Stop training once this number of epochs is reached."})
    save_checkpoint_n_steps: int = field(default=10000, metadata={"help": "Save a checkpoint every N steps."})
    auto_scale_batch_size: str = field(
        default="binsearch",
        metadata={
            "help": "If set to True, will initially run a batch size finder trying to find "
            "the largest batch size that fits into memory."
        },
    )
    # Fixed: the two implicitly-concatenated string literals had no separator,
    # so the rendered help read "...sequence length.else: random batch".
    sampler: str = field(
        default="else", metadata={"help": "smart: batching with similar sequence length. else: random batch"}
    )
@dataclass
class CPUResumeTrainerConfigs(BaseTrainerConfigs):
    """Trainer configuration for resuming training from a checkpoint on CPU."""

    name: str = field(default="cpu-resume", metadata={"help": "Trainer name"})
    # Checkpoint to resume from (required; no usable default).
    checkpoint_path: str = field(default=MISSING, metadata={"help": "Path of model checkpoint."})
    device: str = field(default="cpu", metadata={"help": "Training device."})
    use_cuda: bool = field(default=False, metadata={"help": "If set True, will train with GPU"})
@dataclass
class GPUResumeTrainerConfigs(BaseTrainerConfigs):
    """Trainer configuration for resuming training from a checkpoint on GPU."""

    name: str = field(default="gpu-resume", metadata={"help": "Trainer name"})
    # Checkpoint to resume from (required; no usable default).
    checkpoint_path: str = field(default=MISSING, metadata={"help": "Path of model checkpoint."})
    device: str = field(default="gpu", metadata={"help": "Training device."})
    use_cuda: bool = field(default=True, metadata={"help": "If set True, will train with GPU"})
    auto_select_gpus: bool = field(
        default=True, metadata={"help": "If enabled and gpus is an integer, pick available gpus automatically."}
    )
@dataclass
class TPUResumeTrainerConfigs(BaseTrainerConfigs):
    """Trainer configuration for resuming training from a checkpoint on TPU."""

    name: str = field(default="tpu-resume", metadata={"help": "Trainer name"})
    # Checkpoint to resume from (required; no usable default).
    checkpoint_path: str = field(default=MISSING, metadata={"help": "Path of model checkpoint."})
    device: str = field(default="tpu", metadata={"help": "Training device."})
    use_cuda: bool = field(default=False, metadata={"help": "If set True, will train with GPU"})
    # Fixed copy-pasted help text: this flag enables TPU training, not GPU.
    use_tpu: bool = field(default=True, metadata={"help": "If set True, will train with TPU"})
    tpu_cores: int = field(default=8, metadata={"help": "Number of TPU cores"})
@dataclass
class CPUTrainerConfigs(BaseTrainerConfigs):
    """Trainer configuration for training on CPU."""

    name: str = field(default="cpu", metadata={"help": "Trainer name"})
    device: str = field(default="cpu", metadata={"help": "Training device."})
    use_cuda: bool = field(default=False, metadata={"help": "If set True, will train with GPU"})
@dataclass
class GPUTrainerConfigs(BaseTrainerConfigs):
    """Trainer configuration for training on GPU."""

    name: str = field(default="gpu", metadata={"help": "Trainer name"})
    device: str = field(default="gpu", metadata={"help": "Training device."})
    use_cuda: bool = field(default=True, metadata={"help": "If set True, will train with GPU"})
    auto_select_gpus: bool = field(
        default=True, metadata={"help": "If enabled and gpus is an integer, pick available gpus automatically."}
    )
@dataclass
class TPUTrainerConfigs(BaseTrainerConfigs):
    """Trainer configuration for training on TPU."""

    name: str = field(default="tpu", metadata={"help": "Trainer name"})
    device: str = field(default="tpu", metadata={"help": "Training device."})
    use_cuda: bool = field(default=False, metadata={"help": "If set True, will train with GPU"})
    # Fixed copy-pasted help text: this flag enables TPU training, not GPU.
    use_tpu: bool = field(default=True, metadata={"help": "If set True, will train with TPU"})
    tpu_cores: int = field(default=8, metadata={"help": "Number of TPU cores"})
@dataclass
class Fp16GPUTrainerConfigs(GPUTrainerConfigs):
    """GPU trainer configuration using half (16-bit) precision."""

    name: str = field(default="gpu-fp16", metadata={"help": "Trainer name"})
    precision: int = field(
        default=16,
        metadata={
            "help": "Double precision (64), full precision (32) or half precision (16). "
            "Can be used on CPU, GPU or TPUs."
        },
    )
    amp_backend: str = field(
        default="apex", metadata={"help": "The mixed precision backend to use (“native” or “apex”)"}
    )
@dataclass
class Fp16TPUTrainerConfigs(TPUTrainerConfigs):
    """TPU trainer configuration using half (16-bit) precision."""

    name: str = field(default="tpu-fp16", metadata={"help": "Trainer name"})
    precision: int = field(
        default=16,
        metadata={
            "help": "Double precision (64), full precision (32) or half precision (16). "
            "Can be used on CPU, GPU or TPUs."
        },
    )
    amp_backend: str = field(
        default="apex", metadata={"help": "The mixed precision backend to use (“native” or “apex”)"}
    )
@dataclass
class Fp64CPUTrainerConfigs(CPUTrainerConfigs):
    """CPU trainer configuration using double (64-bit) precision."""

    name: str = field(default="cpu-fp64", metadata={"help": "Trainer name"})
    precision: int = field(
        default=64,
        metadata={
            "help": "Double precision (64), full precision (32) or half precision (16). "
            "Can be used on CPU, GPU or TPUs."
        },
    )
    amp_backend: str = field(
        default="apex", metadata={"help": "The mixed precision backend to use (“native” or “apex”)"}
    )
@dataclass
class LearningRateSchedulerConfigs(OpenspeechDataclass):
    """Base dataclass for learning-rate scheduler configurations."""

    # Initial learning rate.
    lr: float = field(default=1e-04, metadata={"help": "Learning rate"})
@dataclass
class TokenizerConfigs(OpenspeechDataclass):
    """Base dataclass for tokenizer configurations (special tokens + encoding)."""

    sos_token: str = field(default="<sos>", metadata={"help": "Start of sentence token"})
    eos_token: str = field(default="<eos>", metadata={"help": "End of sentence token"})
    pad_token: str = field(default="<pad>", metadata={"help": "Pad token"})
    blank_token: str = field(default="<blank>", metadata={"help": "Blank token (for CTC training)"})
    # Text encoding of the vocabulary file.
    encoding: str = field(default="utf-8", metadata={"help": "Encoding of vocab"})
@dataclass
class EvaluationConfigs(OpenspeechDataclass):
    """Configuration for evaluating a single model checkpoint."""

    use_cuda: bool = field(default=True, metadata={"help": "If set True, will evaluate with GPU"})
    # Required paths (MISSING means the user must supply them).
    dataset_path: str = field(default=MISSING, metadata={"help": "Path of dataset."})
    checkpoint_path: str = field(default=MISSING, metadata={"help": "Path of model checkpoint."})
    manifest_file_path: str = field(default=MISSING, metadata={"help": "Path of evaluation manifest file."})
    result_path: str = field(default=MISSING, metadata={"help": "Path of evaluation result file."})
    num_workers: int = field(default=4, metadata={"help": "Number of worker."})
    batch_size: int = field(default=32, metadata={"help": "Batch size."})
    beam_size: int = field(default=1, metadata={"help": "Beam size of beam search."})
@dataclass
class EnsembleEvaluationConfigs(OpenspeechDataclass):
    """Configuration for evaluating an ensemble of model checkpoints."""

    use_cuda: bool = field(default=True, metadata={"help": "If set True, will evaluate with GPU"})
    # Fixed copy-pasted help text: this field holds multiple dataset paths,
    # parallel to `checkpoint_paths` below.
    dataset_paths: str = field(default=MISSING, metadata={"help": "List of dataset paths."})
    checkpoint_paths: str = field(default=MISSING, metadata={"help": "List of model checkpoint path."})
    manifest_file_path: str = field(default=MISSING, metadata={"help": "Path of evaluation manifest file."})
    ensemble_method: str = field(default="vanilla", metadata={"help": "Method of ensemble (vanilla, weighted)"})
    ensemble_weights: str = field(default="(1.0, 1.0, 1.0 ..)", metadata={"help": "Weights of ensemble models."})
    num_workers: int = field(default=4, metadata={"help": "Number of worker."})
    batch_size: int = field(default=32, metadata={"help": "Batch size."})
    beam_size: int = field(default=1, metadata={"help": "Beam size of beam search."})
def generate_openspeech_configs_with_help():
    """Generate ``configuration.md`` documenting every registered config group.

    Iterates all dataclass registries (audio, augment, trainer, model,
    criterion, dataset, lr_scheduler, tokenizer) in the order given by
    ``OPENSPEECH_TRAIN_CONFIGS`` and writes each field's name and help
    text as a markdown bullet list.
    """
    from openspeech.criterion import CRITERION_DATACLASS_REGISTRY
    from openspeech.data import AUDIO_FEATURE_TRANSFORM_DATACLASS_REGISTRY
    from openspeech.dataclass import (
        AUGMENT_DATACLASS_REGISTRY,
        DATASET_DATACLASS_REGISTRY,
        OPENSPEECH_TRAIN_CONFIGS,
        TRAINER_DATACLASS_REGISTRY,
    )
    from openspeech.models import MODEL_DATACLASS_REGISTRY
    from openspeech.optim.scheduler import SCHEDULER_DATACLASS_REGISTRY
    from openspeech.tokenizers import TOKENIZER_DATACLASS_REGISTRY

    # Maps each config-group name to its dataclass registry.
    registries = {
        "audio": AUDIO_FEATURE_TRANSFORM_DATACLASS_REGISTRY,
        "augment": AUGMENT_DATACLASS_REGISTRY,
        "trainer": TRAINER_DATACLASS_REGISTRY,
        "model": MODEL_DATACLASS_REGISTRY,
        "criterion": CRITERION_DATACLASS_REGISTRY,
        "dataset": DATASET_DATACLASS_REGISTRY,
        "lr_scheduler": SCHEDULER_DATACLASS_REGISTRY,
        "tokenizer": TOKENIZER_DATACLASS_REGISTRY,
    }
    # Explicit UTF-8: several help strings contain non-ASCII characters
    # (curly quotes, ellipsis), which would fail under some locale encodings.
    with open("configuration.md", "w", encoding="utf-8") as f:
        for group in OPENSPEECH_TRAIN_CONFIGS:
            dataclass_registry = registries[group]
            f.write(f"## `{group}`\n")
            for config_name, config_cls in dataclass_registry.items():
                f.write(f"### `{config_name}` \n")
                # Instantiate so _get_help can read field metadata.
                config = config_cls()
                for attribute in config.__dataclass_fields__:
                    f.write(f"- `{attribute}` : {config._get_help(attribute)}\n")