In [None]:
import IPython.display as ipd
import torch
import soundfile as sf
from tqdm import tqdm

from api import StableTTSAPI

# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Checkpoints handed to StableTTSAPI: the TTS model plus the vocoder
# (identified by a type string and its own checkpoint path).
tts_model_path = 'CosyVoice-DiT-sfm-libritts.pt'
vocoder_type = 'vocos'
vocoder_model_path = 'vocoders/pretrained/vocos.pt'

# Build the combined TTS + vocoder wrapper and move it to the selected device.
model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type)
model.to(device)

# Report whatever `get_params()` returns for the two sub-models.
tts_param, vocoder_param = model.get_params()
print(f'tts_param: {tts_param}, vocoder_param: {vocoder_param}')

In [None]:
# Utterance id -> wav path, parsed from wav.scp where each line is "<utt_id> <path>".
wavs_dict = {}
with open("CosyVoice-libritts-data/dev-clean/wav.scp", "r") as f:
    wavs = f.readlines()
for wav in wavs:
    wav = wav.rstrip("\r\n")  # also drops a stray "\r" from CRLF files
    if not wav.strip():
        continue  # tolerate blank lines instead of crashing on the split below
    utt_id, wav_path = wav.split(" ", 1)  # split once; the path itself may contain spaces
    # "xxx" is presumably a placeholder prefix in wav.scp -- replace the second
    # argument with your local LibriTTS root before running.
    wavs_dict[utt_id] = wav_path.replace("xxx", "your LibriTTS wav path")

# Transcript loading is disabled; nothing in this cell reads `texts_dict`.
# texts_dict = {}
# with open("CosyVoice-libritts-data/dev-clean/text", "r") as f:
#     texts = f.readlines()
# for t in texts:
#     texts_dict[t.split(" ")[0]] = t.split(" ", 1)[1].strip("\n")

# Per-utterance embeddings, indexed by utterance id in the synthesis cell.
# map_location="cpu" lets a file saved with CUDA tensors load on a CPU-only
# machine; the synthesis cell moves what it needs to `device` itself.
embeds = torch.load("CosyVoice-libritts-data/dev-clean/utt2embedding.pt", map_location="cpu")

# Target utterance id -> prompt utterance id; each line is "<prompt> <target>".
pairs_dict = {}
with open("../libritts-cross_sentence-infer/val_pairs.txt", "r") as f:
    pairs = f.readlines()
for pair in pairs:
    fields = pair.split()  # any whitespace run separates the two ids
    if not fields:
        continue  # blank line
    if len(fields) != 2:
        raise ValueError(
            f"malformed line in val_pairs.txt, expected '<prompt> <target>': {pair!r}"
        )
    prompt, target = fields
    pairs_dict[target] = prompt

# Token sequences indexed by target utterance id (converted to a LongTensor
# in the synthesis cell).
token_dict = torch.load("../libritts-cross_sentence-infer/val_target_tokens.pt", map_location="cpu")

In [None]:
import random

import numpy as np

# --- Synthesis settings -----------------------------------------------------
# All of these except `language` are forwarded to `model.tts_model.synthesise`
# below; `language` is not read anywhere in this cell.
language = 'english' # support chinese, japanese and english
solver = 'dopri5' # recommend using euler, midpoint or dopri5
steps = 30 # passed positionally to synthesise below
cfg = 3 # recommend 1-4
temperature = 1.0
length_scale = 1.0
alpha = 2.5

# Output directory for batch mode (disabled together with the sf.write call below).
# folder = "xxx"
# import os
# if os.path.exists(folder):
#     os.system(f"rm -r {folder}")
# os.makedirs(folder, exist_ok=True)

# --- Reproducibility --------------------------------------------------------
# Seed every RNG this notebook touches and pin cuDNN to deterministic kernels.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Fail fast on an empty pair list: otherwise the loop body never runs and the
# final display would either raise NameError or, on a reused kernel, silently
# play an `audio_output` left over from an earlier run.
if not pairs_dict:
    raise ValueError("pairs_dict is empty: there is no (prompt, target) pair to synthesise")

# Accumulators for the disabled batch-mode statistics; they stay at 0 here.
tps = 0.
sigma_ps = 0.
for target, prompt in tqdm(list(pairs_dict.items())):
    # Looked up but not consumed by the synthesis call in this cell.
    ref_audio = wavs_dict[target]
    # Target token sequence as a batch of one, plus its length.
    token = torch.tensor(token_dict[target], dtype=torch.long, device=device).unsqueeze(0)
    token_length = torch.tensor([token.size(-1)], dtype=torch.long, device=device)
    # The embedding comes from the *prompt* utterance, not the target.
    c = torch.tensor(embeds[prompt]).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs, tp, sigma_p  = model.tts_model.synthesise(token, token_length, c, steps, temperature, alpha, length_scale, solver, cfg)
        audio_output = model.vocoder_model(outputs['decoder_outputs']).cpu()

    # Preview mode: stop after the first pair so the cell only renders one
    # sample. Remove this `break` and re-enable the lines below (and `folder`
    # above) to synthesise and save the whole list.
    break
#     tps += tp
#     sigma_ps += sigma_p
#     sf.write(f'{folder}/{target+".wav"}', audio_output[0], model.mel_config.sample_rate, 'PCM_24')

# tps = round(tps/len(pairs_dict.keys()), 8)
# sigma_ps = round(sigma_ps/len(pairs_dict.keys()), 8)

# Play the synthesised sample inline.
ipd.display(ipd.Audio(audio_output[0], rate=model.mel_config.sample_rate))