# 🎭 SadTalker: Minimal Setup (No Full Clone)

**Uses Colab's pre-installed PyTorch.** Downloads only essential code + models. No full repo clone.

1. Upload **face image** and **audio** (base sound)
2. Click **Generate** → talking video with lip sync
3. Gradio UI for easy interaction

## Step 1: Enable GPU

**Runtime → Change runtime type → Hardware accelerator → GPU**

In [None]:
# Print the GPU name plus total and free memory as header-less CSV (one line per GPU)
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

## Step 2: Install dependencies

In [None]:
# Use Colab's pre-installed CUDA-enabled PyTorch (no need to reinstall)
# Only install packages Colab may not have or that SadTalker needs
# Note: face-alignment (hyphen) installs but imports as face_alignment (underscore)
!pip install -q edge-tts face-alignment imageio imageio-ffmpeg librosa resampy pydub kornia yacs scikit-image basicsr facexlib gfpgan av safetensors gradio
# Best-effort ffmpeg install: stderr of the install step is discarded and
# `|| true` keeps the cell from failing if apt-get returns a non-zero status
!apt-get update -qq && apt-get install -y -qq ffmpeg 2>/dev/null || true

# Verify CUDA is available (Colab's PyTorch should already have it)
import torch
print(f"‚úì PyTorch {torch.__version__}")
print(f"‚úì CUDA available: {torch.cuda.is_available()}")
# Only query the device name when CUDA is actually usable
if torch.cuda.is_available():
    print(f"‚úì CUDA device: {torch.cuda.get_device_name(0)}")

## Step 3: Download only essential code (src/ + inference.py)

In [None]:
import os
import zipfile
import urllib.request
import shutil

# Working directories: BASE_DIR receives the final files, TEMP_DIR is scratch space
BASE_DIR = "/content/SadTalker"
TEMP_DIR = "/tmp/sadtalker_extract"
os.makedirs(BASE_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)
os.chdir(BASE_DIR)

# Download repo as zip
repo_url = "https://github.com/OpenTalker/SadTalker/archive/refs/heads/main.zip"
zip_path = "/tmp/sadtalker.zip"
print("Downloading SadTalker repository...")

