In [None]:
from google.colab import drive, userdata
import shutil
import os

drive.mount('/content/drive', force_remount=True)
source_folder_path = '/content/drive/MyDrive/VLM_for_FIQA'
destination_folder_path = '/content/VLM_for_FIQA'

# Copy the project folder from the mounted Drive to /content.
# dirs_exist_ok=True keeps this cell re-runnable: without it copytree raises
# FileExistsError on a second run because the destination already exists.
shutil.copytree(source_folder_path, destination_folder_path, dirs_exist_ok=True)

In [None]:
!pip install mtcnn --quiet
!pip install insightface --quiet
!pip install onnxruntime --quiet

In [None]:
import zipfile
from mtcnn import MTCNN
from mtcnn.utils.images import load_images_batch
import glob
from tqdm import tqdm
from insightface.utils import face_align
from pathlib import Path
import cv2
import numpy as np
import os
import csv
import shutil
import matplotlib.pyplot as plt
import multiprocessing as mp
import torch
import concurrent.futures

In [None]:
# Unpack the CelebA-HQ archive (copied from Drive above) into /content/celeba.
zip_path = '/content/VLM_for_FIQA/CelebA-HQ/celeba.zip'
extract_path = '/content/celeba'

with zipfile.ZipFile(zip_path, mode='r') as celeba_archive:
    celeba_archive.extractall(path=extract_path)

