In [1]:
import argparse
import codecs
import re
import numpy as np
import soundfile as sf
import tempfile
import tomli
import torch
import torchaudio
import tqdm

from cached_path import cached_path
from einops import rearrange
from model import CFM, DiT, MMDiT, UNetT
from model.utils import (convert_char_to_pinyin, get_tokenizer,
                         load_checkpoint, save_spectrogram)
from pathlib import Path
from pydub import AudioSegment, silence
from tempfile import NamedTemporaryFile
from torchaudio import transforms
from transformers import pipeline
from vocos import Vocos

In [17]:
# --------------------- Settings -------------------- #
# Module-level settings. The functions defined later in this cell read several
# of these as globals (e.g. target_sample_rate, hop_length, target_rms, speed).
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None
remove_silence_default = True

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Extract necessary configurations from the loaded toml.
# The file is opened in a `with` block so the handle is closed after parsing.
with open("./samples/story.toml", "rb") as config_file:
    config = tomli.load(config_file)

ref_audio = config["ref_audio"]
ref_text = config["ref_text"]
gen_text = config["gen_text"]
gen_file = config["gen_file"]
if gen_file:
    # When a text file is configured, its contents replace the inline gen_text.
    with codecs.open(gen_file, "r", "utf-8") as gen_text_file:
        gen_text = gen_text_file.read()
output_dir = config["output_dir"]
model = config["model"]
# Empty strings make the loading code below use the default checkpoint / vocab.
ckpt_file = ""
vocab_file = ""
remove_silence = config["remove_silence"]
wave_path = Path(output_dir)/"out.wav"
spectrogram_path = Path(output_dir)/"out.png"

# Load the vocos model once
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

print(f"Using {device} device")

def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
    """Build a CFM model around ``model_cls`` and load its checkpoint.

    Args:
        model_cls: Backbone class; instantiated with ``model_cfg`` plus
            ``text_num_embeds`` and ``mel_dim``.
        model_cfg (dict): Keyword arguments forwarded to ``model_cls``.
        ckpt_path: Checkpoint location handed to ``load_checkpoint``.
        file_vocab (str): Vocabulary handed to ``get_tokenizer``. An empty
            string selects "Emilia_ZH_EN" with the "pinyin" tokenizer; any
            other value is used with the "custom" tokenizer.

    Returns:
        Whatever ``load_checkpoint(..., use_ema=True)`` returns for the model.
    """
    use_default_vocab = file_vocab == ""
    tokenizer = "pinyin" if use_default_vocab else "custom"
    if use_default_vocab:
        file_vocab = "Emilia_ZH_EN"

    print("\nvocab : ", file_vocab, tokenizer)
    print("tokenizer : ", tokenizer)
    print("model : ", ckpt_path, "\n")

    vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)

    # Assemble the pieces separately, then wire them into the CFM wrapper.
    backbone = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
    mel_spec_kwargs = {
        "target_sample_rate": target_sample_rate,
        "n_mel_channels": n_mel_channels,
        "hop_length": hop_length,
    }
    odeint_kwargs = {"method": ode_method}

    cfm = CFM(
        transformer=backbone,
        mel_spec_kwargs=mel_spec_kwargs,
        odeint_kwargs=odeint_kwargs,
        vocab_char_map=vocab_char_map,
    ).to(device)

    return load_checkpoint(cfm, ckpt_path, device, use_ema=True)

# Load models outside of functions to prevent multiple loading
F5TTS_model_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

# Configured model name -> (backbone class, backbone config, repo name, experiment name).
_MODEL_REGISTRY = {
    "F5-TTS": (DiT, F5TTS_model_cfg, "F5-TTS", "F5TTS_Base"),
    "E2-TTS": (UNetT, E2TTS_model_cfg, "E2-TTS", "E2TTS_Base"),
}

# Fail here with a clear message instead of leaving `ema_model` undefined,
# which would only surface later as a NameError.
if model not in _MODEL_REGISTRY:
    raise ValueError(
        f"Unknown model {model!r}; expected one of {sorted(_MODEL_REGISTRY)}"
    )

# Load the TTS model once
model_cls, model_cfg, repo_name, exp_name = _MODEL_REGISTRY[model]
if ckpt_file == "":
    # No explicit checkpoint configured: resolve the default one via cached_path.
    ckpt_step = 1200000
    ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))

ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)

