Skip to content

Commit

Permalink
Modify SpeechSynthesisDataset class to return text (#1205)
Browse files Browse the repository at this point in the history
* Modify SpeechSynthesisDataset to return text

* Remove TokenCollater from SpeechSynthesisDataset

* fix test/dataset/test_speech_synthesis_dataset.py

---------

Co-authored-by: Piotr Żelasko <petezor@gmail.com>
Co-authored-by: Desh Raj <r.desh26@gmail.com>
  • Loading branch information
3 people committed Nov 30, 2023
1 parent 89ca0e6 commit b869488
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 19 deletions.
37 changes: 25 additions & 12 deletions lhotse/dataset/speech_synthesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import TokenCollater, collate_audio
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone

Expand All @@ -19,31 +19,34 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset):
{
'audio': (B x NumSamples) float tensor
'features': (B x NumFrames x NumFeatures) float tensor
'tokens': (B x NumTokens) long tensor
'audio_lens': (B, ) int tensor
'features_lens': (B, ) int tensor
'tokens_lens': (B, ) int tensor
'speakers': List[str] of len B (optional) # if return_spk_ids is True
'text': List[str] of len B # when return_text=True
'tokens': List[List[str]] # when return_tokens=True
'speakers': List[str] of len B # when return_spk_ids=True
'cut': List of Cuts # when return_cuts=True
}
"""

def __init__(
self,
cuts: CutSet,
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
feature_input_strategy: BatchIO = PrecomputedFeatures(),
feature_transforms: Union[Sequence[Callable], Callable] = None,
add_eos: bool = True,
add_bos: bool = True,
return_text: bool = True,
return_tokens: bool = False,
return_spk_ids: bool = False,
return_cuts: bool = False,
) -> None:
super().__init__()

self.cuts = cuts
self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
self.cut_transforms = ifnone(cut_transforms, [])
self.feature_input_strategy = feature_input_strategy

self.return_text = return_text
self.return_tokens = return_tokens
self.return_spk_ids = return_spk_ids
self.return_cuts = return_cuts

if feature_transforms is None:
feature_transforms = []
Expand All @@ -67,18 +70,28 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
for transform in self.feature_transforms:
features = transform(features)

tokens, tokens_lens = self.token_collater(cuts)
batch = {
"audio": audio,
"features": features,
"tokens": tokens,
"audio_lens": audio_lens,
"features_lens": features_lens,
"tokens_lens": tokens_lens,
}

if self.return_text:
# use normalized text
text = [cut.supervisions[0].normalized_text for cut in cuts]
batch["text"] = text

if self.return_tokens:
tokens = [cut.tokens for cut in cuts]
batch["tokens"] = tokens

if self.return_spk_ids:
batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

if self.return_cuts:
batch["cut"] = [cut for cut in cuts]

return batch


Expand Down
2 changes: 1 addition & 1 deletion lhotse/recipes/ljspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def prepare_ljspeech(
supervisions = []
with open(metadata_csv_path) as f:
for line in f:
recording_id, text, normalized = line.split("|")
recording_id, text, normalized = line.strip().split("|")
audio_path = corpus_dir / "wavs" / f"{recording_id}.wav"
if not audio_path.is_file():
logging.warning(f"No such file: {audio_path}")
Expand Down
8 changes: 3 additions & 5 deletions test/dataset/test_speech_synthesis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ def test_speech_synthesis_dataset(cut_set, transform):
else:
transform = None

dataset = SpeechSynthesisDataset(cut_set, feature_transforms=transform)
dataset = SpeechSynthesisDataset(feature_transforms=transform)
example = dataset[cut_set]
assert example["audio"].shape[1] > 0
assert example["features"].shape[1] > 0
assert example["tokens"].shape[1] > 0
assert len(example["text"]) > 0
assert len(example["text"][0]) > 0

assert example["audio"].ndim == 2
assert example["features"].ndim == 3
assert example["tokens"].ndim == 2

assert isinstance(example["audio_lens"], torch.IntTensor)
assert isinstance(example["features_lens"], torch.IntTensor)
assert isinstance(example["tokens_lens"], torch.IntTensor)

assert example["audio_lens"].ndim == 1
assert example["features_lens"].ndim == 1
assert example["tokens_lens"].ndim == 1
4 changes: 3 additions & 1 deletion test/fixtures/ljspeech/cuts.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"duration": 1.5396371882086168,
"channel": 0,
"text": "IN EIGHTEEN THIRTEEN",
"custom":{"normalized_text": "IN EIGHTEEN THIRTEEN"},
"language": "English",
"gender": "female"
}
Expand Down Expand Up @@ -59,6 +60,7 @@
"duration": 1.5976870748299319,
"channel": 0,
"text": "EIGHT THE PRESS YARD",
"custom":{"normalized_text": "EIGHT THE PRESS YARD"},
"language": "English",
"gender": "female"
}
Expand Down Expand Up @@ -93,4 +95,4 @@
},
"type": "MonoCut"
}
]
]

0 comments on commit b869488

Please sign in to comment.