# Qwen3-TTS Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kurtvalcorza/notebooks/blob/main/Qwen3_TTS.ipynb)

This notebook runs the [Qwen3-TTS](https://huggingface.co/spaces/Qwen/Qwen3-TTS) demo.

### Notes
- **Runtime**: GPU is required (T4 is sufficient).
- **Dependencies**: Uses `flash-attn` when it is available in the runtime; otherwise falls back to standard (eager) attention. The notebook does not install `flash-attn` itself.


In [None]:
#@title 1. Setup and Installation
#@markdown This step clones the repository and installs necessary dependencies.

import os
import subprocess
import sys

# Clone the repository
if not os.path.exists("Qwen3-TTS"):
    !git clone https://huggingface.co/spaces/Qwen/Qwen3-TTS

%cd Qwen3-TTS

# Install dependencies
# Removing 'spaces' as it is HF specific and we will mock it
%pip install -r requirements.txt
%pip uninstall -y spaces

# Install ffmpeg for audio processing
!apt-get install -y ffmpeg

print("Setup complete.")

In [None]:
#@title 2. Mocking HF Spaces
#@markdown We replace the `spaces` library with a dummy implementation since we are running on a dedicated Colab GPU.

import sys
from functools import wraps

# Mock the spaces module
class MockSpaces:
    """Minimal stand-in for the Hugging Face ``spaces`` package.

    Only the ``GPU`` decorator is provided, and it is a no-op: the decorated
    function is called directly with its original arguments.
    """

    def GPU(self, *args, duration=60, **kwargs):
        """Return the function unchanged in behavior, whatever the call form.

        Supports both usages:

        * ``@spaces.GPU``: the function arrives as the only positional argument.
        * ``@spaces.GPU(duration=120)`` / ``@spaces.GPU(120)``: a decorator is
          returned.

        ``duration`` and any extra keyword options are accepted and ignored.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*call_args, **call_kwargs):
                return func(*call_args, **call_kwargs)
            return wrapper

        # Bare usage passes the function positionally. A positional number is
        # treated as the (ignored) duration, so ``GPU(60)`` keeps working.
        if args and callable(args[0]):
            return decorator(args[0])
        return decorator

# Register the mock under the name "spaces" so that a later `import spaces`
# picks up this object from sys.modules.
sys.modules["spaces"] = MockSpaces()

In [None]:
#@title 3. Load Models and Define Functions
#@markdown This cell loads the model logic adapted from `app.py`.

import os
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download, login

# A token is optional for public models; when HF_TOKEN is set in the
# environment, authenticate with it.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Model sizes offered in the UI.
MODEL_SIZES = ["0.6B", "1.7B"]

# Cache of loaded models, keyed by (model_type, model_size).
loaded_models = {}

def get_model_path(model_type: str, model_size: str) -> str:
    """Return the path reported by ``snapshot_download`` for a checkpoint.

    The repository id is ``Qwen/Qwen3-TTS-12Hz-<model_size>-<model_type>``.
    """
    repo_id = f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}"
    return snapshot_download(repo_id)

def _pick_attn_implementation() -> str:
    """Choose the attention backend passed to ``from_pretrained``.

    Returns ``"flash_attention_2"`` only when the ``flash_attn`` package is
    importable AND the current GPU has compute capability 8.0 or higher;
    otherwise returns ``"eager"``.
    """
    try:
        import flash_attn  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return "eager"
    # FlashAttention-2 kernels target Ampere-class GPUs and newer. On older
    # cards such as the T4 (capability 7.5) the package may import fine but
    # fail at run time, so fall back to eager attention there.
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        return "flash_attention_2"
    return "eager"


def get_model(model_type: str, model_size: str):
    """Return the model for ``(model_type, model_size)``, loading it once.

    Loaded models are cached in the module-level ``loaded_models`` dict, so
    repeated calls with the same arguments reuse the same object.

    Args:
        model_type: Checkpoint variant, used in the repository name.
        model_size: Checkpoint size, used in the repository name.

    Returns:
        The object returned by ``Qwen3TTSModel.from_pretrained``.
    """
    key = (model_type, model_size)
    if key in loaded_models:
        return loaded_models[key]

    # The repo expects 'qwen_tts' to be importable. When running from the root
    # of the cloned repo it normally is; otherwise add the cwd to sys.path.
    try:
        from qwen_tts import Qwen3TTSModel
    except ImportError:
        import sys

        sys.path.append(os.getcwd())
        from qwen_tts import Qwen3TTSModel

    model_path = get_model_path(model_type, model_size)
    attn_implementation = _pick_attn_implementation()

    print(f"Loading {model_type} {model_size} with {attn_implementation}...")
    loaded_models[key] = Qwen3TTSModel.from_pretrained(
        model_path,
        device_map="cuda",
        dtype=torch.bfloat16,
        token=HF_TOKEN,
        attn_implementation=attn_implementation
    )
    return loaded_models[key]

def _normalize_audio(wav, eps=1e-12, clip=True):
    x = np.asarray(wav)
    if np.issubdtype(x.dtype, np.integer):
        info = np.iinfo(x.dtype)
        if info.min < 0:
            y = x.astype(np.float32) / max(abs(info.min), info.max)
        else:
            mid = (info.max + 1) / 2.0
            y = (x.astype(np.float32) - mid) / mid
    elif np.issubdtype(x.dtype, np.floating):
        y = x.astype(np.float32)
        m = np.max(np.abs(y)) if y.size else 0.0
        if m > 1.0 + 1e-6:
            y = y / (m + eps)
    else:
        raise TypeError(f"Unsupported dtype: {x.dtype}")
    if clip:
        y = np.clip(y, -1.0, 1.0)
    if y.ndim > 1:
        y = np.mean(y, axis=-1).astype(np.float32)
    return y

def _audio_to_tuple(audio):
    if audio is None:
        return None
    if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[0], int):
        sr, wav = audio
        wav = _normalize_audio(wav)
        return wav, int(sr)
    if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
        sr = int(audio["sampling_rate"])
        wav = _normalize_audio(audio["data"])
        return wav, sr
    return None

# --- Generation Functions ---

def generate_voice_design(text, language, voice_description):
    """Synthesize ``text`` with a voice described in natural language.

    Uses the "VoiceDesign" model at size "1.7B".

    Returns:
        ``((sample_rate, waveform), status_message)`` on success, or
        ``(None, error_message)`` when validation fails or generation raises.
    """
    text_missing = not text or not text.strip()
    if text_missing:
        return None, "Error: Text is required."

    description_missing = not voice_description or not voice_description.strip()
    if description_missing:
        return None, "Error: Voice description is required."

    try:
        model = get_model("VoiceDesign", "1.7B")
        request = {
            "text": text.strip(),
            "language": language,
            "instruct": voice_description.strip(),
            "non_streaming_mode": True,
            "max_new_tokens": 2048,
        }
        wavs, sr = model.generate_voice_design(**request)
        audio = (sr, wavs[0])
        return audio, "Voice design generation completed successfully!"
    except Exception as exc:
        # Report any failure as a status message for the UI.
        return None, f"Error: {type(exc).__name__}: {exc}"

def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size):
    """Synthesize ``target_text`` in the voice of a reference recording.

    Uses the "Base" model at the requested ``model_size``. ``ref_text`` is
    required unless ``use_xvector_only`` is set.

    Returns:
        ``((sample_rate, waveform), status_message)`` on success, or
        ``(None, error_message)`` when validation fails or generation raises.
    """
    if not target_text or not target_text.strip():
        return None, "Error: Target text is required."

    reference = _audio_to_tuple(ref_audio)
    if reference is None:
        return None, "Error: Reference audio is required."

    if not use_xvector_only:
        ref_text_missing = not ref_text or not ref_text.strip()
        if ref_text_missing:
            return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."

    try:
        model = get_model("Base", model_size)
        request = {
            "text": target_text.strip(),
            "language": language,
            "ref_audio": reference,
            "ref_text": ref_text.strip() if ref_text else None,
            "x_vector_only_mode": use_xvector_only,
            "max_new_tokens": 2048,
        }
        wavs, sr = model.generate_voice_clone(**request)
        audio = (sr, wavs[0])
        return audio, "Voice clone generation completed successfully!"
    except Exception as exc:
        # Report any failure as a status message for the UI.
        return None, f"Error: {type(exc).__name__}: {exc}"

def generate_custom_voice(text, language, speaker, instruct, model_size):
    """Synthesize ``text`` with one of the preset speakers.

    Uses the "CustomVoice" model at the requested ``model_size``. The speaker
    name is lower-cased and its spaces are replaced with underscores before
    it is passed to the model. ``instruct`` is optional.

    Returns:
        ``((sample_rate, waveform), status_message)`` on success, or
        ``(None, error_message)`` when validation fails or generation raises.
    """
    text_missing = not text or not text.strip()
    if text_missing:
        return None, "Error: Text is required."
    if not speaker:
        return None, "Error: Speaker is required."

    try:
        model = get_model("CustomVoice", model_size)
        request = {
            "text": text.strip(),
            "language": language,
            "speaker": speaker.lower().replace(" ", "_"),
            "instruct": instruct.strip() if instruct else None,
            "non_streaming_mode": True,
            "max_new_tokens": 2048,
        }
        wavs, sr = model.generate_custom_voice(**request)
        audio = (sr, wavs[0])
        return audio, "Generation completed successfully!"
    except Exception as exc:
        # Report any failure as a status message for the UI.
        return None, f"Error: {type(exc).__name__}: {exc}"


In [None]:
#@title 4. Launch Gradio Interface

# Preset speaker names offered in the "Custom Voices" tab.
SPEAKERS = [
    "Aiden",
    "Dylan",
    "Eric",
    "Ono_anna",
    "Ryan",
    "Serena",
    "Sohee",
    "Uncle_fu",
    "Vivian",
]

# Language choices shared by every tab's dropdown.
LANGUAGES = [
    "Auto",
    "Chinese",
    "English",
    "Japanese",
    "Korean",
    "French",
    "German",
    "Spanish",
    "Portuguese",
    "Russian",
]

# Using the UI code from app.py but putting it inside a function to keep scope clean
def build_ui():
    """Build and return the Gradio Blocks app for the demo.

    The app has three tabs, each wiring a "Generate" button to one of the
    generation functions and showing the result in an audio player plus a
    status message:

    * Voice Design -> ``generate_voice_design``
    * Voice Clone  -> ``generate_voice_clone``
    * Custom Voices -> ``generate_custom_voice``

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    theme = gr.themes.Soft(
        font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
    )
    css = """
    .gradio-container {max-width: none !important;}
    .tab-content {padding: 20px;}
    """
    
    with gr.Blocks(theme=theme, css=css, title="Qwen3-TTS Demo") as demo:
        gr.Markdown("# Qwen3-TTS Demo")
        
        with gr.Tabs():
            # Tab 1: Voice Design
            with gr.Tab("Voice Design (Instruct to Speech)"):
                gr.Markdown("**Note: Only 1.7B model supports Voice Design.**")
                with gr.Row():
                    # Left column: inputs; right column: outputs.
                    with gr.Column(scale=1):
                        vd_text = gr.Textbox(label="Text to Synthesize", lines=3, placeholder="Enter text here...")
                        vd_language = gr.Dropdown(choices=LANGUAGES, value="Auto", label="Language")
                        vd_desc = gr.Textbox(label="Voice Description", lines=2, placeholder="e.g., A gentle, soothing female voice.")
                        vd_button = gr.Button("Generate", variant="primary")
                    with gr.Column(scale=1):
                        vd_output = gr.Audio(label="Generated Audio")
                        vd_status = gr.Markdown()
                # Input order must match generate_voice_design(text, language, voice_description).
                vd_button.click(
                    generate_voice_design,
                    inputs=[vd_text, vd_language, vd_desc],
                    outputs=[vd_output, vd_status]
                )

            # Tab 2: Voice Clone (Zero-Shot)
            with gr.Tab("Voice Clone (Zero-Shot)"):
                with gr.Row():
                    with gr.Column(scale=1):
                        vc_target_text = gr.Textbox(label="Target Text", lines=3)
                        vc_language = gr.Dropdown(choices=LANGUAGES, value="Auto", label="Language")
                        with gr.Group():
                            vc_ref_audio = gr.Audio(label="Reference Audio", type="numpy")
                            vc_ref_text = gr.Textbox(label="Reference Text (Optional if 'Use x-vector only')")
                        vc_xvector = gr.Checkbox(label="Use x-vector only (ignores ref text)", value=False)
                        vc_model_size = gr.Radio(choices=MODEL_SIZES, value="1.7B", label="Model Size")
                        vc_button = gr.Button("Generate", variant="primary")
                    with gr.Column(scale=1):
                        vc_output = gr.Audio(label="Generated Audio")
                        vc_status = gr.Markdown()
                # Input order must match generate_voice_clone(ref_audio, ref_text,
                # target_text, language, use_xvector_only, model_size).
                vc_button.click(
                    generate_voice_clone,
                    inputs=[vc_ref_audio, vc_ref_text, vc_target_text, vc_language, vc_xvector, vc_model_size],
                    outputs=[vc_output, vc_status]
                )

            # Tab 3: Custom Voices (Pre-set)
            with gr.Tab("Custom Voices (Pre-set)"):
                with gr.Row():
                    with gr.Column(scale=1):
                        cv_text = gr.Textbox(label="Text", lines=3)
                        cv_language = gr.Dropdown(choices=LANGUAGES, value="Auto", label="Language")
                        cv_speaker = gr.Dropdown(choices=SPEAKERS, value="Serena", label="Speaker")
                        cv_instruct = gr.Textbox(label="Instruction (Optional)", placeholder="e.g., speaking fast and likely")
                        cv_model_size = gr.Radio(choices=MODEL_SIZES, value="1.7B", label="Model Size")
                        cv_button = gr.Button("Generate", variant="primary")
                    with gr.Column(scale=1):
                        cv_output = gr.Audio(label="Generated Audio")
                        cv_status = gr.Markdown()
                # Input order must match generate_custom_voice(text, language,
                # speaker, instruct, model_size).
                cv_button.click(
                    generate_custom_voice,
                    inputs=[cv_text, cv_language, cv_speaker, cv_instruct, cv_model_size],
                    outputs=[cv_output, cv_status]
                )

    return demo

# Build the interface and start the Gradio app.
demo = build_ui()
# NOTE(review): share=True is expected to request a public share link and
# debug=True to keep the cell attached so errors show in the output; confirm
# against the installed Gradio version.
demo.launch(share=True, debug=True)