def chunk_text(text, max_chars=135):
    """
    Break ``text`` into chunks at punctuation boundaries.

    The text is split after ASCII punctuation followed by whitespace and after
    full-width CJK punctuation. Pieces are then packed greedily: a piece joins
    the current chunk while the combined UTF-8 byte length stays within
    ``max_chars``; otherwise the current chunk is emitted and the piece starts
    a new one. A single piece longer than ``max_chars`` is kept whole.

    Args:
        text (str): The text to be split.
        max_chars (int): Size budget per chunk, compared against UTF-8 byte
            lengths (not code points).
    Returns:
        List[str]: The stripped, non-empty chunks in original order.
    """
    def _with_separator(piece):
        # A piece ending in a single-byte character gets a trailing space so the
        # next piece does not run into it; multi-byte endings and empty pieces
        # are returned unchanged.
        if piece and len(piece[-1].encode('utf-8')) == 1:
            return piece + " "
        return piece

    pieces = re.split(r'(?<=[;:,.!?])\s+|(?<=[；：，。！？])', text)

    result = []
    buffer = ""
    for piece in pieces:
        fits = len(buffer.encode('utf-8')) + len(piece.encode('utf-8')) <= max_chars
        if fits:
            buffer += _with_separator(piece)
            continue
        # The piece does not fit: flush what has been collected and start over.
        if buffer:
            result.append(buffer.strip())
        buffer = _with_separator(piece)

    if buffer:
        result.append(buffer.strip())

    return result

def infer_batch(ref_audio, ref_text, gen_text_batches, ema_model, remove_silence, cross_fade_duration=0.15):
    """Synthesize each text batch with the reference voice and join the results.

    Args:
        ref_audio: ``(audio, sr)`` pair; ``audio`` is a tensor whose first
            dimension is averaged away when it is larger than 1, and ``sr`` is
            its sample rate.
        ref_text (str): Transcript of the reference audio; must be non-empty.
        gen_text_batches: Iterable of text chunks to synthesize, in order.
        ema_model: Model exposing ``sample(cond=..., text=..., duration=...,
            steps=..., cfg_strength=..., sway_sampling_coef=...)``.
        remove_silence: Accepted for interface compatibility; not used here.
        cross_fade_duration (float): Overlap in seconds used to cross-fade
            consecutive chunks; ``<= 0`` means plain concatenation.

    Returns:
        tuple: ``(final_wave, combined_spectrogram)`` as numpy arrays. When
        ``gen_text_batches`` is empty, an empty waveform and a spectrogram
        with zero frames are returned instead of raising.
    """
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Boost quiet reference audio up to target_rms; undone on the output below.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    # Separate reference and generated text when the reference ends in a
    # single-byte character.
    if len(ref_text[-1].encode('utf-8')) == 1:
        ref_text = ref_text + " "

    # Character class: each individual pause mark adds weight. (Used as a bare
    # pattern, the string would only match the whole seven-character run.)
    zh_pause_punc = r"[。，、；：？！]"

    # These depend only on the reference, so compute them once.
    ref_audio_len = audio.shape[-1] // hop_length
    ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))

    generated_waves = []
    spectrograms = []

    for gen_text in tqdm.tqdm(gen_text_batches):
        # Prepare the text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Calculate duration: reference frames plus an estimate for the new
        # text, scaled by the reference's frames-per-byte rate and `speed`.
        gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        # Drop the frames that correspond to the reference audio.
        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            # Scale back down to the loudness of the original reference.
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        generated_waves.append(generated_wave)
        spectrograms.append(generated_mel_spec[0].cpu().numpy())

    if not generated_waves:
        # Nothing was synthesized: return empty results rather than failing on
        # generated_waves[0] / np.concatenate([]) below.
        empty_wave = np.zeros(0, dtype=np.float32)
        empty_spectrogram = np.zeros((n_mel_channels, 0), dtype=np.float32)
        return empty_wave, empty_spectrogram

    # Combine all generated waves with cross-fading
    if cross_fade_duration <= 0:
        # Simply concatenate
        final_wave = np.concatenate(generated_waves)
    else:
        max_fade_samples = int(cross_fade_duration * target_sample_rate)
        final_wave = generated_waves[0]
        for next_wave in generated_waves[1:]:
            prev_wave = final_wave

            # Never overlap more samples than either side actually has.
            cross_fade_samples = min(max_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # No overlap possible, concatenate
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Linear fade-out of the tail blended with a linear fade-in of the head.
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)
            cross_faded_overlap = (
                prev_wave[-cross_fade_samples:] * fade_out
                + next_wave[:cross_fade_samples] * fade_in
            )

            final_wave = np.concatenate([
                prev_wave[:-cross_fade_samples],
                cross_faded_overlap,
                next_wave[cross_fade_samples:]
            ])

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, combined_spectrogram

