# 🎤 Spark-TTS on Google Colab

This notebook allows you to run Spark-TTS (Text-to-Speech with Voice Cloning) on Google Colab.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YOUR_USERNAME/Spark-TTS/blob/main/Spark_TTS_Colab.ipynb)

## Features:
- 🎯 **Voice Cloning**: Clone any voice from a reference audio
- 🎨 **Voice Creation**: Generate synthetic voices with custom parameters
- 🚀 **GPU Acceleration**: Utilizes Colab's free GPU for faster inference
- 🌐 **Public Access**: Share your TTS interface with others

---

## 📋 Setup Instructions

1. **Enable GPU**: Go to `Runtime` > `Change runtime type` and select a GPU (e.g. `T4 GPU`) as the Hardware accelerator
2. **Run the cells below** in order
3. **Upload your model** when prompted
4. **Access the web interface** via the generated link

---

In [None]:
#@title 🔧 Check GPU Availability

import sys
import torch

print("🔍 System Information:")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"🔥 GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("⚠️ GPU not available. Consider enabling GPU in Runtime settings.")

In [None]:
#@title 📦 Install Dependencies

print("📦 Installing required packages...")

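# Install core packages
!pip install -q gradio soundfile torch torchaudio transformers

# Install additional dependencies, plus packages imported by the
# Spark-TTS code itself (listed in the repository's requirements.txt)
!pip install -q librosa numpy scipy einops einx omegaconf safetensors soxr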

print("✅ Dependencies installed successfully!")
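Because `pip install -q` suppresses most output, a failed install can go unnoticed before the success message prints. A quick sanity check is to look each package up with `importlib.util.find_spec`, which locates a module without importing it. The module list below is an assumption: it mirrors the packages installed in the cell above, whose import names match their pip names.

```python
import importlib.util

# Packages installed in the cell above (import name == pip name for these)
REQUIRED_MODULES = ["torch", "torchaudio", "transformers", "gradio",
                    "soundfile", "librosa", "numpy", "scipy"]

def find_missing(modules):
    """Return the names of modules that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]

missing = find_missing(REQUIRED_MODULES)
if missing:
    print(f"⚠️ Missing modules: {', '.join(missing)}")
else:
    print("✅ All required modules are importable.")
```

If anything is reported missing, re-run the install cell and read its output without `-q`.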

In [None]:
#@title 📥 Clone Spark-TTS Repository

import os

# Clone the repository
if not os.path.exists('/content/Spark-TTS'):
    print("📥 Cloning Spark-TTS repository...")
    !git clone https://github.com/SparkAudio/Spark-TTS.git /content/Spark-TTS
    print("✅ Repository cloned successfully!")
else:
    print("📁 Repository already exists, updating...")
    !cd /content/Spark-TTS && git pull
    print("✅ Repository updated!")

# Change to the project directory
os.chdir('/content/Spark-TTS')
print(f"📂 Current directory: {os.getcwd()}")

In [None]:
#@title 🤖 Download Pre-trained Model

import os
from google.colab import files

model_dir = "/content/Spark-TTS/pretrained_models/Spark-TTS-0.5B"

# Create model directory
os.makedirs(model_dir, exist_ok=True)
os.makedirs("/content/example/results", exist_ok=True)

print("🤖 Model Setup Options:")
print("1. Auto-download (if URL available)")
print("2. Manual upload")
print()

# Check if model already exists
if os.path.exists(f"{model_dir}/config.yaml"):
    print("✅ Model already exists!")
else:
    print("📥 Model not found. Please upload your model files.")
    print(f"📁 Upload your model files to: {model_dir}")
    print("Required layout: config.yaml plus the LLM/, BiCodec/ and wav2vec2-large-xlsr-53/ folders")
    print()
    print("💡 Tip: You can upload files using the file browser on the left panel")
    
    # Option to upload files
    upload_choice = input("\nDo you want to upload files now? (y/n): ")
    if upload_choice.lower() == 'y':
        print("📤 Please select your model files:")
        uploaded = files.upload()
        
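        # files.upload() returns {filename: bytes} and cannot recreate subfolders,
        # so upload nested model folders with the file browser instead
        for filename, content in uploaded.items():
            with open(os.path.join(model_dir, filename), 'wb') as f:
                f.write(content)
            print(f"✅ Uploaded: {filename}")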
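The cell above lists auto-download as an option but only implements manual upload. The Spark-TTS README describes fetching the weights from the Hugging Face Hub with `huggingface_hub.snapshot_download`. A minimal sketch, assuming the repository ID `SparkAudio/Spark-TTS-0.5B` and assuming that a top-level `config.yaml` means the model is already present (verify both against the upstream README):

```python
import os

MODEL_DIR = "/content/Spark-TTS/pretrained_models/Spark-TTS-0.5B"
REPO_ID = "SparkAudio/Spark-TTS-0.5B"  # assumption: Hub repo ID from the Spark-TTS README

def ensure_model(model_dir=MODEL_DIR, repo_id=REPO_ID, marker="config.yaml"):
    """Download the model snapshot unless `marker` already exists in model_dir.

    Assumption: the presence of `marker` means a previous download completed.
    Returns the model directory.
    """
    if os.path.exists(os.path.join(model_dir, marker)):
        print(f"✅ Model found in {model_dir}")
        return model_dir
    # Imported lazily so the existence check works without huggingface_hub
    from huggingface_hub import snapshot_download
    print(f"📥 Downloading {repo_id} to {model_dir} ...")
    snapshot_download(repo_id=repo_id, local_dir=model_dir)
    return model_dir
```

Calling `ensure_model()` in a new cell replaces the manual upload step; the download is several gigabytes, so expect it to take a few minutes.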

In [None]:
#@title 🌐 Create and Launch WebUI

# Create the Colab-optimized webui file
webui_code = '''
import os
import torch
import soundfile as sf
import gradio as gr
from datetime import datetime
from cli.SparkTTS import SparkTTS
from sparktts.utils.token_parser import LEVELS_MAP_UI

def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        print(f"🔥 Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("⚠️ GPU not available, using CPU")
    return device

def initialize_model(model_dir="/content/Spark-TTS/pretrained_models/Spark-TTS-0.5B"):
    print(f"📂 Loading model from: {model_dir}")
    device = get_device()
    model = SparkTTS(model_dir, device)
    return model

def run_tts(text, model, prompt_text=None, prompt_speech=None, 
           gender=None, pitch=None, speed=None, save_dir="/content/example/results"):
    if prompt_text is not None:
        prompt_text = None if len(prompt_text) <= 1 else prompt_text
    
    os.makedirs(save_dir, exist_ok=True)
    # Microseconds keep two clips generated in the same second from overwriting each other
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")
    
    with torch.no_grad():
        wav = model.inference(text, prompt_speech, prompt_text, gender, pitch, speed)
        sf.write(save_path, wav, samplerate=16000)
    
    return save_path

# Initialize model
try:
    model = initialize_model()
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    model = None

def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
    if not text.strip():
        return None, "Please enter text to synthesize"
    if model is None:
        return None, "Model not loaded"
    
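    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
    if prompt_speech is None:
        return None, "Please upload or record reference audio"
    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text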
    
    try:
        audio_path = run_tts(text, model, prompt_text_clean, prompt_speech)
        return audio_path, "✅ Audio generated successfully!"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"

def voice_creation(text, gender, pitch, speed):
    if not text.strip():
        return None, "Please enter text to synthesize"
    if model is None:
        return None, "Model not loaded"
    
    try:
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        audio_path = run_tts(text, model, gender=gender, pitch=pitch_val, speed=speed_val)
        return audio_path, "✅ Audio generated successfully!"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Spark-TTS Colab") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🎤 Spark-TTS on Google Colab</h1>
        <p>High-quality Text-to-Speech synthesis with voice cloning</p>
    </div>
    """)
    
    with gr.Tabs():
        with gr.TabItem("🎯 Voice Clone"):
            with gr.Row():
                prompt_wav_upload = gr.Audio(sources=["upload"], type="filepath", 
                                           label="📁 Upload Reference Audio")
                prompt_wav_record = gr.Audio(sources=["microphone"], type="filepath", 
                                           label="🎙️ Record Reference Audio")
            
            with gr.Row():
                text_input = gr.Textbox(label="📝 Text to Synthesize", lines=3, 
                                      value="Hello, this is a test of voice cloning.")
                prompt_text_input = gr.Textbox(label="📄 Reference Text (Optional)", lines=3)
            
            generate_btn = gr.Button("🚀 Generate Speech", variant="primary")
            
            with gr.Row():
                audio_output = gr.Audio(label="🔊 Generated Audio")
                status_output = gr.Textbox(label="Status", interactive=False)
            
            generate_btn.click(voice_clone, 
                             inputs=[text_input, prompt_text_input, prompt_wav_upload, prompt_wav_record],
                             outputs=[audio_output, status_output])
        
        with gr.TabItem("🎨 Voice Creation"):
            with gr.Row():
                with gr.Column():
                    gender = gr.Radio(["male", "female"], value="male", label="👤 Gender")
                    pitch = gr.Slider(1, 5, value=3, step=1, label="🎵 Pitch")
                    speed = gr.Slider(1, 5, value=3, step=1, label="⚡ Speed")
                
                with gr.Column():
                    text_creation = gr.Textbox(label="📝 Input Text", lines=4,
                                             value="You can customize voice parameters.")
                    create_btn = gr.Button("🎯 Create Voice", variant="primary")
            
            with gr.Row():
                audio_creation = gr.Audio(label="🔊 Generated Audio")
                status_creation = gr.Textbox(label="Status", interactive=False)
            
            create_btn.click(voice_creation,
                           inputs=[text_creation, gender, pitch, speed],
                           outputs=[audio_creation, status_creation])

# Launch the interface
demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
'''

# Write the webui code to a file
with open('/content/Spark-TTS/webui_colab.py', 'w') as f:
    f.write(webui_code)

print("✅ WebUI code created successfully!")
print("🚀 Launching Spark-TTS WebUI...")
print("⏳ This may take a few moments...")

# Run the webui in this notebook's namespace (so `model` and `demo` stay accessible)
with open('/content/Spark-TTS/webui_colab.py') as f:
    exec(f.read())

## 💡 Usage Tips

### Voice Cloning:
1. **Upload a reference audio clip** (WAV or MP3) with clear speech, or record one with the microphone
2. **Enter the text** you want to synthesize
3. **Optionally provide the reference text** (the transcript of the reference clip) for better results
4. Click **Generate Speech**

### Voice Creation:
1. **Select gender** (male/female)
2. **Adjust pitch and speed** (1=lowest, 5=highest)
3. **Enter your text**
4. Click **Create Voice**
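The WebUI turns each slider value into a discrete level name through `LEVELS_MAP_UI`. A self-contained sketch of that lookup, assuming the mapping runs from `very_low` at 1 to `very_high` at 5 (check `sparktts/utils/token_parser.py` for the actual values):

```python
# Assumed copy of LEVELS_MAP_UI from sparktts/utils/token_parser.py
LEVELS_MAP_UI = {1: "very_low", 2: "low", 3: "moderate", 4: "high", 5: "very_high"}

def slider_to_level(value):
    """Convert a Gradio slider value (int or float) to a pitch/speed level name."""
    level = int(round(value))
    if level not in LEVELS_MAP_UI:
        raise ValueError(f"slider value must be between 1 and 5, got {value}")
    return LEVELS_MAP_UI[level]

print(slider_to_level(3))    # moderate
print(slider_to_level(5.0))  # very_high
```

The WebUI itself uses `int(...)`; rounding first is a small defensive choice that gives the same result for sliders with `step=1`.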

### Tips:
- 📁 **Reference audio**: Use clear, noise-free audio for best results
- ⏱️ **Length**: 3-30 seconds of reference audio works best
- 🎧 **Quality**: Higher quality reference audio = better cloned voice
- 💾 **Download**: Use the download button on the audio player, or run the download cell below
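The length and sample-rate tips can be checked before uploading a clip. A minimal sketch using Python's standard `wave` module, which only reads uncompressed WAV files (convert MP3s first). The 3 to 30 second window and the 16 kHz floor are the rules of thumb from this notebook, not limits enforced by Spark-TTS:

```python
import wave

def check_reference_wav(path, min_sec=3.0, max_sec=30.0, min_rate=16000):
    """Return a list of warnings about a reference WAV clip (empty list = looks fine)."""
    with wave.open(path, "rb") as wav:
        rate = wav.getframerate()
        duration = wav.getnframes() / rate
    warnings = []
    if duration < min_sec:
        warnings.append(f"clip is short ({duration:.1f} s < {min_sec} s)")
    if duration > max_sec:
        warnings.append(f"clip is long ({duration:.1f} s > {max_sec} s)")
    if rate < min_rate:
        warnings.append(f"low sample rate ({rate} Hz < {min_rate} Hz)")
    return warnings
```

For example, `check_reference_wav("/content/my_voice.wav")` (a hypothetical path) returns an empty list for a 10-second, 44.1 kHz recording.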

---

In [None]:
#@title 💾 Download Generated Audio Files

import os
import zipfile
from google.colab import files

results_dir = "/content/example/results"

audio_files = []
if os.path.exists(results_dir):
    audio_files = sorted(f for f in os.listdir(results_dir) if f.endswith('.wav'))

if audio_files:
    print("📁 Found generated audio files:")
    
    for i, file in enumerate(audio_files, 1):
        print(f"{i}. {file}")
    
    if len(audio_files) == 1:
        # Download single file
        file_path = os.path.join(results_dir, audio_files[0])
        files.download(file_path)
        print(f"✅ Downloaded: {audio_files[0]}")
    else:
        # Create zip file for multiple files
        zip_path = "/content/generated_audio.zip"
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for file in audio_files:
                file_path = os.path.join(results_dir, file)
                zipf.write(file_path, file)
        
        files.download(zip_path)
        print(f"✅ Downloaded zip file with {len(audio_files)} audio files")
else:
    print("❌ No generated audio files found.")
    print("💡 Generate some audio using the interface above first.")

## 🔧 Troubleshooting

### Common Issues:

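1. **Model loading errors**:
   - Ensure all model files, including subfolders, are in `/content/Spark-TTS/pretrained_models/Spark-TTS-0.5B`
   - Read the `❌ Failed to load model: ...` message printed by the WebUI cell for the missing file or module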

2. **GPU memory issues**:
   - Restart runtime: `Runtime` > `Restart runtime`
   - Use shorter text inputs

3. **Interface not loading**:
   - Wait for the model to finish loading before opening the link
   - Check the output of the WebUI cell for error messages

4. **Audio quality issues**:
   - Use high-quality reference audio (16kHz+)
   - Ensure reference audio is clear and noise-free
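For the "use shorter text inputs" advice, one option is to split long text into sentence-sized chunks and synthesize them one at a time. Below is a hypothetical helper (not part of Spark-TTS) that groups sentences up to a character budget; the 200-character default is an arbitrary assumption to tune for your GPU, and the pattern only splits on `.`, `!` and `?`, so Chinese punctuation such as `。` would need adding.

```python
import re

def split_text(text, max_chars=200):
    """Split text at sentence boundaries into chunks of roughly max_chars characters.

    A single sentence longer than max_chars is kept whole rather than cut mid-sentence.
    """
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        if not sentence:
            continue
        candidate = f"{current} {sentence}".strip()
        if current and len(candidate) > max_chars:
            chunks.append(current)  # budget exceeded: close the current chunk
            current = sentence
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks

print(split_text("One. Two. Three.", max_chars=10))  # ['One. Two.', 'Three.']
```

Each chunk can then be passed to `run_tts` from the WebUI cell, and the resulting WAV files joined afterwards (for example by reading them with `soundfile.read` and combining the arrays with `numpy.concatenate`).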

### Need Help?
- 📖 Check the [Spark-TTS documentation](https://github.com/SparkAudio/Spark-TTS)
- 🐛 Report issues on [GitHub](https://github.com/SparkAudio/Spark-TTS/issues)

---

### 📄 License
This project is licensed under the Apache License 2.0. See the LICENSE file for details.

### 🙏 Credits
- **Spark-TTS** by SparkAudio
- **Gradio** for the web interface
- **Google Colab** for free GPU access
