Skip to content

Commit

Permalink
SpeechSynthesisDataset returns speaker_ids (#1206)
Browse files Browse the repository at this point in the history
* Update speech_synthesis.py

* Update speech_synthesis.py

* Update speech_synthesis.py

* Update speech_synthesis.py

* Update lhotse/dataset/speech_synthesis.py

Co-authored-by: Piotr Żelasko <petezor@gmail.com>

* Update speech_synthesis.py

* Update speech_synthesis.py

* updated doc str

* removed unused func

---------

Co-authored-by: Piotr Żelasko <petezor@gmail.com>
  • Loading branch information
JinZr and pzelasko committed Nov 10, 2023
1 parent 3d1f4b5 commit 79821ce
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions lhotse/dataset/speech_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset):
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'tokens_lens': (B, ) int tensor
'speakers': List[str] of len B (optional) # if return_spk_ids is True
}
"""

Expand All @@ -34,13 +35,15 @@ def __init__(
feature_transforms: Union[Sequence[Callable], Callable] = None,
add_eos: bool = True,
add_bos: bool = True,
return_spk_ids: bool = False,
) -> None:
super().__init__()

self.cuts = cuts
self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy
self.return_spk_ids = return_spk_ids

if feature_transforms is None:
feature_transforms = []
Expand All @@ -65,15 +68,18 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
features = transform(features)

tokens, tokens_lens = self.token_collater(cuts)

return {
batch = {
"audio": audio,
"features": features,
"tokens": tokens,
"audio_lens": audio_lens,
"features_lens": features_lens,
"tokens_lens": tokens_lens,
}
if self.return_spk_ids:
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

return batch


def validate_for_tts(cuts: CutSet) -> None:
Expand Down

0 comments on commit 79821ce

Please sign in to comment.