# Music Video Synthesis
* Extract lyrics from song with timestamps
* Compose scenes, include timestamps
* Construct video text prompt for each scene
* Build videos for each scene
* Stitch together

# We will use OpenAI Whisper for stability

In [None]:
#!pip install --quiet --upgrade pip
#!pip3 install torch==2.4.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
#!pip install --quiet --upgrade openai-whisper openai
# Ubuntu or Debian
#!sudo apt update && sudo apt install ffmpeg
#!pip install setuptools-rust
#!pip install -U diffusers imageio imageio_ffmpeg opencv-python moviepy transformers huggingface-hub optimum pillow safetensors optimum-quanto accelerate
#!pip install --upgrade optimum-quanto torchao --extra-index-url https://download.pytorch.org/whl/cu124 # full options are cpu/cu118/cu121/cu124
#!pip install git+https://github.com/xhinker/sd_embed.git@main
#!pip install accelerate flash_attention numba -U
#!pip install flash_attn --no-build-isolation
#!pip install -r requirements.txt -U
#!pip install protobuf==3.20.3
#!pip install --upgrade google-api-core google-cloud

In [None]:
import codecs
import cv2
import diffusers
import gc
import imageio
import imageio_ffmpeg
import json
import math
import moviepy.editor as mp
import numpy as np
import os
import random
import re
import soundfile as sf
import tempfile
import time
import transformers
import torch
import torch.multiprocessing as mp
import torchaudio
import tqdm
import whisper

from cached_path import cached_path
from contextlib import contextmanager
from datetime import datetime, timedelta
from diffusers import AutoencoderKL, AutoencoderKLCogVideoX, AutoPipelineForText2Image, CogVideoXTransformer3DModel, CogVideoXPipeline, CogVideoXDPMScheduler
from diffusers import CogVideoXTransformer3DModel, CogVideoXImageToVideoPipeline, FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers.utils import export_to_video, load_video, load_image
from einops import rearrange
from huggingface_hub import hf_hub_download, snapshot_download
from model import CFM, DiT, MMDiT, UNetT
from model.utils import (convert_char_to_pinyin, get_tokenizer,
                         load_checkpoint, save_spectrogram)
from model4 import T5EncoderModel as m_T5EncoderModel, FluxTransformer2DModel
from numba import cuda
from openai import OpenAI
from optimum.quanto import freeze, qfloat8, quantize, requantize
from pathlib import Path
from PIL import Image
from pydub import AudioSegment, silence
from safetensors.torch import load_file as load_safetensors, save_file as save_safetensors
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
from tempfile import NamedTemporaryFile
from torchaudio import transforms
from torchao.quantization import quantize_, int8_weight_only, int8_dynamic_activation_int8_weight
from transformers import pipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5TokenizerFast, T5EncoderModel as t_T5EncoderModel
from vocos import Vocos

# Turn off tokenizers parallelism through its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Default tensor dtypes used by the pipelines below.
dtype = torch.bfloat16
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

# Pick the best available backend. torch device strings are lowercase, so the
# fallback must be "cpu" ("CPU" is rejected by torch.device / Tensor.to).
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"Using {device} device")

# Number of attempts for retried operations (copied into CONFIG["retry_limit"]).
retry_limit = 3

In [None]:
# Configuration
# Secrets default to the OPENAI_API_KEY / HF_TOKEN environment variables so
# they do not have to be pasted into the notebook; when the variables are
# unset the values stay empty strings, exactly as before.
CONFIG = {
    "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
    "openai_model": "gpt-4o-mini",
    "openai_model_large": "gpt-4o",
    "hf_token": os.environ.get("HF_TOKEN", ""),
    "base_working_dir": "./images",
    "base_video_dir": "./output",
    "goals": [
        '''The goal of the animation is to craft a relatable, emotionally resonant story that inspires viewers to reflect on their lives and encourages a shift in their mindset. The movie should be fun and catchy while delivering a deeper message about self-acceptance and spiritual growth.''',
        # Add more goals
    ],
    "device": device,
    "dtype": dtype,
    "retry_limit": retry_limit,
    "MAX_SEED": MAX_SEED,
}
# Derived from the video directory so the two cannot drift apart
# (evaluates to './output/scenes.json' with the defaults above).
scenes_file_path = os.path.join(CONFIG["base_video_dir"], "scenes.json")

# Ensure base directories exist
os.makedirs(CONFIG["base_working_dir"], exist_ok=True)
os.makedirs(CONFIG["base_video_dir"], exist_ok=True)

# Speech-synthesis settings read as module-level globals by load_model(),
# infer_batch() and infer(). Each value is defined exactly once here (the
# group used to be repeated verbatim).
target_sample_rate = 24000  # reference audio is resampled to this rate
n_mel_channels = 100  # mel dimension passed to the model
hop_length = 256  # samples per mel frame when computing durations
target_rms = 0.1  # quieter reference audio is scaled up to this RMS
nfe_step = 32  # 16, 32 -- passed as `steps` to the sampler
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0  # divides the estimated generated duration
remove_silence_default = True
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None

