In [None]:
import logging
import os
import threading
import uuid

import soundfile as sf
import torch
import torch.nn.functional as F
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm

from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import load_wav

class CosyVoiceFrontEnd_eval(CosyVoiceFrontEnd):
    """Front end whose zero-shot path takes the target sample rate explicitly."""

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, resample_rate):
        """Build the zero-shot model input dict from two texts and a prompt waveform.

        Args:
            tts_text: Text to synthesize; passed to ``_extract_text_token``.
            prompt_text: Transcript of the prompt; passed to ``_extract_text_token``.
            prompt_speech_16k: Prompt waveform, treated as 16 kHz audio (it is
                resampled from 16000 Hz and fed to the token/embedding extractors).
            resample_rate: Sample rate the prompt is resampled to before the
                speech features are extracted.

        Returns:
            dict with the text tokens, prompt tokens, prompt speech tokens
            (shared by the ``llm_*`` and ``flow_*`` keys), prompt speech
            features and the speaker embedding (shared by both embedding keys).
        """
        text_ids, text_ids_len = self._extract_text_token(tts_text)
        prompt_ids, prompt_ids_len = self._extract_text_token(prompt_text)

        # Features come from the resampled prompt; tokens and the speaker
        # embedding are extracted from the 16 kHz prompt as given.
        resampler = torchaudio.transforms.Resample(orig_freq=16000, new_freq=resample_rate)
        feat, feat_len = self._extract_speech_feat(resampler(prompt_speech_16k))
        token, token_len = self._extract_speech_token(prompt_speech_16k)

        if resample_rate == 24000:
            # cosyvoice2: trim both sequences so that the number of feature
            # frames is exactly twice the number of speech tokens.
            n_tok = min(int(feat.shape[1] / 2), token.shape[1])
            feat = feat[:, :2 * n_tok]
            feat_len[:] = 2 * n_tok
            token = token[:, :n_tok]
            token_len[:] = n_tok

        spk_embedding = self._extract_spk_embedding(prompt_speech_16k)

        return {
            'text': text_ids, 'text_len': text_ids_len,
            'prompt_text': prompt_ids, 'prompt_text_len': prompt_ids_len,
            'llm_prompt_speech_token': token, 'llm_prompt_speech_token_len': token_len,
            'flow_prompt_speech_token': token, 'flow_prompt_speech_token_len': token_len,
            'prompt_speech_feat': feat, 'prompt_speech_feat_len': feat_len,
            'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding,
        }

