diff --git a/lhotse/dataset/speech_synthesis.py b/lhotse/dataset/speech_synthesis.py
index 8de147b97..3a823a8ba 100644
--- a/lhotse/dataset/speech_synthesis.py
+++ b/lhotse/dataset/speech_synthesis.py
@@ -4,7 +4,7 @@
 from lhotse import validate
 from lhotse.cut import CutSet
-from lhotse.dataset.collation import TokenCollater, collate_audio
+from lhotse.dataset.collation import collate_audio
 from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
 from lhotse.utils import ifnone
 
@@ -19,31 +19,34 @@ class SpeechSynthesisDataset(torch.utils.data.Dataset):
     {
         'audio': (B x NumSamples) float tensor
         'features': (B x NumFrames x NumFeatures) float tensor
-        'tokens': (B x NumTokens) long tensor
         'audio_lens': (B, ) int tensor
         'features_lens': (B, ) int tensor
-        'tokens_lens': (B, ) int tensor
-        'speakers': List[str] of len B (optional)  # if return_spk_ids is True
+        'text': List[str] of len B  # when return_text=True
+        'tokens': List[List[str]]  # when return_tokens=True
+        'speakers': List[str] of len B  # when return_spk_ids=True
+        'cut': List of Cuts  # when return_cuts=True
     }
     """

     def __init__(
         self,
-        cuts: CutSet,
         cut_transforms: List[Callable[[CutSet], CutSet]] = None,
         feature_input_strategy: BatchIO = PrecomputedFeatures(),
         feature_transforms: Union[Sequence[Callable], Callable] = None,
-        add_eos: bool = True,
-        add_bos: bool = True,
+        return_text: bool = True,
+        return_tokens: bool = False,
         return_spk_ids: bool = False,
+        return_cuts: bool = False,
     ) -> None:
         super().__init__()
-        self.cuts = cuts
-        self.token_collater = TokenCollater(cuts, add_eos=add_eos, add_bos=add_bos)
         self.cut_transforms = ifnone(cut_transforms, [])
         self.feature_input_strategy = feature_input_strategy
+
+        self.return_text = return_text
+        self.return_tokens = return_tokens
         self.return_spk_ids = return_spk_ids
+        self.return_cuts = return_cuts

         if feature_transforms is None:
             feature_transforms = []
@@ -67,18 +70,28 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
         for transform in self.feature_transforms:
             features = transform(features)

-        tokens, tokens_lens = self.token_collater(cuts)
         batch = {
             "audio": audio,
             "features": features,
-            "tokens": tokens,
             "audio_lens": audio_lens,
             "features_lens": features_lens,
-            "tokens_lens": tokens_lens,
         }
+
+        if self.return_text:
+            # use normalized text
+            text = [cut.supervisions[0].normalized_text for cut in cuts]
+            batch["text"] = text
+
+        if self.return_tokens:
+            tokens = [cut.tokens for cut in cuts]
+            batch["tokens"] = tokens
+
         if self.return_spk_ids:
             batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

+        if self.return_cuts:
+            batch["cut"] = [cut for cut in cuts]
+
         return batch
diff --git a/lhotse/recipes/ljspeech.py b/lhotse/recipes/ljspeech.py
index ab777b6c3..9d9866406 100644
--- a/lhotse/recipes/ljspeech.py
+++ b/lhotse/recipes/ljspeech.py
@@ -72,7 +72,7 @@ def prepare_ljspeech(
     supervisions = []
     with open(metadata_csv_path) as f:
         for line in f:
-            recording_id, text, normalized = line.split("|")
+            recording_id, text, normalized = line.strip().split("|")
             audio_path = corpus_dir / "wavs" / f"{recording_id}.wav"
             if not audio_path.is_file():
                 logging.warning(f"No such file: {audio_path}")
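For context, a minimal usage sketch of the reworked API (not part of this diff; it assumes it is run from the repo root so the LJSpeech test fixture resolves, and that the fixture cuts carry recordings and precomputed features, as the updated test below relies on):

    # Hedged sketch: the dataset no longer receives a CutSet or builds a
    # TokenCollater at construction time; a batch of cuts is passed to
    # __getitem__ instead, and tokenization is left to the caller.
    from lhotse import CutSet
    from lhotse.dataset import SpeechSynthesisDataset

    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    dataset = SpeechSynthesisDataset()  # return_text=True by default
    batch = dataset[cuts]

    assert batch["audio"].ndim == 2     # (B, NumSamples)
    assert batch["features"].ndim == 3  # (B, NumFrames, NumFeatures)
    print(batch["text"])                # normalized transcripts, one per cut

The remaining files update the test and fixture to match: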
diff --git a/test/dataset/test_speech_synthesis_dataset.py b/test/dataset/test_speech_synthesis_dataset.py
index 32ecac7f3..376572ae1 100644
--- a/test/dataset/test_speech_synthesis_dataset.py
+++ b/test/dataset/test_speech_synthesis_dataset.py
@@ -20,20 +20,18 @@ def test_speech_synthesis_dataset(cut_set, transform):
     else:
         transform = None

-    dataset = SpeechSynthesisDataset(cut_set, feature_transforms=transform)
+    dataset = SpeechSynthesisDataset(feature_transforms=transform)

     example = dataset[cut_set]

     assert example["audio"].shape[1] > 0
     assert example["features"].shape[1] > 0
-    assert example["tokens"].shape[1] > 0
+    assert len(example["text"]) > 0
+    assert len(example["text"][0]) > 0

     assert example["audio"].ndim == 2
     assert example["features"].ndim == 3
-    assert example["tokens"].ndim == 2

     assert isinstance(example["audio_lens"], torch.IntTensor)
     assert isinstance(example["features_lens"], torch.IntTensor)
-    assert isinstance(example["tokens_lens"], torch.IntTensor)

     assert example["audio_lens"].ndim == 1
     assert example["features_lens"].ndim == 1
-    assert example["tokens_lens"].ndim == 1
diff --git a/test/fixtures/ljspeech/cuts.json b/test/fixtures/ljspeech/cuts.json
index 8d75b51f9..32bf45c6d 100644
--- a/test/fixtures/ljspeech/cuts.json
+++ b/test/fixtures/ljspeech/cuts.json
@@ -12,6 +12,7 @@
           "duration": 1.5396371882086168,
           "channel": 0,
           "text": "IN EIGHTEEN THIRTEEN",
+          "custom": {"normalized_text": "IN EIGHTEEN THIRTEEN"},
           "language": "English",
           "gender": "female"
         }
@@ -59,6 +60,7 @@
           "duration": 1.5976870748299319,
           "channel": 0,
           "text": "EIGHT THE PRESS YARD",
+          "custom": {"normalized_text": "EIGHT THE PRESS YARD"},
           "language": "English",
           "gender": "female"
         }
@@ -93,4 +95,4 @@
     },
     "type": "MonoCut"
   }
-]
\ No newline at end of file
+]
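A note on the data contract implied by the changes above: return_text=True reads the supervision's normalized_text, which Lhotse resolves from the custom dict via attribute access (hence the new "custom" field in the fixture), and return_tokens=True expects each cut to already carry a tokens list, since tokenization was moved out of the dataset. A hedged sketch of preparing such cuts; attach_char_tokens and the character-level tokens are hypothetical, chosen only for illustration:

    from lhotse import CutSet
    from lhotse.dataset import SpeechSynthesisDataset

    def attach_char_tokens(cut):
        # Any tokenizer works here; Lhotse stores unknown cut attributes
        # in the cut's `custom` dict, so plain assignment suffices.
        cut.tokens = list(cut.supervisions[0].normalized_text)
        return cut

    cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    cuts = CutSet.from_cuts(attach_char_tokens(c) for c in cuts)

    dataset = SpeechSynthesisDataset(return_tokens=True, return_cuts=True)
    batch = dataset[cuts]
    print(batch["tokens"])  # List[List[str]], one token list per cut
    print(batch["cut"])     # the batched cuts themselves, e.g. for debugging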