Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify SpeechSynthesisDataset class, make it return text #1205

Merged
merged 8 commits into from
Nov 30, 2023
37 changes: 25 additions & 12 deletions lhotse/dataset/speech_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import TokenCollater, collate_audio
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone

Expand All @@ -19,31 +19,34 @@
{
'audio': (B x NumSamples) float tensor
'features': (B x NumFrames x NumFeatures) float tensor
'tokens': (B x NumTokens) long tensor
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'tokens_lens': (B, ) int tensor
'speakers': List[str] of len B (optional) # if return_spk_ids is True
'text': List[str] of len B # when return_text=True
'tokens': List[List[str]] # when return_tokens=True
'speakers': List[str] of len B # when return_spk_ids=True
'cut': List of Cuts # when return_cuts=True
}
"""

def __init__(
self,
cuts: CutSet,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
feature_input_strategy: BatchIO = PrecomputedFeatures(),
feature_transforms: Union[Sequence[Callable], Callable] = None,
add_eos: bool = True,
add_bos: bool = True,
return_text: bool = True,
return_tokens: bool = False,
return_spk_ids: bool = False,
return_cuts: bool = False,
) -> None:
super().__init__()

self.cuts = cuts
self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy

self.return_text = return_text
self.return_tokens = return_tokens
self.return_spk_ids = return_spk_ids
self.return_cuts = return_cuts

if feature_transforms is None:
feature_transforms = []
Expand All @@ -67,18 +70,28 @@
for transform in self.feature_transforms:
features = transform(features)

tokens, tokens_lens = self.token_collater(cuts)
batch = {
"audio": audio,
"features": features,
"tokens": tokens,
"audio_lens": audio_lens,
"features_lens": features_lens,
"tokens_lens": tokens_lens,
}

if self.return_text:
# use normalized text
text = [cut.supervisions[0].normalized_text for cut in cuts]
batch["text"] = text

if self.return_tokens:
tokens = [cut.tokens for cut in cuts]
batch["tokens"] = tokens

Check warning on line 87 in lhotse/dataset/speech_synthesis.py

View check run for this annotation

Codecov / codecov/patch

lhotse/dataset/speech_synthesis.py#L86-L87

Added lines #L86 - L87 were not covered by tests

if self.return_spk_ids:
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

if self.return_cuts:
batch["cut"] = [cut for cut in cuts]

Check warning on line 93 in lhotse/dataset/speech_synthesis.py

View check run for this annotation

Codecov / codecov/patch

lhotse/dataset/speech_synthesis.py#L93

Added line #L93 was not covered by tests

return batch


Expand Down
2 changes: 1 addition & 1 deletion lhotse/recipes/ljspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def prepare_ljspeech(
supervisions = []
with open(metadata_csv_path) as f:
for line in f:
recording_id, text, normalized = line.split("|")
recording_id, text, normalized = line.strip().split("|")
audio_path = corpus_dir / "wavs" / f"{recording_id}.wav"
if not audio_path.is_file():
logging.warning(f"No such file: {audio_path}")
Expand Down
8 changes: 3 additions & 5 deletions test/dataset/test_speech_synthesis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ def test_speech_synthesis_dataset(cut_set, transform):
else:
transform = None

dataset = SpeechSynthesisDataset(cut_set, feature_transforms=transform)
dataset = SpeechSynthesisDataset(feature_transforms=transform)
example = dataset[cut_set]
assert example["audio"].shape[1] > 0
assert example["features"].shape[1] > 0
assert example["tokens"].shape[1] > 0
assert len(example["text"]) > 0
assert len(example["text"][0]) > 0

assert example["audio"].ndim == 2
assert example["features"].ndim == 3
assert example["tokens"].ndim == 2

assert isinstance(example["audio_lens"], torch.IntTensor)
assert isinstance(example["features_lens"], torch.IntTensor)
assert isinstance(example["tokens_lens"], torch.IntTensor)

assert example["audio_lens"].ndim == 1
assert example["features_lens"].ndim == 1
assert example["tokens_lens"].ndim == 1
4 changes: 3 additions & 1 deletion test/fixtures/ljspeech/cuts.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"duration": 1.5396371882086168,
"channel": 0,
"text": "IN EIGHTEEN THIRTEEN",
"custom":{"normalized_text": "IN EIGHTEEN THIRTEEN"},
"language": "English",
"gender": "female"
}
Expand Down Expand Up @@ -59,6 +60,7 @@
"duration": 1.5976870748299319,
"channel": 0,
"text": "EIGHT THE PRESS YARD",
"custom":{"normalized_text": "EIGHT THE PRESS YARD"},
"language": "English",
"gender": "female"
}
Expand Down Expand Up @@ -93,4 +95,4 @@
},
"type": "MonoCut"
}
]
]
Loading