In [None]:
import os
from skimage.metrics import structural_similarity as ssim
import cv2
import numpy as np
from tqdm import tqdm
import shutil

def _load_grayscale(image_path):
    """Read ``image_path`` with OpenCV; return it as grayscale, or None if unreadable."""
    image = cv2.imread(image_path)
    if image is None:
        return None
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


def _max_ssim_against(synth_gray, real_images):
    """Return the highest SSIM between ``synth_gray`` and any image in ``real_images``."""
    # When a pair's shapes differ, both sides are compared at the size of the
    # first real image.
    ref_height, ref_width = real_images[0].shape[:2]
    synth_at_ref = None  # resized lazily, at most once per synthetic image
    best_score = -1.0

    for real_gray in real_images:
        if synth_gray.shape == real_gray.shape:
            candidate, reference = synth_gray, real_gray
        else:
            if synth_at_ref is None:
                # cv2.resize takes (width, height), the reverse of ndarray.shape.
                synth_at_ref = cv2.resize(synth_gray, (ref_width, ref_height))
            candidate = synth_at_ref
            reference = real_gray
            if reference.shape != candidate.shape:
                reference = cv2.resize(real_gray, (ref_width, ref_height))

        # data_range is the span of *possible* pixel values, not the spread
        # observed in one image. cv2.imread with its default flag yields 8-bit
        # data, so the span is 255. Deriving it from max() - min() would change
        # the SSIM constants per image and give 0 for a constant image.
        score = ssim(candidate, reference, data_range=255)
        if score > best_score:
            best_score = score

    return best_score


def _copy_into(image_path, target_dir):
    """Copy ``image_path`` into ``target_dir`` if one is set; report failures instead of raising."""
    if not target_dir:
        return
    try:
        shutil.copy(image_path, os.path.join(target_dir, os.path.basename(image_path)))
    except OSError as e:
        # Copying is a convenience side effect; one unreadable or missing file
        # must not abort the whole filtering run.
        print(f"Warning: Could not copy {image_path} to {target_dir}: {e}")


def filter_synthetic_samples_by_ssim(
    synthetic_image_paths,
    real_image_paths,
    ssim_threshold=0.8,
    output_retained_dir=None,
    output_filtered_dir=None
):
    """Keep the synthetic images that are structurally similar to a real image.

    Every synthetic image is converted to grayscale and scored with SSIM
    against every loadable real image. It is retained when its best score is
    greater than or equal to ``ssim_threshold``; otherwise it is filtered out.
    Synthetic images that cannot be loaded or scored are filtered out as well.

    Args:
        synthetic_image_paths: Paths of the synthetic images to evaluate.
        real_image_paths: Paths of the real images used as references.
        ssim_threshold: Minimum best-match SSIM required to retain an image.
        output_retained_dir: If set, the directory is created when missing and
            retained images are copied into it under their base name.
        output_filtered_dir: If set, the directory is created when missing and
            filtered-out images are copied into it under their base name.

    Returns:
        The list of retained synthetic paths. If ``real_image_paths`` is empty,
        ``synthetic_image_paths`` is returned unchanged; if none of the real
        images can be loaded, an empty list is returned.
    """
    if not real_image_paths:
        print("Warning: No real image paths provided for comparison. Returning all synthetic paths.")
        return synthetic_image_paths

    real_images = []
    print(f"Loading and processing {len(real_image_paths)} real images...")
    for real_path in tqdm(real_image_paths):
        try:
            real_gray = _load_grayscale(real_path)
        except Exception as e:
            print(f"Error processing real image {real_path}: {e}")
            continue
        if real_gray is None:
            print(f"Warning: Could not load real image {real_path}. Skipping.")
            continue
        real_images.append(real_gray)

    if not real_images:
        print("Error: Could not process any real images. Aborting filtering.")
        return []

    if output_retained_dir:
        os.makedirs(output_retained_dir, exist_ok=True)
    if output_filtered_dir:
        os.makedirs(output_filtered_dir, exist_ok=True)

    retained_paths = []
    filtered_paths = []

    print(f"\nFiltering {len(synthetic_image_paths)} synthetic images...")
    for synth_path in tqdm(synthetic_image_paths):
        keep = False
        try:
            synth_gray = _load_grayscale(synth_path)
            if synth_gray is None:
                print(f"Warning: Could not load synthetic image {synth_path}. Filtering out.")
            else:
                keep = _max_ssim_against(synth_gray, real_images) >= ssim_threshold
        except Exception as e:
            # Per-item boundary of a batch job: any scoring failure (for
            # example an image too small for the SSIM window) only rejects
            # this one image.
            print(f"Error processing synthetic image {synth_path}: {e}. Filtering out.")

        # Classification and copying happen outside the try block so that each
        # path lands in exactly one list, even when a copy fails.
        if keep:
            retained_paths.append(synth_path)
            _copy_into(synth_path, output_retained_dir)
        else:
            filtered_paths.append(synth_path)
            _copy_into(synth_path, output_filtered_dir)

    print(f"\nFiltering complete. Retained: {len(retained_paths)}, Filtered: {len(filtered_paths)}")
    return retained_paths