class CosyVoiceModel_eval(CosyVoiceModel):
    """Evaluation variant of ``CosyVoiceModel``.

    Adds ``inference``, which skips the LLM and feeds given speech tokens
    straight into the flow model and the vocoder, and overrides ``load``,
    ``token2wav`` and ``tts``.

    NOTE(review): ``token2wav`` calls ``self.flow.inference`` with a
    ``flow_cache`` argument and unpacks two return values, while ``inference``
    calls it with solver arguments and unpacks three. Confirm the configured
    flow model supports both call forms.
    """

    def load(self, llm_model, flow_model, hift_model):
        """Load the three checkpoints onto ``self.device`` and switch to eval mode.

        Args:
            llm_model: Path to the LLM state dict.
            flow_model: Path to the flow state dict (loaded with ``strict=False``).
            hift_model: Path to the vocoder state dict.
        """
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=True)
        self.llm.to(self.device).eval()
        if self.fp16 is True:
            self.llm.half()
        # Non-strict: missing/unexpected keys in the flow checkpoint are tolerated.
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.to(self.device).eval()

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        """Convert speech tokens to a waveform using the per-request caches.

        Args:
            token: Speech tokens to synthesize.
            prompt_token: Prompt speech tokens for the flow model.
            prompt_feat: Prompt speech features for the flow model.
            embedding: Speaker embedding for the flow model.
            uuid: Key of this request in the ``*_dict`` caches.
            finalize: When False, the trailing mel frames and samples are held
                back in the caches for cross-fading with the next chunk.
            speed: Speed factor; values other than 1.0 are only applied when
                ``finalize`` is True and no vocoder cache exists.

        Returns:
            The synthesized speech returned by ``self.hift.inference``
            (cross-faded / trimmed as described above).
        """
        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_token=prompt_token.to(self.device),
                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_feat=prompt_feat.to(self.device),
                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                  embedding=embedding.to(self.device),
                                                  flow_cache=self.flow_cache_dict[uuid])
        self.flow_cache_dict[uuid] = flow_cache

        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                # Stretch/compress the mel along its last axis to change speed.
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        """Run the LLM to completion, then synthesize all tokens in one pass.

        Unlike a generator-style streaming API, this returns a single waveform:
        the LLM thread is joined before ``token2wav`` is called with
        ``finalize=True``. ``stream`` and extra ``kwargs`` are accepted but not
        used here.

        Returns:
            The speech returned by ``token2wav`` for the full token sequence.
        """
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        try:
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
            p.start()
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            return self.token2wav(token=this_tts_speech_token,
                                  prompt_token=flow_prompt_speech_token,
                                  prompt_feat=prompt_speech_feat,
                                  embedding=flow_embedding,
                                  uuid=this_uuid,
                                  finalize=True,
                                  speed=speed)
        finally:
            # Drop this request's state so the caches do not grow with every
            # call (also runs when synthesis raises).
            with self.lock:
                self.tts_speech_token_dict.pop(this_uuid, None)
                self.llm_end_dict.pop(this_uuid, None)
                self.mel_overlap_dict.pop(this_uuid, None)
                self.hift_cache_dict.pop(this_uuid, None)
                self.flow_cache_dict.pop(this_uuid, None)

    def inference(self, flow_embedding,
                  llm_prompt_speech_token,
                  flow_prompt_speech_token,
                  prompt_speech_feat,
                  stream=False, speed=1.0,
                  n_timesteps=10, temperature=1.0, alpha=1.0, solver='euler', **kwargs):
        """Synthesize speech from given speech tokens, bypassing the LLM.

        Args:
            flow_embedding: Speaker embedding for the flow model.
            llm_prompt_speech_token: Tokens to synthesize. Despite the name,
                they are passed to the flow model as ``token``.
            flow_prompt_speech_token: Prompt speech tokens for the flow model.
            prompt_speech_feat: Prompt speech features for the flow model.
            stream, speed: Accepted for signature compatibility; not used here.
            n_timesteps, temperature, alpha, solver: Forwarded unchanged to
                ``self.flow.inference``.
            **kwargs: Ignored (absorbs extra keys of the model-input dict).

        Returns:
            Tuple ``(tts_speech, tp, sigma)``: the vocoder output plus the two
            extra values returned by ``self.flow.inference``.
        """
        tts_mel, tp, sigma = self.flow.inference(token=llm_prompt_speech_token.to(self.device),
                                                 token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                 prompt_token=flow_prompt_speech_token.to(self.device),
                                                 prompt_token_len=torch.tensor([flow_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                                 prompt_feat=prompt_speech_feat.to(self.device),
                                                 prompt_feat_len=torch.tensor([prompt_speech_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                 embedding=flow_embedding.to(self.device),
                                                 n_timesteps=n_timesteps,
                                                 temperature=temperature,
                                                 alpha=alpha,
                                                 solver=solver)

        # Non-streaming: start the vocoder from an empty source cache.
        hift_cache_source = torch.zeros(1, 1, 0)
        tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
        return tts_speech, tp, sigma


class CosyVoice_eval(CosyVoice):
    """Evaluation wrapper that loads the flow checkpoint and config from explicit paths.

    The parent ``__init__`` is intentionally not called; all attributes are
    set up here.
    """

    def __init__(self, model_dir, flow_dir, config_dir, load_jit=False, load_trt=False, fp16=False):
        """Build the front end and the evaluation model.

        Args:
            model_dir: Directory holding ``llm.pt``, ``hift.pt``, the ONNX
                front-end models and ``spk2info.pt``.
            flow_dir: Path to the flow checkpoint file.
            config_dir: Path to the hyperpyyaml config file.
            load_jit: Load the JIT-exported modules from ``model_dir``.
            load_trt: Load the TensorRT flow decoder estimator from ``model_dir``.
            fp16: Use half precision; forced to False when CUDA is unavailable.
        """
        self.instruct = '-Instruct' in model_dir
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            # NOTE(review): `snapshot_download` is not imported in this file,
            # so this branch raises NameError as written. Import it from the
            # model hub package the project uses, or require a local directory.
            model_dir = snapshot_download(model_dir)
        with open(config_dir, 'r') as f:
            configs = load_hyperpyyaml(f)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          configs['allowed_special'])
        self.sample_rate = configs['sample_rate']
        if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
        # Store the effective precision only after the CPU fallback above, so
        # the attribute never disagrees with what the model was built with.
        self.fp16 = fp16
        self.model = CosyVoiceModel_eval(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        flow_dir,
                        '{}/hift.pt'.format(model_dir))
        precision = 'fp16' if self.fp16 is True else 'fp32'
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, precision),
                                '{}/llm.llm.{}.zip'.format(model_dir, precision),
                                '{}/flow.encoder.{}.zip'.format(model_dir, precision))
        if load_trt:
            self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, precision),
                                '{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
                                self.fp16)
        del configs

    def inference_zero_shot(self, prompt_text, target_text, prompt_token, target_token, prompt_speech_16k, prompt_embed, stream=False, speed=1.0, text_frontend=True, n_timesteps=10, temperature=1.0, alpha=1.0, solver='euler'):
        """Synthesize ``target_token`` in the voice of the prompt, without running the LLM.

        Args:
            prompt_text, target_text: Currently unused (see the comment below).
            prompt_token: Speech tokens of the prompt utterance.
            target_token: Precomputed speech tokens to synthesize.
            prompt_speech_16k: Prompt waveform, resampled from 16000 Hz to
                ``self.sample_rate`` before feature extraction.
            prompt_embed: Speaker embedding of the prompt.
            stream, speed: Forwarded to ``self.model.inference``.
            text_frontend: Currently unused.
            n_timesteps, temperature, alpha, solver: Forwarded to
                ``self.model.inference``.

        Returns:
            Tuple ``(model_output, tp, sigma)`` as returned by
            ``self.model.inference``.
        """
        prompt_speech_resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=self.sample_rate)(prompt_speech_16k)
        prompt_feat, prompt_feat_len = self.frontend._extract_speech_feat(prompt_speech_resample)

        # The target tokens are supplied by the caller. A previous version
        # generated them here by tokenizing `target_text` / `prompt_text` with
        # `self.frontend._extract_text_token` and iterating
        # `self.model.llm.inference(...)`; that path is disabled, which is why
        # the text arguments are unused.
        model_input = {'llm_prompt_speech_token': target_token,
                       'flow_prompt_speech_token': prompt_token,
                       'prompt_speech_feat': prompt_feat, 'prompt_speech_feat_len': prompt_feat_len,
                       'flow_embedding': prompt_embed.to(self.model.device)}
        model_output, tp, sigma = self.model.inference(**model_input, stream=stream, speed=speed, n_timesteps=n_timesteps, temperature=temperature, alpha=alpha, solver=solver)
        return model_output, tp, sigma