try:
    urllib.request.urlretrieve(repo_url, zip_path)

    # Extract only src/ and inference.py instead of unpacking the whole archive
    print("Extracting essential files...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        names = zip_ref.namelist()
        # The archive wraps everything in a single top-level folder; read its
        # name from the first entry and fall back to the expected default.
        root_name = names[0].split("/", 1)[0] if names else "SadTalker-main"
        wanted = [
            name for name in names
            if name == f"{root_name}/inference.py"
            or name.startswith(f"{root_name}/src/")
        ]
        zip_ref.extractall(TEMP_DIR, members=wanted)

    # Move only src/ and inference.py to BASE_DIR
    extracted_root = os.path.join(TEMP_DIR, root_name)

    # Move src/ directory (replacing any copy left by an earlier run)
    src_src = os.path.join(extracted_root, "src")
    src_dst = os.path.join(BASE_DIR, "src")
    if os.path.exists(src_src):
        if os.path.exists(src_dst):
            shutil.rmtree(src_dst)
        shutil.move(src_src, src_dst)
        print("‚úì Moved src/ directory")
    else:
        print("WARNING: src/ directory not found in the downloaded archive")

    # Move inference.py (replacing any copy left by an earlier run)
    inf_src = os.path.join(extracted_root, "inference.py")
    inf_dst = os.path.join(BASE_DIR, "inference.py")
    if os.path.exists(inf_src):
        if os.path.exists(inf_dst):
            os.remove(inf_dst)
        shutil.move(inf_src, inf_dst)
        print("‚úì Moved inference.py")
    else:
        print("WARNING: inference.py not found in the downloaded archive")
finally:
    # Cleanup runs even when the download or extraction fails part-way
    shutil.rmtree(TEMP_DIR, ignore_errors=True)
    if os.path.exists(zip_path):
        os.remove(zip_path)

# numpy compatibility: guard the np.VisibleDeprecationWarning filter in
# preprocess.py with try/except so a numpy without that attribute does not crash
preprocess_file = os.path.join(BASE_DIR, "src", "face3d", "util", "preprocess.py")
if os.path.exists(preprocess_file):
    with open(preprocess_file, "r") as f:
        lines = f.readlines()

    # Index of the first offending line that is not already preceded by "try:"
    target = next(
        (
            idx
            for idx, text in enumerate(lines)
            if "np.VisibleDeprecationWarning" in text
            and not (idx > 0 and "try:" in lines[idx - 1])
        ),
        None,
    )

    if target is not None:
        # Reuse the original line's indentation for the whole replacement block
        original = lines[target]
        pad = " " * (len(original) - len(original.lstrip()))
        guarded = (
            "try:\n",
            '    warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)\n',
            "except AttributeError:\n",
            "    pass  # VisibleDeprecationWarning removed in newer numpy\n",
        )
        # The four lines are stored as one list element; writelines() emits them verbatim
        lines[target] = "".join(pad + part for part in guarded)

        with open(preprocess_file, "w") as f:
            f.writelines(lines)
        print("‚úì Fixed numpy VisibleDeprecationWarning issue")

# Fix np.float in my_awing_arch.py
import re

awing_file = os.path.join(BASE_DIR, "src", "face3d", "util", "my_awing_arch.py")
if os.path.exists(awing_file):
    with open(awing_file, "r") as f:
        content = f.read()
    # The trailing \b stops the pattern from matching np.float32 / np.float64,
    # so re-running the cell on an already patched file replaces nothing.
    content, n_replaced = re.subn(r'\bnp\.float\b', 'np.float64', content)
    # Only rewrite the file (and report a fix) when something actually changed
    if n_replaced:
        with open(awing_file, "w") as f:
            f.write(content)
        print("‚úì Fixed np.float ‚Üí np.float64 in my_awing_arch.py")

# Summary of what ended up in BASE_DIR
print(f"\n‚úì Essential code extracted to {BASE_DIR}")
print(f"‚úì Found src/ directory: {os.path.exists(os.path.join(BASE_DIR, 'src'))}")
print(f"‚úì Found inference.py: {os.path.exists(os.path.join(BASE_DIR, 'inference.py'))}")

# Fix face detection: Add face-alignment fallback to croper.py
# The patch edits the file's text in four steps: (1) insert a module-level lazy
# loader before "class Preprocesser:", (2) store the device on the instance in
# __init__, (3) insert a fallback method before get_landmark, (4) insert a call
# to that fallback before an "except Exception:" line inside get_landmark.
# The whole patch is skipped when "_get_face_alignment" is already in the file.
croper_file = os.path.join(BASE_DIR, "src", "utils", "croper.py")
if os.path.exists(croper_file):
    with open(croper_file, "r") as f:
        lines = f.readlines()
    
    # Check if already patched
    content_str = ''.join(lines)
    if "_get_face_alignment" not in content_str:
        # Find insertion point: after imports, before class Preprocesser
        insert_idx = None
        for i, line in enumerate(lines):
            if "class Preprocesser:" in line:
                insert_idx = i
                break
        
        if insert_idx is not None:
            # Insert face-alignment helper functions
            # (the whole snippet is inserted as a single multi-line list element)
            patch_code = """# Optional: 1adrianb/face-alignment returns 68 points directly (no 98->68 conversion)
_FACE_ALIGNMENT = None

def _get_face_alignment(device='cuda'):
    \"\"\"Lazy-init face_alignment.FaceAlignment (68 landmarks, TWO_D).\"\"\"
    global _FACE_ALIGNMENT
    if _FACE_ALIGNMENT is None:
        try:
            import face_alignment
            LandmarksType = face_alignment.LandmarksType
            # TWO_D or _2D depending on package version
            lm_type = getattr(LandmarksType, 'TWO_D', getattr(LandmarksType, '_2D', 1))
            fa_device = 'cpu' if device == 'cpu' else 'cuda'
            _FACE_ALIGNMENT = face_alignment.FaceAlignment(
                lm_type, device=fa_device, face_detector='sfd'
            )
        except Exception:
            _FACE_ALIGNMENT = False
    return _FACE_ALIGNMENT if _FACE_ALIGNMENT is not None and _FACE_ALIGNMENT is not False else None


"""
            lines.insert(insert_idx, patch_code)
            
            # Add device to __init__
            # Only matches the exact signature text below, and only looks at the
            # next four lines for the KeypointExtractor assignment.
            for i, line in enumerate(lines):
                if "def __init__(self, device='cuda'):" in line:
                    # Find next line with self.predictor
                    for j in range(i+1, min(i+5, len(lines))):
                        if "self.predictor = KeypointExtractor" in lines[j]:
                            # Check if device already added
                            if "self.device = device" not in ''.join(lines[i:j+3]):
                                lines.insert(j+1, "        self.device = device\n")
                            break
                    break
            
            # Add fallback method before get_landmark
            fallback_method = """    def _get_landmark_face_alignment(self, img_np, det):
        \"\"\"Fallback: use 1adrianb/face-alignment to get 68 landmarks on crop (no 98->68).\"\"\"
        fa = _get_face_alignment(self.device)
        if fa is None:
            return None
        try:
            img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
            if img.size == 0 or img.shape[0] == 0 or img.shape[1] == 0:
                return None
            # face_alignment expects RGB; get_landmarks returns list of (68, 2) or (68, 3)
            preds = fa.get_landmarks(img)
            if not preds or len(preds) == 0:
                return None
            lm = np.array(preds[0], dtype=np.float64)
            if lm.ndim == 2 and lm.shape[0] == 68:
                lm = lm[:, :2]
            else:
                return None
            lm[:, 0] += int(det[0])
            lm[:, 1] += int(det[1])
            return lm
        except Exception:
            return None

"""
            
            # Find get_landmark method and insert fallback before it
            for i, line in enumerate(lines):
                if "    def get_landmark(self, img_np):" in line:
                    if "_get_landmark_face_alignment" not in ''.join(lines[max(0,i-10):i]):
                        lines.insert(i, fallback_method)
                    break
            
            # Modify get_landmark to add fallback at the end (before final except)
            # NOTE(review): the call is only inserted if an "except Exception:" line
            # is found between get_landmark and the next "def"; if get_landmark has
            # no such line, the helper and method above are added but never called.
            # Verify against the downloaded croper.py.
            in_get_landmark = False
            for i in range(len(lines)):
                if "    def get_landmark(self, img_np):" in lines[i]:
                    in_get_landmark = True
                elif in_get_landmark and lines[i].strip().startswith("def ") and "get_landmark" not in lines[i]:
                    # End of get_landmark method
                    break
                elif in_get_landmark and "except Exception:" in lines[i] and "# Fallback: face_alignment" not in ''.join(lines[max(0,i-3):i+1]):
                    # Insert fallback before this except
                    fallback_call = "            # Fallback: face_alignment (68 landmarks directly)\n            lm = self._get_landmark_face_alignment(img_np, det)\n            return lm\n\n"
                    lines.insert(i, fallback_call)
                    break
            
            # The file is written and success is reported regardless of which of
            # the individual steps above actually matched.
            with open(croper_file, "w") as f:
                f.writelines(lines)
            print("‚úì Added face-alignment fallback to croper.py")
        else:
            print("‚ö† Could not find Preprocesser class in croper.py")
    else:
        print("‚úì croper.py already patched with face-alignment fallback")

## Step 3.5: Fix numpy compatibility (if Step 3 had issues)

In [None]:
# DIRECT FIX for IndentationError - run this cell
import os
def fix_visible_deprecation_lines(lines):
    """Guard every np.VisibleDeprecationWarning filter line with try/except.

    Handles two situations:
      * a "try:" line followed by a warnings line that is NOT indented deeper
        than the "try:" (a broken earlier patch) - both lines are replaced;
      * a bare warnings line with no "try:" directly above it - it is wrapped.
    Lines that are already correctly wrapped are left untouched, so the
    function is idempotent.

    Args:
        lines: file content as a list of lines (as returned by readlines()).

    Returns:
        (new_lines, changed): the rewritten list of lines and a bool telling
        whether anything was modified.
    """
    def guarded_block(indent):
        # Four-line replacement block, indented like the line it replaces
        pad = " " * indent
        return [
            pad + "try:\n",
            pad + "    warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning)\n",
            pad + "except AttributeError:\n",
            pad + "    pass  # VisibleDeprecationWarning removed in newer numpy\n",
        ]

    new_lines = []
    changed = False
    i = 0
    while i < len(lines):
        line = lines[i]

        # Case 1: "try:" followed by a warnings line without the extra indent
        if "try:" in line and i + 1 < len(lines):
            next_line = lines[i + 1]
            if "np.VisibleDeprecationWarning" in next_line:
                try_indent = len(line) - len(line.lstrip())
                warn_indent = len(next_line) - len(next_line.lstrip())
                if warn_indent <= try_indent:
                    new_lines.extend(guarded_block(warn_indent))
                    changed = True
                    i += 2  # Skip both the broken "try:" and the warnings line
                    continue

        # Case 2: warnings line with no "try:" wrapper directly above it
        # (the first line of the file has no predecessor, so it counts as unwrapped)
        if "np.VisibleDeprecationWarning" in line:
            prev_line = lines[i - 1] if i > 0 else ""
            if "try:" not in prev_line:
                new_lines.extend(guarded_block(len(line) - len(line.lstrip())))
                changed = True
                i += 1
                continue

        new_lines.append(line)
        i += 1

    return new_lines, changed


preprocess_file = "/content/SadTalker/src/face3d/util/preprocess.py"

if os.path.exists(preprocess_file):
    with open(preprocess_file, "r") as f:
        lines = f.readlines()

    new_lines, changed = fix_visible_deprecation_lines(lines)

    # Only touch the file (and claim a fix) when something was actually rewritten
    if changed:
        with open(preprocess_file, "w") as f:
            f.writelines(new_lines)
        print("‚úì Fixed preprocess.py! Now fixing my_awing_arch.py...")
    else:
        print("preprocess.py needs no changes. Now fixing my_awing_arch.py...")
else:
    print("‚ö† preprocess.py not found. Run Step 3 first.")

# Also fix np.float in my_awing_arch.py
import re

# Word boundaries on both sides: matches the removed alias np.float only,
# never np.float32 / np.float64, so an already patched file yields zero hits.
NP_FLOAT_ALIAS = re.compile(r'\bnp\.float\b')

awing_file = "/content/SadTalker/src/face3d/util/my_awing_arch.py"
if os.path.exists(awing_file):
    with open(awing_file, "r") as f:
        content = f.read()
    content, n_replaced = NP_FLOAT_ALIAS.subn('np.float64', content)
    # Use the substitution count to decide between "fixed" and "already fixed";
    # a plain substring test for "np.float" would also match np.float64.
    if n_replaced:
        with open(awing_file, "w") as f:
            f.write(content)
        print("‚úì Fixed np.float ‚Üí np.float64 in my_awing_arch.py")
    else:
        print("‚úì my_awing_arch.py already fixed")
else:
    print("‚ö† my_awing_arch.py not found")

print("\n‚úì All fixes applied! Re-run Step 5 (Gradio UI) now.")

# Also fix torchvision compatibility issue (affects this kernel process only)
print("\nFixing torchvision compatibility...")
try:
    import sys
    import torchvision.transforms as transforms_module
    if not hasattr(transforms_module, 'functional_tensor'):
        # In newer torchvision, functional_tensor was moved/renamed; alias it to
        # transforms.functional, which provides rgb_to_grayscale.
        from torchvision.transforms import functional as functional_tensor
        transforms_module.functional_tensor = functional_tensor
        # "from torchvision.transforms.functional_tensor import ..." is resolved
        # through sys.modules, not through the attribute set above, so the alias
        # must be registered there as well (a None entry would block the import).
        if sys.modules.get('torchvision.transforms.functional_tensor') is None:
            sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor
        print("‚úì Fixed torchvision compatibility")
except Exception as e:
    # Best-effort: report the problem but do not fail the cell
    print(f"‚ö† Torchvision fix warning: {e}")

## Step 4: Download only required models

## Step 4.5: Fix torchvision compatibility

In [None]:
# Fix torchvision compatibility: functional_tensor was moved in newer versions
# Patch basicsr file directly since inference.py runs in subprocess
import os
import site

# Find basicsr installation (first site-packages directory that contains it)
basicsr_path = None
for path in site.getsitepackages():
    degradations_file = os.path.join(path, "basicsr", "data", "degradations.py")
    if os.path.exists(degradations_file):
        basicsr_path = degradations_file
        break

# Also check common Colab paths
if not basicsr_path:
    common_paths = [
        "/usr/local/lib/python3.12/dist-packages/basicsr/data/degradations.py",
        "/usr/local/lib/python3.11/dist-packages/basicsr/data/degradations.py",
        "/usr/local/lib/python3.10/dist-packages/basicsr/data/degradations.py",
    ]
    for p in common_paths:
        if os.path.exists(p):
            basicsr_path = p
            break

if basicsr_path:
    with open(basicsr_path, "r") as f:
        content = f.read()

    # Replace the problematic import
    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"

    if old_import in content and new_import not in content:
        content = content.replace(old_import, new_import)
        with open(basicsr_path, "w") as f:
            f.write(content)
        print(f"‚úì Fixed torchvision compatibility in {basicsr_path}")
    elif new_import in content:
        print("‚úì Torchvision compatibility already fixed")
    else:
        print(f"‚ö† Import line not found in {basicsr_path}")
else:
    print("‚ö† Could not find basicsr/degradations.py - trying alternative fix...")
    # Alternative: alias the missing module inside this kernel process.
    # This does not reach code that runs in a separate Python process.
    import sys

    def patch_torchvision():
        """Make `torchvision.transforms.functional_tensor` importable in this process.

        Aliases the module name to `torchvision.transforms.functional` both as an
        attribute of the package and as a sys.modules entry.
        """
        import torchvision.transforms as transforms_module
        from torchvision.transforms import functional as functional_tensor
        if not hasattr(transforms_module, 'functional_tensor'):
            transforms_module.functional_tensor = functional_tensor
        # A sys.modules entry of None makes the import raise ModuleNotFoundError,
        # so register the real replacement module (also overwriting a stale None).
        if sys.modules.get('torchvision.transforms.functional_tensor') is None:
            sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor

    patch_torchvision()
    print("‚úì Applied runtime patch (may need to restart kernel)")

In [None]:
import os
import urllib.request

def download_if_missing(items, dest_dir):
    """Download each (url, filename) pair into dest_dir unless it already exists.

    Each file is first written to "<filename>.part" and renamed only after the
    download finished, so an interrupted transfer never leaves a truncated file
    that a later run would mistake for a complete one.

    Args:
        items: iterable of (url, filename) tuples.
        dest_dir: existing directory that receives the files.

    Raises:
        Whatever urllib.request.urlretrieve raises on a failed download.
    """
    for url, filename in items:
        filepath = os.path.join(dest_dir, filename)
        if os.path.exists(filepath):
            print(f"  ‚úì {filename} (already exists)")
            continue
        print(f"  Downloading {filename}...")
        partial = filepath + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, filepath)
        finally:
            # Only still present when the download or the rename failed
            if os.path.exists(partial):
                os.remove(partial)
        print(f"  ‚úì {filename}")


CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Download safetensor models (newer, simpler - just one file per size)
print("Downloading SadTalker models (safetensor format)...")
models = [
    ("https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors", "SadTalker_V0.0.2_256.safetensors"),
    ("https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar", "mapping_00109-model.pth.tar"),  # For 'full' preprocess
]
download_if_missing(models, CHECKPOINT_DIR)

# Download GFPGAN weights for enhancer
GFPGAN_DIR = os.path.join(BASE_DIR, "gfpgan", "weights")
os.makedirs(GFPGAN_DIR, exist_ok=True)

print("\nDownloading GFPGAN enhancer weights...")
gfpgan_models = [
    ("https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth", "alignment_WFLW_4HG.pth"),
    ("https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", "detection_Resnet50_Final.pth"),
    ("https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", "GFPGANv1.4.pth"),
    ("https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth", "parsing_parsenet.pth"),
]
download_if_missing(gfpgan_models, GFPGAN_DIR)

print("\n‚úì All models downloaded!")

## Step 5: Gradio UI — upload image + audio → talking video

In [None]:
import os
import sys
import subprocess
from pathlib import Path
import gradio as gr
import cv2
import numpy as np

# numpy compatibility: restore the scalar aliases (np.float, np.int, np.complex,
# np.bool) when the installed numpy no longer provides them. Existing
# attributes are never overwritten.
for _alias, _builtin in (
    ("float", float),
    ("int", int),
    ("complex", complex),
    ("bool", bool),
):
    if not hasattr(np, _alias):
        setattr(np, _alias, _builtin)

