v0.19.1 #3115

Merged
merged 10 commits on Oct 28, 2023
4 changes: 2 additions & 2 deletions TTS/api.py
@@ -71,7 +71,7 @@ def __init__(
         self.voice_converter = None
         self.csapi = None
         self.cs_api_model = cs_api_model
-        self.model_name = None
+        self.model_name = ""

         if gpu:
             warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
@@ -460,7 +460,7 @@ def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None):
"""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
# Lazy code... save it to a temp file to resample it while reading it for VC
self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name)
self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name,speaker_wav=speaker_wav)
if self.voice_converter is None:
self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
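The fix above matters because the intermediate tts_to_file() call previously received no reference clip, which breaks models that require one. A minimal sketch of how the fixed path is exercised (the model name and file path are illustrative, not part of the diff):

    from TTS.api import TTS

    # The fix simply forwards speaker_wav to the underlying tts_to_file()
    # call, so models that need a reference clip get one before the
    # FreeVC voice-conversion pass.
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1")
    wav = tts.tts_with_vc(
        text="Hello world.",
        language="en",
        speaker_wav="reference.wav",  # illustrative path to a short clip
    )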
4 changes: 1 addition & 3 deletions TTS/tts/layers/xtts/tokenizer.py
@@ -483,20 +483,18 @@ def preprocess_text(self, txt, lang):
         if lang == "zh-cn":
             txt = chinese_transliterate(txt)
         elif lang == "ja":
-            assert txt[:4] == "[ja]", "Japanese speech should start with the [ja] token."
-            txt = txt[4:]
             if self.katsu is None:
                 import cutlet
                 self.katsu = cutlet.Cutlet()
             txt = japanese_cleaners(txt, self.katsu)
-            txt = "[ja]" + txt
         else:
             raise NotImplementedError()
         return txt

     def encode(self, txt, lang):
         if self.preprocess:
             txt = self.preprocess_text(txt, lang)
+        txt = f"[{lang}]{txt}"
         txt = txt.replace(" ", "[SPACE]")
         return self.tokenizer.encode(txt).ids

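Net effect of this change: the language tag is now prepended inside encode() for every language, instead of being required on (and re-added to) Japanese input in preprocess_text(). A sketch, assuming a VoiceBpeTokenizer instance named tokenizer:

    # Callers now pass raw, untagged text for every language.
    ids = tokenizer.encode("hello world", lang="en")
    # Internally the text becomes "[en]hello[SPACE]world" before BPE lookup.

    # Japanese no longer needs a pre-applied "[ja]" prefix; preprocess_text()
    # romanizes it with cutlet and encode() adds the tag afterwards.
    ids_ja = tokenizer.encode("こんにちは", lang="ja")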
97 changes: 39 additions & 58 deletions TTS/tts/models/xtts.py
@@ -5,6 +5,7 @@
 import torch
 import torch.nn.functional as F
 import torchaudio
+import librosa
 from coqpit import Coqpit

 from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
@@ -21,34 +22,6 @@
 init_stream_support()


-def load_audio(audiopath, sr=22050):
-    """
-    Load an audio file from disk and resample it to the specified sampling rate.
-
-    Args:
-        audiopath (str): Path to the audio file.
-        sr (int): Target sampling rate.
-
-    Returns:
-        Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples.
-    """
-    audio, sampling_rate = torchaudio.load(audiopath)
-
-    if len(audio.shape) > 1:
-        if audio.shape[0] < 5:
-            audio = audio[0]
-        else:
-            assert audio.shape[1] < 5
-            audio = audio[:, 0]
-
-    if sampling_rate != sr:
-        resampler = torchaudio.transforms.Resample(sampling_rate, sr)
-        audio = resampler(audio)
-
-    audio = audio.clamp_(-1, 1)
-    return audio.unsqueeze(0)
-
-
 def wav_to_mel_cloning(
     wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
 ):
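With load_audio gone, each consumer below resamples a shared in-memory tensor instead of re-reading the file at a fixed rate. The equivalent of the removed helper is now roughly this (a sketch; the path is illustrative):

    import torchaudio

    # Load once at the file's native rate; shape is (channels, T).
    audio, sr = torchaudio.load("reference.wav")
    # Resample on demand per consumer, e.g. to 22.05 kHz for mel cloning.
    audio_22k = torchaudio.functional.resample(audio, sr, 22050)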
@@ -376,32 +349,29 @@ def device(self):
         return next(self.parameters()).device

     @torch.inference_mode()
-    def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
+    def get_gpt_cond_latents(self, audio, sr, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.

         Args:
             audio_path (str): Path to the audio file.
             length (int): Length of the audio in seconds. Defaults to 3.
         """

-        audio = load_audio(audio_path)
-        audio = audio[:, : 22050 * length]
-        mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
+        audio_22k = torchaudio.functional.resample(audio, sr, 22050)
+        audio_22k = audio_22k[:, : 22050 * length]
+        mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()
-    def get_diffusion_cond_latents(
-        self,
-        audio_path,
-    ):
+    def get_diffusion_cond_latents(self, audio, sr):
         from math import ceil

         diffusion_conds = []
         CHUNK_SIZE = 102400
-        audio = load_audio(audio_path, 24000)
-        for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
-            current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
+        audio_24k = torchaudio.functional.resample(audio, sr, 24000)
+        for chunk in range(ceil(audio_24k.shape[1] / CHUNK_SIZE)):
+            current_sample = audio_24k[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
             current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
             cond_mel = wav_to_univnet_mel(
                 current_sample.to(self.device),
@@ -414,27 +384,38 @@ def get_diffusion_cond_latents(
         return diffusion_latent

     @torch.inference_mode()
-    def get_speaker_embedding(self, audio_path):
-        audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
-        speaker_embedding = (
-            self.hifigan_decoder.speaker_encoder.forward(audio.to(self.device), l2_norm=True)
-            .unsqueeze(-1)
-            .to(self.device)
-        )
-        return speaker_embedding
+    def get_speaker_embedding(self, audio, sr):
+        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
+        return self.hifigan_decoder.speaker_encoder.forward(
+            audio_16k.to(self.device), l2_norm=True
+        ).unsqueeze(-1).to(self.device)

     @torch.inference_mode()
     def get_conditioning_latents(
         self,
         audio_path,
-        gpt_cond_len=3,
-    ):
+        gpt_cond_len=6,
+        max_ref_length=10,
+        librosa_trim_db=None,
+        sound_norm_refs=False,
+    ):
         speaker_embedding = None
         diffusion_cond_latents = None
-        if self.args.use_hifigan:
-            speaker_embedding = self.get_speaker_embedding(audio_path)
+
+        audio, sr = torchaudio.load(audio_path)
+        audio = audio[:, : sr * max_ref_length].to(self.device)
+        if audio.shape[0] > 1:
+            audio = audio.mean(0, keepdim=True)
+        if sound_norm_refs:
+            audio = (audio / torch.abs(audio).max()) * 0.75
+        if librosa_trim_db is not None:
+            audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]
+
+        if self.args.use_hifigan or self.args.use_ne_hifigan:
+            speaker_embedding = self.get_speaker_embedding(audio, sr)
         else:
-            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
-        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr)
+        gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len)  # [1, 1024, T]
         return gpt_cond_latents, diffusion_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
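Putting the pieces together: get_conditioning_latents now owns all reference preprocessing (loading, truncation to max_ref_length, downmixing to mono, optional normalization and trimming) and hands one in-memory tensor to the three latent extractors. A sketch of calling the updated method, assuming model is an initialized Xtts instance and the clip path is illustrative:

    gpt_cond, diffusion_cond, spk_emb = model.get_conditioning_latents(
        "speaker.wav",
        gpt_cond_len=6,         # seconds of reference audio fed to the GPT conditioner
        max_ref_length=10,      # hard cap on the reference length, in seconds
        librosa_trim_db=None,   # e.g. 60 to trim leading/trailing silence
        sound_norm_refs=False,  # peak-normalize the reference to 0.75 if True
    )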
@@ -494,7 +475,7 @@ def full_inference(
         repetition_penalty=2.0,
         top_k=50,
         top_p=0.85,
-        gpt_cond_len=4,
+        gpt_cond_len=6,
         do_sample=True,
         # Decoder inference
         decoder_iterations=100,
@@ -531,7 +512,7 @@
                 (aka boring) outputs. Defaults to 0.8.

             gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
-                else the first `gpt_cond_len` secs is used. Defaults to 3 seconds.
+                else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.

             decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
                 more chances to iteratively refine the output, which should theoretically mean a higher quality output.
@@ -610,7 +591,7 @@ def inference(
decoder="hifigan",
**hf_generate_kwargs,
):
text = f"[{language}]{text.strip().lower()}"
text = text.strip().lower()
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

assert (
@@ -722,7 +703,7 @@ def inference_stream(
         assert hasattr(
             self, "hifigan_decoder"
         ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
-        text = f"[{language}]{text.strip().lower()}"
+        text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

         fake_inputs = self.gpt.compute_embeddings(
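The same one-line change appears in both inference() and inference_stream(): with encode() now prepending the tag (see the tokenizer diff above), applying it here as well would double-tag the text.

    # before: the tag was added here and, after this PR, again by the tokenizer
    text = f"[{language}]{text.strip().lower()}"
    # after: the tokenizer owns the tagging
    text = text.strip().lower()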