# Configuration & Workflow

- Configure base paths in the cells below (dataset root, weights path).
- First, generate `data_map.json` by running the script in the dataset root:

```bash
python create_datamap.py
```

- Then run the cells to load the mapping and perform inference. Do not create the mapping inside this notebook.


In [1]:
import os
import zipfile
from glob import glob
from tqdm.notebook import tqdm
import torch
import cv2
import numpy as np
import argparse
from google.colab import drive
from IPython.display import display, Image
import matplotlib.pyplot as plt
from google.colab import files

print("Import complete")

Import complete


In [2]:
# Mount Google Drive so the dataset archive and the weights become reachable
# under /content/drive. A failure is reported but does not stop the notebook.
try:
    drive.mount('/content/drive')
except Exception as mount_error:
    print(f"Google Drive mount error: {mount_error}")
else:
    print("Google Drive mount complete")

Mounted at /content/drive
Google Drive mount complete


In [3]:
import sys

# `from models.megafs import MegaFS` resolves the top-level package `models`,
# so the directory that CONTAINS `models/` (here /content) has to be on
# sys.path. '/content/models' alone would only expose /content/models/models.
# The original entry is kept in case code inside the package relies on it.
# Membership is checked first so re-running the cell does not add duplicates.
for _search_path in ('/content', '/content/models'):
    if _search_path not in sys.path:
        sys.path.append(_search_path)

try:
    from models.megafs import MegaFS
    print("MegaFS module import complete")
except ImportError as e:
    # MegaFS stays undefined on failure; the handler cell checks for the name
    # before trying to use it.
    print(f"Import fail, check the ./models directory in current session: {e}")


Import fail, check the ./models directory in current session: No module named 'models'


In [8]:
# Base locations used by the following cells.
#   dataset_zip_path: archive on the mounted Google Drive
#   base_dir:         directory the archive is extracted into
base_dir = "/content/"
dataset_zip_path = "/content/drive/MyDrive/Datasets/celeba_mask_hq.zip"

print("\n Preparing dataset...")


 Preparing dataset...


In [15]:
# Extract the dataset archive into base_dir, or report that it is missing.
if not os.path.exists(dataset_zip_path):
    print(f" '{dataset_zip_path}' Not Found")
else:
    print(f"Unzipping '{dataset_zip_path}'")
    with zipfile.ZipFile(dataset_zip_path, 'r') as archive:
        archive.extractall(base_dir)
    print("Unzipping complete.")

Unzipping '/content/drive/MyDrive/Datasets/celeba_mask_hq.zip'
Unzipping complete.


In [None]:
# Start with an empty mapping; the loading cell fills both containers.
data_map, valid_ids = {}, []

# Locations inside the extracted archive, all relative to base_dir.
data_map_path = os.path.join(base_dir, "data_map.json")
img_dir = os.path.join(base_dir, "CelebAMask-HQ/CelebA-HQ-img/")
mask_base_dir = os.path.join(base_dir, "CelebAMask-HQ/CelebAMask-HQ-mask-anno/")

In [None]:
import json

# The mapping is produced outside the notebook; this cell only reports
# whether the file is present at data_map_path.
print("\nExpecting data_map.json created externally by create_datamap.py")
if os.path.exists(data_map_path):
    print("Found data_map.json; loading will happen in the next cell.")
else:
    print(f"Data map file '{data_map_path}' not found. Please run create_datamap.py in the dataset root first.")


 Mapping image-mask paths from CelebAMask-HQ...
Searching for mask files paired to the img IDs...


MAPPING:   0%|          | 0/30000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import os
import json

print("\nLoading data map...")

valid_ids = []

if not os.path.exists(data_map_path):
    print(f"Data map file '{data_map_path}' not found. Please create it first.")
