Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpeechSynthesisDataset returns speaker_ids #1206

Merged
merged 11 commits into from
Nov 10, 2023
21 changes: 19 additions & 2 deletions lhotse/dataset/speech_synthesis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, Dict, List, Sequence, Union

import numpy as np
JinZr marked this conversation as resolved.
Show resolved Hide resolved
import torch

from lhotse import validate
Expand All @@ -23,6 +24,7 @@
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'tokens_lens': (B, ) int tensor
'speakers': (B) long tensor (optional)
}
"""

Expand All @@ -32,6 +34,7 @@
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
feature_input_strategy: BatchIO = PrecomputedFeatures(),
feature_transforms: Union[Sequence[Callable], Callable] = None,
speaker_id_mapping: Dict[str, int] = None,
add_eos: bool = True,
add_bos: bool = True,
) -> None:
Expand All @@ -41,6 +44,7 @@
self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy
self.speaker_id_mapping = speaker_id_mapping

if feature_transforms is None:
feature_transforms = []
Expand All @@ -65,15 +69,21 @@
features = transform(features)

tokens, tokens_lens = self.token_collater(cuts)

return {
batch = {
"audio": audio,
"features": features,
"tokens": tokens,
"audio_lens": audio_lens,
"features_lens": features_lens,
"tokens_lens": tokens_lens,
}
if self.speaker_id_mapping is not None:
batch["speakers"] = torch.tensor(

Check warning on line 81 in lhotse/dataset/speech_synthesis.py

View check run for this annotation

Codecov / codecov/patch

lhotse/dataset/speech_synthesis.py#L81

Added line #L81 was not covered by tests
[self.speaker_id_mapping[cut.supervisions[0].speaker] for cut in cuts],
dtype=torch.long,
)
else:
JinZr marked this conversation as resolved.
Show resolved Hide resolved
return batch


def validate_for_tts(cuts: CutSet) -> None:
Expand All @@ -82,3 +92,10 @@
assert (
len(cut.supervisions) == 1
), "Only the Cuts with single supervision are supported."


def get_sid_to_index_map(sid_list) -> Dict[str, np.ndarray]:
JinZr marked this conversation as resolved.
Show resolved Hide resolved
sid_to_onehot_map = {}
for index, sid in enumerate(sid_list):
sid_to_onehot_map[sid] = index
return sid_to_onehot_map

Check warning on line 101 in lhotse/dataset/speech_synthesis.py

View check run for this annotation

Codecov / codecov/patch

lhotse/dataset/speech_synthesis.py#L98-L101

Added lines #L98 - L101 were not covered by tests
Loading