def process_voice(ref_audio_orig, ref_text):
    """Prepare a reference voice: trim silence, clip, export to WAV, transcribe.

    Args:
        ref_audio_orig: Path (or file object) of the reference audio, as
            accepted by ``AudioSegment.from_file``.
        ref_text (str): Transcript of the reference audio. When it is empty or
            whitespace-only, the exported audio is transcribed with the
            "openai/whisper-large-v3-turbo" pipeline instead.

    Returns:
        tuple: ``(ref_audio, ref_text)`` where ``ref_audio`` is the path of a
        temporary WAV file holding the processed clip. The file is created with
        ``delete=False`` and is not removed by this function.
    """
    print("Converting audio...")
    # Decode first: if this fails, no temporary file has been created yet.
    aseg = AudioSegment.from_file(ref_audio_orig)

    non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
    non_silent_wave = AudioSegment.silent(duration=0)
    for non_silent_seg in non_silent_segs:
        non_silent_wave += non_silent_seg
    # If trimming removed everything, keep the untrimmed clip instead of
    # exporting a zero-length reference.
    if len(non_silent_wave) > 0:
        aseg = non_silent_wave

    audio_duration = len(aseg)  # pydub lengths are in milliseconds
    if audio_duration > 15000:
        print("Audio is over 15s, clipping to only first 15s.")
        aseg = aseg[:15000]

    # Reserve a temp file name, close the handle, then let pydub write to it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        ref_audio = f.name
    aseg.export(ref_audio, format="wav")

    if not ref_text.strip():
        print("No reference text provided, transcribing reference audio...")
        pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=torch.float16,
            device=device,
        )
        ref_text = pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )["text"].strip()
        print("Finished transcription")
    else:
        print("Using custom reference text...")
    return ref_audio, ref_text

def infer(ref_audio, ref_text, gen_text, remove_silence, ema_model, cross_fade_duration=0.15):
    """Chunk ``gen_text`` and synthesize it with the given reference voice.

    Args:
        ref_audio: Path of the reference audio, loaded with ``torchaudio.load``.
        ref_text (str): Transcript of the reference audio.
        gen_text (str): Text to synthesize.
        remove_silence: Forwarded unchanged to ``infer_batch``.
        ema_model: Model forwarded to ``infer_batch``.
        cross_fade_duration (float): Forwarded to ``infer_batch``.

    Returns:
        The ``(wave, spectrogram)`` pair returned by ``infer_batch``.
    """
    print(gen_text)

    # Make sure the reference text ends with ". " (or the full-width "。").
    already_terminated = ref_text.endswith(". ") or ref_text.endswith("。")
    if not already_terminated:
        ref_text += " " if ref_text.endswith(".") else ". "

    # Size the chunks from the reference's bytes-per-second rate and the time
    # left after the reference within a 25 second budget.
    audio, sr = torchaudio.load(ref_audio)
    ref_seconds = audio.shape[-1] / sr
    ref_bytes = len(ref_text.encode('utf-8'))
    max_chars = int(ref_bytes / ref_seconds * (25 - ref_seconds))

    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    print('ref_text', ref_text)
    for idx, batch in enumerate(gen_text_batches):
        print(f'gen_text {idx}', batch)

    print(f"Generating audio in {len(gen_text_batches)} batches...")
    return infer_batch((audio, sr), ref_text, gen_text_batches, ema_model, remove_silence, cross_fade_duration)
    

