# MegaFS Modularized - Face Swapping Notebook

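This notebook demonstrates the modularized MegaFS face-swapping system with automatic repository cloning. It is meant to run in Google Colab: it imports `google.colab` to mount Google Drive, and that import fails with `ModuleNotFoundError` in other environments.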

## Setup Instructions

1. **Update the repository URL** (`repo_url`) in the first code cell below
2. **Upload your dataset** to Google Drive:
   - Upload `celeba_mask_hq.zip` to `/content/drive/MyDrive/Datasets/`
3. **Upload weight files** to Google Drive:
   - Place all weight files (`ftm_final.pth`, `injection_final.pth`, `lcr_final.pth`, `stylegan2-ffhq-config-f.pth`) in `/content/drive/MyDrive/Datasets/weights/`
4. **Run all cells** - everything will be set up automatically

## Features

- **Automatic setup**: Git clone, dataset extraction, data mapping
- **Modular architecture**: Better debugging and maintenance
- **Comprehensive logging**: Error handling and performance profiling
- **Configuration management**: Easy parameter adjustment


In [None]:
# Setup and clone repository
import os
import sys
import subprocess
import shutil

# IMPORTANT: Update this URL with your actual GitHub repository
repo_url = "https://github.com/n01r1r/MegaFS.git"  # ⚠️ CHANGE THIS URL IF NEEDED ⚠️
repo_dir = "/content/MegaFS"

print("=" * 60)
print("MegaFS Modularized - Automatic Setup")
print("=" * 60)
print(f"Repository URL: {repo_url}")
print(f"Target directory: {repo_dir}")
print("=" * 60)

# Clone the repository on first run; on later runs, pull the latest changes.
if not os.path.exists(repo_dir):
    print("INFO: Cloning MegaFS repository...")
    try:
        subprocess.run(["git", "clone", repo_url, repo_dir], check=True)
        print("SUCCESS: Repository cloned successfully")
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Failed to clone repository: {e}")
        print("Please check the repository URL and try again")
        print("Make sure the repository is public or you have access")
        # Raise rather than sys.exit(): inside a notebook kernel SystemExit is
        # intercepted by IPython and reported as a terse warning, whereas a
        # regular exception gives a normal traceback and still halts "Run All".
        raise RuntimeError(f"Could not clone {repo_url} into {repo_dir}") from e
else:
    print("INFO: Repository already exists, updating...")
    try:
        subprocess.run(["git", "-C", repo_dir, "pull"], check=True)
        print("SUCCESS: Repository updated")
    except subprocess.CalledProcessError as e:
        # Best effort: an existing (possibly stale) checkout is still usable.
        print(f"WARNING: Failed to update repository: {e}")

# Make the cloned packages (config, models, utils) importable. Guarded so that
# re-running this cell does not keep prepending duplicate sys.path entries.
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

# Change to the repository directory
os.chdir(repo_dir)

print("SUCCESS: Repository setup complete")
print(f"INFO: Working directory: {os.getcwd()}")
print("=" * 60)

ModuleNotFoundError: No module named 'google.colab'

In [None]:
# Import required libraries
import zipfile
from glob import glob
from tqdm.notebook import tqdm
import torch
import cv2
import numpy as np
import argparse
from google.colab import drive
from IPython.display import display, Image
import matplotlib.pyplot as plt
from google.colab import files

# Import modularized MegaFS components
from config import Config, DEFAULT_CONFIGS
from models.megafs import MegaFS
from models.weight_loaders import verify_all_weights
from utils.debug_utils import check_system_requirements
from utils.data_utils import DataMapManager

# Reaching this line means every import in this cell (including the modules
# from the cloned repository) succeeded.
print("SUCCESS: All imports complete - Modularized MegaFS ready")

Mounted at /content/drive
Google Drive mount complete


In [None]:
# Mount Google Drive. This is best effort: a failure is reported, and the
# notebook carries on.
try:
    drive.mount('/content/drive')
except Exception as e:
    print(f"ERROR: Google Drive mount failed: {e}")
else:
    print("SUCCESS: Google Drive mounted")

# Dataset preparation: the archive lives on Drive and is unpacked under base_dir.
print("INFO: Preparing dataset...")
dataset_zip_path = "/content/drive/MyDrive/Datasets/celeba_mask_hq.zip"
base_dir = "/content/"


 Preparing dataset...


In [None]:
# Extract the dataset archive into base_dir. The dataset configuration cell
# expects it to unpack into base_dir/CelebAMask-HQ. If that folder already
# exists (e.g. the cell is re-run), skip the slow re-extraction.
extracted_dataset_dir = os.path.join(base_dir, "CelebAMask-HQ")
if os.path.isdir(extracted_dataset_dir):
    # NOTE: a partially extracted folder also passes this check; delete the
    # folder to force a fresh extraction.
    print(f"INFO: Dataset already extracted at '{extracted_dataset_dir}', skipping extraction")
elif os.path.exists(dataset_zip_path):
    print(f"INFO: Extracting dataset from '{dataset_zip_path}'")
    with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
        zip_ref.extractall(base_dir)
    print("SUCCESS: Dataset extraction complete")
else:
    print(f"WARNING: Dataset zip file not found at '{dataset_zip_path}'")
    print("Please ensure the dataset is uploaded to Google Drive")

 '/content/drive/MyDrive/Datasets/celeba_mask_hq.zip' Not Found


In [None]:
# Dataset configuration.
# Layout used below: <base_dir>/CelebAMask-HQ/{CelebA-HQ-img, CelebAMask-HQ-mask-anno}
dataset_root = os.path.join(base_dir, "CelebAMask-HQ")
img_dir, mask_base_dir = (
    os.path.join(dataset_root, sub_dir)
    for sub_dir in ("CelebA-HQ-img", "CelebAMask-HQ-mask-anno")
)
# The data map is taken from the cloned repository, not from the dataset archive.
data_map_path = os.path.join(repo_dir, "data_map.json")

print(f"INFO: Dataset root: {dataset_root}")
print(f"INFO: Data map path: {data_map_path}")

# Load the data map now; valid_ids starts empty and is filled in a later cell.
data_manager = DataMapManager(data_map_path)
data_map = data_manager.data_map
valid_ids = []

In [None]:
# Check data map status. If the JSON file is missing, try to generate it with
# create_datamap.py and then reload the data manager.
print("INFO: Checking data map status...")
if not os.path.exists(data_map_path):
    print(f"ERROR: Data map file '{data_map_path}' not found.")
    print("INFO: Generating data_map.json from dataset...")

    if os.path.exists(dataset_root):
        try:
            # Run create_datamap.py in the dataset root. sys.executable pins the
            # kernel's own interpreter; a bare "python" may resolve to another one.
            # NOTE(review): confirm create_datamap.py lives in dataset_root and
            # writes its output to data_map_path (which is inside repo_dir).
            result = subprocess.run(
                [sys.executable, "create_datamap.py"],
                cwd=dataset_root,
                capture_output=True,
                text=True,
            )

            if result.returncode != 0:
                print(f"ERROR: Failed to generate data_map.json: {result.stderr}")
            elif not os.path.exists(data_map_path):
                # The script ran, but not where this notebook looks for its output.
                print(f"ERROR: create_datamap.py finished but '{data_map_path}' was not created")
            else:
                print("SUCCESS: data_map.json generated")
                # Reload data manager so later cells see the new map.
                data_manager = DataMapManager(data_map_path)
                data_map = data_manager.data_map
                print(f"INFO: Loaded {len(data_map)} entries from data map")
        except Exception as e:
            print(f"ERROR: Failed to generate data_map.json: {e}")
    else:
        print("ERROR: Dataset root not found. Please check dataset extraction.")
else:
    print("SUCCESS: Found data_map.json")
    print(f"INFO: Loaded {len(data_map)} entries from data map")


Expecting data_map.json created externally by create_datamap.py
Found data_map.json; loading will happen in the next cell.


In [None]:
# The data map was already loaded by DataMapManager in an earlier cell. Ask it
# for valid dataset IDs (sample_size=100) to use in the swap cells below.
print("INFO: Getting valid dataset IDs...")
valid_ids = data_manager.get_valid_ids(dataset_root, sample_size=100)
print(f"SUCCESS: Found {len(valid_ids)} valid IDs")

# Spot-check a few entries, but only when there is something to check.
if not valid_ids:
    print("WARNING: No valid IDs found. Check dataset paths and data map.")
else:
    stats = data_manager.verify_sample(sample_size=10, dataset_root=dataset_root)
    print(f"INFO: Sample verification - {stats}")



Loading data map...
Loaded 0 valid items from './data_map.json'
   id=0  image_exists=False  mask_exists=False
     image_path: \content\CelebAMask-HQ\CelebA-HQ-img\0.jpg
     mask_path:  None
   id=1  image_exists=False  mask_exists=False
     image_path: \content\CelebAMask-HQ\CelebA-HQ-img\1.jpg
     mask_path:  None
   id=10  image_exists=False  mask_exists=False
     image_path: \content\CelebAMask-HQ\CelebA-HQ-img\10.jpg
     mask_path:  None


In [None]:
# System requirements check. The call is kept as the cell's last expression,
# so the notebook displays any non-None value it returns.
print("INFO: Checking system requirements...")
check_system_requirements()


In [None]:
# Upper-case aliases for the image and mask directories from the dataset configuration cell.
IMG_ROOT, MASK_ROOT = img_dir, mask_base_dir

In [None]:
# Configuration setup
print("INFO: Setting up configuration...")

# Google Drive checkpoint directory (also used by the weight-verification cell).
checkpoint_dir = '/content/drive/MyDrive/Datasets/weights'

# NOTE: swap_type below is only an initial value. The MegaFS initialization cell
# overwrites config.swap.swap_type with SWAP_TYPE, so change SWAP_TYPE instead.
config = Config(
    swap_type="ftm",
    dataset_root=dataset_root,
    img_root=img_dir,
    mask_root=mask_base_dir,
    checkpoint_dir=checkpoint_dir,
)
print("SUCCESS: Configuration created")
config.print_config()

# Show which files from the cloned repository are in use.
for component in ("models", "utils"):
    print(f"INFO: Using {component} from: {repo_dir}/{component}")
print(f"INFO: Using data_map from: {data_map_path}")

In [None]:
# Verify that the checkpoint directory holds the expected weight files.
print("INFO: Verifying weight files...")
weights_ok = verify_all_weights(checkpoint_dir)
if weights_ok:
    print("SUCCESS: All weight files verified")
else:
    print("ERROR: Weight verification failed. Please check your weight files.")
    print("Required files:")
    for required_weight in (
        "ftm_final.pth",
        "injection_final.pth",
        "lcr_final.pth",
        "stylegan2-ffhq-config-f.pth",
    ):
        print(f"  - {required_weight}")

In [None]:
# Swap type configuration: the value actually used when MegaFS is initialized.
# Each supported type has a matching *_final.pth weight file.
SWAP_TYPE = "ftm"  # Change to "injection" or "lcr" as needed

# Fail fast on a typo, before the model is built with an unknown swap type.
SUPPORTED_SWAP_TYPES = ("ftm", "injection", "lcr")
if SWAP_TYPE not in SUPPORTED_SWAP_TYPES:
    raise ValueError(
        f"Unsupported SWAP_TYPE {SWAP_TYPE!r}; expected one of {SUPPORTED_SWAP_TYPES}"
    )
print(f"INFO: Using swap type: {SWAP_TYPE}")

In [None]:
# Initialize MegaFS with modularized components.
print("INFO: Initializing MegaFS...")
handler = None  # stays None if initialization fails; the swap cells check for this

try:
    # SWAP_TYPE overrides whatever swap type the Config object was created with.
    config.swap.swap_type = SWAP_TYPE
    handler = MegaFS(config=config, data_map=data_map, debug=True)  # debug=True: debug logging
except Exception as e:
    print(f"ERROR: Failed to initialize MegaFS: {e}")
    import traceback
    traceback.print_exc()
else:
    print(f"SUCCESS: {SWAP_TYPE}-MegaFS model handler created")

## Helper functions for inference

In [None]:
# Placeholder cell: the reference implementation was removed. Inference uses
# the MegaFS class imported from models.megafs.
print("INFO: Using modularized MegaFS implementation")


In [None]:
# Placeholder cell: the reference handler was removed. The modularized handler
# is created earlier, in the MegaFS initialization cell.
print("INFO: Modularized handler is initialized in the MegaFS initialization cell above")


In [None]:
# Placeholder cell: the reference swap function was removed. Swaps go through
# run_swap / run_batch_swap, defined in the following cells.
print("INFO: Using modularized run_swap function for face swapping")


In [None]:
def run_swap(handler_instance, src_id, tgt_id, refine=True, output_dir="/content"):
    """Run a face swap for a single image pair and display the saved result.

    Args:
        handler_instance: Initialized MegaFS handler. A falsy value (e.g. None
            after a failed initialization) makes the call a no-op.
        src_id: Dataset ID of the source face.
        tgt_id: Dataset ID of the target image.
        refine: Forwarded unchanged to ``handler_instance.run``.
        output_dir: Directory for the result image, which is named
            ``swap_result_<src_id>_to_<tgt_id>.jpg``. Defaults to ``/content``.

    Returns:
        The path reported by the handler on success, otherwise ``None``.
        Failures are printed rather than raised, so a "Run All" keeps going.
    """
    if not handler_instance:
        print("ERROR: Handler not initialized")
        return None

    print(f"INFO: Starting face swap - Source ID: {src_id}, Target ID: {tgt_id}")
    save_path = os.path.join(output_dir, f"swap_result_{src_id}_to_{tgt_id}.jpg")

    try:
        # The handler returns (path, image); only the saved file is needed here.
        result_path, _ = handler_instance.run(
            src_idx=src_id,
            tgt_idx=tgt_id,
            refine=refine,
            save_path=save_path,
        )

        if not result_path:
            print("WARNING: Face swap completed but failed to save result")
            return None

        print(f"SUCCESS: Result saved to {result_path}")
        display(Image(result_path))
        return result_path

    except Exception as e:
        print(f"ERROR: Face swap failed for IDs {src_id} -> {tgt_id}: {e}")
        print("Check if the dataset IDs exist and paths are correct")
        return None

In [None]:
def run_batch_swap(handler_instance, id_pairs, refine=True, output_dir="/content"):
    """Run face swaps for several (source, target) ID pairs and show them stacked.

    Every successful result image is stacked vertically with ``cv2.vconcat``.
    The stack is saved as ``batch_result_<len(id_pairs)>_pairs.jpg`` in
    ``output_dir`` and displayed scaled to 800 px wide. Pairs that raise or
    return no image are reported and skipped.

    Args:
        handler_instance: Initialized MegaFS handler. A falsy value makes the call a no-op.
        id_pairs: Sized sequence of ``(src_id, tgt_id)`` tuples (``len()`` is taken).
        refine: Forwarded unchanged to ``handler_instance.run``.
        output_dir: Directory for the combined image. Defaults to ``/content``.

    Returns:
        Path of the saved combined image, or ``None`` if no pair produced an
        image or the file could not be written.
    """
    if not handler_instance:
        print("ERROR: Handler not initialized")
        return None

    all_results = []
    print(f"INFO: Starting batch processing for {len(id_pairs)} pairs...")

    for src_id, tgt_id in tqdm(id_pairs, desc="Batch processing"):
        try:
            _, result_image = handler_instance.run(
                src_idx=src_id,
                tgt_idx=tgt_id,
                refine=refine,
            )
        except Exception as e:
            print(f"ERROR: Failed to process pair ({src_id}, {tgt_id}): {e}")
            continue

        if result_image is not None:
            all_results.append(result_image)
        else:
            print(f"WARNING: Skipping pair ({src_id}, {tgt_id}) - no result")

    if not all_results:
        print("ERROR: No results processed")
        return None

    # vconcat requires all images to share width and dtype.
    final_image = cv2.vconcat(all_results)
    result_filename = f"batch_result_{len(id_pairs)}_pairs.jpg"
    result_path = os.path.join(output_dir, result_filename)
    # cv2.imwrite signals most failures (bad path, unwritable dir) by returning
    # False, not by raising, so its result must be checked.
    if not cv2.imwrite(result_path, final_image):
        print(f"ERROR: Failed to write batch result to {result_path}")
        return None
    print(f"SUCCESS: Batch result saved to {result_filename}")

    # shape[:2] works for both colour (H, W, C) and single-channel (H, W) images.
    height, width = final_image.shape[:2]
    scale = 800 / width
    display(Image(result_path, width=int(width * scale), height=int(height * scale)))
    return result_path

# Inference / Test Code

In [None]:
# Single image face swap for one hand-picked pair of dataset IDs.
print("INFO: Running single image face swap...")

SOURCE_ID, TARGET_ID = 2332, 2107  # Change these IDs as needed
print(f"INFO: Source ID: {SOURCE_ID}, Target ID: {TARGET_ID}")

if not handler:
    print("ERROR: MegaFS handler not initialized. Check previous cells.")
else:
    run_swap(handler, SOURCE_ID, TARGET_ID, refine=True)


In [None]:
# Alternative single image swap, for testing other IDs. If the requested IDs
# are not among valid_ids, it falls back to the first two valid IDs.
print("INFO: Alternative single image swap...")

# You can test different IDs here
ALT_SOURCE_ID = 100
ALT_TARGET_ID = 200

print(f"INFO: Alternative test - Source ID: {ALT_SOURCE_ID}, Target ID: {ALT_TARGET_ID}")

if handler and valid_ids:
    # Compare as strings: IDs coming from a JSON data map may be stored as str
    # (JSON object keys are always strings), while the constants above are ints.
    # NOTE(review): confirm the element type returned by get_valid_ids.
    valid_id_keys = {str(i) for i in valid_ids}
    if str(ALT_SOURCE_ID) in valid_id_keys and str(ALT_TARGET_ID) in valid_id_keys:
        run_swap(handler, ALT_SOURCE_ID, ALT_TARGET_ID, refine=True)
    else:
        print(f"WARNING: IDs not in valid set. Available IDs: {len(valid_ids)}")
        if len(valid_ids) >= 2:
            print("Using first two valid IDs instead...")
            run_swap(handler, valid_ids[0], valid_ids[1], refine=True)
        else:
            print("WARNING: Fewer than two valid IDs available; skipping fallback swap")
else:
    print("ERROR: Handler not initialized or no valid IDs available")

In [None]:
# Batch processing of several (source, target) pairs.
print("INFO: Running batch processing...")

# Define batch pairs - modify these IDs as needed
batch_pairs = [
    (100, 200),
    (300, 400),
    (500, 600),
    # Add more pairs as needed
]

print(f"INFO: Processing {len(batch_pairs)} pairs...")

if handler and valid_ids:
    # Compare as strings: IDs coming from a JSON data map may be stored as str
    # (JSON object keys are always strings), while the pairs above are ints.
    # A set also makes each membership check O(1).
    valid_id_keys = {str(i) for i in valid_ids}
    valid_pairs = []
    for src, tgt in batch_pairs:
        if str(src) in valid_id_keys and str(tgt) in valid_id_keys:
            valid_pairs.append((src, tgt))
        else:
            print(f"WARNING: Skipping invalid pair ({src}, {tgt})")

    if valid_pairs:
        print(f"INFO: Running batch swap with {len(valid_pairs)} valid pairs")
        run_batch_swap(handler, valid_pairs, refine=True)
    else:
        print("ERROR: No valid pairs found. Check your ID selection.")
        print(f"Available valid IDs: {len(valid_ids)}")
else:
    print("ERROR: Handler not initialized or no valid IDs available")

In [None]:
# Result file management: list the saved swap images and print the
# files.download(...) calls that fetch them from Colab.
print("INFO: Checking for result files...")

RESULT_PREFIXES = ("swap_result_", "batch_result_")
result_files = []
for entry in os.listdir('/content/'):
    if entry.startswith(RESULT_PREFIXES) and entry.endswith(".jpg"):
        result_files.append(entry)

if not result_files:
    print("WARNING: No result files found. Run face swap cells first.")
else:
    print(f"SUCCESS: Found {len(result_files)} result files:")
    for name in result_files:
        print(f"  - {name}")

    print("\nINFO: To download files, run the following commands:")
    for name in result_files:
        print(f"files.download('/content/{name}')")