Modify SpeechSynthesisDataset class, make it return text #1205

Merged: 8 commits, Nov 30, 2023
lhotse/dataset/speech_synthesis.py (42 changes: 32 additions & 10 deletions)
```diff
@@ -19,26 +19,35 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset):
     {
         'audio': (B x NumSamples) float tensor
         'features': (B x NumFrames x NumFeatures) float tensor
-        'tokens': (B x NumTokens) long tensor
+        'tokens': (B x NumTokens) long tensor  # when return_tokens=True
         'audio_lens': (B, ) int tensor
         'features_lens': (B, ) int tensor
-        'tokens_lens': (B, ) int tensor
+        'tokens_lens': (B, ) int tensor  # when return_tokens=True
+        'text': List[str] of len B  # when return_tokens=False
+        'cut': List of Cuts  # when return_cuts=True
     }
     """

     def __init__(
         self,
-        cuts: CutSet,
+        cuts: CutSet = None,
+        return_cuts: bool = False,
```
Review thread on `return_cuts: bool = False,`:

Collaborator: This change is breaking; can you move `return_cuts` to some position towards the end of the parameter list?

Contributor (author): OK.
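For context on why this counts as breaking: inserting a parameter early in the list silently re-binds existing positional calls. A minimal sketch of the failure mode, using abbreviated, hypothetical signatures rather than the real ones:

```python
# Abbreviated, hypothetical signatures, only to show the binding problem.
def old_init(cuts, cut_transforms=None, add_eos=True, add_bos=True):
    return cut_transforms

def new_init(cuts, return_cuts=False, cut_transforms=None, add_eos=True, add_bos=True):
    return cut_transforms

transforms = [lambda cs: cs]  # stand-in for a real cut transform

# A caller written against the old signature still "works" but means something else:
assert old_init("cuts", transforms) == transforms  # binds to cut_transforms
assert new_init("cuts", transforms) is None        # transforms silently bound to return_cuts
```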

```diff
         cut_transforms: List[Callable[[CutSet], CutSet]] = None,
         feature_input_strategy: BatchIO = PrecomputedFeatures(),
         feature_transforms: Union[Sequence[Callable], Callable] = None,
+        return_tokens: bool = True,
         add_eos: bool = True,
         add_bos: bool = True,
     ) -> None:
         super().__init__()
-        self.cuts = cuts
-        self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
+        self.return_cuts = return_cuts
+        self.return_tokens = return_tokens
+        if return_tokens:
+            assert cuts is not None, "cuts is required when return_tokens=True"
+            self.cuts = cuts
+            self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
         self.cut_transforms = ifnone(cut_transforms, [])
         self.feature_input_strategy = feature_input_strategy
```
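A small sketch of the constructor contract this hunk introduces; `cuts` here is assumed to be a prepared CutSet with supervisions:

```python
from lhotse.dataset import SpeechSynthesisDataset

# Tokens requested (the default): a CutSet is required up front, because
# TokenCollater builds its vocabulary from the supervision texts in `cuts`.
ds_tokens = SpeechSynthesisDataset(cuts=cuts, return_tokens=True)

# Plain text requested instead: the CutSet can be omitted at construction time.
ds_text = SpeechSynthesisDataset(return_tokens=False)

# Tokens without cuts trips the new assert:
# SpeechSynthesisDataset(return_tokens=True)  ->  AssertionError
```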

```diff
@@ -64,17 +73,30 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
         for transform in self.feature_transforms:
             features = transform(features)

-        tokens, tokens_lens = self.token_collater(cuts)
-
-        return {
+        batch = {
             "audio": audio,
             "features": features,
-            "tokens": tokens,
             "audio_lens": audio_lens,
             "features_lens": features_lens,
-            "tokens_lens": tokens_lens,
         }

+        if self.return_cuts:
+            batch["cut"] = [cut for cut in cuts]
+
+        if self.return_tokens:
+            tokens, tokens_lens = self.token_collater(cuts)
+            batch["tokens"] = tokens
+            batch["tokens_lens"] = tokens_lens
+        else:
+            # use normalized text
+            text = [
+                " ".join(sup.normalized_text for sup in cut.supervisions)
+                for cut in cuts
+            ]
+            batch["text"] = text
+
+        return batch


 def validate_for_tts(cuts: CutSet) -> None:
     validate(cuts)
```
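Putting the two hunks together, a minimal usage sketch under stated assumptions: the manifest path is illustrative, the cuts are assumed to carry precomputed features, and the sampler settings are arbitrary:

```python
from torch.utils.data import DataLoader

from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler, SpeechSynthesisDataset

# Illustrative path; any cut manifest with features and supervisions works.
cuts = CutSet.from_file("ljspeech_cuts_train.jsonl.gz")

# return_tokens=False: batches carry normalized text instead of token ids,
# so no CutSet is needed when constructing the dataset itself.
dataset = SpeechSynthesisDataset(return_tokens=False, return_cuts=True)

sampler = SimpleCutSampler(cuts, max_duration=100.0)
loader = DataLoader(dataset, sampler=sampler, batch_size=None)

batch = next(iter(loader))
print(batch["text"][0])          # normalized transcript of the first cut
print(batch["features"].shape)   # (B, NumFrames, NumFeatures)
print(batch["cut"][0].id)        # the originating Cut, since return_cuts=True
```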
lhotse/recipes/ljspeech.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -72,7 +72,7 @@ def prepare_ljspeech(
     supervisions = []
     with open(metadata_csv_path) as f:
         for line in f:
-            recording_id, text, normalized = line.split("|")
+            recording_id, text, normalized = line.strip().split("|")
             audio_path = corpus_dir / "wavs" / f"{recording_id}.wav"
             if not audio_path.is_file():
                 logging.warning(f"No such file: {audio_path}")
```
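Why the one-character change matters: each metadata.csv line ends with a newline, which `line.split("|")` leaves attached to the last field, so it leaked into the normalized transcript. A quick illustration with an abridged LJSpeech line:

```python
line = "LJ001-0001|Printing, in the only sense|printing, in the only sense\n"

_, _, normalized = line.split("|")
print(repr(normalized))   # 'printing, in the only sense\n'  (newline leaks in)

_, _, normalized = line.strip().split("|")
print(repr(normalized))   # 'printing, in the only sense'    (clean field)
```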