In [1]:
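# One-time data download; uncomment to fetch and unpack the parliament-texts archive.
# A later cell reads its input from osman-parliament-edge-tts-text/*.json.
# !wget https://huggingface.co/datasets/mesolitica/azure-tts-osman/resolve/main/parliament-texts.tar
# !tar -xvf parliament-texts.tar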

In [2]:
import os

# Limit CUDA to device index 0. This is set here, before torch is imported in
# a later cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [3]:
from dynamicbatch_ttspipeline.f5_tts.load import (
    load_f5_tts,
    load_vocoder,
    target_sample_rate,
    hop_length,
    nfe_step,
    cfg_strength,
    sway_sampling_coef,
)
from dynamicbatch_ttspipeline.f5_tts.utils import (
    chunk_text,
    convert_char_to_pinyin,
)
from pydub import AudioSegment, silence
import torchaudio
import torch
import torch.nn.functional as F
import numpy as np
import librosa
import soundfile as sf

In [4]:
from ctc_forced_aligner import (
    load_audio,
    load_alignment_model,
    generate_emissions,
    preprocess_text,
    get_alignments,
    get_spans,
    postprocess_results,
)

# Language code handed to preprocess_text during forced alignment.
# NOTE(review): "ms" is the two-letter (ISO 639-1) code for Malay, not a
# three-letter ISO 639-3 code -- confirm which form the aligner expects.
language = "ms"
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): batch_size is not referenced by any visible cell; the
# synthesis loop calls generate_emissions with batch_size=1.
batch_size = 16

# Half precision when running on the GPU, full precision on the CPU.
if device == "cuda":
    alignment_dtype = torch.float16
else:
    alignment_dtype = torch.float32

alignment_model, alignment_tokenizer = load_alignment_model(device, dtype=alignment_dtype)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [5]:
from glob import glob
import json

# Load every JSON record from the extracted dataset, in sorted filename order.
# Each record is later indexed by its 'normalized' key in the synthesis loop.
text = []
files = sorted(glob('osman-parliament-edge-tts-text/*.json'))
for f in files:
    # Decode explicitly as UTF-8 instead of relying on the platform's
    # locale-dependent default encoding.
    with open(f, encoding='utf-8') as fopen:
        text.append(json.load(fopen))

len(text)

59582

In [6]:
# Reference sentence; a later cell assigns it to ref_text and prepends it to
# every text to be generated.
# NOTE(review): presumably the transcript of the reference clip
# 'husein-news.mp3' -- confirm.
original_husein = 'Titah Pemangku Sultan Johor, Tunku Mahkota Ismail Sultan Ibrahim, mengenai pertukaran cuti hujung minggu negeri itu kepada Sabtu dan Ahad, tidak perlu dijadikan bahan politik.'

In [7]:
# Single source of truth for the TTS model precision. Previously this variable
# was set to bfloat16 but never used, while the model was loaded with a
# literal float16; the effective dtype (float16) is kept.
torch_dtype = torch.float16
# Overrides the conditional device chosen for the aligner: the reference-audio
# cell calls .cuda() directly, so a GPU is required from here on.
device = 'cuda'

model_name = 'mesolitica/Malaysian-F5-TTS'
model = load_f5_tts(model_name = model_name, device = device, dtype = torch_dtype)
vocoder = load_vocoder(device = device)

In [8]:
# Switch the TTS model to evaluation mode. Binding the result to `_` keeps the
# notebook from echoing the return value as cell output.
_ = model.eval()

In [9]:
# Load the reference clip and average its channels into one mono numpy signal.
audio_input = 'husein-news.mp3'
dwav, sr_ = torchaudio.load(audio_input)
dwav = dwav.mean(dim=0).numpy()

target_rms = 0.1
audio = dwav
rms = np.sqrt(np.mean(np.square(audio)))
# A silent or empty clip yields an RMS of 0 (or NaN). Fail here instead of
# dividing by zero below and feeding NaN/inf samples to the model.
if not np.isfinite(rms) or rms == 0:
    raise ValueError(f'reference audio {audio_input!r} is silent or empty (rms={rms})')
# Quiet references are amplified up to target_rms; louder ones are left as is.
# `rms` keeps the pre-scaling value because the synthesis loop uses it to undo
# this gain on the generated waveform.
if rms < target_rms:
    audio = audio * target_rms / rms

# Resample to target_sample_rate (imported from the F5-TTS loader) when the
# file uses a different rate. sr_ still holds the file's original rate.
if sr_ != target_sample_rate:
    audio = librosa.resample(audio, orig_sr = sr_, target_sr = target_sample_rate)

# Add a leading batch dimension and move the signal to the GPU.
audios = torch.tensor(audio)[None].cuda()

In [10]:
# Ensure the reference text ends with ". " (or the full-width "。"), since the
# text to generate is concatenated directly after it.
ref_text = original_husein
if not ref_text.endswith((". ", "。")):
    # A bare trailing "." only needs the space; anything else gets ". ".
    ref_text += " " if ref_text.endswith(".") else ". "

ref_text

'Titah Pemangku Sultan Johor, Tunku Mahkota Ismail Sultan Ibrahim, mengenai pertukaran cuti hujung minggu negeri itu kepada Sabtu dan Ahad, tidak perlu dijadikan bahan politik. '

In [11]:
# Duration of the reference audio in seconds. `audios` has already been
# resampled to target_sample_rate, so that rate (not the file's original sr_)
# is the correct divisor when converting samples to seconds.
ref_audio_seconds = audios.shape[-1] / target_sample_rate
# Byte budget for generated text: reference UTF-8 bytes per second multiplied
# by the seconds left in a 25-second window after the reference audio.
max_chars = int(len(ref_text.encode("utf-8")) / ref_audio_seconds * (25 - ref_audio_seconds))
# Reference length in hop_length-sized frames; the synthesis loop uses it to
# cut the reference part off the generated spectrogram.
ref_audio_len = audios.shape[-1] // hop_length
speed = 1

In [12]:
# Create the output directory; exist_ok=True keeps the cell safe to re-run
# when the directory is already there.
os.makedirs('generate-husein-wiki-normalized', exist_ok=True)

In [None]:
import re
from tqdm import tqdm

# Tunables for the synthesis loop.
MAX_ATTEMPTS = 5              # generation retries per text before giving up
MIN_WORD_SCORE = -20          # a word scoring below this rejects the attempt
GENERATED_SAMPLE_RATE = 24000 # rate used for the vocoder output and the saved file
ALIGNER_SAMPLE_RATE = 16000   # rate the waveform is resampled to before alignment

# Indices whose every attempt was rejected by the alignment check. Nothing is
# written for them, so a later re-run of this cell will retry them.
failed_indices = []

for i in tqdm(range(len(text))):
    # Skip outputs that already exist so an interrupted run can be resumed.
    new_filename = os.path.join('generate-husein-wiki-normalized', f'{i}.mp3')
    if os.path.exists(new_filename):
        continue

    # Drop quote characters, collapse runs of spaces and trim the text.
    gen_text = text[i]['normalized'].replace('\'', '').replace('"', '')
    gen_text = re.sub(r'[ ]+', ' ', gen_text).strip()
    if len(gen_text) < 3:
        continue

    # Model input is the reference text immediately followed by the new text.
    text_list = [ref_text + gen_text]
    final_text_list = convert_char_to_pinyin(text_list)

    # Estimate the generated length in frames by scaling the reference frame
    # count by the ratio of UTF-8 byte lengths (generated / reference).
    ref_text_len = len(ref_text.encode("utf-8"))
    gen_text_len = len(gen_text.encode("utf-8"))
    after_duration = int(ref_audio_len / ref_text_len * gen_text_len / speed)
    final_text_lists = [final_text_list[0]]
    durations = [ref_audio_len + after_duration]
    after_durations = [after_duration]
    # Frame counts converted to waveform samples.
    actual_after_durations = [d * hop_length for d in after_durations]

    # The alignment targets depend only on gen_text, so compute them once per
    # text instead of once per retry.
    tokens_starred, text_starred = preprocess_text(
        gen_text,
        romanize=True,
        language=language,
    )

    for attempt in range(MAX_ATTEMPTS):
        with torch.no_grad():
            generated, _ = model.sample(
                cond=audios.repeat(len(final_text_lists), 1),
                text=final_text_lists,
                duration=torch.Tensor(durations).to(device).type(torch.long),
                steps=nfe_step,
                cfg_strength=2,
                sway_sampling_coef=-1.0,
            )
            # Keep only the frames after the reference prompt, then vocode.
            generated_mel_spec = generated.to(torch.float32)[:, ref_audio_len:, :].permute(0, 2, 1)
            generated_wave = vocoder.decode(generated_mel_spec)
            # Undo the gain that was applied to a quiet reference clip.
            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms
            new_wav = generated_wave[0, :actual_after_durations[0]]

            # Force-align the generated audio against its own text. The cast
            # to float16 matches the dtype the aligner is loaded with on CUDA.
            audio_waveform = torchaudio.functional.resample(
                new_wav, orig_freq=GENERATED_SAMPLE_RATE, new_freq=ALIGNER_SAMPLE_RATE
            ).type(torch.float16)
            emissions, stride = generate_emissions(
                alignment_model, audio_waveform, batch_size=1
            )
            segments, scores, blank_token = get_alignments(
                emissions,
                tokens_starred,
                alignment_tokenizer,
            )
            spans = get_spans(tokens_starred, segments, blank_token)
            word_timestamps = postprocess_results(text_starred, spans, stride, scores)

            # Accept the attempt only when no word falls below the threshold.
            scores = [w['score'] for w in word_timestamps if w['score'] < MIN_WORD_SCORE]
            if not len(scores):
                a = new_wav.cpu().numpy()
                sf.write(new_filename, a, GENERATED_SAMPLE_RATE)
                break
    else:
        # Every attempt was rejected; remember the index instead of dropping
        # the item silently.
        failed_indices.append(i)

  0%|                                                 | 0/59582 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.318 seconds.
Prefix dict has been built successfully.
 22%|██████▉                        | 13250/59582 [16:54:03<34:40:29,  2.69s/it]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 39%|████████████                   | 23279/59582 [29:46:50<38:26:48,  3.81s/it]