def process(ref_audio, ref_text, text_gen, remove_silence, ema_model):
    """Render ``text_gen`` to ``wave_path``, switching voices on ``[name]`` tags.

    The text is split in front of every ``[word]`` tag. A chunk that starts
    with a tag naming a configured voice is spoken with that voice; chunks
    without a tag, or with an unknown tag, use the "main" voice. Chunks that
    contain no text once the tags are removed are skipped.

    Args:
        ref_audio: Reference audio path for the "main" voice.
        ref_text (str): Reference transcript for the "main" voice.
        text_gen (str): Text to synthesize, optionally containing voice tags.
        remove_silence: When truthy, silence is trimmed from the written file.
        ema_model: Model forwarded to ``infer``.

    Returns:
        None. The audio is written to the module-level ``wave_path`` and the
        path is printed; nothing is written when no audio was generated.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice
    for name in voices:
        voices[name]['ref_audio'], voices[name]['ref_text'] = process_voice(voices[name]['ref_audio'], voices[name]['ref_text'])

    generated_audio_segments = []
    reg1 = r'(?=\[\w+\])'
    chunks = re.split(reg1, text_gen)
    reg2 = r'\[(\w+)\]'
    for text in chunks:
        match = re.match(reg2, text)
        # Check the tag that was just matched (not a leftover variable), so an
        # unknown tag falls back to "main" instead of raising KeyError.
        voice = match[1] if match and match[1] in voices else "main"
        gen_text = re.sub(reg2, "", text).strip()
        if not gen_text:
            # e.g. the empty leading chunk produced when the text starts with a tag
            continue
        print(f"Voice: {voice}")
        audio, _spectrogram = infer(
            voices[voice]['ref_audio'], voices[voice]['ref_text'], gen_text, remove_silence, ema_model
        )
        generated_audio_segments.append(audio)

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)
        # Write by path: no extra handle stays open on the file while
        # soundfile/pydub rewrite it, and the output directory is created.
        wave_path.parent.mkdir(parents=True, exist_ok=True)
        out_file = str(wave_path)
        sf.write(out_file, final_wave, target_sample_rate)
        # Remove silence
        if remove_silence:
            aseg = AudioSegment.from_file(out_file)
            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                non_silent_wave += non_silent_seg
            aseg = non_silent_wave
            aseg.export(out_file, format="wav")
        print(out_file)

# Run the whole pipeline with the settings and the model loaded earlier in this cell.
process(
    ref_audio=ref_audio,
    ref_text=ref_text,
    text_gen=gen_text,
    remove_silence=remove_silence,
    ema_model=ema_model,
)

  state_dict = torch.load(model_path, map_location="cpu")


Using cuda device

vocab :  Emilia_ZH_EN pinyin
tokenizer :  pinyin
model :  /home/hulk/.cache/huggingface/hub/models--SWivid--F5-TTS/snapshots/d0ac03c2b5ded76f302ada887ae0da5675e88a5d/F5TTS_Base/model_1200000.safetensors 

Converting audio...
No reference text provided, transcribing reference audio...




Finished transcription
Converting audio...
No reference text provided, transcribing reference audio...




Finished transcription
Converting audio...
No reference text provided, transcribing reference audio...




Finished transcription
Converting audio...
No reference text provided, transcribing reference audio...




Finished transcription
Voice: main
A Town Mouse and a Country Mouse were acquaintances, and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came, and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour. The fare was not much to the taste of the guest, and presently he broke out with
ref_text It's not God so much as afterlife. Religion is pretty much about not dying. I would say 99.9% of human beings don't want to die. 
gen_text 0 A Town Mouse and a Country Mouse were acquaintances,
gen_text 1 and the Country Mouse one day invited his friend to come and see him at his home in the fields. The Town Mouse came,
gen_text 2 and they sat down to a dinner of barleycorns and roots, the latter of which had a distinctly earthy flavour.
gen_text 3 The fare was not much to the taste of the guest, and presently he broke out with
Generating audio in 4 batches...


100%|█████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.45s/it]


Voice: town
“My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.”
ref_text The difference in the rainbow depends considerably upon the size of the drops and the width of the coloured band increases as the size of the drops increases. 
gen_text 0 “My poor dear friend, you live here no better than the ants. Now, you should just see how I fare! My larder is a regular horn of plenty. You must come and stay with me, and I promise you you shall live on the fat of the land.”
Generating audio in 1 batches...


100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.84s/it]


Voice: main
So when he returned to town he took the Country Mouse with him, and showed him into a larder containing flour and oatmeal and figs and honey and dates. The Country Mouse had never seen anything like it, and sat down to enjoy the luxuries his friend provided: but before they had well begun, the door of the larder opened and someone came in. The two Mice scampered off and hid themselves in a narrow and exceedingly uncomfortable hole. Presently, when all was quiet, they ventured out again; but someone else came in, and off they scuttled again. This was too much for the visitor.
ref_text It's not God so much as afterlife. Religion is pretty much about not dying. I would say 99.9% of human beings don't want to die. 
gen_text 0 So when he returned to town he took the Country Mouse with him,
gen_text 1 and showed him into a larder containing flour and oatmeal and figs and honey and dates.
gen_text 2 The Country Mouse had never seen anything like it, and sat down to enjoy the luxur

100%|█████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:25<00:00,  3.64s/it]


Voice: country
“Goodbye,”
ref_text Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob. 
gen_text 0 “Goodbye,”
Generating audio in 1 batches...


100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]


Voice: main
said he,
ref_text It's not God so much as afterlife. Religion is pretty much about not dying. I would say 99.9% of human beings don't want to die. 
gen_text 0 said he,
Generating audio in 1 batches...


100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.08s/it]


Voice: country
“I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."
ref_text Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob. 
gen_text 0 “I’m off. You live in the lap of luxury, I can see, but you are surrounded by dangers; whereas at home I can enjoy my simple dinner of roots and corn in peace."
Generating audio in 1 batches...


100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.93s/it]


samples/out.wav