In [None]:
def reset_memory(device):
    """Run garbage collection and, when CUDA is in use, release cached GPU memory.

    Args:
        device: Device whose CUDA memory statistics should be reset. May be a
            string such as ``"cuda"`` / ``"cuda:0"``, a ``torch.device``, a
            CUDA device index, or ``None`` for the current CUDA device.

    The CUDA statistic resets only accept CUDA devices, so they are skipped
    when CUDA is unavailable or when ``device`` names another backend
    (for example ``"cpu"`` or ``"mps"``) instead of raising.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()

    targets_cuda = (
        device is None
        or isinstance(device, int)
        or str(device).lower().startswith("cuda")
    )
    if targets_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.reset_accumulated_memory_stats(device)
    
def get_openai_prompt_response(
    prompt: str,
    config: dict,
    max_tokens: int = 6000,
    temperature: float = 0.33,
    openai_model: str = "",
):
    """
    Sends a prompt to OpenAI's API and retrieves the response with retry logic.

    Args:
        prompt: User message sent to the chat completion endpoint.
        config: Must provide "openai_api_key", "openai_model" and "retry_limit".
        max_tokens: Completion token limit forwarded to the API.
        temperature: Sampling temperature forwarded to the API.
        openai_model: Model name; falls back to config["openai_model"] when empty.

    Returns:
        The message content of the first choice, or "" when every attempt
        failed (or the API returned no content).
    """
    client = OpenAI(api_key=config["openai_api_key"])
    retry_limit = config["retry_limit"]

    for attempt in range(1, retry_limit + 1):
        try:
            # The request itself lives inside the try block: network and API
            # errors are raised here, so this is the call that must be retried.
            response = client.chat.completions.create(
                max_tokens=max_tokens,
                messages=[
                    {
                        "role": "system",
                        "content": """Act as a helpful assistant, you are an expert editor.""",
                    },
                    {"role": "user", "content": prompt},
                ],
                model=openai_model or config["openai_model"],
                temperature=temperature,
            )
            message_content = response.choices[0].message.content
            return message_content if message_content is not None else ""
        except Exception as e:
            print(f"Error occurred: {e}")
            if attempt == retry_limit:
                print("Retry limit reached. Moving to the next iteration.")
                return ""
            print(f"Retrying... (Attempt {attempt}/{retry_limit})")
            time.sleep(1)  # Optional: wait before retrying

    # Only reached when retry_limit is not positive.
    return ""


def load_flux_pipe():
    """
    Builds the Flux text-to-image pipeline and moves it to CUDA.

    The second text encoder and the transformer come from the
    "HighCWu/FLUX.1-dev-4bit" repository (presumably 4-bit weights, per the
    repository name); the remaining components come from
    "black-forest-labs/FLUX.1-dev". Everything is loaded in bfloat16.

    Returns:
        The pipeline object returned by ``.to('cuda')``.
    """
    quantized_repo = "HighCWu/FLUX.1-dev-4bit"
    compute_dtype = torch.bfloat16

    quantized_t5 = m_T5EncoderModel.from_pretrained(
        quantized_repo,
        subfolder="text_encoder_2",
        torch_dtype=compute_dtype,
    )
    quantized_transformer = FluxTransformer2DModel.from_pretrained(
        quantized_repo,
        subfolder="transformer",
        torch_dtype=compute_dtype,
    )

    flux = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        text_encoder_2=quantized_t5,
        transformer=quantized_transformer,
        torch_dtype=compute_dtype,
    )

    # Author's note: with enable_model_cpu_offload() this costs about 8.5GB of
    # VRAM; fully resident on the GPU (as done here) it costs about 11GB.
    flux.remove_all_hooks()
    return flux.to("cuda")


def gen_flux_image(pipe, prompt, config: dict, height=1024, width=1024, guidance_scale=3.5, num_inference_steps=4, max_sequence_length=512, seed=-1):
    """
    Renders a single PIL image for ``prompt`` with the Flux pipeline.

    The prompt is converted to weighted embeddings with
    ``get_weighted_text_embeddings_flux1`` and the embeddings are passed to
    the pipeline in place of the raw text.

    Args:
        pipe: Flux pipeline, as returned by load_flux_pipe().
        prompt: Text description of the image.
        config: Accepted for interface symmetry; not read by this function.
        height, width: Output size in pixels.
        guidance_scale, num_inference_steps, max_sequence_length: Forwarded
            to the pipeline call unchanged.
        seed: Seed for the CPU generator; -1 draws one from [0, MAX_SEED].

    Returns:
        The first image of the pipeline output.
    """
    chosen_seed = random.randint(0, MAX_SEED) if seed == -1 else seed
    rng = torch.Generator("cpu").manual_seed(chosen_seed)

    with torch.no_grad():
        embeds, pooled_embeds = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt)
        output = pipe(
            prompt_embeds=embeds,
            pooled_prompt_embeds=pooled_embeds,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            max_sequence_length=max_sequence_length,
            generator=rng,
        )

    return output.images[0]


def load_video_pipeline():
    """
    Loads and configures the video generation pipeline.

    A temporary CogVideoX text-to-video pipeline is assembled first so that
    its VAE, scheduler, tokenizer and text encoder can be reused; those parts
    are then combined with the image-to-video transformer into a
    CogVideoXImageToVideoPipeline, which is moved to the module-level
    ``device`` and returned.

    Returns:
        The image-to-video pipeline.
    """
    # torchao int8 weight-only quantization, applied in place by quantize_().
    quantization = int8_weight_only

    text_encoder = t_T5EncoderModel.from_pretrained("THUDM/CogVideoX-5b", subfolder="text_encoder", torch_dtype=torch.bfloat16)
    quantize_(text_encoder, quantization())
    
    # NOTE(review): this text-to-video transformer is only used to build the
    # temporary pipeline below and is not part of the returned pipeline --
    # confirm whether loading and quantizing it is actually required.
    transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16)
    quantize_(transformer, quantization())
    
    # The image-to-video transformer is the one used by the returned pipeline.
    # NOTE(review): unlike the other components it is not quantized -- confirm
    # that this is intentional.
    i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    
    vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-5b", subfolder="vae", torch_dtype=torch.bfloat16)
    quantize_(vae, quantization())
    
    # Temporary text-to-video pipeline; only its components are kept.
    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-5b",
        text_encoder=text_encoder,
        transformer=transformer,
        vae=vae,
        torch_dtype=torch.bfloat16,
    )
    #pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()
    
    # Swap in the DPM scheduler with trailing timestep spacing.
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    
    # Keep references to the shared components before dropping the pipeline.
    i2v_vae=pipe.vae
    i2v_scheduler=pipe.scheduler
    i2v_tokenizer=pipe.tokenizer
    i2v_text_encoder=pipe.text_encoder
    
    del pipe
    reset_memory(device)
    
    # Assemble the image-to-video pipeline from the shared components.
    pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b-I2V",
        transformer=i2v_transformer,
        vae=i2v_vae,
        scheduler=i2v_scheduler,
        tokenizer=i2v_tokenizer,
        text_encoder=i2v_text_encoder,
        torch_dtype=torch.bfloat16,
    ).to(device)

    #pipe_image.transformer.to(memory_format=torch.channels_last)
    #pipe_image.transformer = torch.compile(pipe_image.transformer, fullgraph=True)

    return pipe_image


def infer_vid(pipe_image, prompt: str, image_input: str, config: dict, num_inference_steps: int = 50, guidance_scale: float = 7.0, seed: int = -1, num_frames: int = 49):
    """
    Runs the image-to-video pipeline for one conditioning image and prompt.

    Args:
        pipe_image: Image-to-video pipeline (see load_video_pipeline()).
        prompt: Text describing the motion/action.
        image_input: Path of the conditioning image; it is resized to 720x480.
        config: Accepted for interface symmetry; not read by this function.
        num_inference_steps, guidance_scale, num_frames: Forwarded to the
            pipeline call unchanged.
        seed: Seed for the CPU generator; -1 draws one from [0, 255].

    Returns:
        Tuple of (the pipeline output's ``frames`` with output_type "pt",
        the seed that was used).
    """
    chosen_seed = seed if seed != -1 else random.randint(0, 255)

    # Open from disk, resize, then hand the PIL image to diffusers' loader.
    resized = Image.open(image_input).resize(size=(720, 480))
    conditioning_image = load_image(resized)

    result = pipe_image(
        image=conditioning_image,
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_videos_per_prompt=1,
        use_dynamic_cfg=True,
        output_type="pt",
        guidance_scale=guidance_scale,
        generator=torch.Generator(device="cpu").manual_seed(chosen_seed),
        num_frames=num_frames,
    )

    return result.frames, chosen_seed


def generate_video(pipe_image, prompt, image_input, config: dict, seed_value: int = -1, video_filename: str = "", num_frames: int = 49):
    """
    Produces one video clip from an image and a prompt and writes it to disk.

    Inference always uses 50 steps and a guidance scale of 7.0. Every clip in
    the returned batch is converted to PIL frames, but only the first clip is
    saved. The playback rate is ceil((frame_count - 1) / 6), i.e. 8 fps for
    the default 49 frames.

    Args:
        pipe_image: Image-to-video pipeline.
        prompt: Text describing the action.
        image_input: Path of the conditioning image.
        config: Forwarded to infer_vid().
        seed_value: Seed forwarded to infer_vid() (-1 for random).
        video_filename: Destination path of the video file.
        num_frames: Number of frames to generate.

    Returns:
        The value returned by save_video() (the destination path).
    """
    frames_pt, _used_seed = infer_vid(
        pipe_image,
        prompt,
        image_input,
        config,
        num_inference_steps=50,
        guidance_scale=7.0,
        seed=seed_value,
        num_frames=num_frames,
    )

    def _to_pil_frames(clip):
        # Re-stack the frames of one clip, then tensor -> numpy -> PIL.
        stacked = torch.stack(list(clip))
        return VaeImageProcessor.numpy_to_pil(VaeImageProcessor.pt_to_numpy(stacked))

    clips = [_to_pil_frames(frames_pt[idx]) for idx in range(frames_pt.shape[0])]

    first_clip = clips[0]
    playback_fps = math.ceil((len(first_clip) - 1) / 6)
    saved_path = save_video(first_clip, fps=playback_fps, filename=video_filename)

    # Drop every reference to the frames before clearing cached GPU memory.
    del first_clip
    del clips
    del frames_pt
    reset_memory(device)
    return saved_path


def save_video(frames, fps: int, filename: str):
    """
    Saves a list of frames as a video file.

    The video is first written to a temporary ".mp4" file and then moved to
    ``filename``.

    Args:
        frames: Iterable of frames; each one is converted with ``np.array``.
        fps: Playback rate passed to the imageio writer.
        filename: Destination path of the finished video.

    Returns:
        ``filename``.
    """
    import shutil

    # Reserve a temp path and close the handle before imageio writes to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_video_path = temp_file.name

    try:
        writer = imageio.get_writer(temp_video_path, fps=fps)
        try:
            for frame in frames:
                writer.append_data(np.array(frame))
        finally:
            # Always finalize the writer, even when a frame fails to encode.
            writer.close()

        # shutil.move (unlike os.rename) also works when the temp directory
        # and the destination live on different filesystems.
        shutil.move(temp_video_path, filename)
    except Exception:
        # Do not leave a half-written temp file behind.
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)
        raise

    return filename


def convert_to_gif(video_path: str) -> str:
    """
    Converts a video file to a GIF.

    The GIF is written next to the video with the extension replaced by
    ".gif", at 8 fps and a height of 240 pixels.

    Args:
        video_path: Path of the source video.

    Returns:
        Path of the written GIF.
    """
    # Imported locally: the module-level alias `mp` is rebound to
    # torch.multiprocessing by a later import, so `mp.VideoFileClip` fails.
    from moviepy.editor import VideoFileClip

    # Swap only the extension. str.replace would touch ".mp4" anywhere in the
    # path and return the source path itself for non-".mp4" inputs.
    gif_path = os.path.splitext(video_path)[0] + ".gif"

    clip = VideoFileClip(video_path)
    try:
        small_clip = clip.set_fps(8).resize(height=240)
        small_clip.write_gif(gif_path, fps=8)
    finally:
        # Release the underlying reader.
        clip.close()
    return gif_path


def resize_if_unfit(input_video: str) -> str:
    """
    Returns a path to a 720x480 version of ``input_video``.

    When the video already measures 720x480 the input path is returned as
    is; otherwise the result of center_crop_resize() is returned.
    """
    width, height = get_video_dimensions(input_video)
    already_fits = width == 720 and height == 480
    return input_video if already_fits else center_crop_resize(input_video)


def get_video_dimensions(input_video_path: str) -> tuple:
    """
    Retrieves the dimensions of the video.

    Args:
        input_video_path: Path of the video to inspect.

    Returns:
        The "size" entry of the metadata yielded first by
        ``imageio_ffmpeg.read_frames``.
    """
    reader = imageio_ffmpeg.read_frames(input_video_path)
    try:
        # The first item produced by the generator is the metadata dict.
        metadata = next(reader)
    finally:
        # Close the generator explicitly so the reader does not stay open
        # after only the metadata has been consumed.
        reader.close()
    return metadata["size"]


def center_crop_resize(input_video_path: str, target_width: int = 720, target_height: int = 480) -> str:
    """
    Resizes and center-crops the video to the target dimensions.

    The video is scaled so that it covers the target size, cropped around
    its centre, sub-sampled towards 8 fps (skipping at most 5 frames between
    kept frames) and truncated to 49 frames. The result is written to a
    temporary ".mp4" file.

    Args:
        input_video_path: Path of the source video.
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        Path of the temporary output video.

    Raises:
        ValueError: If the video cannot be opened or reports no dimensions.
    """
    max_frames = 49
    target_fps = 8
    processed_frames = []

    cap = cv2.VideoCapture(input_video_path)
    try:
        orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # An unreadable file reports zero dimensions; fail with a clear
        # message instead of a ZeroDivisionError further down.
        if not cap.isOpened() or orig_width <= 0 or orig_height <= 0:
            raise ValueError(f"Could not read video dimensions from '{input_video_path}'.")

        orig_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Scale by the larger factor so the frame covers the whole target.
        resize_factor = max(target_width / orig_width, target_height / orig_height)
        inter_width = int(orig_width * resize_factor)
        inter_height = int(orig_height * resize_factor)

        # Crop offsets are the same for every frame.
        start_x = (inter_width - target_width) // 2
        start_y = (inter_height - target_height) // 2

        ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
        skip = min(5, ideal_skip)  # Cap at 5

        # Skip fewer frames when the video is too short to yield max_frames.
        while (total_frames / (skip + 1)) < max_frames and skip > 0:
            skip -= 1

        total_read = 0
        while len(processed_frames) < max_frames and total_read < total_frames:
            ret, frame = cap.read()
            if not ret:
                break

            if total_read % (skip + 1) == 0:
                resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
                processed_frames.append(
                    resized[start_y:start_y + target_height, start_x:start_x + target_width]
                )

            total_read += 1
    finally:
        cap.release()

    # Reserve a temp path and close the handle before OpenCV writes to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_video_path = temp_file.name

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))
    try:
        for frame in processed_frames:
            out.write(frame)
    finally:
        out.release()

    return temp_video_path


def extract_last_frame(video_filename: str, output_image_filename: str):
    """
    Extracts the last frame from a video file and saves it as an image.

    Errors are reported on stdout rather than raised, so callers cannot tell
    from the return value (always None) whether the image was written.

    Args:
        video_filename: Path of the video to read.
        output_image_filename: Path of the image to write.
    """
    reader = None
    try:
        reader = imageio.get_reader(video_filename, 'ffmpeg')

        # Walk through every frame, keeping only the most recent one.
        last_frame = None
        for frame in reader:
            last_frame = frame

        if last_frame is not None:
            imageio.imwrite(output_image_filename, last_frame)
            print(f"Last frame saved successfully as '{output_image_filename}'.")
        else:
            print("The video contains no frames.")

    except FileNotFoundError:
        print(f"Error: The file '{video_filename}' was not found.")
    except ValueError as err:
        print(f"ValueError: {err}")
    except RuntimeError as err:
        # Bound to `err`, not `re`, to avoid shadowing the regex module name.
        print(f"RuntimeError: {err}")
    except Exception as err:
        print(f"An unexpected error occurred: {err}")
    finally:
        # Close the reader on both the success and the error path.
        if reader is not None:
            reader.close()

def revise_scenes(scenes, config: dict):
    """
    Revise scenes based on the extracted scenes.

    Asks the OpenAI model to rewrite the scene_description and
    action_sequence fields of every scene and parses the reply as JSON.

    Args:
        scenes: Scene list (or an already serialized JSON string).
        config: Configuration passed to get_openai_prompt_response(); its
            "openai_model" entry selects the model.

    Returns:
        The parsed JSON reply.

    Raises:
        ValueError: If the reply is empty or is not valid JSON
            (json.JSONDecodeError is a ValueError subclass).
    """
    # Serialize as real JSON; interpolating the list directly would embed a
    # Python repr (single quotes, None/True) while the prompt promises JSON.
    scenes_json = scenes if isinstance(scenes, str) else json.dumps(scenes, ensure_ascii=False)

    # Generate scenes JSON
    prompt = f'''Revise the JSON scenes to update the scene_description and action_sequence to engage the senses and imagination, suitable for creating a stunning, cinematic video experience.  Use descriptions of special effects in the scenes.  JSON scenes: {scenes_json}   
The scene_description (200 words or less) should include details such as attire, setting, mood, lighting, and any significant movements or expressions, painting a clear visual scene consistent with the video theme and different from other scenes.
The action_sequence (30 words or less) should describe the action in the scene.  The goal is to create input to create a stunning, cinematic video experience.
Only update the scene_description and action_sequence.  Do not delete any items as having scenes with the given start times are important.  We do not want to have the same scene_description and action_sequence for the items with repeatitive input text.  Please change these to be creative and consistent with dynamic video sequences.
Return only the json list, less jargon. The json list fields should be: start, text, scene_description, action_sequence'''

    result = get_openai_prompt_response(prompt, config, openai_model=config["openai_model"], temperature=0.33)
    if not result or not result.strip():
        raise ValueError("OpenAI returned an empty response; cannot parse scenes.")

    # Take the content of a ```json ... ``` fence when present, otherwise the
    # whole reply. Only the fence/label is removed, never text inside values.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", result, flags=re.DOTALL)
    payload = fenced.group(1) if fenced else result
    payload = re.sub(r"\A\s*json\b", "", payload)
    # Raw newlines are dropped (as before) so they cannot break string values.
    payload = payload.replace("\n", "")

    return json.loads(payload)


def process_audio_scenes(scenes, config: dict):
    """
    Creates the per-run image and video directories.

    Both directories are named "video_<YYYYmmdd_HHMMSS>" and are created
    under config["base_working_dir"] and config["base_video_dir"].

    Returns:
        Tuple of (scenes unchanged, images directory, videos directory).
    """
    # Timestamped name so every run gets its own directories.
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = "_".join(("video", run_stamp))

    print("Create unique directories for images and videos")
    images_dir, videos_dir = (
        os.path.join(config[base_key], run_name)
        for base_key in ("base_working_dir", "base_video_dir")
    )
    for directory in (images_dir, videos_dir):
        os.makedirs(directory, exist_ok=True)

    return scenes, images_dir, videos_dir

def process_audio_images(config: dict, scenes, audio_images_dir):
    """
    Renders one starting image per scene with the Flux pipeline.

    Images are generated from each scene's "scene_description" at 720x480
    and saved as "image_01.jpg", "image_02.jpg", ... in ``audio_images_dir``.
    The pipeline is moved to the CPU and released afterwards.
    """
    # Step 4: Load Flux pipeline and generate images
    print("Load Flux pipeline and generate images")
    flux_pipe = load_flux_pipe()

    render_settings = {
        "height": 480,
        "width": 720,
        "guidance_scale": 3.9,
        "num_inference_steps": 48,
        "max_sequence_length": 512,
        "seed": -1,  # a fresh random seed for every image
    }

    for index, scene in enumerate(scenes, start=1):
        rendered = gen_flux_image(flux_pipe, scene['scene_description'], config, **render_settings)
        target_path = os.path.join(audio_images_dir, f"image_{index:02d}.jpg")
        rendered.save(target_path, dpi=(300, 300))

    # Move the pipeline back to CPU and delete it
    flux_pipe.to('cpu')
    del flux_pipe
    reset_memory(device)

def process_audio_video(config: dict, scenes, audio_images_dir, audio_videos_dir):
    """
    Generates the video segments for every scene.

    Each scene gets two segments. The first starts from the scene's image
    ("image_NN.jpg"); the second starts from the last frame of the first,
    which is written to "temp_image.jpg". Files are named
    "video_<running number>_<segment number>.mp4", where the running number
    increases with every segment across all scenes. The pipeline is moved to
    the CPU and released afterwards.
    """
    # Step 6: Load Video Pipeline
    print("Load Video Pipeline")
    video_pipe = load_video_pipeline()

    carry_over_frame = os.path.join(audio_images_dir, "temp_image.jpg")
    segments_per_scene = 2
    clip_counter = 1

    # Step 7: Generate video sequences
    for scene_idx, scene in enumerate(scenes, start=1):
        action_prompt = scene["action_sequence"]
        current_image = os.path.join(audio_images_dir, f"image_{scene_idx:02d}.jpg")

        print(f"Scene {scene_idx} has {segments_per_scene} segments")
        for segment_idx in range(1, segments_per_scene + 1):
            clip_name = f"video_{clip_counter:02d}_{segment_idx:02d}.mp4"
            clip_path = os.path.join(audio_videos_dir, clip_name)
            generate_video(video_pipe, action_prompt, current_image, config, seed_value=-1, video_filename=clip_path)
            time.sleep(1)  # Pause for 1 second

            # Chain the segments: the last frame seeds the next segment.
            extract_last_frame(clip_path, carry_over_frame)
            current_image = carry_over_frame

            clip_counter += 1

    # Move the pipeline back to CPU before deleting
    video_pipe.to('cpu')
    del video_pipe
    reset_memory(device)


def process_all_audios(scenes, config: dict):
    """
    Prepares the run directories and renders the starting image of each scene.

    Returns:
        Tuple of (config, scenes, images directory, videos directory).
    """
    prepared = process_audio_scenes(scenes, config)
    scenes, images_dir, videos_dir = prepared

    # Create starting images for scenes
    process_audio_images(config, scenes, images_dir)

    return (config, scenes, images_dir, videos_dir)

def create_scenes(config, goal):
    """
    Creates scenes for the animation goal using OpenAI's API.

    The model is asked for a JSON list of scenes; the reply is parsed and
    also written to the module-level ``scenes_file_path``.

    Args:
        config: Configuration passed to get_openai_prompt_response(); its
            "openai_model" entry selects the model.
        goal: Free-text description of what the animation should achieve.

    Returns:
        The parsed JSON reply.

    Raises:
        ValueError: If the reply is empty or is not valid JSON
            (json.JSONDecodeError is a ValueError subclass).
    """
    # Generate scenes JSON
    # The example objects below are valid JSON (a comma after every field);
    # a malformed example invites malformed replies.
    prompt = f'''
I am creating Python code to generate scenes for a short animation.

{goal}

Please generate a series of scenes in JSON format, where each scene is a **complete and independent description** with no knowledge of the previous or future scenes. Each scene should contain the following fields:
    "character": The character speaking in the scene (narrator, country, town, etc.)
    "emotion": The emotion of the character (angry, calm, disgust, fearful, happy, neutral, sad, surprised)
    "text": The narration or dialogue associated with the scene.
    "scene_description": A detailed description of the scene, using evocative language to engage the audience emotionally.
    "action_sequence": A description of the actions taking place in the scene.

Each scene should contribute to the overall goal of the animation but be written in a way that it stands alone.

Here's the format to use:    [ {{
     "character": "The character goes here...",
     "emotion": "The character emotion goes here...",
     "text": "Narration or dialogue goes here...",
     "scene_description": "Detailed description of the scene...",
     "action_sequence": "Description of the actions in the scene..."
    }},
    {{
     "character": "The character goes here...",
     "emotion": "The character emotion goes here...",
     "text": "Next narration or dialogue...",
     "scene_description": "Detailed description of the next scene...",
     "action_sequence": "Description of the actions in the next scene..."
    }}
    // Additional scenes...
    ]
    Return only the json, less jargon.
    '''
    result = get_openai_prompt_response(prompt, config, openai_model=config["openai_model"], temperature=0.66)
    if not result or not result.strip():
        raise ValueError("OpenAI returned an empty response; cannot parse scenes.")

    # Take the content of a ```json ... ``` fence when present, otherwise the
    # whole reply. Only the fence/label is removed, never text inside values.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", result, flags=re.DOTALL)
    payload = fenced.group(1) if fenced else result
    payload = re.sub(r"\A\s*json\b", "", payload)
    # Raw newlines are dropped (as before) so they cannot break string values.
    payload = payload.replace("\n", "")
    scenes = json.loads(payload)

    # Save the scenes to scenes.json, creating the target directory if needed.
    scenes_dir = os.path.dirname(scenes_file_path)
    if scenes_dir:
        os.makedirs(scenes_dir, exist_ok=True)
    with open(scenes_file_path, "w", encoding="utf-8") as scenes_file:
        json.dump(scenes, scenes_file)

    return scenes

def get_first_file_in_directory(directory_path):
    """
    Returns the full file path to the first file found in the specified directory.
    If the directory does not contain any files, returns None.

    Entries are examined in sorted name order, so the result is the same on
    every call and on every platform (os.listdir itself returns entries in
    arbitrary order).

    Parameters:
        directory_path (str): The path to the directory to search.

    Returns:
        str or None: The full path to the first file found, or None if no
        files are found or the directory cannot be read.
    """
    try:
        entries = sorted(os.listdir(directory_path))
    except FileNotFoundError:
        print(f"Error: The directory '{directory_path}' does not exist.")
        return None
    except PermissionError:
        print(f"Error: Permission denied for directory '{directory_path}'.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None

    for entry in entries:
        full_path = os.path.join(directory_path, entry)
        # Skip sub-directories; only regular files qualify.
        if os.path.isfile(full_path):
            return full_path

    # No files found in the directory
    return None

def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
    """
    Builds the CFM speech model around ``model_cls`` and loads its checkpoint.

    Args:
        model_cls: Backbone class, instantiated with ``model_cfg`` plus the
            vocabulary size and mel dimension.
        model_cfg: Keyword arguments for ``model_cls``.
        ckpt_path: Checkpoint path handed to load_checkpoint().
        file_vocab: Vocabulary name/path. An empty string selects
            "Emilia_ZH_EN" with the "pinyin" tokenizer; anything else uses
            the "custom" tokenizer.

    Returns:
        The model returned by load_checkpoint() (called with use_ema=True).
    """
    using_default_vocab = file_vocab == ""
    tokenizer = "pinyin" if using_default_vocab else "custom"
    if using_default_vocab:
        file_vocab = "Emilia_ZH_EN"

    print("\nvocab : ", file_vocab, tokenizer)
    print("tokenizer : ", tokenizer)
    print("model : ", ckpt_path, "\n")

    vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)

    backbone = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
    mel_settings = {
        "target_sample_rate": target_sample_rate,
        "n_mel_channels": n_mel_channels,
        "hop_length": hop_length,
    }
    cfm = CFM(
        transformer=backbone,
        mel_spec_kwargs=mel_settings,
        odeint_kwargs={"method": ode_method},
        vocab_char_map=vocab_char_map,
    ).to(device)

    return load_checkpoint(cfm, ckpt_path, device, use_ema=True)

def chunk_text(text, max_chars=135):
    """
    Groups the sentences of ``text`` into chunks of limited size.

    The text is split after sentence punctuation, and consecutive sentences
    are packed into a chunk for as long as the chunk's UTF-8 byte length
    stays within ``max_chars``. A single sentence that is longer than the
    limit still becomes a chunk of its own.

    Args:
        text (str): The text to be split.
        max_chars (int): Size limit per chunk, measured in UTF-8 bytes.
    Returns:
        List[str]: A list of text chunks.
    """
    def _with_separator(piece):
        # Pieces ending in a single-byte character get a trailing space;
        # pieces ending in a multi-byte character (or empty ones) do not.
        if piece and len(piece[-1].encode('utf-8')) == 1:
            return piece + " "
        return piece

    # Split after ASCII punctuation followed by whitespace, or directly
    # after full-width punctuation.
    pieces = re.split(r'(?<=[;:,.!?])\s+|(?<=[；：，。！？])', text)

    result = []
    buffer = ""
    for piece in pieces:
        combined_size = len(buffer.encode('utf-8')) + len(piece.encode('utf-8'))
        if combined_size <= max_chars:
            buffer += _with_separator(piece)
            continue

        # The piece does not fit: flush the buffer and start a new chunk.
        if buffer:
            result.append(buffer.strip())
        buffer = _with_separator(piece)

    if buffer:
        result.append(buffer.strip())

    return result

def _cross_fade_waves(waves, cross_fade_duration):
    """Join 1-D waveforms, linearly cross-fading ``cross_fade_duration`` seconds at each seam."""
    if cross_fade_duration <= 0:
        # Simply concatenate
        return np.concatenate(waves)

    max_overlap = int(cross_fade_duration * target_sample_rate)
    combined = waves[0]
    for next_wave in waves[1:]:
        # The overlap can never exceed either of the two waves.
        overlap = min(max_overlap, len(combined), len(next_wave))
        if overlap <= 0:
            # No overlap possible, concatenate
            combined = np.concatenate([combined, next_wave])
            continue

        fade_out = np.linspace(1, 0, overlap)
        fade_in = np.linspace(0, 1, overlap)
        blended = combined[-overlap:] * fade_out + next_wave[:overlap] * fade_in

        combined = np.concatenate([combined[:-overlap], blended, next_wave[overlap:]])

    return combined


def infer_batch(ref_audio, ref_text, gen_text_batches, ema_model, remove_silence, vocos, cross_fade_duration=0.15):
    """
    Synthesizes speech for every text batch and joins the results.

    Args:
        ref_audio: Tuple of (waveform tensor, sample rate) of the reference
            voice. Multi-channel audio is averaged to one channel.
        ref_text: Transcript of the reference audio.
        gen_text_batches: Texts to synthesize, one model call per entry.
        ema_model: Model exposing ``sample(cond=..., text=..., duration=...,
            steps=..., cfg_strength=..., sway_sampling_coef=...)``.
        remove_silence: Accepted for interface compatibility; not used here.
        vocos: Vocoder exposing ``decode(mel)``.
        cross_fade_duration: Seconds of cross-fade between consecutive
            batches; values <= 0 concatenate without fading.

    Returns:
        Tuple of (joined waveform as a numpy array, spectrograms of all
        batches concatenated along axis 1).

    Raises:
        ValueError: If ``ref_text`` or ``gen_text_batches`` is empty, or if
            no batch produced any audio.
    """
    if not ref_text:
        raise ValueError("ref_text is empty, cannot proceed with inference.")
    # Check if gen_text_batches is empty
    if len(gen_text_batches) == 0:
        raise ValueError("gen_text_batches is empty, cannot proceed with inference.")

    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    # Quiet references are boosted to target_rms for inference and the output
    # is scaled back afterwards. A silent reference (rms == 0) is left alone
    # to avoid dividing by zero.
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    boost_reference = bool(0 < rms < target_rms)
    if boost_reference:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    if len(ref_text[-1].encode('utf-8')) == 1:
        ref_text = ref_text + " "

    # Character class: count each pause punctuation mark individually. (The
    # bare string would only match all seven marks appearing in sequence.)
    zh_pause_punc = r"[。，、；：？！]"

    # Values that depend only on the reference are computed once.
    ref_audio_len = audio.shape[-1] // hop_length
    ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))

    generated_waves = []
    spectrograms = []

    for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
        # Prepare the text
        final_text_list = convert_char_to_pinyin([ref_text + gen_text])

        # Estimate the total duration (in mel frames) from the text lengths.
        gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)

        # inference
        with torch.inference_mode():
            try:
                generated, _ = ema_model.sample(
                    cond=audio,
                    text=final_text_list,
                    duration=duration,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )
            except Exception as e:
                print(f"Error during model inference: {e}")
                continue  # Skip to the next batch

        # Ensure `generated` is not empty
        if generated is None or generated.shape[0] == 0:
            print(f"Generated is empty for batch {i}, skipping...")
            continue

        # Drop the frames that correspond to the reference audio.
        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if boost_reference:
            # Undo the loudness boost applied to the reference.
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_waves.append(generated_wave.squeeze().cpu().numpy())
        spectrograms.append(generated_mel_spec[0].cpu().numpy())

    # If no waves were generated, raise an error
    if len(generated_waves) == 0:
        raise ValueError("No waves were generated, something went wrong during inference.")

    # Combine all generated waves with cross-fading
    final_wave = _cross_fade_waves(generated_waves, cross_fade_duration)

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, combined_spectrogram


def process_voice(ref_audio_orig, ref_text):
    """
    Prepares a reference voice: a cleaned-up audio file and its transcript.

    The audio is split on silence, the segments are joined again, the result
    is clipped to 15000 (reported as 15s) and exported to a temporary ".wav"
    file. When ``ref_text`` is blank the exported audio is transcribed with
    the pipeline returned by get_whisper_pipe().

    Returns:
        Tuple of (path of the exported wav file, reference text).
    """
    print("Converting audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as wav_file:
        source = AudioSegment.from_file(ref_audio_orig)

        # Re-assemble the audio from its non-silent segments.
        voiced_parts = silence.split_on_silence(source, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
        trimmed = AudioSegment.silent(duration=0)
        for part in voiced_parts:
            trimmed += part

        if len(trimmed) > 15000:
            print("Audio is over 15s, clipping to only first 15s.")
            trimmed = trimmed[:15000]

        trimmed.export(wav_file.name, format="wav")
        ref_audio = wav_file.name

    if ref_text.strip():
        print("Using custom reference text...")
        return ref_audio, ref_text

    print("No reference text provided, transcribing reference audio...")
    pipe = get_whisper_pipe()
    transcription = pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe"},
        return_timestamps=False,
    )
    ref_text = transcription["text"].strip()
    print("Finished transcription")
    return ref_audio, ref_text

def infer(ref_audio, ref_text, gen_text, remove_silence, ema_model, vocos, cross_fade_duration=0.15):
    """Chunk ``gen_text`` and hand the batches to ``infer_batch``.

    Args:
        ref_audio: Path of the reference clip, loaded with ``torchaudio.load``.
        ref_text: Transcript of the reference clip; it is normalized to end
            with ". " unless it already ends with ". " or "。".
        gen_text: Text to synthesize.
        remove_silence: Forwarded unchanged to ``infer_batch``.
        ema_model: Forwarded unchanged to ``infer_batch``.
        vocos: Forwarded unchanged to ``infer_batch``.
        cross_fade_duration: Forwarded unchanged to ``infer_batch``.

    Returns:
        Whatever ``infer_batch`` returns.
    """
    print(gen_text)

    # Make sure the reference transcript ends with ". " (or the CJK full stop).
    if not ref_text.endswith((". ", "。")):
        ref_text += " " if ref_text.endswith(".") else ". "

    audio, sr = torchaudio.load(ref_audio)

    # Per-batch size budget: UTF-8 bytes per second of the reference clip,
    # scaled by what remains of a 25-second window after the reference.
    ref_seconds = audio.shape[-1] / sr
    ref_byte_count = len(ref_text.encode('utf-8'))
    max_chars = int(ref_byte_count / ref_seconds * (25 - ref_seconds))

    # Split the input text into batches
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)

    print('ref_text', ref_text)
    for batch_index, batch_text in enumerate(gen_text_batches):
        print(f'gen_text {batch_index}', batch_text)

    print(f"Generating audio in {len(gen_text_batches)} batches...")
    return infer_batch(
        (audio, sr),
        ref_text,
        gen_text_batches,
        ema_model,
        remove_silence,
        vocos,
        cross_fade_duration,
    )
    

def process(ref_audio, ref_text, text_gen, remove_silence, ema_model, config, vocos, wave_path):
    """Synthesize a multi-voice dialog and write it to ``dialog.wav``.

    ``text_gen`` may contain ``[name]`` tags. The text is split in front of
    every tag and each chunk is synthesized with the voice named by its
    leading tag. Chunks without a tag, or with a tag that names an unknown
    voice, are spoken by the "main" voice. Chunks with no text left after
    removing the tag are skipped.

    Args:
        ref_audio: Reference audio for the "main" voice.
        ref_text: Reference transcript for the "main" voice.
        text_gen: Text to synthesize, optionally containing ``[name]`` tags.
        remove_silence: Forwarded to ``infer``; when true the written file
            is additionally silence-split and re-exported.
        ema_model: Forwarded to ``infer``.
        config: Dict; its optional "voices" entry maps voice names to
            ``{"ref_audio": ..., "ref_text": ...}``. That mapping is updated
            in place (a "main" entry is added and every entry is replaced by
            the output of ``process_voice``).
        vocos: Forwarded to ``infer``.
        wave_path: Ignored. The output always goes to
            ``f'{audio_videos_dir}/dialog.wav'`` using the module-level
            ``audio_videos_dir``; kept that way because other code may read
            that exact file.

    Returns:
        None. Nothing is written when no chunk produced audio.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice

    # Convert every reference clip and fill in missing transcripts.
    for entry in voices.values():
        entry['ref_audio'], entry['ref_text'] = process_voice(entry['ref_audio'], entry['ref_text'])

    split_before_tag = r'(?=\[\w+\])'
    tag_pattern = r'\[(\w+)\]'

    # Pause inserted after every chunk: 2.0 * target_sample_rate zero samples.
    pause = np.zeros(int(target_sample_rate * 2.0))

    generated_audio_segments = []
    for chunk in re.split(split_before_tag, text_gen):
        gen_text = re.sub(tag_pattern, "", chunk).strip()
        if not gen_text:
            # re.split yields an empty first chunk when the text starts with
            # a tag; a tag with no text has nothing to synthesize either.
            continue

        # Resolve the voice from THIS chunk's tag before checking that it is
        # configured; checking a stale name lets unknown tags reach the
        # lookup below and raise KeyError.
        match = re.match(tag_pattern, chunk)
        voice = match[1] if match else "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"

        print(f"Voice: {voice}")
        audio, _ = infer(
            voices[voice]['ref_audio'],
            voices[voice]['ref_text'],
            gen_text,
            remove_silence,
            ema_model,
            vocos,
        )
        generated_audio_segments.append(audio)
        generated_audio_segments.append(pause)

    wave_path = f'{audio_videos_dir}/dialog.wav'
    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)
        sf.write(wave_path, final_wave, target_sample_rate)
        # Remove silence
        if remove_silence:
            aseg = AudioSegment.from_file(wave_path)
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                non_silent_wave += non_silent_seg
            non_silent_wave.export(wave_path, format="wav")
        print(wave_path)