In [None]:
def display_alignment(image_dict, detection_list, show_labels=True):
    """Plot each original image (with its detection drawn) next to its aligned crop.

    NOTE: a later cell in this notebook defines ``display_alignment`` again;
    once that cell has run, the later definition is the one in effect.

    Args:
        image_dict: mapping of display name -> image array. Images are passed
            to ``imshow`` unchanged, so they should already be RGB.
        detection_list: one detection dict per image, in ``image_dict``
            iteration order. Each dict needs ``'box'`` as ``(x, y, w, h)`` and
            ``'keypoints'`` as name -> ``(x, y)``. A single dict (not wrapped
            in a list) is accepted for a one-image ``image_dict``.
        show_labels: if True, write each keypoint's name next to its marker.

    Raises:
        AssertionError: if the number of images and detections differ.
    """
    if not isinstance(detection_list, list):
        detection_list = [detection_list]

    assert len(image_dict) == len(detection_list), "Number of images and detections should be the same."

    # plt.subplots cannot build a grid with zero rows; there is nothing to show.
    if not image_dict:
        return

    cols = 2
    rows = len(image_dict)
    # squeeze=False always yields a 2-D (rows, cols) array of Axes, so the
    # single-image case needs no special reshaping.
    _, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)

    for i, ((path, image), detection) in enumerate(zip(image_dict.items(), detection_list)):
        aligned_image = align(image, detection)

        # Draw on a copy so the caller's image is left untouched.
        original_image = image.copy()
        # OpenCV drawing functions require plain Python ints for coordinates.
        x, y, w, h = (int(v) for v in detection['box'])
        cv2.rectangle(original_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
        for key, (px, py) in detection['keypoints'].items():
            px, py = int(px), int(py)
            cv2.circle(original_image, (px, py), 3, (0, 0, 255), -1)
            if show_labels:
                cv2.putText(original_image, key, (px + 5, py - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)

        axes[i, 0].imshow(original_image)
        axes[i, 0].set_title(f"{path} - Original Image")
        axes[i, 1].imshow(aligned_image)
        axes[i, 1].set_title(f"{path} - Aligned Image")

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)

    plt.tight_layout()
    plt.show()

In [None]:
def align(image, detection, output_size=224):
    """Crop and warp the face in ``image`` using its five detected keypoints.

    Args:
        image: image array to crop from.
        detection: dict whose ``'keypoints'`` entry maps ``left_eye``,
            ``right_eye``, ``nose``, ``mouth_left`` and ``mouth_right`` to
            ``(x, y)`` points.
        output_size: passed to ``face_align.norm_crop`` as ``image_size``.
            NOTE(review): insightface's norm_crop appears to require a
            multiple of 112 or 128 here — confirm before changing the default.

    Returns:
        The aligned image produced by ``face_align.norm_crop``.
    """
    # The order is significant: norm_crop pairs these points with its
    # reference template by position.
    landmark_names = ('left_eye', 'right_eye', 'nose', 'mouth_left', 'mouth_right')
    keypoints = detection['keypoints']
    landmarks = np.array([keypoints[name] for name in landmark_names])
    return face_align.norm_crop(image, landmarks, image_size=output_size)

def worker(path, image, detection):
    """Align one image and return it paired with its key: ``(path, aligned_image)``."""
    aligned = align(image, detection)
    return path, aligned

def align_batch(image_dict, detection_list, max_workers=None):
    """Align a batch of face images concurrently.

    Args:
        image_dict: mapping of filename -> image array.
        detection_list: one detection per image, in ``image_dict`` iteration order.
        max_workers: number of worker threads; ``None`` lets
            ``ThreadPoolExecutor`` choose its default.

    Returns:
        dict mapping each key of ``image_dict`` to its aligned image.

    Raises:
        AssertionError: if the number of images and detections differ.
    """
    from concurrent.futures import ThreadPoolExecutor

    assert len(image_dict) == len(detection_list), "Number of images and detections should be the same."
    if not image_dict:
        return {}

    # Threads instead of a per-call multiprocessing.Pool. The pool forked the
    # whole kernel (including the already-constructed MTCNN detector) for every
    # batch and pickled every image across process boundaries. Under the
    # "spawn" start method, it also cannot locate functions defined in a
    # notebook. The per-image work is mostly native OpenCV/numpy code, which
    # releases the GIL, so threads parallelise it without those costs.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = executor.map(worker, image_dict.keys(), image_dict.values(), detection_list)
        return dict(results)

def display_alignment(image_dict, detection_list, show_labels=False):
    """Display original and aligned images with bounding boxes and landmarks.

    Each row shows one image with its detection box and keypoints drawn,
    next to the crop returned by ``align``.

    Args:
        image_dict: mapping of display name -> image array. Images are passed
            to ``imshow`` unchanged, so they should already be RGB.
        detection_list: one detection dict per image, in ``image_dict``
            iteration order. Each dict needs ``'box'`` as ``(x, y, w, h)`` and
            ``'keypoints'`` as name -> ``(x, y)``. A single dict (not wrapped
            in a list) is accepted for a one-image ``image_dict``.
        show_labels: if True, also write each keypoint's name next to its marker.

    Raises:
        AssertionError: if the number of images and detections differ.
    """
    if not isinstance(detection_list, list):
        detection_list = [detection_list]

    assert len(image_dict) == len(detection_list), "Number of images and detections should be the same."

    # plt.subplots cannot build a grid with zero rows; there is nothing to show.
    if not image_dict:
        return

    cols = 2
    rows = len(image_dict)
    # squeeze=False always yields a 2-D (rows, cols) array of Axes, so the
    # single-image case needs no special reshaping.
    _, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)

    for i, ((path, image), detection) in enumerate(zip(image_dict.items(), detection_list)):
        aligned_image = align(image, detection)

        # Draw on a copy so the caller's image is left untouched.
        original_image = image.copy()

        # OpenCV drawing functions require plain Python ints for coordinates.
        x, y, w, h = (int(v) for v in detection['box'])
        cv2.rectangle(original_image, (x, y), (x + w, y + h), (0, 255, 0), 2)

        for key, (px, py) in detection['keypoints'].items():
            px, py = int(px), int(py)
            cv2.circle(original_image, (px, py), 3, (0, 0, 255), -1)
            if show_labels:
                cv2.putText(original_image, key, (px + 5, py - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)

        axes[i, 0].imshow(original_image)
        axes[i, 0].set_title(f"{path} - Original Image")
        axes[i, 1].imshow(aligned_image)
        axes[i, 1].set_title(f"{path} - Aligned Image")

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_frame_on(False)

    plt.tight_layout()
    plt.show()


In [None]:
batch_size = 128

image_directory = Path("/content/celeba/celeba_hq_256")
if not image_directory.exists():
    raise FileNotFoundError(f"The directory '{image_directory}' does not exist. Please check the path.")
# Sorted so batch composition is reproducible; glob order depends on the filesystem.
image_paths = sorted(str(path) for path in image_directory.glob("*.jpg"))
if not image_paths:
    raise FileNotFoundError(f"No .jpg image files found in the specified directory: '{image_directory}'.")

output_dir = "/content/celeba_mtcnn_aligned"
# Start from an empty output directory so crops from earlier runs are not mixed in.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

# NOTE(review): the device string ("GPU:0"/"CPU:0") is chosen by probing CUDA
# through torch, but it is consumed by the MTCNN detector — confirm the
# detector's backend sees the same GPU.
device = "GPU:0" if torch.cuda.is_available() else "CPU:0"
detector = MTCNN(device=device)
print(f"Using device: {device}")

failed_num = 0          # images with no detected face or that could not be read
write_failures = []     # output paths cv2.imwrite reported as not written
for i in tqdm(range(0, len(image_paths), batch_size), desc="Processing Batches", unit="batch"):
    batch_image_paths = image_paths[i:i + batch_size]
    batch_detections = detector.detect_faces(batch_image_paths, batch_stack_justification="center")

    assert len(batch_detections) == len(batch_image_paths), "Number of detections does not match the number of images in the batch."

    valid_images = {}
    valid_detections = []

    for detection, image_path in zip(batch_detections, batch_image_paths):
        # Check for a detection first so images without a face are never decoded.
        if detection is None or len(detection) == 0:
            failed_num += 1
            continue
        image = cv2.imread(image_path)
        if image is None:
            failed_num += 1
            continue

        # When several faces are found, keep only the most confident one.
        best_detection = max(detection, key=lambda d: d['confidence'])

        valid_detections.append(best_detection)
        valid_images[os.path.basename(image_path)] = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    assert (
        len(valid_detections) == len(valid_images)
    ), "Number of valid detections does not match the number of images."

    aligned_images = align_batch(valid_images, valid_detections)

    for filename, aligned_image in aligned_images.items():
        output_path = os.path.join(output_dir, filename)
        # cv2.imwrite signals failure only through its return value.
        if not cv2.imwrite(output_path, cv2.cvtColor(aligned_image, cv2.COLOR_RGB2BGR)):
            write_failures.append(output_path)

print(
    f"Aligned {len(image_paths) - failed_num - len(write_failures)} of {len(image_paths)} images; "
    f"{failed_num} skipped (no face detected or unreadable), {len(write_failures)} failed to write."
)


In [None]:
# Pack the aligned crops into an archive; make_archive appends the ".zip"
# extension to base_name itself and returns the archive's path.
folder = '/content/celeba_mtcnn_aligned'
zip_filename = '/content/celeba_mtcnn_aligned'
shutil.make_archive(base_name=zip_filename, format='zip', root_dir=folder)

In [None]:
# Copy the archive into the project's CelebA-HQ folder on the mounted Drive.
destination_directory = "/content/drive/MyDrive/VLM_for_FIQA/CelebA-HQ"
aligned_archive_path = '/content/celeba_mtcnn_aligned.zip'
shutil.copy(aligned_archive_path, destination_directory)