# 🚀 SadTalker: Optimized Cached Setup (Low Cost)

**Pre-process face + voice once → generate videos from text instantly**

**Cost optimization:**
- ✅ Pre-process face **once** (saves 3DMM coefficients)
- ✅ Pre-process voice **once** (optional - use your voice model)
- ✅ Generate videos from **text only** (no face/voice reprocessing)
- ✅ **~10x faster** generation (no face detection/3DMM extraction each time)

## Step 1: Enable GPU

**Runtime → Change runtime type → Hardware accelerator → GPU**

In [None]:
# Query the attached GPU's name, total memory and free memory (CSV output, no header row).
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

## Step 2: Install dependencies

In [None]:
# Use Colab's pre-installed CUDA-enabled PyTorch
# (torch itself is not in the pip list below; only the additional packages are installed).
!pip install -q edge-tts face_alignment imageio imageio-ffmpeg librosa resampy pydub kornia yacs scikit-image basicsr facexlib gfpgan av safetensors gradio
# Install ffmpeg quietly; stderr is discarded and `|| true` keeps the cell from failing if apt-get errors.
!apt-get update -qq && apt-get install -y -qq ffmpeg 2>/dev/null || true

# Sanity check: report the PyTorch version in use and whether it can see a CUDA device.
import torch
print(f"‚úì PyTorch {torch.__version__}")
print(f"‚úì CUDA available: {torch.cuda.is_available()}")

## Step 3: Download code + models (same as minimal setup)

In [None]:
# TODO: copy Steps 3-4 from colab_minimal_setup.ipynb into this cell.
# This cell should download the code and the models.
# For now it is a placeholder that assumes they are already downloaded.

## Step 3.5: Setup Assets Directory

In [None]:
# Prepare the folders that receive the uploaded face image and voice recording.
import os

BASE_DIR = "/content/SadTalker"
ASSETS_DIR = os.path.join(BASE_DIR, "assets")

# Build each sub-folder path once, then create it (no error if it already exists).
_image_dir, _audio_dir = (os.path.join(ASSETS_DIR, _kind) for _kind in ("image", "audio"))
for _folder in (_image_dir, _audio_dir):
    os.makedirs(_folder, exist_ok=True)

# Emit the confirmation and the upload instructions in a single write.
print("\n".join([
    "‚úì Assets directory created",
    f"  Image folder: {_image_dir}",
    f"  Audio folder: {_audio_dir}",
    "\nüìÅ Upload your files:",
    "  - female-image-01.jpg ‚Üí assets/image/",
    "  - female-voice-01.mp3 ‚Üí assets/audio/",
]))

## Step 4: Optimized Cached Pipeline

**Two modes:**
1. **Setup Mode**: Upload face image + voice → pre-process and cache
2. **Generate Mode**: Enter text → use cached face/voice → fast generation

In [None]:
import os
import sys
import subprocess
import pickle
import json
from pathlib import Path
from datetime import datetime
import asyncio
import edge_tts
from pydub import AudioSegment
import gradio as gr
import cv2
import numpy as np

# NumPy removed the deprecated `np.float` / `np.int` aliases in release 1.24.
# Re-create them only when absent (presumably for SadTalker dependencies that
# still reference the old names - confirm before removing this shim).
for _alias, _builtin in (("float", float), ("int", int)):
    if not hasattr(np, _alias):
        setattr(np, _alias, _builtin)

# Project layout: everything lives under the SadTalker checkout.
BASE_DIR = "/content/SadTalker"
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
CACHE_DIR = os.path.join(BASE_DIR, "cache")
RESULT_DIR = os.path.join(BASE_DIR, "results")

# Make sure the writable folders exist before anything tries to use them.
for _folder in (CACHE_DIR, RESULT_DIR):
    os.makedirs(_folder, exist_ok=True)

# Work from the checkout and put it first on the import path so that
# `src.*` and `inference.py` resolve relative to it.
os.chdir(BASE_DIR)
sys.path.insert(0, BASE_DIR)

# Pickle files that persist the pre-processing results between generations.
FACE_CACHE_FILE = os.path.join(CACHE_DIR, "face_cache.pkl")
VOICE_CACHE_FILE = os.path.join(CACHE_DIR, "voice_cache.pkl")

# Default assets used by the auto-setup path (adjust the file names as needed).
ASSETS_DIR = os.path.join(BASE_DIR, "assets")
DEFAULT_IMAGE = os.path.join(ASSETS_DIR, "image", "female-image-01.jpg")
DEFAULT_VOICE = os.path.join(ASSETS_DIR, "audio", "female-voice-01.mp3")