else:
    with open(data_map_path, 'r', encoding='utf-8') as map_file:
        loaded_map_str_keys = json.load(map_file)
    # JSON object keys are always strings; the rest of the notebook uses int IDs.
    data_map = {int(raw_key): entry for raw_key, entry in loaded_map_str_keys.items()}

    for entry in data_map.values():
        # Image path: joined onto base_dir when it is a string, otherwise None.
        image_rel = entry.get("image_path")
        entry["image_path"] = (
            os.path.normpath(os.path.join(base_dir, image_rel))
            if isinstance(image_rel, str)
            else None
        )

        # Masks: prefer the first element of the 'mask_paths' list.
        mask_candidates = entry.get("mask_paths")
        if isinstance(mask_candidates, list) and mask_candidates:
            entry["mask_path"] = os.path.normpath(os.path.join(base_dir, mask_candidates[0]))
            continue

        # Otherwise fall back to a legacy single 'mask_path' string, if any.
        legacy_mask = entry.get("mask_path")
        has_legacy_mask = isinstance(legacy_mask, str) and bool(legacy_mask)
        entry["mask_path"] = (
            os.path.normpath(os.path.join(base_dir, legacy_mask))
            if has_legacy_mask
            else None
        )

    # Only IDs with both an image path and a mask path count as valid.
    valid_ids = sorted(
        entry_id
        for entry_id, entry in data_map.items()
        if entry.get("image_path") and entry.get("mask_path")
    )
    print(f"Loaded {len(valid_ids)} valid items from '{data_map_path}'")

    IMG_ROOT, MASK_ROOT = img_dir, mask_base_dir

    if valid_ids:
        # Show one resolved entry as a quick sanity check.
        sample_id = valid_ids[0]
        sample_entry = data_map[sample_id]
        print(f"   - Sample ID: {sample_id}")
        print(f"     - Image path: {sample_entry['image_path']}")
        print(f"     - Mask path: {sample_entry['mask_path']}")


In [None]:
import random
import os
import json

print("\nVerifying data map entries on disk...")


def _entry_files_exist(rec):
    """Return True only when both paths of a record are set and exist on disk.

    The loading cell stores None for a missing image or mask path, and
    os.path.exists(None) raises TypeError, so unset paths are rejected before
    the filesystem is touched.
    """
    image_path = rec.get("image_path")
    mask_path = rec.get("mask_path")
    if not image_path or not mask_path:
        return False
    return os.path.exists(image_path) and os.path.exists(mask_path)


if data_map:
    # Cheap spot check first: the full scan below only runs when the sample
    # reveals at least one broken entry.
    check_count = min(100, len(data_map))
    sample_ids = random.sample(list(data_map.keys()), check_count)
    missing_sample = [img_id for img_id in sample_ids if not _entry_files_exist(data_map[img_id])]

    if missing_sample:
        print(f"Missing {len(missing_sample)} of {check_count} sampled entries; pruning all missing entries...")
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        to_delete = [img_id for img_id, rec in data_map.items() if not _entry_files_exist(rec)]
        for img_id in to_delete:
            del data_map[img_id]

        valid_ids = sorted(data_map.keys())
        try:
            # NOTE: this overwrites the file at data_map_path with the
            # in-memory records, i.e. with the paths already joined onto base_dir.
            # Same encoding as the cell that reads the file.
            with open(data_map_path, 'w', encoding='utf-8') as f:
                json.dump(data_map, f, indent=4)
            print(f"Pruned {len(to_delete)} entries; updated data map saved to '{data_map_path}'.")
            print(f"Remaining entries: {len(valid_ids)}")
        except Exception as e:
            print(f"Error while saving pruned data map: {e}")
    else:
        print(f"No missing files found in the sampled {check_count} entries.")
else:
    print("Data map is empty; skipping verification.")


In [None]:
# Root folders handed to the model handler (images, mask annotations).
IMG_ROOT, MASK_ROOT = img_dir, mask_base_dir

In [None]:
# Checkpoint folder on the mounted Google Drive.
checkpoint_dir = '/content/drive/MyDrive/Datasets/weights'

print("\n\n---\n\n")
print("📂 체크포인트 경로를 설정합니다...")

In [None]:
# Report whether the checkpoint folder exists and which expected weight files
# are missing from it. Nothing is loaded here.
if not os.path.isdir(checkpoint_dir):
    print(f"Following filepath not found: '{checkpoint_dir}'.")
