Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a TTS recipe VITS on LJSpeech dataset #1372

Merged
merged 17 commits into from
Nov 29, 2023
Merged
100 changes: 100 additions & 0 deletions egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the LJSpeech dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/spectrogram.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The generated fbank features are saved in data/spectrogram.
The generated spectrogram features are saved in data/spectrogram.

"""

import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_spectrogram_ljspeech():
src_dir = Path("data/manifests")
output_dir = Path("data/spectrogram")
num_jobs = min(4, os.cpu_count())

sampling_rate = 22050
frame_length = 1024 / sampling_rate # (in second)
frame_shift = 256 / sampling_rate # (in second)
use_fft_mag = True

prefix = "ljspeech"
suffix = "jsonl.gz"
partition = "all"

recordings = load_manifest(
src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
)
supervisions = load_manifest(
src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
)

config = SpectrogramConfig(
sampling_rate=sampling_rate,
frame_length=frame_length,
frame_shift=frame_shift,
use_fft_mag=use_fft_mag,
)
extractor = Spectrogram(config)

with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logging.info(f"{partition} already exists - skipping.")
logging.info(f"{cuts_filename} already exists - skipping.")

return
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=recordings, supervisions=supervisions
)

cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
compute_spectrogram_ljspeech()
73 changes: 73 additions & 0 deletions egs/ljspeech/TTS/local/display_manifest_statistics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.

See the function `remove_short_and_long_utt()` in vits/train.py
for usage.
"""


from lhotse import load_manifest_lazy


def main():
path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz"
cuts = load_manifest_lazy(path)
cuts.describe()


if __name__ == "__main__":
main()

"""
Cut statistics:
╒═══════════════════════════╤══════════╕
│ Cuts count: │ 13100 │
├───────────────────────────┼──────────┤
│ Total duration (hh:mm:ss) │ 23:55:18 │
├───────────────────────────┼──────────┤
│ mean │ 6.6 │
├───────────────────────────┼──────────┤
│ std │ 2.2 │
├───────────────────────────┼──────────┤
│ min │ 1.1 │
├───────────────────────────┼──────────┤
│ 25% │ 5.0 │
├───────────────────────────┼──────────┤
│ 50% │ 6.8 │
├───────────────────────────┼──────────┤
│ 75% │ 8.4 │
├───────────────────────────┼──────────┤
│ 99% │ 10.0 │
├───────────────────────────┼──────────┤
│ 99.5% │ 10.1 │
├───────────────────────────┼──────────┤
│ 99.9% │ 10.1 │
├───────────────────────────┼──────────┤
│ max │ 10.1 │
├───────────────────────────┼──────────┤
│ Recordings available: │ 13100 │
├───────────────────────────┼──────────┤
│ Features available: │ 13100 │
├───────────────────────────┼──────────┤
│ Supervisions available: │ 13100 │
╘═══════════════════════════╧══════════╛
"""
116 changes: 116 additions & 0 deletions egs/ljspeech/TTS/local/prepare_token_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file reads the texts in given manifest and generate the file that maps tokens to IDs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This file reads the texts in given manifest and generate the file that maps tokens to IDs.
This file reads the texts in given manifest and generates the file that maps tokens to IDs.

"""

import argparse
import logging
from collections import Counter
from pathlib import Path
from typing import Dict

import g2p_en
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add code about how to install g2p_en.

import tacotron_cleaner.cleaners
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add code about how to install tacotron_cleaner?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!

from lhotse import load_manifest


def get_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--manifest-file",
type=Path,
default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
help="Path to the manifest file",
)

parser.add_argument(
"--tokens",
type=Path,
default=Path("data/tokens.txt"),
help="Path to the tokens",
)

return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.

Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.

Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")


def get_token2id(manifest_file: Path) -> Dict[str, int]:
"""Return a dict that maps token to IDs."""
extra_tokens = {
"<blk>": 0, # blank
"<sos/eos>": 1, # sos and eos symbols.
"<unk>": 2, # OOV
}
cut_set = load_manifest(manifest_file)
g2p = g2p_en.G2p()
counter = Counter()

for cut in cut_set:
# Each cut only contain one supervision
assert len(cut.supervisions) == 1, len(cut.supervisions)
text = cut.supervisions[0].normalized_text
# Text normalization
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# Convert to phonemes
tokens = g2p(text)
for t in tokens:
counter[t] += 1

# Sort by the number of occurrences in descending order
tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to sort them by count?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just make it easy to cut off the vocabulary according the counts if needed. But we don't need this now.


for token, idx in extra_tokens.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Items in a dict are iterated in an unknown order.
Please use a list for extra_tokens.

You can use

tokens_and_counts = extra_tokens + tokens_and_counts

tokens_and_counts.insert(idx, (token, None))

token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)}
token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)}

return token2id


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

args = get_args()
manifest_file = Path(args.manifest_file)
out_file = Path(args.tokens)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check that if out_file exists, it returns directly without any further computation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have checked this in prepare.sh


token2id = get_token2id(manifest_file)
write_mapping(out_file, token2id)
70 changes: 70 additions & 0 deletions egs/ljspeech/TTS/local/validate_manifest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:

- Single supervision per cut

We will add more checks later if needed.

Usage example:

python3 ./local/validate_manifest.py \
./data/spectrogram/ljspeech_cuts_all.jsonl.gz

"""

import argparse
import logging
from pathlib import Path

from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts


def get_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)

return parser.parse_args()


def main():
args = get_args()

manifest = args.manifest
logging.info(f"Validating {manifest}")

assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert isinstance(cut_set, CutSet)
assert isinstance(cut_set, CutSet), type(cut_set)


validate_for_tts(cut_set)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)

main()
Loading
Loading