def preprocess_and_cache_face(image_path: str, cache_id: str = "default"):
    """Pre-process a face image once and persist the result to FACE_CACHE_FILE.

    Args:
        image_path: Path to the source face image.
        cache_id: Label used to name the working directory under CACHE_DIR
            (``face_<cache_id>``); it is also stored in the cache record.

    Returns:
        A ``(cache_data, message)`` tuple. ``cache_data`` is the dict that was
        pickled (coefficient path, cropped picture path, crop info, source
        image path, cache id). It is ``None`` when the image is missing or no
        coefficients were produced, and ``message`` then explains why.
    """
    # Reject a missing/empty path up front instead of loading the models first.
    if not image_path or not os.path.isfile(image_path):
        return None, f"Face image not found: {image_path}"

    print("Pre-processing face (this runs once)...")

    # Imported lazily inside the function; `src` is the SadTalker package,
    # assumed importable from BASE_DIR.
    from src.utils.preprocess import CropAndExtract
    from src.utils.init_path import init_path
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sadtalker_paths = init_path(CHECKPOINT_DIR, os.path.join(BASE_DIR, 'src/config'), 256, False, 'full')

    preprocess_model = CropAndExtract(sadtalker_paths, device)

    try:
        # Extract face coefficients (expensive operation - done once)
        cache_frame_dir = os.path.join(CACHE_DIR, f"face_{cache_id}")
        os.makedirs(cache_frame_dir, exist_ok=True)

        first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
            image_path, cache_frame_dir, 'full', source_image_flag=True, pic_size=256
        )
    finally:
        # Release the model on every path (success, detection failure, or
        # exception) so GPU memory is not held by a failed attempt.
        del preprocess_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if first_coeff_path is None:
        return None, "Face detection failed"

    # Cache the results
    cache_data = {
        'first_coeff_path': first_coeff_path,
        'crop_pic_path': crop_pic_path,
        'crop_info': crop_info,
        'image_path': image_path,
        'cache_id': cache_id
    }

    with open(FACE_CACHE_FILE, 'wb') as f:
        pickle.dump(cache_data, f)

    return cache_data, "‚úì Face pre-processed and cached!"


def load_face_cache():
    """Return the unpickled face-cache record, or None when no cache file exists."""
    if not os.path.exists(FACE_CACHE_FILE):
        return None
    with open(FACE_CACHE_FILE, 'rb') as cache_file:
        record = pickle.load(cache_file)
    return record


def preprocess_and_cache_voice(audio_path: str, cache_id: str = "default"):
    """Convert an MP3 voice file to WAV if needed and persist its path.

    Args:
        audio_path: Path to the voice recording. Files with an ``.mp3``
            extension (any letter case) are converted to a sibling ``.wav``;
            other files are cached as-is.
        cache_id: Label stored alongside the path in the cache record.

    Returns:
        A ``(cache_data, message)`` tuple where ``cache_data`` is the dict
        pickled to VOICE_CACHE_FILE and ``message`` is a status line.
    """
    print("Pre-processing voice file...")

    # Split off only the real extension. A plain str.replace would also rewrite
    # ".mp3" occurring earlier in the path (e.g. in a directory name), and
    # endswith('.mp3') would miss upper-case extensions.
    stem, extension = os.path.splitext(audio_path)
    if extension.lower() == '.mp3':
        wav_path = stem + '.wav'
        # Reuse an existing conversion rather than re-encoding every time.
        if not os.path.exists(wav_path):
            audio = AudioSegment.from_mp3(audio_path)
            audio.export(wav_path, format="wav")
        audio_path = wav_path

    # Cache voice file path
    cache_data = {
        'voice_path': audio_path,
        'cache_id': cache_id
    }

    with open(VOICE_CACHE_FILE, 'wb') as f:
        pickle.dump(cache_data, f)

    return cache_data, f"‚úì Voice file cached: {os.path.basename(audio_path)}"


def load_voice_cache():
    """Return the unpickled voice-cache record, or None when no cache file exists."""
    if not os.path.exists(VOICE_CACHE_FILE):
        return None
    with open(VOICE_CACHE_FILE, 'rb') as cache_file:
        record = pickle.load(cache_file)
    return record


def auto_setup_from_assets():
    """Cache the default face and voice assets when present and not yet cached.

    Returns:
        One status line per asset that was processed, joined with newlines,
        or a fixed notice when nothing needed processing.
    """
    # Read both caches up front, before any processing starts.
    cached_face = load_face_cache()
    cached_voice = load_voice_cache()

    report = []

    # Face: only when nothing is cached yet and the default image is on disk.
    if not cached_face and os.path.exists(DEFAULT_IMAGE):
        print("Auto-setting up face from assets...")
        _, face_msg = preprocess_and_cache_face(DEFAULT_IMAGE, "female-01")
        report.append(f"Face: {face_msg}")

    # Voice: same rule for the default recording.
    if not cached_voice and os.path.exists(DEFAULT_VOICE):
        print("Auto-setting up voice from assets...")
        _, voice_msg = preprocess_and_cache_voice(DEFAULT_VOICE, "female-01")
        report.append(f"Voice: {voice_msg}")

    if not report:
        return "‚úì Assets already cached or not found"
    return "\n".join(report)


