# Configuration & Workflow

- Configure base paths in the cells below (dataset root, weights path).
- First, generate `data_map.json` by running the script in the dataset root:

```bash
python create_datamap.py
```

- Then run the cells to load the mapping and perform inference. Do not create the mapping inside this notebook.


In [3]:
import os
import zipfile
from glob import glob
from tqdm.notebook import tqdm
import torch
import cv2
import numpy as np
import argparse
from google.colab import drive
from IPython.display import display, Image
import matplotlib.pyplot as plt
from google.colab import files

print("Import complete")

ModuleNotFoundError: No module named 'google.colab'

In [None]:
try:
    drive.mount('/content/drive')
    print("Google Drive mount complete")
except Exception as e:
    print(f"Google Drive mount error: {e}")

Mounted at /content/drive
Google Drive mount complete


In [6]:
print("\n Preparing dataset...")
dataset_zip_path = "/content/drive/MyDrive/Datasets/celeba_mask_hq.zip"
base_dir = "/content/"


 Preparing dataset...


In [7]:
if os.path.exists(dataset_zip_path):
        print(f"Unzipping '{dataset_zip_path}'")
        with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
            zip_ref.extractall(base_dir)
        print("Unzipping complete.")
else:
        print(f" '{dataset_zip_path}' Not Found")

 '/content/drive/MyDrive/Datasets/celeba_mask_hq.zip' Not Found


In [11]:
dataset_root = os.path.join(base_dir, "CelebAMask-HQ")
img_dir = os.path.join(dataset_root, "CelebA-HQ-img")
mask_base_dir = os.path.join(dataset_root, "CelebAMask-HQ-mask-anno")
data_map_path = os.path.join("./data_map.json")
data_map = {}
valid_ids = []

In [12]:
import json

print("\nExpecting data_map.json created externally by create_datamap.py")
if not os.path.exists(data_map_path):
    print(f"Data map file '{data_map_path}' not found. Please run create_datamap.py in the dataset root first.")
else:
    print("Found data_map.json; loading will happen in the next cell.")


Expecting data_map.json created externally by create_datamap.py
Found data_map.json; loading will happen in the next cell.


In [None]:
import os, json

print("\nLoading data map (fast mode)...")

# Fast load: keys -> int, no per-entry exists-checks
with open(data_map_path, 'r', encoding='utf-8') as f:
    loaded_map_str_keys = json.load(f)

data_map = {int(k): v for k, v in loaded_map_str_keys.items()}

# On-demand resolver for a single id (used by inference helpers)
def resolve_paths_for_id(img_id: int):
    rec = data_map.get(img_id)
    if rec is None:
        return None, None

    def _norm(rel):
        if not isinstance(rel, str) or not rel:
            return None
        rel = rel.replace("\\", "/").lstrip("/\\")
        if rel.startswith("CelebAMask-HQ/"):
            rel = rel.split("/", 1)[1]
        return os.path.normpath(os.path.join(dataset_root, rel))

    image_path = _norm(rec.get("image_path"))
    mask_path = None
    mp = rec.get("mask_paths")
    if isinstance(mp, list) and mp:
        mask_path = _norm(mp[0])
    elif isinstance(rec.get("mask_path"), str):
        mask_path = _norm(rec["mask_path"])
    return image_path, mask_path

print("Loaded map entries:", len(data_map))
print("Fast mode: no global os.path.exists or glob at load.")



Loading data map...
Loaded 0 valid items from './data_map.json'
   id=0  image_exists=False  mask_exists=False
     image_path: \content\CelebAMask-HQ\CelebA-HQ-img\0.jpg
     mask_path:  None
   id=1  image_exists=False  mask_exists=False
     image_path: \content\CelebAMask-HQ\CelebA-HQ-img\1.jpg
     mask_path:  None
   id=10  image_exists=False  mask_exists=False
     image_path: \content\CelebAMask-HQ\CelebA-HQ-img\10.jpg
     mask_path:  None


In [None]:
import random
import os

print("\nVerify (sample-only, fast)...")

