In [None]:
from pathlib import Path

import click
import hydra
import librosa
import numpy as np
import soundfile as sf
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning import LightningModule
from loguru import logger
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from fish_speech.models.vits_decoder.lit_module import VITSDecoder
from fish_speech.utils.file import AUDIO_EXTENSIONS

# Register an "eval" resolver so Hydra/OmegaConf configs can use
# `${eval:...}` interpolations.
# replace=True keeps the registration idempotent: without it, running this
# code twice in one interpreter (e.g. re-executing a notebook cell) or after
# another module has already registered "eval" raises an error.
# NOTE(security): this resolver hands config text to Python's built-in eval,
# so only load configuration files you trust.
OmegaConf.register_new_resolver("eval", eval, replace=True)


def load_model(config_name, checkpoint_path, device="cuda") -> VITSDecoder:
    """Instantiate the decoder from a Hydra config and restore its weights.

    Args:
        config_name: Name of the Hydra config to compose. The config directory
            is the relative path ``../../fish_speech/configs`` handed to
            ``hydra.initialize``; the composed config must have a ``model``
            node.
        checkpoint_path: File passed to ``torch.load``. It may hold a bare
            state dict or a mapping with the weights under ``"state_dict"``.
        device: Device the model is moved to once the weights are loaded.

    Returns:
        The instantiated model, switched to eval mode and moved to ``device``.
    """
    # Drop any global Hydra state left by an earlier call so that
    # initialize() can be entered again in the same process.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        cfg = compose(config_name=config_name)

    # Build the decoder model from the composed config.
    model: VITSDecoder = instantiate(cfg.model)
    state_dict = torch.load(
        checkpoint_path,
        map_location=model.device,
    )

    # Unwrap checkpoints that nest the weights under "state_dict".
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # strict=False tolerates key mismatches, which would otherwise hide a
    # checkpoint that does not belong to this config. Surface them in the log.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        logger.warning(
            f"{len(load_result.missing_keys)} model keys were not found in the "
            f"checkpoint, e.g. {load_result.missing_keys[:5]}"
        )
    if load_result.unexpected_keys:
        logger.warning(
            f"{len(load_result.unexpected_keys)} checkpoint keys were not used "
            f"by the model, e.g. {load_result.unexpected_keys[:5]}"
        )

    model.eval()
    model.to(device)
    logger.info("Restored model from checkpoint")

    return model


@torch.no_grad()
def main(
    reference_path,
    text,
    tokenizer="fishaudio/fish-speech-1",
    output_path="chatts.wav",
    config_name="vits_decoder_finetune",
    checkpoint_path="checkpoints/vq-gan-group-fsq-2x1024.pth",
    device="cuda",
    quantized_path="./chatts.npy",
):
    """Decode stored quantized features into audio styled after a reference clip.

    The reference audio is encoded into a style embedding, the text is
    tokenized, and both are passed together with the quantized features to
    the decoder. The result is written to ``output_path``.

    Args:
        reference_path: Reference audio file (``str`` or ``pathlib.Path``).
            Its suffix must be listed in ``AUDIO_EXTENSIONS``.
        text: Text that is tokenized and passed to the decoder.
        tokenizer: Name or path given to ``AutoTokenizer.from_pretrained``.
        output_path: Destination of the generated audio file.
        config_name: Hydra config name forwarded to ``load_model``.
        checkpoint_path: Checkpoint file forwarded to ``load_model``.
        device: Device forwarded to ``load_model``.
        quantized_path: ``.npy`` file holding the quantized features to
            decode. The size of its last axis is passed to the decoder as the
            feature length. NOTE(review): the expected array layout is not
            visible here; confirm against ``generator.decode``.

    Raises:
        AssertionError: If ``reference_path`` does not have an audio suffix.
    """
    # Accept plain strings as well as Path objects: `.suffix` below only
    # exists on Path, and the script entry point passes a string.
    reference_path = Path(reference_path)

    # Make sure the reference file has a known audio extension. This runs
    # before the model is loaded so bad input fails fast.
    assert (
        reference_path.suffix in AUDIO_EXTENSIONS
    ), f"Expected audio file, got {reference_path.suffix}"

    model: VITSDecoder = load_model(config_name, checkpoint_path, device=device)

    # Load the reference audio resampled to the model's sampling rate.
    reference_audio, _ = librosa.load(reference_path, sr=model.sampling_rate)
    reference_audio = torch.from_numpy(reference_audio).to(model.device).float()
    # librosa returns a mono 1-D waveform by default; `[None]` adds a leading
    # batch axis before the spectrogram transform.
    # NOTE(review): reference_spec is presumably (1, freq_bins, frames) — confirm.
    reference_spec = model.spec_transform(reference_audio[None])
    # Encode the reference spectrogram into the reference embedding; the
    # length argument is the size of the spectrogram's last axis.
    reference_embedding = model.generator.encode_ref(
        reference_spec,
        torch.tensor([reference_spec.shape[-1]], device=model.device),
    )
    logger.info(
        f"Loaded reference audio from {reference_path}, shape: {reference_audio.shape}"
    )

    # Encode the text with the requested tokenizer. A separate local name
    # keeps the `tokenizer` argument (a name/path) from being overwritten.
    text_tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    encoded_text = text_tokenizer(text, return_tensors="pt").input_ids.to(
        model.device
    )
    logger.info(f"Encoded text: {encoded_text.shape}")

    # Restore the precomputed quantized features from disk.
    quantized = torch.from_numpy(np.load(quantized_path)).to(model.device).long()
    logger.info(f"Restored VQ features: {quantized.shape}")

    # Decode: generate audio from the quantized features, the encoded text
    # and the reference embedding.
    fake_audios = model.generator.decode(
        quantized,
        torch.tensor([quantized.shape[-1]], device=model.device),
        encoded_text,
        torch.tensor([encoded_text.shape[-1]], device=model.device),
        ge=reference_embedding,
    )
    logger.info(
        f"Generated audio: {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
    )

    # Save the first generated waveform (index [0, 0]) at the model's
    # sampling rate.
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.sampling_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    # main() reads `reference_path.suffix`, so pass a Path rather than a bare
    # string (a str has no `.suffix` attribute).
    main(
        reference_path=Path("zhongli.ogg"),
        text="黄金是璃月的财富，是令璃月的心脏搏动的血液。你是否拥有黄金般闪耀的心，就让我拭目以待吧。",
    )