In [None]:
# Build the evaluation pipeline: base model directory, separately trained flow
# checkpoint, and the hyperpyyaml config. JIT/TensorRT/fp16 stay disabled.
cosyvoice = CosyVoice_eval(
    model_dir='pretrained_models/CosyVoice-300M',
    flow_dir='pretrained_models/CosyVoice-sfm-epoch_199_step_200201.pt',
    config_dir="configs/cosyvoice.yaml",
    load_jit=False,
    load_trt=False,
    fp16=False,
)

In [None]:
# Map utterance id -> wav path, read from "<utt_id> <path>" lines.
wavs_dict = {}
with open("CosyVoice-libritts-data/test-clean/wav.scp", "r") as f:
    wavs = f.readlines()
for wav in wavs:
    utt_id, sep, wav_path = wav.partition(" ")
    if not sep:
        # Blank or id-only line (e.g. a trailing newline at end of file):
        # skip it instead of failing with IndexError.
        continue
    # "xxx" in the scp is a placeholder for the local LibriTTS root.
    wavs_dict[utt_id] = wav_path.strip("\n").replace("xxx", "your LibriTTS wav path")

# texts_dict = {}
# with open("CosyVoice-libritts-data/test-clean/text", "r") as f:
#     texts = f.readlines()
# for t in texts:
#     texts_dict[t.split(" ")[0]] = t.split(" ", 1)[1].strip("\n")