if data_map:
    total = len(data_map)
    sample_ids = random.sample(list(data_map.keys()), min(20, total))

    missing = 0
    for mid in sample_ids:
        ip, mp = resolve_paths_for_id(mid)
        iex = os.path.exists(ip) if ip else False
        mex = os.path.exists(mp) if mp else False
        print(f"id={mid} image_exists={iex} mask_exists={mex}")
        if not (iex and mex):
            missing += 1

    print(f"Sampled {len(sample_ids)}; missing in sample: {missing}")
else:
    print("Data map is empty; skipping verification.")


In [None]:
IMG_ROOT = img_dir
MASK_ROOT = mask_base_dir

In [None]:
print("\n\n---\n\n")
print("📂 체크포인트 경로를 설정합니다...")

# Google Drive 내의 체크포인트 폴더 경로
checkpoint_dir = '/content/drive/MyDrive/Datasets/weights'

In [None]:
if os.path.isdir(checkpoint_dir):
    print(f"체크포인트 디렉토리 확인: '{checkpoint_dir}'")
    # 필요한 파일들이 있는지 간단히 확인
    required_files = ["ftm_final.pth", "stylegan2-ffhq-config-f.pth", "injection_final.pth", "lcr_final.pth"]
    missing_files = [f for f in required_files if not os.path.exists(os.path.join(checkpoint_dir, f))]
    if missing_files:
        print(f"[WARNING] following files do not exist: {missing_files}")
    else:
        print("ALL CHECKPOINTS CONFIRMED")
else:
    print(f"Following filepath not found: '{checkpoint_dir}'.")

In [None]:
# Deep weight load verification (non-destructive)
import os
import torch

print("\nVerifying all expected weights can be loaded...")

expected_files = {
    "ftm_final.pth": {"required_keys": ["e", "s"]},
    "injection_final.pth": {"required_keys": ["e", "s"]},
    "lcr_final.pth": {"required_keys": ["e", "s"]},
    "stylegan2-ffhq-config-f.pth": {"required_keys": ["g_ema"]},
}

if not os.path.isdir(checkpoint_dir):
    print(f"❌ checkpoint_dir not found: {checkpoint_dir}")
else:
    for fname, spec in expected_files.items():
        fpath = os.path.join(checkpoint_dir, fname)
        if not os.path.exists(fpath):
            print(f"❌ MISSING: {fname}")
            continue
        try:
            ckpt = torch.load(fpath, map_location=torch.device("cpu"))
        except Exception as e:
            print(f"❌ LOAD FAIL: {fname} -> {e}")
            continue
        missing_keys = [k for k in spec["required_keys"] if k not in ckpt]
        if missing_keys:
            print(f"⚠️ LOADED but missing keys in {fname}: {missing_keys}")
        else:
            print(f"✅ OK: {fname} (has {spec['required_keys']})")
        # free
        del ckpt



In [None]:
SWAP_TYPE = "ftm" # ftm, injection, lcr

In [None]:
handler = None
if 'MegaFS' in locals():
    try:
        handler = MegaFS(
            swap_type=SWAP_TYPE,
            img_root=IMG_ROOT,
            mask_root=MASK_ROOT,
            data_map=data_map, # 개선된 데이터 맵 전달
            checkpoint_dir=checkpoint_dir # 체크포인트 경로 전달
        )
        print(f"'{SWAP_TYPE}'-MegaFS model handler created.")
    except Exception as e:
        print(f"Error while generating handler: {e}")
        import traceback
        traceback.print_exc()
else:
    print("MegaFS class failed to import. Check previous cells")

## Helper functions for inference

In [None]:
# Toggle to use reference implementation
REF_MODE = True

if REF_MODE:
    import sys
    sys.path.append('/content/ref_models')
    try:
        from ref_models.megafs import MegaFS as RefMegaFS
        print("Reference MegaFS import complete")
    except Exception as e:
        print(f"Reference import failed: {e}")


In [None]:
# Initialize reference handler
ref_handler = None
if 'RefMegaFS' in locals():
    try:
        ref_handler = RefMegaFS(swap_type=SWAP_TYPE).cuda().eval()
        print(f"Reference handler ready for '{SWAP_TYPE}'")
    except Exception as e:
        print(f"Failed to init reference handler: {e}")


In [None]:
# Helper for reference swap using existing data_map
import numpy as np
import cv2
import torch