else:
    print(f"체크포인트 디렉토리 확인: '{checkpoint_dir}'")
    required_files = ["ftm_final.pth", "stylegan2-ffhq-config-f.pth", "injection_final.pth", "lcr_final.pth"]
    missing_files = []
    for weight_name in required_files:
        if not os.path.exists(os.path.join(checkpoint_dir, weight_name)):
            missing_files.append(weight_name)
    if not missing_files:
        print("ALL CHECKPOINTS CONFIRMED")
    else:
        print(f"[WARNING] following files do not exist: {missing_files}")

In [None]:
# Section banner printed before the model handler is created.
for banner_line in ("\n\n---\n\n", "Initializing models..."):
    print(banner_line)

In [None]:
# Swap variant handed to the MegaFS handler as `swap_type`.
# The values listed on the line below are the options noted by the author.
SWAP_TYPE = "ftm" # ftm, injection, lcr

In [None]:
# `handler` stays None unless MegaFS was imported and could be constructed;
# the inference cells test it for truthiness before running.
handler = None
if 'MegaFS' not in locals():
    print("MegaFS class failed to import. Check previous cells")
else:
    handler_kwargs = dict(
        swap_type=SWAP_TYPE,
        img_root=IMG_ROOT,
        mask_root=MASK_ROOT,
        data_map=data_map,  # mapping loaded from data_map.json
        checkpoint_dir=checkpoint_dir,  # folder holding the weight files
    )
    try:
        handler = MegaFS(**handler_kwargs)
        print(f"'{SWAP_TYPE}'-MegaFS model handler created.")
    except Exception as e:
        print(f"Error while generating handler: {e}")
        import traceback
        traceback.print_exc()

## Helper functions for inference

In [None]:
def run_swap(handler_instance, src_id, tgt_id, refine=True, output_dir="/content/"):
    """Run a swap for one source/target pair, save the result as JPEG and display it.

    Args:
        handler_instance: Object exposing ``run(src_id, tgt_id, refine)``. A falsy
            value is treated as "handler not initialised" and nothing is run.
        src_id: ID forwarded to the handler as the source.
        tgt_id: ID forwarded to the handler as the target.
        refine: Forwarded unchanged to ``handler_instance.run``.
        output_dir: Directory the result file is written to. Defaults to
            ``/content/``, the location that used to be hard-coded.

    Returns:
        None. Progress and errors are reported with ``print``; no exception
        escapes this function.
    """
    if not handler_instance:
        print("❌ 핸들러가 초기화되지 않았습니다.")
        return

    print(f"\n🔄 Source ID: {src_id}, Target ID: {tgt_id} 얼굴 교체를 시작합니다...")
    try:
        result_image = handler_instance.run(src_id, tgt_id, refine)

        # Save the result next to the other outputs.
        result_filename = f"swap_result_{src_id}_to_{tgt_id}.jpg"
        result_path = os.path.join(output_dir, result_filename)
        # cv2.imwrite signals failure through its return value rather than an
        # exception, so it has to be checked before claiming success.
        if not cv2.imwrite(result_path, result_image):
            print(f"❌ 결과 이미지를 '{result_path}'에 저장하지 못했습니다.")
            return
        print(f"💾 결과가 '{result_filename}' 파일로 저장되었습니다.")

        # Show the saved file inline.
        display(Image(result_path))

    except Exception as e:
        print(f"❌ ID {src_id} 또는 {tgt_id} 처리 중 오류 발생: {e}")
        print("   - 데이터셋에 해당 ID가 존재하는지 확인해주세요.")

