From 59576fc0ecf7c86681741bf440a699f85c11bdc5 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Fri, 20 Oct 2023 17:29:43 -0300
Subject: [PATCH] Bug fix on XTTS v1.1 inference (#3093)

* Bug fix on XTTS v1.1 inference

* Update .models.json

---------

Co-authored-by: Julian Weber
---
 TTS/.models.json               | 10 +++++-----
 TTS/tts/configs/xtts_config.py |  6 +++---
 TTS/tts/models/xtts.py         | 11 ++++++++---
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/TTS/.models.json b/TTS/.models.json
index 8e35893bef..0c31874046 100644
--- a/TTS/.models.json
+++ b/TTS/.models.json
@@ -18,12 +18,12 @@
             "xtts_v1.1": {
                 "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
                 "hf_url": [
-                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/model.pth",
-                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/config.json",
-                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/vocab.json",
-                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1/hash.md5"
+                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
+                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
+                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
+                    "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
                 ],
-                "model_hash": "10163afc541dc86801b33d1f3217b456",
+                "model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
                 "default_vocoder": null,
                 "commit": "82910a63",
                 "license": "CPML",
diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py
index b968559047..4e5031ba5a 100644
--- a/TTS/tts/configs/xtts_config.py
+++ b/TTS/tts/configs/xtts_config.py
@@ -78,13 +78,13 @@ class XttsConfig(BaseTTSConfig):
     )
 
     # inference params
-    temperature: float = 0.2
+    temperature: float = 0.85
     length_penalty: float = 1.0
     repetition_penalty: float = 2.0
     top_k: int = 50
-    top_p: float = 0.8
+    top_p: float = 0.85
     cond_free_k: float = 2.0
     diffusion_temperature: float = 1.0
-    num_gpt_outputs: int = 16
+    num_gpt_outputs: int = 1
     decoder_iterations: int = 30
     decoder_sampler: str = "ddim"
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index 76c5595ec3..40e8f946c6 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -821,8 +821,6 @@ def load_checkpoint(
         self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
 
         self.init_models()
-        if eval:
-            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
 
         checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
         ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
@@ -831,7 +829,14 @@ def load_checkpoint(
         for key in list(checkpoint.keys()):
             if key.split(".")[0] in ignore_keys:
                 del checkpoint[key]
-        self.load_state_dict(checkpoint, strict=strict)
+
+        # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not
+        try:
+            self.load_state_dict(checkpoint, strict=strict)
+        except:
+            if eval:
+                self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
+                self.load_state_dict(checkpoint, strict=strict)
         if eval:
             if hasattr(self, "hifigan_decoder"):
                 self.hifigan_decoder.eval()