def run_swap_ref(ref_model, src_id, tgt_id):
    if not ref_model:
        print("❌ 레퍼런스 모델이 초기화되지 않았습니다.")
        return
    # resolve paths via existing resolver
    src_img_path, _ = resolve_paths_for_id(src_id)
    tgt_img_path, _ = resolve_paths_for_id(tgt_id)
    if not (src_img_path and os.path.exists(src_img_path)):
        print(f"❌ 소스 이미지 없음: {src_img_path}")
        return
    if not (tgt_img_path and os.path.exists(tgt_img_path)):
        print(f"❌ 타겟 이미지 없음: {tgt_img_path}")
        return
    # load and preprocess to 256
    src = cv2.cvtColor(cv2.imread(src_img_path), cv2.COLOR_BGR2RGB)
    tgt = cv2.cvtColor(cv2.imread(tgt_img_path), cv2.COLOR_BGR2RGB)
    src = cv2.resize(src, (256, 256))
    tgt = cv2.resize(tgt, (256, 256))
    ts = torch.from_numpy(src.transpose(2,0,1)).float().mul_(1/255.0)
    tt = torch.from_numpy(tgt.transpose(2,0,1)).float().mul_(1/255.0)
    ts = (ts - 0.5)/0.5
    tt = (tt - 0.5)/0.5
    with torch.no_grad():
        out = ref_model(ts.unsqueeze(0).cuda(), tt.unsqueeze(0).cuda())
        img = out[0].clamp(0,1).permute(1,2,0).cpu().numpy()
        img_bgr = (img*255).astype(np.uint8)[:,:,::-1]
        result = np.hstack((cv2.cvtColor(src, cv2.COLOR_RGB2BGR), cv2.cvtColor(tgt, cv2.COLOR_RGB2BGR), img_bgr))
        out_name = f"ref_swap_{src_id}_to_{tgt_id}.jpg"
        cv2.imwrite(f"/content/{out_name}", result)
        print(f"💾 레퍼런스 결과 저장: {out_name}")


In [None]:
def run_swap(handler_instance, src_id, tgt_id, refine=True):
    """단일 이미지 쌍에 대해 스왑을 실행하고 결과를 표시합니다."""
    if not handler_instance:
        print("❌ 핸들러가 초기화되지 않았습니다.")
        return

    # On-demand resolve and existence checks
    src_img, _ = resolve_paths_for_id(src_id)
    tgt_img, tgt_mask = resolve_paths_for_id(tgt_id)
    if not (src_img and os.path.exists(src_img)):
        print(f"❌ 소스 이미지가 없습니다: {src_img}")
        return
    if not (tgt_img and os.path.exists(tgt_img)):
        print(f"❌ 타겟 이미지가 없습니다: {tgt_img}")
        return
    #if not (tgt_mask and os.path.exists(tgt_mask)):
    #    print(f"❌ 타겟 마스크가 없습니다: {tgt_mask}")
    #    return

    print(f"\n🔄 Source ID: {src_id}, Target ID: {tgt_id} 얼굴 교체를 시작합니다...")
    try:
        handler_instance.data_map = data_map
        result = handler_instance.run(src_id, tgt_id, refine)
        save_path, result_image = (result if isinstance(result, tuple) else (None, result))

        result_filename = f"swap_result_{src_id}_to_{tgt_id}.jpg"
        cv2.imwrite(f"/content/{result_filename}", result_image)
        print(f"💾 결과가 '{result_filename}' 파일로 저장되었습니다.")

        img_disp = Image(f'/content/{result_filename}')
        display(img_disp)

    except Exception as e:
        print(f"❌ ID {src_id} 또는 {tgt_id} 처리 중 오류 발생: {e}")
        print("   - 데이터셋에 해당 ID가 존재하는지 확인해주세요.")

In [None]:
def run_batch_swap(handler_instance, id_pairs, refine=True):
    """여러 이미지 쌍에 대해 배치 스왑을 실행하고 결과를 결합하여 표시합니다."""
    if not handler_instance:
        print("❌ 핸들러가 초기화되지 않았습니다.")
        return

    all_results = []
    print(f"\n⚙️ 총 {len(id_pairs)}개의 쌍에 대한 배치 작업을 시작합니다...")
    for src_id, tgt_id in tqdm(id_pairs, desc="배치 처리"):
        # Resolve paths on-demand per pair
        src_img, _ = resolve_paths_for_id(src_id)
        tgt_img, tgt_mask = resolve_paths_for_id(tgt_id)
        if not (src_img and os.path.exists(src_img) and tgt_img and os.path.exists(tgt_img) and tgt_mask and os.path.exists(tgt_mask)):
            print(f"⚠️ 건너뜀 (경로 누락): ({src_id}, {tgt_id})")
            continue
        try:
            result = handler_instance.run(src_id, tgt_id, refine)
            # handler.run returns (save_path, result) or (None, result)
            _, result_image = (result if isinstance(result, tuple) else (None, result))
            all_results.append(result_image)
        except Exception as e:
            print(f"❌ ID 쌍 ({src_id}, {tgt_id}) 처리 중 오류 발생: {e}")
            continue

    if all_results:
        final_image = cv2.vconcat(all_results)
        result_filename = f"batch_result_{len(id_pairs)}_pairs.jpg"
        cv2.imwrite(f"/content/{result_filename}", final_image)
        print(f"\n💾 배치 결과가 '{result_filename}' 파일로 저장되었습니다.")
        height, width, _ = final_image.shape
        scale = 800 / width
        img_disp = Image(f'/content/{result_filename}', width=int(width*scale), height=int(height*scale))
        display(img_disp)
    else:
        print("처리된 결과가 없습니다.")

# INFERENCE / TEST codes

In [None]:
# Use reference model for single inference
print("\n\n---\n\n")
print("🎨 [Reference] 단일 이미지 스왑을 실행합니다.")

# --- 여기서 소스와 타겟 ID를 변경하세요 ---
SOURCE_ID = 2332
TARGET_ID = 2107
# -----------------------------------------

if ref_handler:
    run_swap_ref(ref_handler, SOURCE_ID, TARGET_ID)
else:
    print("⚠️ 레퍼런스 모델이 초기화되지 않았습니다.")


In [None]:
print("\n\n---\n\n")
print("🎨 단일 이미지 스왑을 실행합니다.")

# --- 여기서 소스와 타겟 ID를 변경하세요 ---
SOURCE_ID = 2332
TARGET_ID = 2107
# -----------------------------------------

if handler:
    run_swap(handler, SOURCE_ID, TARGET_ID, refine=True)
else:
    print("⚠️ 모델이 초기화되지 않았습니다. 이전 단계를 확인해주세요.")

In [None]:
print("\n\n---\n\n")
print("🏭 배치 처리를 실행합니다.")

# --- 여기에 처리할 (소스, 타겟) ID 쌍 목록을 추가하세요 ---
batch_pairs = [
    (100, 200),
    (300, 400),
    (500, 600)
    # 필요한 만큼 쌍을 추가하세요.
]
# ----------------------------------------------------

if handler:
    # 배치 쌍 유효성 검사
    valid_pairs = []
    for src, tgt in batch_pairs:
        if src in valid_ids and tgt in valid_ids:
            valid_pairs.append((src, tgt))
        else:
            print(f"⚠️ ID 쌍 ({src}, {tgt})은(는) 유효하지 않아 건너뜁니다.")

    if valid_pairs:
        run_batch_swap(handler, valid_pairs, refine=True)
    else:
        print("처리할 유효한 ID 쌍이 없습니다.")
else:
    print("⚠️ 모델이 초기화되지 않았습니다. 이전 단계를 확인해주세요.")

In [None]:
print("\n\n---\n\n")
print("📥 결과 파일 다운로드를 준비합니다.")

# 결과 파일들 찾기
result_files = [f for f in os.listdir('/content/') if f.startswith(("swap_result_", "batch_result_")) and f.endswith(".jpg")]

if result_files:
    print(f"📁 발견된 결과 파일: {len(result_files)}개")
    for file in result_files:
        print(f"  - {file}")

    print("\n💾 파일을 다운로드하려면 아래 코드를 별도의 셀에서 실행하세요:")
    for file in result_files:
        print(f"files.download('/content/{file}')")
else:
    print("❌ 다운로드할 결과 파일을 찾을 수 없습니다.")