# Locations used by the UI cell; all derived from the SadTalker install directory
BASE_DIR = "/content/SadTalker"
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")
RESULT_DIR = os.path.join(BASE_DIR, "results")

# Work from the install directory and make its `src` package importable
os.chdir(BASE_DIR)
sys.path.insert(0, BASE_DIR)

# Generated videos are written here
os.makedirs(RESULT_DIR, exist_ok=True)

# Cached Preprocesser instance; remains None until an initialization succeeds
_face_preprocessor = None

def get_face_preprocessor():
    """Return the shared face-detection Preprocesser, creating it on first use.

    Picks "cuda" when torch reports CUDA as available, otherwise "cpu".
    If the import or construction fails, a warning is printed and None is
    returned; since nothing is cached in that case, the next call tries again.
    """
    global _face_preprocessor
    if _face_preprocessor is not None:
        return _face_preprocessor

    try:
        from src.utils.croper import Preprocesser
        import torch
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        _face_preprocessor = Preprocesser(device=device)
    except Exception as e:
        print(f"‚ö† Face detection init warning: {e}")
    return _face_preprocessor


def validate_face_detection(image_path: str) -> tuple[bool, str]:
    """Check that a face can be found in the image before running inference.

    Returns:
        (ok, message). When the preprocessor could not be initialized the check
        is skipped and (True, "") is returned. Any exception raised while
        reading or analysing the image is reported as (False, "<error text>").
    """
    detector = get_face_preprocessor()
    if detector is None:
        # Skip validation if not initialized
        return True, ""

    try:
        bgr = cv2.imread(image_path)
        if bgr is None:
            return False, "Could not read image file."
        # The image is converted from OpenCV's BGR order to RGB before detection
        landmarks = detector.get_landmark(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    except Exception as e:
        return False, f"Face detection error: {str(e)}"

    if landmarks is None:
        return False, "‚ùå No face detected. Use a clear front-facing face image."
    return True, f"‚úì Face detected (68 landmarks found)"


def run_sadtalker(image_path: str, audio_path: str, result_dir: str) -> str:
    """Run inference.py in a child process and return the newest MP4 it left behind.

    Args:
        image_path: source face image.
        audio_path: driving audio file.
        result_dir: directory passed to inference.py as --result_dir.

    Returns:
        Path of the most recently modified *.mp4 found under result_dir.

    Raises:
        RuntimeError: inference.py exited with a non-zero status (the message
            carries its stderr, or stdout when stderr is empty).
        FileNotFoundError: no *.mp4 exists under result_dir afterwards.
    """
    # Child environment: a copy of ours with PYTHONPATH pointing at the install dir
    child_env = dict(os.environ)
    child_env["PYTHONPATH"] = BASE_DIR

    argv = [sys.executable, "inference.py"]
    argv += ["--driven_audio", audio_path]
    argv += ["--source_image", image_path]
    argv += ["--result_dir", result_dir]
    argv += ["--checkpoint_dir", CHECKPOINT_DIR]
    argv += ["--still", "--preprocess", "full", "--enhancer", "gfpgan"]

    proc = subprocess.run(argv, cwd=BASE_DIR, env=child_env, capture_output=True, text=True)
    if proc.returncode != 0:
        # Prefer stderr; fall back to stdout when stderr is empty
        err = (proc.stderr or "").strip() or (proc.stdout or "").strip()
        raise RuntimeError(f"inference.py failed:\n{err}")

    videos = list(Path(result_dir).rglob("*.mp4"))
    if not videos:
        raise FileNotFoundError("No output video found.")
    # Newest file by modification time
    return str(max(videos, key=os.path.getmtime))


def _to_path(x):
    """Normalize an uploaded-file value to a filesystem path string.

    Accepted inputs:
      * None                      -> None
      * str                       -> returned unchanged
      * os.PathLike (e.g. Path)   -> os.fspath(x)
      * object with a .name       -> x.name
      * mapping-like with .get    -> x["path"] or x["name"] (first truthy one)
    Any other value yields None instead of raising.
    """
    if x is None:
        return None
    if isinstance(x, str):
        return x
    if isinstance(x, os.PathLike):
        # Path objects also have .name, but that is only the final component,
        # so they must be handled before the generic .name branch.
        return os.fspath(x)
    if hasattr(x, "name"):
        return x.name
    getter = getattr(x, "get", None)
    if callable(getter):
        return getter("path") or getter("name")
    return None


def generate_video(image, audio):
    """Gradio click handler: validate the uploads, run SadTalker, report the outcome.

    Args:
        image: uploaded face image value from the UI (or None).
        audio: uploaded audio value from the UI (or None).

    Returns:
        (video_path, status): video_path is None on any failure and status then
        holds the reason; on success status names the generated file.
    """
    # Missing uploads are rejected first, image before audio
    for value, complaint in (
        (image, "Please upload a face image."),
        (audio, "Please upload an audio file (base sound)."),
    ):
        if value is None:
            return None, complaint

    image_path, audio_path = _to_path(image), _to_path(audio)

    # Both uploads must resolve to existing files
    for candidate, complaint in (
        (image_path, "Invalid image file."),
        (audio_path, "Invalid audio file."),
    ):
        if not candidate or not os.path.isfile(candidate):
            return None, complaint

    # Pre-flight face check so obvious failures are reported before inference runs
    face_ok, face_msg = validate_face_detection(image_path)
    if not face_ok:
        return None, face_msg

    try:
        video_path = run_sadtalker(image_path, audio_path, RESULT_DIR)
        return video_path, f"‚úì Done: {os.path.basename(video_path)}"
    except Exception as e:
        err_msg = str(e)
        # Errors mentioning these markers are presented as face-detection failures
        looks_like_face_issue = (
            any(marker in err_msg for marker in ("Can't get the coeffs", "No face is detected"))
            or "landmark" in err_msg.lower()
        )
        if looks_like_face_issue:
            return None, f"‚ùå Face detection failed: {err_msg}\n\nTip: Use a clear front-facing face image."
        return None, f"Error: {err_msg}"


# Build the UI: inputs and the button in the left column, results in the right one
with gr.Blocks(title="SadTalker ‚Äî Minimal") as demo:
    gr.Markdown("## Upload **image** + **audio** (base sound) ‚Üí talking video")
    with gr.Row():
        with gr.Column():
            # Both inputs are upload-only and are handed to the callback as file paths
            image_in = gr.Image(type="filepath", label="Face image", sources=["upload"])
            audio_in = gr.Audio(type="filepath", label="Base sound (audio)", sources=["upload"])
            btn = gr.Button("Generate video", variant="primary")
        with gr.Column():
            video_out = gr.Video(label="Output video")
            # Read-only box that shows the status text returned by generate_video
            status = gr.Textbox(label="Status", interactive=False)
    # generate_video returns (video_path, status_text), matching the two outputs
    btn.click(
        fn=generate_video,
        inputs=[image_in, audio_in],
        outputs=[video_out, status]
    )

# In Colab, launch with share=True and server_name so the UI is accessible
demo.launch(share=True, server_name="0.0.0.0", server_port=7860)