async def text_to_speech_async(text: str, voice: str, out_path: str):
    """Synthesize ``text`` with edge-tts and write it to ``out_path`` as WAV.

    The speech is first saved as MP3 to a scratch file, then converted to WAV
    with pydub.

    Args:
        text: The text to speak.
        voice: edge-tts voice identifier.
        out_path: Destination path of the WAV file.

    Returns:
        ``out_path``.
    """
    import tempfile

    # Use a uniquely named scratch file next to the output. Deriving the name
    # with out_path.replace(".wav", ".mp3") could yield out_path itself (when
    # it has no ".wav"), which was then deleted after export, or could clobber
    # and delete an unrelated sibling .mp3.
    scratch_fd, mp3_path = tempfile.mkstemp(
        suffix=".mp3", dir=os.path.dirname(out_path) or None
    )
    os.close(scratch_fd)
    try:
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(mp3_path)
        audio = AudioSegment.from_mp3(mp3_path)
        audio.export(out_path, format="wav")
    finally:
        # Remove the scratch file even when synthesis or conversion fails.
        if os.path.exists(mp3_path):
            os.remove(mp3_path)
    return out_path


def generate_video_fast(text: str, use_cached_voice: bool = False):
    """Generate a talking-head video from the cached face plus TTS or cached voice.

    Args:
        text: Text to synthesize. Only used when ``use_cached_voice`` is False.
        use_cached_voice: When True, the cached voice file is used as the
            driving audio instead of synthesizing ``text``.

    Returns:
        A ``(video_path, status)`` tuple. ``video_path`` is ``None`` on any
        failure, and ``status`` then carries the reason.
    """
    import shutil

    # Load cached face
    face_cache = load_face_cache()
    if not face_cache:
        return None, "‚ùå No cached face found. Run Setup Mode first."

    # Check the cached artefacts before doing any audio work, so a stale cache
    # yields a readable message instead of an unhandled FileNotFoundError.
    cached_coeff = face_cache['first_coeff_path']
    source_image = face_cache['image_path']
    missing = [p for p in (cached_coeff, source_image) if not p or not os.path.exists(p)]
    if missing:
        return None, f"Cached face files are missing ({', '.join(map(str, missing))}). Run Setup Mode again."

    ts = datetime.now().strftime("%Y_%m_%d_%H.%M.%S")
    audio_path = os.path.join(RESULT_DIR, f"audio_{ts}.wav")

    if use_cached_voice:
        # Option 1: Use cached voice file (if available and requested)
        voice_cache = load_voice_cache()
        if not (voice_cache and os.path.exists(voice_cache['voice_path'])):
            return None, "‚ùå No cached voice found. Run Setup Mode first or use TTS."
        shutil.copy(voice_cache['voice_path'], audio_path)
        print(f"Using cached voice: {os.path.basename(voice_cache['voice_path'])}")
    else:
        # Option 2: Generate speech from text using TTS
        if not text or not text.strip():
            return None, "Please enter some text"
        voice_id = "en-US-JennyNeural"  # Default female voice
        print("Generating speech from text...")
        asyncio.run(text_to_speech_async(text.strip(), voice_id, audio_path))

    print("Running fast inference with cached face...")

    # Create a dedicated output dir for this generation
    gen_dir = os.path.join(RESULT_DIR, f"gen_{ts}")
    os.makedirs(gen_dir, exist_ok=True)

    # Copy cached coeff to gen dir
    # NOTE(review): the command below passes --source_image and no coefficient
    # path, so it is not evident that inference.py consumes this copy - confirm.
    gen_coeff_path = os.path.join(gen_dir, os.path.basename(cached_coeff))
    shutil.copy(cached_coeff, gen_coeff_path)

    # Run inference as a subprocess, from BASE_DIR, with the same interpreter.
    cmd = [
        sys.executable, "inference.py",
        "--driven_audio", audio_path,
        "--source_image", source_image,
        "--result_dir", gen_dir,
        "--checkpoint_dir", CHECKPOINT_DIR,
        "--still", "--preprocess", "full", "--enhancer", "gfpgan"
    ]

    env = os.environ.copy()
    env["PYTHONPATH"] = BASE_DIR

    r = subprocess.run(cmd, cwd=BASE_DIR, env=env, capture_output=True, text=True)

    if r.returncode != 0:
        # Prefer stderr; fall back to stdout when stderr is empty.
        err = (r.stderr or "").strip() or (r.stdout or "").strip()
        return None, f"Error: {err}"

    # Find output video: newest .mp4 anywhere under the generation dir.
    mp4s = sorted(Path(gen_dir).rglob("*.mp4"), key=os.path.getmtime, reverse=True)
    if not mp4s:
        return None, "No output video found"

    return str(mp4s[0]), f"‚úì Generated: {os.path.basename(mp4s[0])}"


