From 79821ce53fc540db7f9de34fa416e05b420c3843 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 10 Nov 2023 10:54:32 +0800 Subject: [PATCH] SpeechSynthesisDataset returns `speaker_ids` (#1206) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update speech_synthesis.py * Update speech_synthesis.py * Update speech_synthesis.py * Update speech_synthesis.py * Update lhotse/dataset/speech_synthesis.py Co-authored-by: Piotr Żelasko * Update speech_synthesis.py * Update speech_synthesis.py * updated doc str * removed unused func --------- Co-authored-by: Piotr Żelasko --- lhotse/dataset/speech_synthesis.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lhotse/dataset/speech_synthesis.py b/lhotse/dataset/speech_synthesis.py index f94bc9690..8de147b97 100644 --- a/lhotse/dataset/speech_synthesis.py +++ b/lhotse/dataset/speech_synthesis.py @@ -23,6 +23,7 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset): 'audio_lens': (B, ) int tensor 'features_lens': (B, ) int tensor 'tokens_lens': (B, ) int tensor + 'speakers': List[str] of len B (optional) # if return_spk_ids is True } """ @@ -34,6 +35,7 @@ def __init__( feature_transforms: Union[Sequence[Callable], Callable] = None, add_eos: bool = True, add_bos: bool = True, + return_spk_ids: bool = False, ) -> None: super().__init__() @@ -41,6 +43,7 @@ def __init__( self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos) self.cut_transforms = ifnone(cut_transforms, []) self.feature_input_strategy = feature_input_strategy + self.return_spk_ids = return_spk_ids if feature_transforms is None: feature_transforms = [] @@ -65,8 +68,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: features = transform(features) tokens, tokens_lens = self.token_collater(cuts) - - return { + batch = { "audio": audio, "features": features, "tokens": tokens, @@ -74,6 +76,10 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: "features_lens": features_lens, "tokens_lens": tokens_lens, } + if self.return_spk_ids: + batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts] + + return batch def validate_for_tts(cuts: CutSet) -> None: