# 🚀 SadTalker: Optimized Cached Setup (Low Cost)

**Pre-process face + voice once → generate videos from text instantly**

**Cost optimization:**
- ✅ Pre-process face **once** (saves 3DMM coefficients)
- ✅ Pre-process voice **once** (optional - use your voice model)
- ✅ Generate videos from **text only** (no face/voice reprocessing)
- ✅ **~10x faster** generation (no face detection/3DMM extraction each time)

## Step 1: Enable GPU

**Runtime → Change runtime type → Hardware accelerator → GPU**

In [None]:
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

## Step 2: Install dependencies

In [None]:
# Use Colab's pre-installed CUDA-enabled PyTorch
# (torch is deliberately absent from the pip list below so the existing build is kept)
# Note: face-alignment (hyphen) installs but imports as face_alignment (underscore)
!pip install -q edge-tts face-alignment imageio imageio-ffmpeg librosa resampy pydub kornia yacs scikit-image basicsr facexlib gfpgan av safetensors gradio
# `2>/dev/null || true` hides apt errors and keeps the cell from failing if the install step fails
!apt-get update -qq && apt-get install -y -qq ffmpeg 2>/dev/null || true

# Sanity check: report the PyTorch version in use and whether it can see a GPU
import torch
print(f"‚úì PyTorch {torch.__version__}")
print(f"‚úì CUDA available: {torch.cuda.is_available()}")

## Step 3: Download code + models (same as minimal setup)

In [None]:
# Download only essential code (src/ + inference.py)
import os
import zipfile
import urllib.request
import shutil

BASE_DIR = "/content/SadTalker"
TEMP_DIR = "/tmp/sadtalker_extract"
os.makedirs(BASE_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)
os.chdir(BASE_DIR)

# Download repo as zip
repo_url = "https://github.com/OpenTalker/SadTalker/archive/refs/heads/main.zip"
zip_path = "/tmp/sadtalker.zip"
print("Downloading SadTalker repository...")

try:
    urllib.request.urlretrieve(repo_url, zip_path)

    # Extract entire zip to temp directory
    print("Extracting essential files...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(TEMP_DIR)

    # Move only src/ and inference.py to BASE_DIR
    extracted_root = os.path.join(TEMP_DIR, "SadTalker-main")
    if not os.path.isdir(extracted_root):
        # Without this the cell would finish silently with nothing installed.
        print(f"WARNING: expected folder {extracted_root} not found in the downloaded archive")
    else:
        for entry, moved_msg in (
            ("src", "‚úì Moved src/ directory"),
            ("inference.py", "‚úì Moved inference.py"),
        ):
            entry_src = os.path.join(extracted_root, entry)
            entry_dst = os.path.join(BASE_DIR, entry)
            if not os.path.exists(entry_src):
                print(f"WARNING: {entry} not found in the downloaded archive")
                continue
            # Replace any copy left behind by a previous run of this cell.
            if os.path.isdir(entry_dst):
                shutil.rmtree(entry_dst)
            elif os.path.exists(entry_dst):
                os.remove(entry_dst)
            shutil.move(entry_src, entry_dst)
            print(moved_msg)
finally:
    # Cleanup runs even when the download or extraction fails, so /tmp is not left littered.
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
    if os.path.exists(zip_path):
        os.remove(zip_path)

# Wrap the np.VisibleDeprecationWarning reference in preprocess.py with
# try/except AttributeError so the module still imports when numpy no longer
# exposes that name.
preprocess_file = os.path.join(BASE_DIR, "src", "face3d", "util", "preprocess.py")
if os.path.exists(preprocess_file):
    with open(preprocess_file, "r") as f:
        lines = f.readlines()

    def _is_unwrapped_reference(idx):
        """True when line idx mentions the warning and is not already preceded by a `try:` line."""
        if "np.VisibleDeprecationWarning" not in lines[idx]:
            return False
        return not (idx > 0 and "try:" in lines[idx - 1])

    # Only the first unwrapped occurrence is rewritten.
    target = next((idx for idx in range(len(lines)) if _is_unwrapped_reference(idx)), None)
    fixed = target is not None

    if fixed:
        original_line = lines[target]
        pad = " " * (len(original_line) - len(original_line.lstrip()))
        lines[target] = (
            f"{pad}try:\n"
            f'{pad}    warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)\n'
            f"{pad}except AttributeError:\n"
            f"{pad}    pass  # VisibleDeprecationWarning removed in newer numpy\n"
        )
        with open(preprocess_file, "w") as f:
            f.writelines(lines)
        print("‚úì Fixed numpy VisibleDeprecationWarning issue")

# Fix np.float in my_awing_arch.py
# The bare `np.float` alias no longer exists in current numpy; rewrite it to np.float64.
import re

awing_file = os.path.join(BASE_DIR, "src", "face3d", "util", "my_awing_arch.py")
if os.path.exists(awing_file):
    with open(awing_file, "r") as f:
        content = f.read()
    # The trailing \b keeps np.float32 / np.float64 / np.float_ untouched, so
    # re-running this cell on an already patched file replaces nothing.
    content, replaced_count = re.subn(r'\bnp\.float\b', 'np.float64', content)
    if replaced_count:
        # Only rewrite the file (and report a fix) when something really changed.
        with open(awing_file, "w") as f:
            f.write(content)
        print("‚úì Fixed np.float ‚Üí np.float64 in my_awing_arch.py")

print(f"\n‚úì Essential code extracted to {BASE_DIR}")
print(f"‚úì Found src/ directory: {os.path.exists(os.path.join(BASE_DIR, 'src'))}")
print(f"‚úì Found inference.py: {os.path.exists(os.path.join(BASE_DIR, 'inference.py'))}")

# Fix face detection: Add face-alignment fallback to croper.py
# This section edits src/utils/croper.py as text. It (1) inserts a lazily
# initialised face_alignment helper just before `class Preprocesser:`,
# (2) stores the device on the instance, (3) adds a
# `_get_landmark_face_alignment` method ahead of `get_landmark`, and
# (4) tries to insert a call to that method inside `get_landmark`.
croper_file = os.path.join(BASE_DIR, "src", "utils", "croper.py")
if os.path.exists(croper_file):
    with open(croper_file, "r") as f:
        lines = f.readlines()
    
    # Check if already patched
    # (the helper name doubles as the marker, so re-running the cell is a no-op)
    content_str = ''.join(lines)
    if "_get_face_alignment" not in content_str:
        # Find insertion point: after imports, before class Preprocesser
        insert_idx = None
        for i, line in enumerate(lines):
            if "class Preprocesser:" in line:
                insert_idx = i
                break
        
        if insert_idx is not None:
            # Insert face-alignment helper functions
            patch_code = """# Optional: 1adrianb/face-alignment returns 68 points directly (no 98->68 conversion)
_FACE_ALIGNMENT = None

def _get_face_alignment(device='cuda'):
    \"\"\"Lazy-init face_alignment.FaceAlignment (68 landmarks, TWO_D).\"\"\"
    global _FACE_ALIGNMENT
    if _FACE_ALIGNMENT is None:
        try:
            import face_alignment
            LandmarksType = face_alignment.LandmarksType
            # TWO_D or _2D depending on package version
            lm_type = getattr(LandmarksType, 'TWO_D', getattr(LandmarksType, '_2D', 1))
            fa_device = 'cpu' if device == 'cpu' else 'cuda'
            _FACE_ALIGNMENT = face_alignment.FaceAlignment(
                lm_type, device=fa_device, face_detector='sfd'
            )
        except Exception:
            _FACE_ALIGNMENT = False
    return _FACE_ALIGNMENT if _FACE_ALIGNMENT is not None and _FACE_ALIGNMENT is not False else None


"""
            # The whole helper goes in as ONE list element; the scans below
            # therefore see it as a single multi-line "line".
            lines.insert(insert_idx, patch_code)
            
            # Add device to __init__
            # (the inserted fallback method reads self.device)
            for i, line in enumerate(lines):
                if "def __init__(self, device='cuda'):" in line:
                    # Find next line with self.predictor
                    # Only the 4 entries after the signature are inspected.
                    for j in range(i+1, min(i+5, len(lines))):
                        if "self.predictor = KeypointExtractor" in lines[j]:
                            # Check if device already added
                            if "self.device = device" not in ''.join(lines[i:j+3]):
                                lines.insert(j+1, "        self.device = device\n")
                            break
                    break
            
            # Add fallback method before get_landmark
            fallback_method = """    def _get_landmark_face_alignment(self, img_np, det):
        \"\"\"Fallback: use 1adrianb/face-alignment to get 68 landmarks on crop (no 98->68).\"\"\"
        fa = _get_face_alignment(self.device)
        if fa is None:
            return None
        try:
            img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
            if img.size == 0 or img.shape[0] == 0 or img.shape[1] == 0:
                return None
            # face_alignment expects RGB; get_landmarks returns list of (68, 2) or (68, 3)
            preds = fa.get_landmarks(img)
            if not preds or len(preds) == 0:
                return None
            lm = np.array(preds[0], dtype=np.float64)
            if lm.ndim == 2 and lm.shape[0] == 68:
                lm = lm[:, :2]
            else:
                return None
            lm[:, 0] += int(det[0])
            lm[:, 1] += int(det[1])
            return lm
        except Exception:
            return None

"""
            
            # Find get_landmark method and insert fallback before it
            for i, line in enumerate(lines):
                if "    def get_landmark(self, img_np):" in line:
                    if "_get_landmark_face_alignment" not in ''.join(lines[max(0,i-10):i]):
                        lines.insert(i, fallback_method)
                    break
            
            # Modify get_landmark to add fallback at the end (before final except)
            # NOTE(review): the call is inserted before the FIRST "except Exception:"
            # entry found after the get_landmark signature, and only if one appears
            # before the next `def`. If get_landmark contains no such line, nothing
            # is inserted here, yet the success message below is still printed.
            # The inserted call also assumes a local named `det` exists at that
            # point in the upstream method - confirm against the croper.py in use.
            in_get_landmark = False
            for i in range(len(lines)):
                if "    def get_landmark(self, img_np):" in lines[i]:
                    in_get_landmark = True
                elif in_get_landmark and lines[i].strip().startswith("def ") and "get_landmark" not in lines[i]:
                    # End of get_landmark method
                    break
                elif in_get_landmark and "except Exception:" in lines[i] and "# Fallback: face_alignment" not in ''.join(lines[max(0,i-3):i+1]):
                    # Insert fallback before this except
                    fallback_call = "            # Fallback: face_alignment (68 landmarks directly)\n            lm = self._get_landmark_face_alignment(img_np, det)\n            return lm\n\n"
                    lines.insert(i, fallback_call)
                    break
            
            # Persist the edited source; multi-line entries are written verbatim.
            with open(croper_file, "w") as f:
                f.writelines(lines)
            print("‚úì Added face-alignment fallback to croper.py")
        else:
            print("‚ö† Could not find Preprocesser class in croper.py")
    else:
        print("‚úì croper.py already patched with face-alignment fallback")

## Step 4: Download models

In [None]:
import os
import urllib.request

def _download_if_missing(url, dest_dir, filename):
    """Download url to dest_dir/filename unless that file already exists.

    The payload is first written to a ``.part`` file and renamed only after the
    transfer completes, so an interrupted download is never mistaken for a
    finished model on the next run.
    """
    filepath = os.path.join(dest_dir, filename)
    if os.path.exists(filepath):
        print(f"  ‚úì {filename} (already exists)")
        return

    print(f"  Downloading {filename}...")
    partial_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filepath)
    finally:
        # Only present here when the transfer or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"  ‚úì {filename}")


CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Download safetensor models
print("Downloading SadTalker models (safetensor format)...")
models = [
    ("https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors", "SadTalker_V0.0.2_256.safetensors"),
    ("https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar", "mapping_00109-model.pth.tar"),  # For 'full' preprocess
]

for url, filename in models:
    _download_if_missing(url, CHECKPOINT_DIR, filename)

# Download GFPGAN weights for enhancer
GFPGAN_DIR = os.path.join(BASE_DIR, "gfpgan", "weights")
os.makedirs(GFPGAN_DIR, exist_ok=True)

print("\nDownloading GFPGAN enhancer weights...")
gfpgan_models = [
    ("https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth", "alignment_WFLW_4HG.pth"),
    ("https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", "detection_Resnet50_Final.pth"),
    ("https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", "GFPGANv1.4.pth"),
    ("https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth", "parsing_parsenet.pth"),
]

for url, filename in gfpgan_models:
    _download_if_missing(url, GFPGAN_DIR, filename)

print("\n‚úì All models downloaded!")

## Step 4.5: Fix torchvision compatibility

In [None]:
# Fix torchvision compatibility: functional_tensor was moved in newer versions
import os
import site

# Find basicsr installation
basicsr_path = None
for path in site.getsitepackages():
    degradations_file = os.path.join(path, "basicsr", "data", "degradations.py")
    if os.path.exists(degradations_file):
        basicsr_path = degradations_file
        break

# Also check common Colab paths
if not basicsr_path:
    common_paths = [
        "/usr/local/lib/python3.12/dist-packages/basicsr/data/degradations.py",
        "/usr/local/lib/python3.11/dist-packages/basicsr/data/degradations.py",
        "/usr/local/lib/python3.10/dist-packages/basicsr/data/degradations.py",
    ]
    for p in common_paths:
        if os.path.exists(p):
            basicsr_path = p
            break

if basicsr_path:
    with open(basicsr_path, "r") as f:
        content = f.read()

    # Replace the problematic import
    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"

    if old_import in content and new_import not in content:
        content = content.replace(old_import, new_import)
        with open(basicsr_path, "w") as f:
            f.write(content)
        print(f"‚úì Fixed torchvision compatibility in {basicsr_path}")
    elif new_import in content:
        print("‚úì Torchvision compatibility already fixed")
    else:
        print(f"‚ö† Import line not found in {basicsr_path}")
else:
    print("‚ö† Could not find basicsr/degradations.py - trying alternative fix...")
    # Alternative: monkey-patch at import time
    import sys

    def patch_torchvision():
        """Make `torchvision.transforms.functional_tensor` importable again.

        If the installed torchvision still ships the module, nothing is changed.
        Otherwise `torchvision.transforms.functional` is registered under the old
        name, both in sys.modules and as an attribute of the package.
        """
        import torchvision.transforms as transforms_module
        try:
            import torchvision.transforms.functional_tensor  # noqa: F401
            return
        except ImportError:
            pass
        from torchvision.transforms import functional as functional_module
        # A real module object is required here: a None entry in sys.modules
        # makes the import machinery raise ModuleNotFoundError for that name.
        sys.modules['torchvision.transforms.functional_tensor'] = functional_module
        transforms_module.functional_tensor = functional_module

    patch_torchvision()
    print("‚úì Applied runtime patch")

## Step 3.5: Setup Assets Directory

In [None]:
# Create assets directory structure
import os
# Project root and the assets/ folders that hold the user's face image and voice clip.
BASE_DIR = "/content/SadTalker"
ASSETS_DIR = os.path.join(BASE_DIR, "assets")
image_dir, audio_dir = (os.path.join(ASSETS_DIR, sub) for sub in ("image", "audio"))
for folder in (image_dir, audio_dir):
    os.makedirs(folder, exist_ok=True)

# Tell the user where to put the default inputs.
print("‚úì Assets directory created")
print(f"  Image folder: {image_dir}")
print(f"  Audio folder: {audio_dir}")
print("\nüìÅ Upload your files:")
print("  - female-image-01.jpg ‚Üí assets/image/")
print("  - female-voice-01.mp3 ‚Üí assets/audio/")

## Step 5: Optimized Cached Pipeline

**Two modes:**
1. **Setup Mode**: Upload face image + voice → pre-process and cache
2. **Generate Mode**: Enter text → use cached face/voice → fast generation

In [None]:
import os
import sys
import subprocess
import pickle
import json
from pathlib import Path
from datetime import datetime
import asyncio
import edge_tts
from pydub import AudioSegment
import gradio as gr
import cv2
import numpy as np

# Restore the np.float / np.int aliases when the installed numpy lacks them,
# so code that still references them keeps working.
for _alias, _builtin in (("float", float), ("int", int)):
    if not hasattr(np, _alias):
        setattr(np, _alias, _builtin)

# Project layout
BASE_DIR = "/content/SadTalker"
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
CACHE_DIR = os.path.join(BASE_DIR, "cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Run from the project root and make its packages importable.
os.chdir(BASE_DIR)
sys.path.insert(0, BASE_DIR)

RESULT_DIR = os.path.join(BASE_DIR, "results")
os.makedirs(RESULT_DIR, exist_ok=True)

# Pickle files holding the cached face / voice metadata
FACE_CACHE_FILE = os.path.join(CACHE_DIR, "face_cache.pkl")
VOICE_CACHE_FILE = os.path.join(CACHE_DIR, "voice_cache.pkl")

# Default inputs picked up by the auto-setup (adjust paths as needed)
ASSETS_DIR = os.path.join(BASE_DIR, "assets")
DEFAULT_IMAGE = os.path.join(ASSETS_DIR, "image", "female-image-01.jpg")
DEFAULT_VOICE = os.path.join(ASSETS_DIR, "audio", "female-voice-01.mp3")


def preprocess_and_cache_face(image_path: str, cache_id: str = "default"):
    """Pre-process a face image once and cache the results.

    Runs SadTalker's ``CropAndExtract`` on ``image_path`` in 'full' mode at
    256px, writes its artefacts to ``CACHE_DIR/face_<cache_id>`` and pickles
    the resulting paths to ``FACE_CACHE_FILE``.

    Args:
        image_path: Path of the source face image.
        cache_id: Suffix for the per-face cache directory. The pickle file
            itself is shared, so the most recent call wins.

    Returns:
        ``(cache_data, message)``. ``cache_data`` is the cached dict on success
        and ``None`` on any failure, in which case ``message`` describes the
        error. Exceptions are reported through the message, not raised.
    """
    try:
        print("Pre-processing face (this runs once)...")

        # Validate image path
        if not image_path or not os.path.exists(image_path):
            return None, f"‚ùå Image file not found: {image_path}"

        # Check if image is readable
        if cv2.imread(image_path) is None:
            return None, f"‚ùå Could not read image file: {image_path}. Make sure it's a valid image (PNG/JPG)."

        # Imported lazily: these only resolve once the SadTalker sources exist.
        from src.utils.preprocess import CropAndExtract
        from src.utils.init_path import init_path
        import torch

        # Check checkpoints exist
        if not os.path.exists(CHECKPOINT_DIR):
            return None, f"‚ùå Checkpoints directory not found: {CHECKPOINT_DIR}\nPlease run Step 3-4 to download models first."

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        try:
            sadtalker_paths = init_path(CHECKPOINT_DIR, os.path.join(BASE_DIR, 'src/config'), 256, False, 'full')
        except Exception as e:
            return None, f"‚ùå Failed to initialize paths: {str(e)}\nMake sure checkpoints are downloaded (Step 3-4)."

        try:
            preprocess_model = CropAndExtract(sadtalker_paths, device)
        except Exception as e:
            return None, f"‚ùå Failed to load preprocess model: {str(e)}\nCheck if all model files are downloaded."

        # Extract face coefficients (expensive operation - done once)
        cache_frame_dir = os.path.join(CACHE_DIR, f"face_{cache_id}")
        os.makedirs(cache_frame_dir, exist_ok=True)

        try:
            first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
                image_path, cache_frame_dir, 'full', source_image_flag=True, pic_size=256
            )
        except Exception as e:
            error_msg = str(e)
            if "face" in error_msg.lower() or "detect" in error_msg.lower():
                return None, f"‚ùå Face detection failed: {error_msg}\n\nTip: Use a clear front-facing face image with good lighting."
            return None, f"‚ùå Pre-processing error: {error_msg}"
        finally:
            # Release the model on every path, including failures, so a bad
            # image does not leave the CUDA cache occupied.
            del preprocess_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if first_coeff_path is None:
            return None, "‚ùå Face detection failed - no face found in image.\n\nTip: Use a clear front-facing face image."

        # Cache the results
        cache_data = {
            'first_coeff_path': first_coeff_path,
            'crop_pic_path': crop_pic_path,
            'crop_info': crop_info,
            'image_path': image_path,
            'cache_id': cache_id
        }

        with open(FACE_CACHE_FILE, 'wb') as f:
            pickle.dump(cache_data, f)

        return cache_data, "‚úì Face pre-processed and cached!"

    except Exception as e:
        # Last-resort boundary: the UI expects a message, not a traceback popup.
        import traceback
        error_details = traceback.format_exc()
        return None, f"‚ùå Unexpected error during face pre-processing:\n{str(e)}\n\nDetails:\n{error_details}"


def load_face_cache():
    """Return the cached face metadata dict, or None when no usable cache exists.

    A missing, truncated or otherwise unreadable ``FACE_CACHE_FILE`` is treated
    as "no cache" so callers fall back to their "run setup first" path instead
    of crashing.
    """
    if not os.path.exists(FACE_CACHE_FILE):
        return None
    try:
        with open(FACE_CACHE_FILE, 'rb') as f:
            # NOTE: pickle.load is acceptable only because this file is written
            # locally by this notebook; never point it at untrusted data.
            return pickle.load(f)
    except (EOFError, pickle.UnpicklingError) as exc:
        print(f"Ignoring unreadable face cache {FACE_CACHE_FILE}: {exc}")
        return None


def preprocess_and_cache_voice(audio_path: str, cache_id: str = "default"):
    """Convert a voice file to WAV if needed and cache its path.

    Args:
        audio_path: Path of the voice recording. An ``.mp3`` file (extension
            matched case-insensitively) is converted to a sibling ``.wav``
            unless that WAV already exists.
        cache_id: Identifier stored alongside the path in the cache.

    Returns:
        ``(cache_data, message)``; ``cache_data`` is ``None`` when the input
        file is missing, otherwise the dict pickled to ``VOICE_CACHE_FILE``.
    """
    print("Pre-processing voice file...")

    # Guard against an empty upload or a stale temp path before touching pydub.
    if not audio_path or not os.path.exists(audio_path):
        return None, f"Voice file not found: {audio_path}"

    # Convert MP3 to WAV if needed.
    # splitext only touches the real extension; str.replace would also rewrite
    # a ".mp3" that happens to occur inside a directory name.
    stem, ext = os.path.splitext(audio_path)
    if ext.lower() == '.mp3':
        wav_path = stem + '.wav'
        if not os.path.exists(wav_path):
            audio = AudioSegment.from_mp3(audio_path)
            audio.export(wav_path, format="wav")
        audio_path = wav_path

    # Cache voice file path
    cache_data = {
        'voice_path': audio_path,
        'cache_id': cache_id
    }

    with open(VOICE_CACHE_FILE, 'wb') as f:
        pickle.dump(cache_data, f)

    return cache_data, f"‚úì Voice file cached: {os.path.basename(audio_path)}"


def load_voice_cache():
    """Return the cached voice metadata dict, or None when no usable cache exists.

    A missing, truncated or otherwise unreadable ``VOICE_CACHE_FILE`` is
    treated as "no cache" instead of raising.
    """
    if not os.path.exists(VOICE_CACHE_FILE):
        return None
    try:
        with open(VOICE_CACHE_FILE, 'rb') as f:
            # NOTE: pickle.load is acceptable only because this file is written
            # locally by this notebook; never point it at untrusted data.
            return pickle.load(f)
    except (EOFError, pickle.UnpicklingError) as exc:
        print(f"Ignoring unreadable voice cache {VOICE_CACHE_FILE}: {exc}")
        return None


def auto_setup_from_assets():
    """Cache the default face image and voice file when present and not yet cached.

    Returns a newline-joined status report, or a fixed notice when nothing
    needed doing.
    """
    # Both caches are read up front, before any pre-processing starts.
    jobs = (
        ("Face", DEFAULT_IMAGE, load_face_cache(),
         preprocess_and_cache_face, "Auto-setting up face from assets..."),
        ("Voice", DEFAULT_VOICE, load_voice_cache(),
         preprocess_and_cache_voice, "Auto-setting up voice from assets..."),
    )

    report = []
    for label, asset_path, cached, worker, banner in jobs:
        # Skip assets that are absent or already have a cache entry.
        if cached or not os.path.exists(asset_path):
            continue
        print(banner)
        _, outcome = worker(asset_path, "female-01")
        report.append(f"{label}: {outcome}")

    if not report:
        return "‚úì Assets already cached or not found"
    return "\n".join(report)


async def text_to_speech_async(text: str, voice: str, out_path: str):
    """Synthesise ``text`` with edge-tts and write it to ``out_path`` as WAV.

    Args:
        text: Text to speak.
        voice: edge-tts voice identifier.
        out_path: Destination path of the WAV file.

    Returns:
        ``out_path``.
    """
    # The intermediate MP3 gets a name that can never equal out_path. Deriving
    # it with str.replace(".wav", ".mp3") made both paths identical whenever
    # out_path lacked ".wav", and the final cleanup then deleted the result.
    mp3_path = f"{out_path}.tmp.mp3"
    try:
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(mp3_path)
        AudioSegment.from_mp3(mp3_path).export(out_path, format="wav")
    finally:
        # Remove the intermediate file even when synthesis or conversion fails.
        if os.path.exists(mp3_path):
            os.remove(mp3_path)
    return out_path


def generate_video_fast(text: str, use_cached_voice: bool = False):
    """Generate a video from the cached face plus TTS audio or the cached voice file.

    Args:
        text: Text to synthesise with edge-tts. Not used when
            ``use_cached_voice`` is true.
        use_cached_voice: Drive the video with the cached voice file instead
            of synthesising speech.

    Returns:
        ``(video_path, status)``. ``video_path`` is ``None`` on failure and
        ``status`` then holds the error text.
    """
    import shutil

    # Load cached face
    face_cache = load_face_cache()
    if not face_cache:
        return None, "‚ùå No cached face found. Run Setup Mode first."

    # The cache only stores paths; the files behind them may be gone
    # (e.g. a temporary upload), so verify before starting any work.
    cached_coeff = face_cache['first_coeff_path']
    source_image = face_cache['image_path']
    for required_path in (cached_coeff, source_image):
        if not required_path or not os.path.exists(required_path):
            return None, f"Cached face file is missing: {required_path}. Run Setup Mode again."

    ts = datetime.now().strftime("%Y_%m_%d_%H.%M.%S")
    audio_path = os.path.join(RESULT_DIR, f"audio_{ts}.wav")

    # Option 1: Use cached voice file (if available and requested)
    if use_cached_voice:
        voice_cache = load_voice_cache()
        if voice_cache and os.path.exists(voice_cache['voice_path']):
            # Use the cached voice file directly
            shutil.copy(voice_cache['voice_path'], audio_path)
            print(f"Using cached voice: {os.path.basename(voice_cache['voice_path'])}")
        else:
            return None, "‚ùå No cached voice found. Run Setup Mode first or use TTS."
    else:
        # Option 2: Generate speech from text using TTS
        print("Generating speech from text...")
        # "en-US-JennyNeural" is the default female voice
        asyncio.run(text_to_speech_async(text.strip(), "en-US-JennyNeural", audio_path))

    # Use cached face data for fast inference
    print("Running fast inference with cached face...")

    # Create temp dir for this generation
    gen_dir = os.path.join(RESULT_DIR, f"gen_{ts}")
    os.makedirs(gen_dir, exist_ok=True)

    # Copy cached coeff to gen dir
    # NOTE(review): inference.py is still given --source_image and
    # --preprocess full; whether it reuses this copied coefficient file or
    # recomputes it cannot be told from this notebook - confirm.
    gen_coeff_path = os.path.join(gen_dir, os.path.basename(cached_coeff))
    shutil.copy(cached_coeff, gen_coeff_path)

    # Run inference with cached face
    cmd = [
        sys.executable, "inference.py",
        "--driven_audio", audio_path,
        "--source_image", source_image,
        "--result_dir", gen_dir,
        "--checkpoint_dir", CHECKPOINT_DIR,
        "--still", "--preprocess", "full", "--enhancer", "gfpgan"
    ]

    env = os.environ.copy()
    env["PYTHONPATH"] = BASE_DIR

    r = subprocess.run(cmd, cwd=BASE_DIR, env=env, capture_output=True, text=True)

    if r.returncode != 0:
        # Prefer stderr, fall back to stdout when stderr is empty.
        err = (r.stderr or "").strip() or (r.stdout or "").strip()
        return None, f"Error: {err}"

    # Find output video (newest first)
    mp4s = sorted(Path(gen_dir).rglob("*.mp4"), key=os.path.getmtime, reverse=True)
    if not mp4s:
        return None, "No output video found"

    return str(mp4s[0]), f"‚úì Generated: {os.path.basename(mp4s[0])}"


