In [None]:
# Marvis TTS - Enhanced Gradio Interface with Streaming and Voice Cloning
# Install required packages
!pip install -U transformers gradio soundfile librosa numpy torch

import datetime
import io
import os
import re
import tempfile
import threading
import time
from pathlib import Path
from typing import Generator, Optional, Tuple

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer, AutoProcessor, CsmForConditionalGeneration

class MarvisTTSInterface:
    """Wrapper around the Marvis TTS checkpoint used by the Gradio handlers.

    The instance owns the Hugging Face processor/model pair (loaded lazily by
    `load_model`) and an output directory into which generated audio is
    written as 16-bit PCM WAV files.
    """

    # Sample rate (Hz) used when loading reference audio and reported for
    # every waveform returned by the generate_* methods.
    SAMPLE_RATE = 24000
    # Speaking rate assumed by `estimate_audio_duration`.
    WORDS_PER_MINUTE = 150
    # One sentence body followed by its (possibly empty) run of terminators.
    _SENTENCE_RE = re.compile(r"([^.!?]+)([.!?]*)")

    def __init__(self):
        self.model_id = "Marvis-AI/marvis-tts-0.25m-v0.1-transformers"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = None
        self.model = None
        self.is_loaded = False
        # Create output directory for saving files
        self.output_dir = "marvis_tts_outputs"
        os.makedirs(self.output_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _peak_normalize(audio: np.ndarray, peak: float = 1.0) -> np.ndarray:
        """Scale *audio* so that its largest absolute sample equals *peak*.

        Empty or all-zero input is returned unchanged: there is nothing to
        scale, and dividing by a zero peak would fill the array with NaN.
        """
        if audio.size == 0:
            return audio
        loudest = np.max(np.abs(audio))
        if loudest > 0:
            return audio / loudest * peak
        return audio

    def _unique_output_path(self, stem: str) -> str:
        """Return a not-yet-existing `<output_dir>/<stem>[_N].wav` path.

        The plain `<stem>.wav` name is preferred; a numeric suffix is only
        appended when that file already exists, so two saves within the same
        second no longer overwrite each other.
        """
        candidate = os.path.join(self.output_dir, f"{stem}.wav")
        suffix = 1
        while os.path.exists(candidate):
            candidate = os.path.join(self.output_dir, f"{stem}_{suffix}.wav")
            suffix += 1
        return candidate

    @staticmethod
    def _free_cuda_cache() -> None:
        """Release unused cached GPU memory; a no-op without CUDA.

        This only returns allocator memory, it does not reset any model state.
        """
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _speaker_prompt(speaker_id: int, text: str) -> str:
        """Prefix *text* with the `[<speaker_id>]` tag the prompts here use."""
        # int() keeps a float such as 1.0 (e.g. from a UI slider) from
        # producing a "[1.0]" tag.
        return f"[{int(speaker_id)}]{text}"

    def _synthesize(self, prompt: str, **generate_kwargs) -> np.ndarray:
        """Run one sampling pass of the model and return the waveform.

        Args:
            prompt: Text already prefixed with its speaker tag.
            **generate_kwargs: Extra sampling options forwarded verbatim to
                `self.model.generate` (temperature, top_k, pad_token_id, ...).

        Returns:
            The first generated audio as a NumPy array, squeezed when it has
            more than one dimension.
        """
        inputs = self.processor(prompt, add_special_tokens=True, return_tensors="pt").to(self.device)

        with torch.no_grad():
            audio = self.model.generate(
                input_ids=inputs['input_ids'],
                output_audio=True,
                do_sample=True,
                **generate_kwargs,
            )

        audio_np = audio[0].cpu().numpy()
        if audio_np.ndim > 1:
            audio_np = audio_np.squeeze()
        return audio_np

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def save_audio_file(self, audio_data: np.ndarray, sample_rate: int = SAMPLE_RATE, prefix: str = "marvis") -> str:
        """Write *audio_data* to a timestamped WAV file and return its path.

        The audio is squeezed when it has more than one dimension and scaled
        to a peak of 0.95 before being written as 16-bit PCM. Silent or empty
        audio is written as-is.

        Args:
            audio_data: Samples to save.
            sample_rate: Sample rate in Hz recorded in the file.
            prefix: Leading part of the file name
                (`<prefix>_YYYYMMDD_HHMMSS.wav`).

        Returns:
            Path of the file that was written.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # The directory is created in __init__, but may have been removed since.
        os.makedirs(self.output_dir, exist_ok=True)
        filepath = self._unique_output_path(f"{prefix}_{timestamp}")

        audio_data = np.asarray(audio_data)
        if audio_data.ndim > 1:
            audio_data = audio_data.squeeze()

        # Leave a little headroom below full scale to prevent clipping
        audio_data = self._peak_normalize(audio_data, peak=0.95)

        sf.write(filepath, audio_data, sample_rate, subtype="PCM_16")
        return filepath

    def load_model(self):
        """Load the processor and model, returning a status message.

        Failures are reported through the returned string rather than raised.
        The instance attributes are only updated once both the processor and
        the model have loaded, so a failed attempt leaves no half-loaded state.
        """
        if self.is_loaded:
            return "Model already loaded!"

        try:
            print("Loading Marvis TTS model...")
            processor = AutoProcessor.from_pretrained(self.model_id)
            model = CsmForConditionalGeneration.from_pretrained(self.model_id).to(self.device)
        except Exception as e:
            return f"❌ Error loading model: {str(e)}"

        self.processor = processor
        self.model = model
        self.is_loaded = True
        return f"✅ Model loaded successfully on {self.device}!"

    def preprocess_audio(self, audio_file: str, target_length: float = 10.0) -> np.ndarray:
        """Load reference audio and force it to a fixed length.

        The file is loaded at `SAMPLE_RATE`, trimmed or zero-padded at the end
        to exactly `target_length` seconds, then peak-normalised. Silent input
        is returned as zeros instead of NaN.

        Args:
            audio_file: Path to audio file
            target_length: Target length in seconds for voice reference

        Returns:
            The processed samples.

        Raises:
            ValueError: If the file cannot be loaded or processed.
        """
        try:
            audio, sr = librosa.load(audio_file, sr=self.SAMPLE_RATE)

            target_samples = int(target_length * sr)
            if len(audio) > target_samples:
                audio = audio[:target_samples]
            elif len(audio) < target_samples:
                # Pad with silence if too short
                audio = np.pad(audio, (0, target_samples - len(audio)), mode='constant', constant_values=0)

            return self._peak_normalize(audio)
        except Exception as e:
            raise ValueError(f"Error preprocessing audio: {str(e)}") from e

    def generate_basic_tts(self, text: str, speaker_id: int = 0) -> Tuple[int, np.ndarray]:
        """Synthesize *text* in a single pass for the given speaker.

        Returns:
            `(sample_rate, waveform)`.

        Raises:
            ValueError: If the model is not loaded or generation fails.
        """
        if not self.is_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        try:
            self._free_cuda_cache()
            audio_np = self._synthesize(
                self._speaker_prompt(speaker_id, text),
                temperature=0.7,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )
            return self.SAMPLE_RATE, audio_np
        except Exception as e:
            raise ValueError(f"Error generating TTS: {str(e)}") from e

    def estimate_audio_duration(self, text: str) -> float:
        """Estimate spoken duration in seconds at `WORDS_PER_MINUTE`."""
        word_count = len(text.split())
        return (word_count / self.WORDS_PER_MINUTE) * 60  # Convert to seconds

    def smart_text_chunker(self, text: str, max_duration: float = 8.0) -> list:
        """Group sentences into chunks estimated to last at most *max_duration* s.

        Sentences are split on `.`, `!` and `?`. Each sentence keeps its own
        terminator (so questions and exclamations reach the model intact); a
        sentence without one gets a trailing `.`. A single sentence whose
        estimate already exceeds the limit is emitted as its own chunk rather
        than being split further.

        Returns:
            List of chunk strings; empty when *text* contains no sentence.
        """
        chunks = []
        current_chunk = ""

        for match in self._SENTENCE_RE.finditer(text):
            body = match.group(1).strip()
            if not body:
                continue
            sentence = body + (match.group(2) or ".")

            candidate = f"{current_chunk} {sentence}" if current_chunk else sentence
            if current_chunk and self.estimate_audio_duration(candidate) > max_duration:
                # Adding this sentence would overflow: close the current chunk.
                chunks.append(current_chunk)
                current_chunk = sentence
            else:
                current_chunk = candidate

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    def generate_streaming_tts(self, text: str, speaker_id: int = 0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Yield `(sample_rate, waveform)` for each chunk of *text* in order.

        The text is split with `smart_text_chunker` (~8 s per chunk) and every
        chunk is synthesized independently.

        Raises:
            ValueError: If the model is not loaded or generation fails. As this
                is a generator, the error surfaces on iteration.
        """
        if not self.is_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        try:
            self._free_cuda_cache()

            for chunk in self.smart_text_chunker(text, max_duration=8.0):
                if chunk.strip():
                    audio_np = self._synthesize(
                        self._speaker_prompt(speaker_id, chunk),
                        temperature=0.7,
                    )
                    yield self.SAMPLE_RATE, audio_np

        except Exception as e:
            raise ValueError(f"Error in streaming TTS: {str(e)}") from e

    def generate_voice_cloned_tts(self, text: str, reference_audio: str) -> Tuple[int, np.ndarray]:
        """Synthesize *text* as a "voice variation" (not true voice cloning).

        The reference audio is only loaded to confirm it is readable; it is
        not used to condition generation. The output is produced with speaker
        tag `[2]` and looser sampling settings (temperature 0.9, top_k 50,
        top_p 0.95).

        Returns:
            `(sample_rate, waveform)`.

        Raises:
            ValueError: If the model is not loaded, no reference audio is
                given, the reference cannot be read, or generation fails.
        """
        if not self.is_loaded:
            raise ValueError("Model not loaded. Please load the model first.")

        if not reference_audio:
            raise ValueError("Please provide a reference audio file for voice cloning.")

        try:
            self._free_cuda_cache()

            # Validation only: fails early on an unreadable reference file.
            # The decoded samples are intentionally discarded.
            librosa.load(reference_audio, sr=self.SAMPLE_RATE)

            audio_np = self._synthesize(
                f"[2]{text}",  # Use speaker 2 as "cloned" voice
                temperature=0.9,  # Higher temperature for more variation
                top_k=50,
                top_p=0.95,
                pad_token_id=self.processor.tokenizer.eos_token_id,
            )
            return self.SAMPLE_RATE, audio_np

        except Exception as e:
            raise ValueError(f"Error in voice cloning: {str(e)}") from e

# Module-level instance shared by the Gradio handler functions in this file.
# Constructing it creates the `marvis_tts_outputs` directory as a side effect;
# the model itself is not loaded until `load_model()` is called.
tts_interface = MarvisTTSInterface()

# Gradio Interface Functions
def load_model_interface():
    """Gradio callback: load the shared model and hand back its status text."""
    status_message = tts_interface.load_model()
    return status_message

def basic_tts_interface(text: str, speaker_id: int):
    """Gradio callback for the basic tab.

    Returns a `(audio, file_path, status)` triple. `audio` is a
    `(sample_rate, waveform)` pair on success; on empty input or any failure
    the first two items are `None` and `status` carries the message.
    """
    try:
        if text.strip():
            rate, waveform = tts_interface.generate_basic_tts(text, speaker_id)
            # Persist the result so it can be offered for download.
            wav_path = tts_interface.save_audio_file(waveform, rate, "basic_tts")
            outcome = ((rate, waveform), wav_path, f"✅ Audio generated and saved to: {wav_path}")
        else:
            outcome = (None, None, "Please enter some text to synthesize.")
    except Exception as err:
        outcome = (None, None, f"❌ Error: {err}")
    return outcome

def streaming_tts_interface(text: str, speaker_id: int):
    """Gradio generator callback that streams audio chunk by chunk.

    Every yielded item is a `(current_chunk, audio_so_far, status)` triple,
    where the two audio entries are `(sample_rate, waveform)` pairs or `None`.
    One item is yielded per generated chunk, followed by a final item holding
    the complete concatenated audio.

    Because this function is a generator, every message meant for the UI has
    to be *yielded*: a `return value` would end the iteration without
    producing any output.
    """
    try:
        if not text.strip():
            yield None, None, "Please enter some text to synthesize."
            return

        audio_chunks = []
        sample_rate = None

        for i, (sample_rate, chunk) in enumerate(tts_interface.generate_streaming_tts(text, speaker_id)):
            audio_chunks.append(chunk)

            # np.concatenate always allocates a new array, so the audio handed
            # to the UI is independent of the chunks still being collected.
            partial_audio = np.concatenate(audio_chunks)

            # Yield both the individual chunk and the growing concatenated audio
            yield (sample_rate, chunk.copy()), (sample_rate, partial_audio), f"✅ Generated chunk {i+1}..."

        # Final concatenated audio
        if audio_chunks:
            full_audio = np.concatenate(audio_chunks)
            yield (sample_rate, audio_chunks[-1]), (sample_rate, full_audio), f"✅ Streaming complete! Generated {len(audio_chunks)} chunks."
        else:
            yield None, None, "❌ No audio generated."

    except Exception as e:
        yield None, None, f"❌ Error: {str(e)}"

def streaming_tts_interface_simple(text: str, speaker_id: int):
    """Gradio callback: run the chunked generator to completion and save it.

    Returns a `(audio, file_path, status)` triple where `audio` is the
    `(sample_rate, waveform)` of all chunks joined together. On empty input,
    no output, or any failure the first two items are `None`.
    """
    try:
        if not text.strip():
            return None, None, "Please enter some text to synthesize."

        # Drain the generator; each item is a (sample_rate, samples) pair.
        pieces = list(tts_interface.generate_streaming_tts(text, speaker_id))
        if not pieces:
            return None, None, "❌ No audio generated."

        rate = pieces[-1][0]
        waveform = np.concatenate([samples for _, samples in pieces])

        # Persist the joined audio so it can be offered for download.
        wav_path = tts_interface.save_audio_file(waveform, rate, "streaming_tts")
        status = f"✅ Streaming audio generated and saved! ({len(pieces)} chunks) - File: {wav_path}"
        return (rate, waveform), wav_path, status

    except Exception as err:
        return None, None, f"❌ Error: {err}"

def voice_cloning_interface(text: str, reference_audio):
    """Gradio callback for the voice-variation tab.

    Validates the text first and the reference audio second, then delegates
    to the shared TTS object and saves the result. Returns a
    `(audio, file_path, status)` triple; the first two items are `None` when
    validation fails or an error occurs.
    """
    try:
        if not text.strip():
            complaint = "Please enter some text to synthesize."
        elif reference_audio is None:
            complaint = "Please upload a reference audio file."
        else:
            complaint = None

        if complaint is not None:
            return None, None, complaint

        rate, waveform = tts_interface.generate_voice_cloned_tts(text, reference_audio)

        # Persist the result so it can be offered for download.
        wav_path = tts_interface.save_audio_file(waveform, rate, "voice_cloned")
        status = f"✅ Voice variation audio generated and saved to: {wav_path}"
        return (rate, waveform), wav_path, status

    except Exception as err:
        return None, None, f"❌ Error: {err}"

# Create Gradio Interface
def create_gradio_interface():
    """Build and return the Gradio Blocks app (without launching it).

    Layout, top to bottom:
      * a header and a "load model" button with a status box,
      * three tabs (Basic TTS, Streaming TTS, Voice Cloning), each wiring its
        button(s) to one of the module-level `*_interface` handlers,
      * two collapsed accordions with file-management notes and usage tips.

    Components are created inside nested `with` contexts, so the order of the
    statements below determines the on-screen layout.

    Returns:
        The `gr.Blocks` object; the caller is responsible for `.launch()`.
    """

    with gr.Blocks(title="Marvis TTS - Enhanced Interface", theme=gr.themes.Soft()) as interface:

        gr.Markdown("""
        # 🎙️ Marvis TTS - Enhanced Interface

        A powerful text-to-speech system with streaming and voice cloning capabilities.

        **Features:**
        - 🚀 Real-time streaming TTS
        - 🎭 Voice cloning with reference audio
        - 🔊 High-quality 24kHz audio output
        - 💻 GPU acceleration support
        """)

        # Model Loading Section
        with gr.Row():
            with gr.Column():
                load_btn = gr.Button("🔄 Load Marvis TTS Model", variant="primary")
                model_status = gr.Textbox(label="Model Status", interactive=False)

        # The handler takes no inputs; its returned status string fills the textbox.
        load_btn.click(fn=load_model_interface, outputs=model_status)

        # Main TTS Tabs
        with gr.Tabs():

            # Basic TTS Tab
            with gr.TabItem("🎯 Basic TTS"):
                with gr.Row():
                    # Left column: inputs
                    with gr.Column():
                        basic_text = gr.Textbox(
                            label="Text to Synthesize",
                            placeholder="Enter text here...",
                            lines=3
                        )
                        basic_speaker = gr.Slider(
                            label="Speaker ID",
                            minimum=0,
                            maximum=3,
                            value=0,
                            step=1
                        )
                        basic_generate_btn = gr.Button("🗣️ Generate Speech", variant="primary")

                    # Right column: outputs
                    with gr.Column():
                        basic_audio_output = gr.Audio(label="Generated Audio")
                        basic_download_file = gr.File(label="💾 Download Audio File", visible=True)
                        basic_status = gr.Textbox(label="Status", interactive=False)

                # Handler returns (audio, file path, status) in this output order.
                basic_generate_btn.click(
                    fn=basic_tts_interface,
                    inputs=[basic_text, basic_speaker],
                    outputs=[basic_audio_output, basic_download_file, basic_status]
                )

            # Streaming TTS Tab
            with gr.TabItem("⚡ Streaming TTS"):
                gr.Markdown("""
                **Real-time Streaming**: Audio is generated and played back in chunks as they're created.
                """)
                with gr.Row():
                    # Left column: inputs and the two generation modes
                    with gr.Column():
                        streaming_text = gr.Textbox(
                            label="Text to Synthesize (will be streamed)",
                            placeholder="Enter longer text for streaming demo...",
                            lines=4
                        )
                        streaming_speaker = gr.Slider(
                            label="Speaker ID",
                            minimum=0,
                            maximum=3,
                            value=0,
                            step=1
                        )
                        with gr.Row():
                            streaming_generate_btn = gr.Button("📡 Start Streaming", variant="primary")
                            streaming_simple_btn = gr.Button("🔄 Generate Full Audio", variant="secondary")

                    # Right column: outputs
                    with gr.Column():
                        # Real-time streaming output
                        gr.Markdown("**🎵 Current Chunk (Real-time)**")
                        streaming_chunk_output = gr.Audio(label="Current Audio Chunk", autoplay=True)

                        # Full concatenated output
                        gr.Markdown("**💾 Full Audio (Downloadable)**")
                        streaming_full_output = gr.Audio(label="Complete Streamed Audio")
                        streaming_download_file = gr.File(label="💾 Download Complete Audio", visible=True)

                        streaming_status = gr.Textbox(label="Status", interactive=False)

                # Real-time streaming with chunks
                # The handler is a generator: each yield updates these three outputs.
                # Note: the download-file component is not among them, so this
                # mode never populates it.
                streaming_generate_btn.click(
                    fn=streaming_tts_interface,
                    inputs=[streaming_text, streaming_speaker],
                    outputs=[streaming_chunk_output, streaming_full_output, streaming_status]
                )

                # Simple full generation
                # Fills the full-audio player and the download file in one step;
                # the per-chunk player is not touched.
                streaming_simple_btn.click(
                    fn=streaming_tts_interface_simple,
                    inputs=[streaming_text, streaming_speaker],
                    outputs=[streaming_full_output, streaming_download_file, streaming_status]
                )

            # Voice Cloning Tab
            with gr.TabItem("🎭 Voice Cloning"):
                gr.Markdown("""
                ⚠️ **Note**: This model doesn't have built-in voice cloning capabilities.
                This tab demonstrates a simplified approach using different generation parameters.
                Real voice cloning would require a model trained specifically for speaker adaptation.
                """)
                with gr.Row():
                    # Left column: inputs
                    with gr.Column():
                        clone_text = gr.Textbox(
                            label="Text to Synthesize",
                            placeholder="Enter text to speak in a different voice style...",
                            lines=3
                        )
                        # type="filepath": the handler is given a path, which it
                        # forwards to the TTS object as the reference audio.
                        reference_audio = gr.Audio(
                            label="Reference Audio (for analysis - limited effect)",
                            type="filepath"
                        )
                        clone_generate_btn = gr.Button("🎪 Generate with Voice Variation", variant="primary")

                    # Right column: outputs
                    with gr.Column():
                        clone_audio_output = gr.Audio(label="Generated Audio")
                        clone_download_file = gr.File(label="💾 Download Audio File", visible=True)
                        clone_status = gr.Textbox(label="Status", interactive=False)

                # Handler returns (audio, file path, status) in this output order.
                clone_generate_btn.click(
                    fn=voice_cloning_interface,
                    inputs=[clone_text, reference_audio],
                    outputs=[clone_audio_output, clone_download_file, clone_status]
                )

        # File Management Section
        # The output directory name is read from the shared TTS object once,
        # when the interface is built.
        with gr.Accordion("📁 File Management", open=False):
            gr.Markdown(f"""
            **Output Directory**: `{tts_interface.output_dir}/`

            All generated audio files are automatically saved with timestamps:
            - **Basic TTS**: `basic_tts_YYYYMMDD_HHMMSS.wav`
            - **Streaming TTS**: `streaming_tts_YYYYMMDD_HHMMSS.wav`
            - **Voice Variation**: `voice_cloned_YYYYMMDD_HHMMSS.wav`

            Files are saved locally and available for download via the file components above.
            """)

        # Information Section
        with gr.Accordion("ℹ️ Information & Tips", open=False):
            gr.Markdown("""
            ### Usage Tips:

            **Basic TTS:**
            - Choose different speaker IDs (0-3) for voice variety
            - Optimal for short to medium length texts

            **Streaming TTS:**
            - 🎵 **Real-time chunks**: Individual chunks play as they're generated (with autoplay)
            - 💾 **Full audio**: Complete concatenated audio available for download
            - 🔄 **Two modes**: Real-time streaming or full generation
            - ⏱️ **Smart chunking**: Automatically splits long text into ~8-second segments

            **Voice Cloning:**
            - ⚠️ **Important**: This model doesn't have true voice cloning capabilities
            - The "voice cloning" tab uses different generation parameters for voice variation
            - Real voice cloning requires specialized model architecture and training
            - Upload reference audio to analyze characteristics (limited effect)

            ### Technical Details:
            - Sample Rate: 24,000 Hz
            - Model: Marvis TTS 250M parameters
            - Architecture: Conversational Speech Model (CSM)
            - Codec: Kyutai's mimi codec with RVQ tokens
            """)

    return interface

# Launch the interface
if __name__ == "__main__":
    # Build the UI, then serve it:
    #   share=True       -> also request a public shareable link
    #   debug=True       -> run in Gradio's debug mode
    #   server_port=7860 -> Gradio's default port
    interface = create_gradio_interface()
    interface.launch(share=True, debug=True, server_port=7860)