def synthesize_audio_from_scenes(
    scenes,
    audio_videos_dir,
    main_ref_audio="/mnt/d/data/audio/intro.mp3",
    characters_dir="/mnt/d/data/audio/characters",
):
    """Synthesize the dialog for ``scenes`` with one voice per character.

    The first scene's text is emitted without a tag (so ``process`` speaks
    it with the "main" voice); every later scene is emitted as
    ``"[<character lowercased>] <text> "`` so ``process`` can switch voices.
    For each distinct character a reference clip is taken from
    ``<characters_dir>/<character lowercased>``; when that directory is
    missing a random existing character directory is used instead, and the
    character is skipped when none can be found.

    Args:
        scenes: Iterable of dicts with at least "character" and "text".
        audio_videos_dir: Directory of the current run. NOTE: ``process``
            writes to ``f'{audio_videos_dir}/dialog.wav'`` using the
            module-level ``audio_videos_dir``; the returned path assumes
            this argument names the same directory.
        main_ref_audio: Reference clip for the "main" voice.
        characters_dir: Directory holding one sub-directory per character.

    Returns:
        ``Path(audio_videos_dir) / "dialog.wav"``, the location ``process``
        writes the dialog to.

    Raises:
        ValueError: If the configured model name is not supported.
    """
    text_to_synthesize = "".join(
        f"{scene['text']}" if index == 0
        else f"[{scene['character'].lower()}] {scene['text']} "
        for index, scene in enumerate(scenes)
    )

    config = {
        "model": "F5-TTS",
        "ref_audio": main_ref_audio,
        "ref_text": "",
        "gen_text": "",
        "remove_silence": True,
        "output_dir": "samples",
        "voices": {},
    }

    # Derive the characters from the scenes actually being synthesized so the
    # voice table always matches the tags emitted above (no hidden global).
    characters = dict.fromkeys(scene['character'] for scene in scenes)
    for character in characters:
        voice_name = character.lower()
        directory = os.path.join(characters_dir, voice_name)

        if not os.path.isdir(directory):
            print(f"Directory does not exist for character '{character}'. Picking a random existing directory.")
            directory = get_random_existing_directory(characters_dir)
            if not directory:
                print(f"No existing directories found. Skipping character {voice_name}.")
                continue

        first_audio_file = get_first_file_in_directory(directory)
        if first_audio_file:
            config['voices'][voice_name] = {
                "ref_audio": first_audio_file,
                "ref_text": "",
            }
        else:
            print(f"No audio file found in directory: {directory} for character: '{character}'")

    # Pick the architecture before loading anything so an unsupported model
    # name fails fast instead of surfacing later as an unbound ema_model.
    model = config["model"]
    if model == "F5-TTS":
        model_cls = DiT
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
        exp_name = "F5TTS_Base"
    elif model == "E2-TTS":
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
        exp_name = "E2TTS_Base"
    else:
        raise ValueError(f"Unsupported TTS model: {model!r}")
    ckpt_step = 1200000
    vocab_file = ""

    # Still created because earlier versions always created it; nothing in
    # this function writes there.
    os.makedirs(config["output_dir"], exist_ok=True)

    wave_path = Path(audio_videos_dir) / "dialog.wav"

    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    ema_model = None
    try:
        ckpt_file = str(cached_path(f"hf://SWivid/{model}/{exp_name}/model_{ckpt_step}.safetensors"))
        ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
        process(
            config["ref_audio"],
            config["ref_text"],
            text_to_synthesize,
            config["remove_silence"],
            ema_model,
            config,
            vocos,
            wave_path,
        )
    finally:
        # Drop the model references even when synthesis fails so the memory
        # can be reclaimed before the next pipeline stage.
        del vocos, ema_model
        reset_memory(device)

    return wave_path