# Gradio UI
# Two tabs: "Setup" caches a face image and a voice file once; "Generate"
# turns text (or the cached voice file) into a video using the cached face.
with gr.Blocks(title="SadTalker ‚Äî Optimized Cached") as demo:
    gr.Markdown("## üöÄ Optimized: Pre-process once ‚Üí Generate fast")

    with gr.Tabs():
        with gr.TabItem("1Ô∏è‚É£ Setup (Run Once)"):
            gr.Markdown("### Pre-process face + voice ‚Üí Cache for fast generation")

            # Auto-setup from assets
            auto_setup_btn = gr.Button("üöÄ Auto-Setup from Assets", variant="primary")
            auto_setup_status = gr.Textbox(label="Auto-Setup Status", interactive=False)

            def do_auto_setup():
                """Cache the default assets and return the status text."""
                return auto_setup_from_assets()

            auto_setup_btn.click(fn=do_auto_setup, outputs=[auto_setup_status])

            gr.Markdown("---\n### Or Manual Setup:")

            with gr.Row():
                with gr.Column():
                    setup_image = gr.Image(type="filepath", label="Face Image")
                    setup_cache_id = gr.Textbox(label="Face Cache ID", value="default")
                    setup_face_btn = gr.Button("Pre-process Face", variant="secondary")

                with gr.Column():
                    setup_voice = gr.Audio(type="filepath", label="Voice Audio File")
                    setup_voice_id = gr.Textbox(label="Voice Cache ID", value="default")
                    setup_voice_btn = gr.Button("Cache Voice", variant="secondary")

            setup_status = gr.Textbox(label="Setup Status", interactive=False)

            def do_setup_face(image, cache_id):
                """Resolve the uploaded image to a path and cache the face; returns a status message."""
                try:
                    if not image:
                        return "‚ùå Please upload a face image"

                    # Handle Gradio file upload format
                    if isinstance(image, str):
                        image_path = image
                    elif hasattr(image, "name"):
                        image_path = image.name
                    elif isinstance(image, dict):
                        image_path = image.get("path") or image.get("name")
                    else:
                        return "‚ùå Invalid image format. Please upload a PNG or JPG image."

                    if not image_path:
                        return "‚ùå Could not get image path. Please try uploading again."

                    # The message covers both outcomes (success or error text).
                    _, msg = preprocess_and_cache_face(image_path, cache_id or "default")
                    return msg

                except Exception as e:
                    import traceback
                    return f"‚ùå Error in setup: {str(e)}\n\n{traceback.format_exc()}"

            def do_setup_voice(audio, cache_id):
                """Resolve the uploaded audio to a path and cache it; returns a status message."""
                if not audio:
                    return "Please upload a voice audio file"

                # Accept a plain path, a dict payload or a file-like object.
                # Calling .get() on anything that is not a str would raise
                # AttributeError for file-like uploads.
                if isinstance(audio, str):
                    audio_path = audio
                elif isinstance(audio, dict):
                    audio_path = audio.get("path") or audio.get("name")
                else:
                    audio_path = getattr(audio, "name", None)

                if not audio_path:
                    return "Could not get audio path. Please try uploading again."

                _, msg = preprocess_and_cache_voice(audio_path, cache_id or "default")
                return msg

            setup_face_btn.click(fn=do_setup_face, inputs=[setup_image, setup_cache_id], outputs=[setup_status])
            setup_voice_btn.click(fn=do_setup_voice, inputs=[setup_voice, setup_voice_id], outputs=[setup_status])

        with gr.TabItem("2Ô∏è‚É£ Generate (Fast)"):
            gr.Markdown("### Enter text ‚Üí Generate video (uses cached face + voice)")
            gen_text = gr.Textbox(label="Text to speak", lines=4, placeholder="Enter the text for the avatar to read...")

            with gr.Row():
                gen_mode = gr.Radio(
                    choices=["Use TTS (Text-to-Speech)", "Use Cached Voice File"],
                    value="Use TTS (Text-to-Speech)",
                    label="Audio Source"
                )
                gen_btn = gr.Button("üöÄ Generate Video", variant="primary", scale=2)

            gen_video = gr.Video(label="Output Video")
            gen_status = gr.Textbox(label="Status", interactive=False, lines=3)

            def do_generate(text, mode):
                """Validate the text and run generation; returns (video_path, status)."""
                if not text or not text.strip():
                    return None, "Please enter some text"

                use_cached = (mode == "Use Cached Voice File")
                video_path, status = generate_video_fast(text, use_cached_voice=use_cached)
                return video_path, status

            gen_btn.click(fn=do_generate, inputs=[gen_text, gen_mode], outputs=[gen_video, gen_status])

# share=True requests a public Gradio link; 0.0.0.0 binds on all interfaces.
demo.launch(share=True, server_name="0.0.0.0", server_port=7860)