In [None]:
def run_batch_swap(handler_instance, id_pairs, refine=True, output_dir="/content/", display_width=800):
    """Run swaps for several pairs, stack the results vertically, save and display them.

    Args:
        handler_instance: Object exposing ``run(src_id, tgt_id, refine)``. A falsy
            value is treated as "handler not initialised" and nothing is run.
        id_pairs: Sized iterable of ``(src_id, tgt_id)`` tuples.
        refine: Forwarded unchanged to ``handler_instance.run``.
        output_dir: Directory the combined JPEG is written to. Defaults to
            ``/content/``, the location that used to be hard-coded.
        display_width: Pixel width used for the inline preview; the preview
            height is scaled by the same factor. Defaults to 800.

    Returns:
        None. Progress and per-pair errors are reported with ``print``.
    """
    if not handler_instance:
        print("❌ 핸들러가 초기화되지 않았습니다.")
        return

    all_results = []
    print(f"\n⚙️ 총 {len(id_pairs)}개의 쌍에 대한 배치 작업을 시작합니다...")
    for src_id, tgt_id in tqdm(id_pairs, desc="배치 처리"):
        try:
            all_results.append(handler_instance.run(src_id, tgt_id, refine))
        except Exception as e:
            # A failing pair is reported and skipped; the batch keeps going.
            print(f"❌ ID 쌍 ({src_id}, {tgt_id}) 처리 중 오류 발생: {e}")

    if not all_results:
        print("처리된 결과가 없습니다.")
        return

    # Stack every result image vertically into one picture.
    final_image = cv2.vconcat(all_results)

    # The file name counts the requested pairs, including any that failed.
    result_filename = f"batch_result_{len(id_pairs)}_pairs.jpg"
    result_path = os.path.join(output_dir, result_filename)
    # cv2.imwrite signals failure through its return value rather than an
    # exception, so it has to be checked before claiming success.
    if not cv2.imwrite(result_path, final_image):
        print(f"❌ 배치 결과를 '{result_path}'에 저장하지 못했습니다.")
        return
    print(f"\n💾 배치 결과가 '{result_filename}' 파일로 저장되었습니다.")

    # shape[:2] works for both 2-D (grayscale) and 3-D (colour) arrays.
    height, width = final_image.shape[:2]
    # The requested width is used directly; computing int(width * (w / width))
    # can land one pixel short because of float rounding.
    display_height = max(1, round(height * display_width / width))

    display(Image(result_path, width=display_width, height=display_height))

# Inference / test code

In [None]:
print("\n\n---\n\n")
print("🎨 단일 이미지 스왑을 실행합니다.")

# --- Change the source and target IDs here ---
SOURCE_ID = 2332
TARGET_ID = 2107
# ---------------------------------------------

# Run only when the handler exists and both IDs are in the valid ID list.
if not handler:
    print("⚠️ 모델이 초기화되지 않았습니다. 이전 단계를 확인해주세요.")
elif all(candidate_id in valid_ids for candidate_id in (SOURCE_ID, TARGET_ID)):
    run_swap(handler, SOURCE_ID, TARGET_ID, refine=True)
else:
    print(f"❌ ID 오류: 소스({SOURCE_ID}) 또는 타겟({TARGET_ID}) ID가 유효한 데이터 목록에 없습니다.")

In [None]:
print("\n\n---\n\n")
print("🏭 배치 처리를 실행합니다.")

# --- Add the (source, target) ID pairs to process here ---
batch_pairs = [
    (100, 200),
    (300, 400),
    (500, 600)
    # Add as many pairs as needed.
]
# ----------------------------------------------------------

if not handler:
    print("⚠️ 모델이 초기화되지 않았습니다. 이전 단계를 확인해주세요.")
else:
    # Keep only pairs whose two IDs are both in the valid ID list.
    valid_pairs = []
    for src, tgt in batch_pairs:
        if not (src in valid_ids and tgt in valid_ids):
            print(f"⚠️ ID 쌍 ({src}, {tgt})은(는) 유효하지 않아 건너뜁니다.")
            continue
        valid_pairs.append((src, tgt))

    if not valid_pairs:
        print("처리할 유효한 ID 쌍이 없습니다.")
    else:
        run_batch_swap(handler, valid_pairs, refine=True)

In [None]:
print("\n\n---\n\n")
print("📥 결과 파일 다운로드를 준비합니다.")

# Collect the JPEG files in /content/ written by the swap helpers.
result_prefixes = ("swap_result_", "batch_result_")
result_files = []
for entry_name in os.listdir('/content/'):
    if entry_name.endswith(".jpg") and entry_name.startswith(result_prefixes):
        result_files.append(entry_name)

if not result_files:
    print("❌ 다운로드할 결과 파일을 찾을 수 없습니다.")
else:
    print(f"📁 발견된 결과 파일: {len(result_files)}개")
    for found_name in result_files:
        print(f"  - {found_name}")

    # Only the download commands are printed; nothing is downloaded here.
    print("\n💾 파일을 다운로드하려면 아래 코드를 별도의 셀에서 실행하세요:")
    for found_name in result_files:
        print(f"files.download('/content/{found_name}')")