def get_random_existing_directory(base_directory):
    """
    Returns a random existing directory path within the base_directory.

    Only immediate children are considered. Returns None when
    base_directory contains no sub-directories, does not exist, or is not
    itself a directory, so callers can treat all "nothing to pick from"
    cases the same way.
    """
    try:
        names = os.listdir(base_directory)
    except (FileNotFoundError, NotADirectoryError):
        # A missing base is just another way of having no candidates.
        return None

    candidates = (os.path.join(base_directory, name) for name in names)
    dir_list = [path for path in candidates if os.path.isdir(path)]
    if not dir_list:
        return None
    return random.choice(dir_list)

def get_whisper_pipe():
    """Build and return a new speech-recognition pipeline.

    A new pipeline is constructed on every call for the
    "openai/whisper-large-v3-turbo" checkpoint, using the module-level
    ``torch_dtype`` and ``device``.
    """
    asr_options = {
        "model": "openai/whisper-large-v3-turbo",
        "chunk_length_s": 30,
        "batch_size": 16,  # tune to the available hardware
        "torch_dtype": torch_dtype,
        "device": device,
    }
    return pipeline("automatic-speech-recognition", **asr_options)

In [None]:
# Driver cell: for every configured goal, build (or reload) the scenes, then
# synthesize the images, the dialog audio and finally the videos.
load_scene = False
number_of_generations = 1
for goal in CONFIG["goals"]:
    for i in range(number_of_generations):
        if not load_scene:
            scenes = create_scenes(CONFIG, goal)
        else:
            # Reuse scenes saved by an earlier run.
            with open(scenes_file_path, "r") as scenes_file:
                scenes = json.load(scenes_file)

        # Distinct character names appearing in the scenes.
        unique_characters = {entry['character'] for entry in scenes}
        unique_characters_list = list(unique_characters)
        print(f'unique characters: {unique_characters_list}')

        # Starting images for the scenes; also yields the output directories.
        config, scenes, audio_images_dir, audio_videos_dir = process_all_audios(scenes, CONFIG)

        # Keep a copy of the scenes next to the generated media.
        with open(f'{audio_videos_dir}/scenes.json', "w") as scenes_file:
            json.dump(scenes, scenes_file)

        # Dialog audio, then the videos for each scene.
        synthesize_audio_from_scenes(scenes, audio_videos_dir)
        process_audio_video(config, scenes, audio_images_dir, audio_videos_dir)