-
Notifications
You must be signed in to change notification settings - Fork 205
/
speech_synthesis.py
90 lines (72 loc) · 2.92 KB
/
speech_synthesis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import TokenCollater, collate_audio
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone
class SpeechSynthesisDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech synthesis task.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'audio': (B x NumSamples) float tensor
            'features': (B x NumFrames x NumFeatures) float tensor
            'tokens': (B x NumTokens) long tensor
            'audio_lens': (B, ) int tensor
            'features_lens': (B, ) int tensor
            'tokens_lens': (B, ) int tensor
            'speakers': List[str] of len B (optional)  # if return_spk_ids is True
        }
    """

    def __init__(
        self,
        cuts: CutSet,
        cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
        feature_input_strategy: Optional[BatchIO] = None,
        feature_transforms: Optional[Union[Sequence[Callable], Callable]] = None,
        add_eos: bool = True,
        add_bos: bool = True,
        return_spk_ids: bool = False,
    ) -> None:
        """
        :param cuts: the CutSet used to fit the token collater (builds the vocabulary).
        :param cut_transforms: optional transforms applied to each mini-batch CutSet
            in ``__getitem__``, in order.
        :param feature_input_strategy: how to obtain the feature matrices; defaults to
            ``PrecomputedFeatures()``. Instantiated per-dataset here rather than as a
            default argument, so instances never share one strategy object.
        :param feature_transforms: a single callable or a sequence of callables applied
            to the collated feature tensor, in order.
        :param add_eos: append an EOS symbol to each token sequence.
        :param add_bos: prepend a BOS symbol to each token sequence.
        :param return_spk_ids: include per-cut speaker IDs under the ``"speakers"`` key.
        """
        super().__init__()
        self.cuts = cuts
        self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
        self.cut_transforms = ifnone(cut_transforms, [])
        # Create a fresh default strategy per instance: a `PrecomputedFeatures()`
        # default *argument* would be evaluated once and shared across all datasets.
        self.feature_input_strategy = ifnone(feature_input_strategy, PrecomputedFeatures())
        self.return_spk_ids = return_spk_ids

        # Normalize feature_transforms to a list of callables.
        if feature_transforms is None:
            feature_transforms = []
        elif not isinstance(feature_transforms, Sequence):
            feature_transforms = [feature_transforms]
        assert all(
            callable(transform) for transform in feature_transforms
        ), "Feature transforms must be Callable"
        self.feature_transforms = feature_transforms

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        """Collate a mini-batch ``CutSet`` into the batch dict documented on the class."""
        validate_for_tts(cuts)

        # Cut-level transforms run before any collation.
        for transform in self.cut_transforms:
            cuts = transform(cuts)

        audio, audio_lens = collate_audio(cuts)
        features, features_lens = self.feature_input_strategy(cuts)

        # Feature transforms operate on the already-collated (B x T x F) tensor.
        for transform in self.feature_transforms:
            features = transform(features)

        tokens, tokens_lens = self.token_collater(cuts)

        batch = {
            "audio": audio,
            "features": features,
            "tokens": tokens,
            "audio_lens": audio_lens,
            "features_lens": features_lens,
            "tokens_lens": tokens_lens,
        }
        if self.return_spk_ids:
            # validate_for_tts guarantees exactly one supervision per cut.
            batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
        return batch
def validate_for_tts(cuts: CutSet) -> None:
    """Run Lhotse's generic validation, then require exactly one supervision per cut."""
    validate(cuts)
    assert all(
        len(cut.supervisions) == 1 for cut in cuts
    ), "Only the Cuts with single supervision are supported."