# NOTE: torch.load unpickles the file; only load data files you trust.
tokens = torch.load("CosyVoice-libritts-data/test-clean/utt2speech_token.pt")
embeds = torch.load("CosyVoice-libritts-data/test-clean/utt2embedding.pt")

# Map target utterance id -> prompt utterance id, read from "<prompt> <target>" lines.
pairs_dict = {}
with open("../libritts-cross_sentence-infer/test_pairs.txt", "r") as f:
    pairs = f.readlines()
for pair in pairs:
    if not pair.strip():
        # Skip blank lines; malformed non-blank lines still raise ValueError.
        continue
    prompt, target = pair.strip("\n").split(" ")
    pairs_dict[target] = prompt

# Speech tokens of the target utterances, indexed by target id in the loop below.
token_dict = torch.load("../libritts-cross_sentence-infer/test_target_tokens.pt")

In [None]:
import random

import numpy as np

# folder = "xxx"
# import os
# if os.path.exists(folder):
#     os.system(f"rm -r {folder}")
# os.makedirs(folder, exist_ok=True)

# Sampling settings forwarded to `inference_zero_shot`.
alpha = 2.0
n_timesteps = 10
solver = 'euler'

# Running sums, only consumed by the commented-out statistics code below.
tps = 0.0
sigma_ps = 0.0

# Fix every random source so repeated runs use the same seeds.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

for target, prompt in tqdm(pairs_dict.items()):
    prompt_speech_16k = load_wav(wavs_dict[prompt], 16000)
    # Text is not needed: the target speech tokens are precomputed.
    prompt_text = None  # texts_dict[prompt]
    target_text = None  # texts_dict[target]

    # Add a leading batch dimension to each per-utterance entry.
    prompt_token = torch.tensor(tokens[prompt]).unsqueeze(0)
    prompt_embed = torch.tensor(embeds[prompt]).unsqueeze(0)
    target_token = torch.tensor(token_dict[target]).unsqueeze(0)

    output, tp, sigma_p = cosyvoice.inference_zero_shot(
        prompt_text, target_text, prompt_token, target_token,
        prompt_speech_16k, prompt_embed,
        n_timesteps=n_timesteps, alpha=alpha, solver=solver,
    )
    output = output.cpu().squeeze(0)
    # Stop after the first pair; remove this `break` (and enable the lines
    # below) to process and save the whole test set.
    break
    # tps += tp
    # sigma_ps += sigma_p
#     sf.write(f'{folder}/{target+".wav"}', output, 22050, 'PCM_24')

# tps = round(tps/len(pairs_dict.keys()), 8)
# sigma_ps = round(sigma_ps/len(pairs_dict.keys()), 8)
# print(tps, sigma_ps)

In [None]:
import IPython.display as ipd

# Show an inline player for the waveform produced by the loop above,
# played back at 22050 Hz.
player = ipd.Audio(output, rate=22050)
ipd.display(player)