# XttsConfig

* dataclass
* inherits from `BaseTTSConfig`
    * which inherits from `BaseTrainingConfig`
    * which inherits from `TrainerConfig` which is part of the  ["coqui-tts-trainer" package](https://github.com/idiap/coqui-ai-Trainer)
    * which inherits from `Coqpit` which is part of the ["coqpit" package](https://github.com/coqui-ai/coqpit)
    * which inherits from `Serializable` (part of "coqpit") and `MutableMapping` (part of the standard library)
    * only the classes in "coqpit" provide methods, the others only add fields
* `#load_json` loads the json file as dict and calls `self = self.deserialize(dump_dict)` and then `self.check_values()`
    * `self = self.deserialize(dump_dict)` creates the dataclass from the dict.
    * `self.check_values()` does nothing as it is only a dummy method by default and is not overridden by the subclasses

In [None]:
# config = XttsConfig()
# config.load_json("/path/to/xtts/config.json")

# Xtts

* Xtts model implementation (only supports inference)
* inherits from `BaseTTS`
    * which inherits from `BaseTrainerModel`
        * `BaseTrainerModel` is an abstract class which requires the methods `init_from_config`, `inference` and `load_checkpoint` to be defined by subclasses (and that's all it does)
    * which inherits from `TrainerModel` which is part of the  ["coqui-tts-trainer" package](https://github.com/idiap/coqui-ai-Trainer)
    * which inherits from `ABC` and `nn.Module` and also defines mostly abstract methods
* `#init_from_config`
    * just calls `Xtts(config)`
* `#__init__`
    * calls `super().__init__(config, ap=None, tokenizer=None)`
        * super method just sets some instance attributes (directly or via `self._set_model_args(config)`)
    * sets the tokenizer and some instance attributes
    * calls `self.init_models()`
        * sets up `self.gpt` and `self.hifigan_decoder` 
    * calls `self.register_buffer("mel_stats", torch.ones(80))`

In [None]:
# model = Xtts.init_from_config(config)
# model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)

# Xtts#synthesize

* sets some settings
* calls `self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)` (if speaker embeddings are loaded from file, let's focus on that for now)
    * tokenizes the text
    * calls `self.gpt.generate(...)` 

In [None]:
# outputs = model.synthesize(
#     "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
#     config,
#     speaker_wav="/data/TTS-public/_refclips/3.wav",
#     gpt_cond_len=3,
#     language="en",
# )

## Tokenizer
```python
    def encode(self, txt, lang):
        lang = lang.split("-")[0]  # remove the region
        self.check_input_length(txt, lang)
        txt = self.preprocess_text(txt, lang)
        lang = "zh-cn" if lang == "zh" else lang
        txt = f"[{lang}]{txt}"
        txt = txt.replace(" ", "[SPACE]")
        return self.tokenizer.encode(txt).ids
```

# Implementation

In [None]:
from pathlib import Path

In [None]:
# Directory holding the model files read by the cells below
# (config.json, speakers_xtts.pth, vocab.json, model.pth).
MODEL_PATH = Path("./XTTS-v2")
# Key used to look up the speaker in the speakers file.
SPEAKER = "Aaron Dreschner"
# Language code handed to the tokenizer.
LANG = "en"
# Text that is tokenized and synthesized.
TEXT = "Hi! This is a test!"

## Load config

In [None]:
import json

In [None]:
# Read the model configuration as a plain dict and pretty-print it.
# JSON is UTF-8, so name the encoding instead of relying on the platform default.
with open(MODEL_PATH / "config.json", encoding="utf-8") as f:
    config = json.load(f)
print(json.dumps(config, sort_keys=True, indent=4))

## Load Speaker Embeddings

In [None]:
# torch==2.5.1
import torch

In [None]:
# Load the speaker table and list the available speaker names.
speakers = torch.load(
    MODEL_PATH / "speakers_xtts.pth",
    weights_only=True,
)
print(sorted(speakers))

In [None]:
# Fetch both tensors stored for the selected speaker and show their shapes.
gpt_cond_latent, speaker_embedding = (
    speakers[SPEAKER][field] for field in ("gpt_cond_latent", "speaker_embedding")
)
print(gpt_cond_latent.shape, speaker_embedding.shape)

## Tokenizer

In [None]:
# transformers==4.46.2 (for tokenizers==0.20.3)
import re
from tokenizers import Tokenizer

In [None]:
class VoiceBpeTokenizer:
    """Minimal text tokenizer wrapper around a `tokenizers.Tokenizer` vocabulary file.

    Text is lower-cased, stripped of double quotes, whitespace-normalized,
    prefixed with a ``[lang]`` tag and has its spaces replaced by ``[SPACE]``
    before being handed to the underlying tokenizer.
    """

    # Per-language soft character limits; exceeding one only prints a warning.
    CHAR_LIMITS = {
        "en": 250,
        "de": 253,
        "fr": 273,
        "es": 239,
        "it": 213,
        "pt": 203,
        "pl": 224,
        "zh": 82,
        "ar": 166,
        "cs": 186,
        "ru": 182,
        "nl": 251,
        "tr": 226,
        "ja": 71,
        "hu": 224,
        "ko": 95,
        "hi": 150,
    }

    def __init__(self, vocab_file):
        """Load the tokenizer definition from *vocab_file* (path-like)."""
        self.tokenizer = Tokenizer.from_file(str(vocab_file))

    def encode(self, text, lang):
        """Encode *text* for language *lang* and return the tokenizer's encoding object.

        *lang* may carry a region suffix (e.g. ``"en-US"``); the region is dropped.
        """
        lang = lang.split("-")[0]  # remove the region
        self.check_input_length(text, lang)
        text = self.preprocess_text(text)
        # Stripping the region turns "zh-cn" into "zh"; the reference XTTS
        # tokenizer maps it back so the language tag becomes "[zh-cn]".
        lang = "zh-cn" if lang == "zh" else lang
        text = f"[{lang}]{text}"
        text = text.replace(" ", "[SPACE]")
        return self.tokenizer.encode(text)

    def check_input_length(self, text, lang):
        """Print a warning when *text* is longer than the limit for *lang* (default 250)."""
        limit = self.CHAR_LIMITS.get(lang, 250)
        if len(text) > limit:
            print(f"The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.")

    def preprocess_text(self, text):
        """Lower-case *text*, drop double quotes and collapse whitespace runs to one space."""
        # Original has some more stuff, but seems incomplete even there.
        return re.sub(r"\s+", " ", text.lower().replace('"', ""))

In [None]:
# Build the tokenizer from the model's vocabulary and show the tokens for TEXT.
tokenizer = VoiceBpeTokenizer(MODEL_PATH / "vocab.json")
token_encoding = tokenizer.encode(text=TEXT, lang=LANG)
print(token_encoding.tokens)

In [None]:
# Largest token id in the vocabulary (shown as the cell result).
max(tokenizer.tokenizer.get_vocab().values())

## Initialize GPT

In [None]:
from torch import nn
from torch.functional import F
from transformers import GPT2Config, GPT2Model
from transformers import GenerationMixin, GPT2PreTrainedModel
from functools import partial

def null_position_embeddings(range_, dim):
    """Return zeros of shape (range_.shape[0], range_.shape[1], dim) on range_'s device."""
    batch = range_.shape[0]
    length = range_.shape[1]
    return torch.zeros((batch, length, dim), device=range_.device)

class LearnedPositionEmbeddings(nn.Module):
    """Learned position-embedding table with `seq_len` rows of width `model_dim`.

    Args:
        seq_len: Number of positions in the table.
        model_dim: Embedding width.
        init: Standard deviation of the normal initialization.
        relative: When True, `forward` looks up a randomly placed window of
            positions instead of positions ``0 .. sl - 1``.
    """

    def __init__(self, seq_len, model_dim, init=0.02, relative=False):
        super().__init__()
        # nn.Embedding
        self.emb = torch.nn.Embedding(seq_len, model_dim)
        # Initializing this way is standard for GPT-2
        self.emb.weight.data.normal_(mean=0.0, std=init)
        self.relative = relative
        self.seq_len = seq_len

    def forward(self, x):
        """Return position embeddings for the ``x.shape[1]`` positions of *x*.

        Only the second dimension of *x* and its device are used; the result
        has shape ``(x.shape[1], model_dim)``.
        """
        sl = x.shape[1]
        if self.relative:
            # Imported here because the surrounding file never imports `random`;
            # without it this branch raised NameError.
            import random

            start = random.randint(sl, self.seq_len) - sl
            return self.emb(torch.arange(start, start + sl, device=x.device))
        else:
            return self.emb(torch.arange(0, sl, device=x.device))

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of the single position *ind* with shape ``(1, 1, model_dim)``."""
        return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)

In [None]:
# The "model_args" section of the loaded config; the cells below read all model sizes and special token ids from it.
model_config = config["model_args"]

class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
    """GPT-2 stack with separate text/audio embeddings, usable with `generate()`.

    The built-in GPT-2 token and position embeddings are not used: token
    embeddings come from `text_emb` / `mel_emb`, positions from
    `text_pos_emb` / `mel_pos_emb`, and the transformer's `wpe` is replaced
    by a function returning zeros. All sizes are read from the module-level
    `model_config`.
    """

    max_seq_len = model_config["gpt_max_audio_tokens"] + model_config["gpt_max_text_tokens"] + model_config["gpt_max_prompt_tokens"] + 1
    gpt_config = GPT2Config(
        vocab_size=model_config["gpt_max_audio_tokens"],
        n_positions=max_seq_len,
        n_ctx=max_seq_len,
        n_embd=model_config["gpt_n_model_channels"],
        n_layer=model_config["gpt_layers"],
        n_head=model_config["gpt_n_heads"],
        gradient_checkpointing=False,
        use_cache=True,
    )

    def __init__(self):
        super().__init__(self.gpt_config)

        self.text_emb = nn.Embedding(model_config["gpt_number_text_tokens"], self.gpt_config.n_embd)
        self.text_pos_emb = LearnedPositionEmbeddings(model_config["gpt_max_text_tokens"] + 2, self.gpt_config.n_embd)
        self.mel_emb = nn.Embedding(model_config["gpt_num_audio_tokens"], self.gpt_config.n_embd)
        self.mel_pos_emb = LearnedPositionEmbeddings(model_config["gpt_max_audio_tokens"] + 2 + 1, self.gpt_config.n_embd)

        self.transformer = GPT2Model(self.gpt_config)
        # Replace the built-in position embeddings with a function that returns zeros.
        del self.transformer.wpe
        self.transformer.wpe = partial(null_position_embeddings, dim=self.gpt_config.n_embd)
        # Built-in token embeddings are unused.
        del self.transformer.wte

        self.final_norm = nn.LayerNorm(self.gpt_config.n_embd)
        self.mel_head = nn.Linear(self.gpt_config.n_embd, model_config["gpt_num_audio_tokens"])

    def compute_embeddings(self, cond_latents, text_tokens):
        """Cache the prompt embeddings and return placeholder ids for `generate()`.

        Args:
            cond_latents: Conditioning latents, concatenated in front of the text
                embeddings along dim 1 (assumed ``(batch, n_cond, n_embd)`` — TODO confirm).
            text_tokens: Text token ids (assumed ``(batch, n_text)`` — TODO confirm).

        Returns:
            A long tensor of shape ``(batch, prefix_len + 1)`` filled with 1,
            whose last column is the start-audio token. The real prefix
            embeddings are stored in `self.cached_prefix_emb` for `forward`.
        """
        # Add start/end token to the start/end of the sequences.
        text_tokens = F.pad(text_tokens, (1, 0), value=tokenizer.tokenizer.token_to_id('[START]'))
        text_tokens = F.pad(text_tokens, (0, 1), value=tokenizer.tokenizer.token_to_id('[STOP]'))

        emb = self.text_emb(text_tokens) + self.text_pos_emb(text_tokens)
        emb = torch.cat([cond_latents, emb], dim=1)
        self.cached_prefix_emb = emb
        gpt_inputs = torch.full(
            (
                emb.shape[0],
                emb.shape[1] + 1,  # +1 for the start_audio_token
            ),
            fill_value=1,
            dtype=torch.long,
            device=text_tokens.device,
        )
        gpt_inputs[:, -1] = model_config["gpt_start_audio_token"]
        return gpt_inputs

    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, *args, **kwargs):
        """Run one decoding step and attach `logits` to the transformer output.

        `compute_embeddings` must have been called first so that
        `self.cached_prefix_emb` exists.

        Args:
            input_ids: Either the full placeholder prompt (first step) or a
                single new audio token per sequence (later steps).
            attention_mask: Passed to the transformer. Optional on the first
                step; required on single-token steps, where its length gives
                the position of the new token.
            position_ids: Passed to the transformer (whose position embeddings
                are zeros).
            past_key_values: Cache from previous steps, if any.

        Raises:
            ValueError: On a single-token step without `attention_mask`.
        """
        prefix_emb = self.cached_prefix_emb
        prefix_len = prefix_emb.shape[1]
        if input_ids.shape[-1] != 1:
            # Only for first generation step.
            gen_inputs = input_ids[:, prefix_len:]
            gen_emb = self.mel_emb(gen_inputs)
            gen_emb = gen_emb + self.mel_pos_emb(gen_emb)
            emb = torch.cat([prefix_emb, gen_emb], dim=1)
        else:
            if attention_mask is None:
                raise ValueError("attention_mask is required for single-token decoding steps")
            emb = self.mel_emb(input_ids)
            # Index of the new token within the audio part of the sequence.
            emb = emb + self.mel_pos_emb.get_fixed_embedding(
                attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
            )
        outputs = self.transformer(inputs_embeds=emb, attention_mask=attention_mask,
                                   position_ids=position_ids, past_key_values=past_key_values,
                                   use_cache=True)
        outputs.logits = self.mel_head(self.final_norm(outputs.last_hidden_state))
        return outputs

In [None]:
# Build the model (no checkpoint weights are loaded here) and switch it to evaluation mode.
gpt = GPT2InferenceModel()
gpt.eval()

In [None]:
# Batch of one: the text token ids plus the speaker's conditioning latents.
text_tokens = torch.tensor([token_encoding.ids], dtype=torch.int32)
cond_latents = gpt_cond_latent
input_ids = gpt.compute_embeddings(cond_latents=cond_latents, text_tokens=text_tokens)
print(input_ids)

In [None]:
# Single forward pass over the whole prompt. The mask (all positions attended)
# and position ids are supplied explicitly, and the module is invoked through
# __call__ without autograd since this is inference only.
with torch.no_grad():
    outputs = gpt(
        input_ids,
        attention_mask=torch.ones_like(input_ids),
        position_ids=None,
    )
print(outputs)

In [None]:
# Trial run of generate(); the checkpoint weights are only loaded into `gpt` further down,
# so the tokens produced here come from the freshly constructed model.
# The stop-audio token doubles as both pad and eos token.
output = gpt.generate(
    input_ids,
    bos_token_id=model_config["gpt_start_audio_token"],
    pad_token_id=model_config["gpt_stop_audio_token"],
    eos_token_id=model_config["gpt_stop_audio_token"],
    max_length=100 # TODO
)
print(output)

## Load Model

In [None]:
import pickle
import torch
import io
import struct
import numpy as np

class CustomDict(dict):
    """Dict stand-in returned by CustomUnpickler.find_class when a class lookup fails."""

    @property
    def __dict__(self):
        # Expose the mapping itself as the instance dict.
        # NOTE(review): presumably so that pickle's state restoration stores
        # attributes as items of this dict — confirm.
        return self

class TensorWrapper:
    """Rebuild a `torch.Tensor` from a storage-like object.

    The storage only needs two attributes: `data` (a flat NumPy array) and
    `dtype` (the torch dtype of the resulting tensor).
    """

    def __init__(self, size, dtype=torch.float32, device='cpu'):
        self.size = size
        self.dtype = dtype
        self.device = device

    def set_(self, storage, offset, size, stride):
        """Return a tensor of shape *size* viewing *storage* at *offset* with *stride*.

        Args:
            storage: Object with `data` (flat NumPy array) and `dtype` (torch dtype).
            offset: Index of the first element inside `storage.data`.
            size: Target shape; ``()`` yields a 0-dim tensor.
            stride: Element strides per dimension, as recorded in the checkpoint.

        Returns:
            A contiguous tensor that owns its memory.
        """
        size = tuple(int(dim) for dim in size)
        stride = tuple(int(step) for step in stride)

        # Empty tensors need no data at all.
        if int(np.prod(size)) == 0:
            return torch.empty(size, dtype=storage.dtype)

        # Number of storage elements spanned by this (possibly non-contiguous) view.
        span = 1 + sum((dim - 1) * step for dim, step in zip(size, stride))

        # Copy the slice: arrays built with np.frombuffer are read-only, and the
        # result must not alias the raw checkpoint bytes.
        chunk = np.array(storage.data[offset:offset + span])
        flat = torch.from_numpy(chunk).to(storage.dtype)

        # Honor the recorded strides (e.g. transposed weights), then make contiguous.
        return flat.as_strided(size, stride).contiguous()

class CustomStorage:
    """Holds raw tensor data together with the torch dtype it should be read as."""

    def __init__(self, data, dtype):
        self.dtype, self.data = dtype, data

    @property
    def device(self):
        """Always reports 'cpu'."""
        return 'cpu'

    @property
    def _untyped_storage(self):
        """The object acts as its own untyped storage."""
        return self

class CustomUnpickler(pickle.Unpickler):
    """Unpickler for a torch zip checkpoint that tolerates missing classes.

    Tensor rebuild functions and legacy torch storage classes are intercepted;
    classes that cannot be resolved are replaced by `CustomDict`. Tensor
    payloads are read from `zip_archive`.

    SECURITY: every other global is still resolved through
    `pickle.Unpickler.find_class`, so unpickling can import and call arbitrary
    code. Only use this on checkpoints from a trusted source.
    """

    # Legacy storage class name -> (torch dtype, NumPy dtype used to decode the raw bytes).
    DTYPE_MAP = {
        'FloatStorage': (torch.float32, np.float32),
        'DoubleStorage': (torch.float64, np.float64),
        'HalfStorage': (torch.float16, np.float16),
        'LongStorage': (torch.int64, np.int64),
        'IntStorage': (torch.int32, np.int32),
        'ShortStorage': (torch.int16, np.int16),
        'CharStorage': (torch.int8, np.int8),
        'ByteStorage': (torch.uint8, np.uint8),
        'BoolStorage': (torch.bool, np.bool_),
    }

    def __init__(self, file, zip_archive):
        """Read the pickle stream from *file*; tensor data comes from *zip_archive*."""
        super().__init__(file)
        self.zip_archive = zip_archive

    def persistent_load(self, pid):
        """Resolve a persistent id; 'storage' ids become a `CustomStorage` holding the data.

        If the payload cannot be read, a zero-filled storage of the requested
        size is returned and the failure is printed (best effort).
        """
        print(f"Persistent load: {pid}")
        if isinstance(pid, tuple) and pid[0] == 'storage':
            storage_type, storage_class, key, location, numel = pid

            # Handle the case where storage_class is already CustomStorage
            if isinstance(storage_class, type) and storage_class.__name__ == 'CustomStorage':
                storage_name = 'FloatStorage'  # default to float if we get CustomStorage
            else:
                storage_name = storage_class.__name__

            # Unknown storage names fall back to float32.
            torch_dtype, np_dtype = self.DTYPE_MAP.get(storage_name, (torch.float32, np.float32))

            try:
                tensor_data = self.load_tensor_data(key, numel, np_dtype)
                return CustomStorage(tensor_data, dtype=torch_dtype)
            except Exception as e:
                print(f"Failed to load tensor data: {e}")
                return CustomStorage(
                    np.zeros(numel, dtype=np_dtype),
                    dtype=torch_dtype
                )

        return pid[1]

    def load_tensor_data(self, key, numel, dtype):
        """Load tensor data from the zip file."""
        data_file = f'model/data/{key}'
        with self.zip_archive.open(data_file, 'r') as f:
            data_bytes = f.read()
            return np.frombuffer(data_bytes, dtype=dtype)

    def find_class(self, module, name):
        """Override find_class to return dummy classes for missing ones."""
        print(f"Finding: {module}.{name}")
        if module == 'torch._utils' and name == '_rebuild_tensor_v2':
            return self._rebuild_tensor_v2
        if module == 'torch._utils' and name == '_rebuild_tensor':
            return self._rebuild_tensor
        if module == 'torch' and name in self.DTYPE_MAP:
            # Return a marker class that keeps the storage name, so that
            # persistent_load can pick the matching dtype. Returning one shared
            # class for all storages made every storage decode as float32.
            return type(name, (), {})
        try:
            return super().find_class(module, name)
        except Exception:
            # Best effort: any class that cannot be imported becomes a plain
            # dict-like container. KeyboardInterrupt/SystemExit still propagate.
            return CustomDict

    def _rebuild_tensor_v2(self, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
        """Custom tensor rebuilding function."""
        tensor = self._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        return tensor

    def _rebuild_tensor(self, storage, storage_offset, size, stride):
        """Create a new tensor with the given size and data."""
        wrapper = TensorWrapper(size, dtype=storage.dtype)
        return wrapper.set_(storage, storage_offset, size, stride)

In [None]:
import zipfile
# model.pth is opened as a zip archive: the pickle stream is read from model/data.pkl,
# and CustomUnpickler reads tensor payloads from the same archive while unpickling.
with zipfile.ZipFile(MODEL_PATH / "model.pth", 'r') as archive:
    with archive.open('model/data.pkl', 'r') as f:
        data = CustomUnpickler(f, archive).load()

In [None]:
# Top-level entries of the unpickled checkpoint (shown as the cell result).
data.keys()

In [None]:
import re

# Checkpoint entries under "gpt.gpt." belong to the GPT-2 stack, which is named "transformer" in GPT2InferenceModel.
gpt_state_dict = {
    re.sub(r"^gpt\.gpt", "transformer", name): tensor
    for name, tensor in data['model'].items()
    if name.startswith("gpt.gpt.")
}
# Remaining parameters: name in GPT2InferenceModel -> key in the checkpoint.
gpt_state_dict.update({
    local_name: data['model'][checkpoint_key]
    for local_name, checkpoint_key in {
        'text_emb.weight': 'gpt.text_embedding.weight',
        'text_pos_emb.emb.weight': 'gpt.text_pos_embedding.emb.weight',
        'mel_emb.weight': 'gpt.mel_embedding.weight',
        'mel_pos_emb.emb.weight': 'gpt.mel_pos_embedding.emb.weight',
        'final_norm.weight': 'gpt.final_norm.weight',
        'final_norm.bias': 'gpt.final_norm.bias',
        'mel_head.weight': 'gpt.mel_head.weight',
        'mel_head.bias': 'gpt.mel_head.bias',
    }.items()
})

# Every model parameter must have exactly one checkpoint tensor and vice versa.
assert {name for name, _ in gpt.named_parameters()} == set(gpt_state_dict)

In [None]:
# Load the remapped checkpoint tensors into the model.
gpt.load_state_dict(gpt_state_dict)

In [None]:
# Rebuild the prompt; compute_embeddings also refreshes gpt.cached_prefix_emb, which forward() reads.
text_tokens = torch.tensor(token_encoding.ids, dtype=torch.int32).unsqueeze(0)
cond_latents = gpt_cond_latent
input_ids = gpt.compute_embeddings(cond_latents, text_tokens)

# Generate audio codes without autograd. The stop-audio token doubles as pad and eos token.
# NOTE(review): with do_sample=False the sampling options (top_p, top_k, temperature) are
# presumably ignored, and length_penalty only matters for beam search — confirm against the
# transformers generation docs.
with torch.no_grad():
    output = gpt.generate(
        input_ids,
        bos_token_id=model_config["gpt_start_audio_token"],
        pad_token_id=model_config["gpt_stop_audio_token"],
        eos_token_id=model_config["gpt_stop_audio_token"],
        do_sample=False,
        top_p=0.85,
        top_k=50,
        temperature=0.75,
        num_return_sequences=1,
        num_beams=1,
        length_penalty=1.0,
        repetition_penalty=5.0,
        max_new_tokens=model_config['gpt_max_audio_tokens']
    )
print(output)

## Decoder

In [None]:
# Embed the text the same way compute_embeddings does: [START] + tokens + [STOP],
# token embeddings plus learned position embeddings.
text_inputs = F.pad(text_tokens, (1, 0), value=tokenizer.tokenizer.token_to_id('[START]'))
text_inputs = F.pad(text_inputs, (0, 1), value=tokenizer.tokenizer.token_to_id('[STOP]'))
text_emb = gpt.text_emb(text_inputs) + gpt.text_pos_emb(text_inputs)

# Keep only the tokens generated after the prompt that was passed to generate().
gpt_codes = output[:, input_ids.shape[1]:]
code_stride_len = model_config['gpt_code_stride_len']
expected_output_len = gpt_codes.shape[-1] * code_stride_len
# Multiplying and then dividing by the stride cancels out: this is the number of generated codes plus 3.
code_lengths = torch.ceil(torch.tensor([expected_output_len]) / code_stride_len).long() + 3
max_mel_len = code_lengths.max()
# Zero-pad up to max_mel_len, append one stop token, then overwrite everything
# from index (code_lengths - 3) onwards with the stop token.
audio_codes = F.pad(gpt_codes, (0, max_mel_len - gpt_codes.shape[-1]))
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=model_config["gpt_stop_audio_token"])
audio_codes[0, code_lengths[0] - 3:] = model_config["gpt_stop_audio_token"]
# Wrap the code sequence in a start and a stop audio token.
audio_codes = F.pad(audio_codes, (1, 0), value=model_config["gpt_start_audio_token"])
audio_codes = F.pad(audio_codes, (0, 1), value=model_config["gpt_stop_audio_token"])

# Audio-code embeddings plus learned position embeddings.
mel_emb = gpt.mel_emb(audio_codes) + gpt.mel_pos_emb(audio_codes)

# Full input sequence: conditioning latents, then text, then audio codes.
emb = torch.cat([cond_latents, text_emb, mel_emb], dim=1)

In [None]:
# Inference only: skip autograd bookkeeping, as the generate() call above does.
with torch.no_grad():
    gpt_out = gpt.transformer(inputs_embeds=emb, return_dict=True)

In [None]:
# Drop the conditioning-latent positions, then apply the final layer norm to the rest (text and audio positions).
enc = gpt.final_norm(gpt_out.last_hidden_state[:, cond_latents.shape[1]:])
enc

In [None]:
# Keep only the trailing positions, one per audio-code embedding.
mel_logits = enc[:, -mel_emb.shape[1]:]  # These are not really logits, but latents
mel_logits.shape

In [None]:
import sys
sys.path.append("coqui-ai-TTS/TTS/tts/layers/xtts")

from hifigan_decoder import HifiDecoder

In [None]:
# Alias under the name used for the decoder call below.
gpt_latents = mel_logits

In [None]:
# Build the decoder and load its weights: the checkpoint entries that start with
# "hifigan_decoder.", with that prefix removed.
hifigan_decoder = HifiDecoder()
state_dict = {
    name.partition('hifigan_decoder.')[2]: tensor
    for name, tensor in data['model'].items()
    if name.startswith("hifigan_decoder.")
}
hifigan_decoder.load_state_dict(state_dict)
hifigan_decoder.eval()

In [None]:
# Decode the latents, passing the speaker embedding as `g`; move the result to CPU and drop singleton dimensions.
wav = hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()

In [None]:
from IPython.display import Audio, display
# Show an inline audio player. detach() is needed because the decoder call ran with autograd enabled.
# NOTE(review): 24000 is presumably the decoder's output sample rate — confirm against the model config.
audio_widget = Audio(data=wav.detach().numpy(), rate=24000)
display(audio_widget)