# Gradio UI
def _upload_path(upload):
    """Return the filesystem path behind a Gradio upload value, or None.

    Handles a plain path string, a dict carrying a "path" key, and an object
    exposing a ``name`` attribute, without calling ``.get`` on non-dicts.
    """
    if not upload:
        return None
    if isinstance(upload, str):
        return upload
    if isinstance(upload, dict):
        return upload.get("path")
    return getattr(upload, "name", None)


with gr.Blocks(title="SadTalker ‚Äî Optimized Cached") as demo:
    gr.Markdown("## üöÄ Optimized: Pre-process once ‚Üí Generate fast")

    with gr.Tabs():
        # Tab 1: one-off pre-processing of the face image and voice file.
        with gr.TabItem("1Ô∏è‚É£ Setup (Run Once)"):
            gr.Markdown("### Pre-process face + voice ‚Üí Cache for fast generation")

            # Auto-setup from assets
            auto_setup_btn = gr.Button("üöÄ Auto-Setup from Assets", variant="primary")
            auto_setup_status = gr.Textbox(label="Auto-Setup Status", interactive=False)

            def do_auto_setup():
                """Cache the default assets and return the status text."""
                return auto_setup_from_assets()

            auto_setup_btn.click(fn=do_auto_setup, outputs=[auto_setup_status])

            gr.Markdown("---\n### Or Manual Setup:")

            with gr.Row():
                with gr.Column():
                    setup_image = gr.Image(type="filepath", label="Face Image")
                    setup_cache_id = gr.Textbox(label="Face Cache ID", value="default")
                    setup_face_btn = gr.Button("Pre-process Face", variant="secondary")

                with gr.Column():
                    setup_voice = gr.Audio(type="filepath", label="Voice Audio File")
                    setup_voice_id = gr.Textbox(label="Voice Cache ID", value="default")
                    setup_voice_btn = gr.Button("Cache Voice", variant="secondary")

            # Both manual buttons report into this shared status box.
            setup_status = gr.Textbox(label="Setup Status", interactive=False)

            def do_setup_face(image, cache_id):
                """Pre-process the uploaded face image and return the status text."""
                image_path = _upload_path(image)
                if not image_path:
                    return "Please upload a face image"
                cache_data, msg = preprocess_and_cache_face(image_path, cache_id or "default")
                return msg

            def do_setup_voice(audio, cache_id):
                """Cache the uploaded voice file and return the status text."""
                audio_path = _upload_path(audio)
                if not audio_path:
                    return "Please upload a voice audio file"
                cache_data, msg = preprocess_and_cache_voice(audio_path, cache_id or "default")
                return msg

            setup_face_btn.click(fn=do_setup_face, inputs=[setup_image, setup_cache_id], outputs=[setup_status])
            setup_voice_btn.click(fn=do_setup_voice, inputs=[setup_voice, setup_voice_id], outputs=[setup_status])

        # Tab 2: generation from text using whatever was cached in tab 1.
        with gr.TabItem("2Ô∏è‚É£ Generate (Fast)"):
            gr.Markdown("### Enter text ‚Üí Generate video (uses cached face + voice)")
            gen_text = gr.Textbox(label="Text to speak", lines=4, placeholder="Enter the text for the avatar to read...")

            with gr.Row():
                gen_mode = gr.Radio(
                    choices=["Use TTS (Text-to-Speech)", "Use Cached Voice File"],
                    value="Use TTS (Text-to-Speech)",
                    label="Audio Source"
                )
                gen_btn = gr.Button("üöÄ Generate Video", variant="primary", scale=2)

            gen_video = gr.Video(label="Output Video")
            gen_status = gr.Textbox(label="Status", interactive=False, lines=3)

            def do_generate(text, mode):
                """Validate the text, then run generation with the selected audio source."""
                if not text or not text.strip():
                    return None, "Please enter some text"

                # The radio label doubles as the mode switch.
                use_cached = (mode == "Use Cached Voice File")
                video_path, status = generate_video_fast(text, use_cached_voice=use_cached)
                return video_path, status

            gen_btn.click(fn=do_generate, inputs=[gen_text, gen_mode], outputs=[gen_video, gen_status])

# Serve the UI on all network interfaces at port 7860. share=True additionally
# requests a public Gradio share link, so anyone holding that URL can reach the
# app; no authentication is configured in this call.
demo.launch(share=True, server_name="0.0.0.0", server_port=7860)