In [None]:
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from skimage.color import rgb2hsv # For color space conversion
from tqdm import tqdm # For progress bar
from multiprocessing import Pool, cpu_count # Import multiprocessing modules

# --- Helper function for processing a single image ---
def process_single_image_quality(img_path):
    """
    Extract quality metrics from a single image file.

    Intended to be mapped over image paths by a multiprocessing pool, so any
    failure (missing file, undecodable image, ...) is reported as ``None``
    instead of propagating an exception out of the worker.

    Args:
        img_path: Path of the image file to analyze.

    Returns:
        ``(file_size_kb, laplacian_var, h_channel_pixels, s_channel_pixels,
        v_channel_pixels, (width, height))`` on success, where the three
        ``*_channel_pixels`` entries are flat lists of per-pixel HSV values as
        produced by ``skimage.color.rgb2hsv``; ``None`` if processing fails.
    """
    try:
        # Inside the try block so that a missing/unreadable file yields None
        # like every other failure, rather than raising in the worker.
        file_size_kb = os.path.getsize(img_path) / 1024

        # Pillow reads the size from the header without decoding pixel data.
        with Image.open(img_path) as img_pil:
            width, height = img_pil.size

        # cv2.imread with its default flag returns a 3-channel BGR array,
        # or None when the file cannot be decoded.
        img_cv = cv2.imread(img_path)
        if img_cv is None:
            return None

        # Clarity: variance of the Laplacian of the grayscale image.
        gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
        laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()

        # Color distribution: rgb2hsv expects RGB channel order, so the BGR
        # array from OpenCV must be reordered first. Skipping this swaps red
        # and blue and produces wrong hue values.
        img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
        hsv_img = rgb2hsv(img_rgb)

        # tolist() yields plain Python floats, which are much cheaper to build
        # and to pickle back to the parent process than NumPy scalars.
        h_channels = hsv_img[:, :, 0].ravel().tolist()
        s_channels = hsv_img[:, :, 1].ravel().tolist()
        v_channels = hsv_img[:, :, 2].ravel().tolist()

        return (file_size_kb, laplacian_var, h_channels, s_channels, v_channels, (width, height))

    except Exception:
        # Best-effort worker boundary: one bad file must not abort the pool.
        # Nothing is printed here to keep child-process output quiet.
        return None

# Number of bins used for every histogram in the report.
_NUM_BINS = 50

# Class sub-directory name -> (legend label, bar color).
_CLASS_STYLES = {
    '0': ('Class 0 (Fake)', 'skyblue'),
    '1': ('Class 1 (Real)', 'lightcoral'),
}

# Keys under which the per-channel HSV pixel histograms are stored.
_PIXEL_KEYS = ('h_channels', 's_channels', 'v_channels')


def _new_class_stats():
    """Return an empty accumulator for one class."""
    stats = {'resolutions': [], 'file_sizes': [], 'laplacian_vars': []}
    for key in _PIXEL_KEYS:
        # Pixel values are folded into fixed-size bin counts instead of being
        # stored individually, so memory does not grow with the pixel count.
        stats[key] = np.zeros(_NUM_BINS, dtype=np.int64)
    return stats


def _collect_class_stats(image_paths, class_id, num_processes, stats):
    """Process image_paths in a worker pool and fold each result into stats."""
    with Pool(processes=num_processes) as pool:
        results = pool.imap_unordered(process_single_image_quality, image_paths)
        # Consume results as they arrive; nothing per-image is kept around
        # except the scalar metrics and the resolution tuple.
        for result in tqdm(results, total=len(image_paths), desc=f"Processing Class {class_id}"):
            if result is None:
                continue
            file_size_kb, laplacian_var, h_pixels, s_pixels, v_pixels, resolution = result
            stats['file_sizes'].append(file_size_kb)
            stats['laplacian_vars'].append(laplacian_var)
            for key, pixels in zip(_PIXEL_KEYS, (h_pixels, s_pixels, v_pixels)):
                counts, _ = np.histogram(pixels, bins=_NUM_BINS, range=(0.0, 1.0))
                stats[key] += counts
            if resolution[0] is not None and resolution[1] is not None:
                stats['resolutions'].append(resolution)


def _show_empty_message(message):
    """Write message in the center of the current axes."""
    plt.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)


def _plot_scalar_histograms(position, image_data, key, title, xlabel, empty_message):
    """Plot overlaid per-class histograms of per-image scalar values."""
    plt.subplot(2, 3, position)
    if not any(image_data[class_id][key] for class_id in _CLASS_STYLES):
        _show_empty_message(empty_message)
        return
    for class_id, (label, color) in _CLASS_STYLES.items():
        plt.hist(image_data[class_id][key], bins=_NUM_BINS, alpha=0.7, label=label, color=color)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Count')
    plt.legend()


def _plot_pixel_histograms(position, image_data, key, title, xlabel, empty_message):
    """Plot overlaid per-class histograms from pre-binned pixel counts."""
    plt.subplot(2, 3, position)
    if not any(image_data[class_id][key].any() for class_id in _CLASS_STYLES):
        _show_empty_message(empty_message)
        return
    # Same edges np.histogram used when the counts were accumulated.
    edges = np.linspace(0.0, 1.0, _NUM_BINS + 1)
    for class_id, (label, color) in _CLASS_STYLES.items():
        plt.bar(edges[:-1], image_data[class_id][key], width=np.diff(edges), align='edge',
                alpha=0.7, label=label, color=color)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Pixel Count')
    plt.legend()


def _pixel_area(resolution):
    """Return width * height for a (width, height) tuple."""
    return resolution[0] * resolution[1]


def _print_resolution_stats(image_data):
    """Print average, smallest and largest resolution for each class."""
    print("\nResolution Statistics:")
    for class_id in _CLASS_STYLES:
        resolutions = image_data[class_id]['resolutions']
        if not resolutions:
            print(f"Class {class_id} - No resolution data found.")
            continue
        widths = [res[0] for res in resolutions]
        heights = [res[1] for res in resolutions]
        print(f"Class {class_id} - Avg Width: {np.mean(widths):.2f}, Avg Height: {np.mean(heights):.2f}")
        # Rank by pixel area; plain tuple comparison would order by width
        # first and only use height as a tie-breaker.
        smallest = min(resolutions, key=_pixel_area)
        largest = max(resolutions, key=_pixel_area)
        print(f"Class {class_id} - Min Resolution: {smallest}, Max Resolution: {largest}")


def analyze_image_quality_multithreaded(base_path="xinye10/train"):
    """
    Analyze and visualize image quality metrics for the '0' and '1' classes.

    For each class sub-directory of ``base_path``, every ``.jpg``/``.jpeg``
    file is processed by ``process_single_image_quality`` in a pool of worker
    processes (despite the function name, processes are used, not threads).
    The function then shows a 2x3 figure with histograms of file size,
    Laplacian variance, hue, saturation and value, and prints per-class
    resolution statistics.

    HSV pixel values are accumulated into 50 fixed bins over [0, 1] as results
    arrive, so memory use does not depend on the total number of pixels.

    Args:
        base_path: Directory containing the class sub-directories '0' and '1'.

    Returns:
        None. Output is the displayed figure and printed statistics.
    """
    image_data = {class_id: _new_class_stats() for class_id in _CLASS_STYLES}

    num_processes = cpu_count()  # Use all available CPU cores
    print(f"Using {num_processes} processes for image quality analysis.")

    for class_id in _CLASS_STYLES:
        class_path = os.path.join(base_path, class_id)
        if not os.path.isdir(class_path):
            print(f"Warning: Directory {class_path} not found. Skipping class {class_id}.")
            continue

        print(f"Analyzing images in class: {class_id}")
        image_paths = [
            os.path.join(class_path, file_name)
            for file_name in os.listdir(class_path)
            if file_name.lower().endswith(('.jpg', '.jpeg'))
        ]
        if not image_paths:
            # No point in starting a worker pool for an empty directory.
            print(f"Warning: No JPEG images found in {class_path}. Skipping class {class_id}.")
            continue

        _collect_class_stats(image_paths, class_id, num_processes, image_data[class_id])

    # --- Plotting ---
    plt.figure(figsize=(18, 12))

    _plot_scalar_histograms(1, image_data, 'file_sizes',
                            'File Size Distribution (KB)', 'File Size (KB)',
                            "No file size data to plot.")
    _plot_scalar_histograms(2, image_data, 'laplacian_vars',
                            'Laplacian Variance (Sharpness)', 'Laplacian Variance',
                            "No sharpness data to plot.")
    _plot_pixel_histograms(4, image_data, 'h_channels',
                           'Hue (H) Channel Distribution', 'Hue Value (0-1)',
                           "No hue data to plot.")
    _plot_pixel_histograms(5, image_data, 's_channels',
                           'Saturation (S) Channel Distribution', 'Saturation Value (0-1)',
                           "No saturation data to plot.")
    _plot_pixel_histograms(6, image_data, 'v_channels',
                           'Value (V) Channel Distribution', 'Value (0-1)',
                           "No value data to plot.")

    plt.tight_layout()
    plt.show()

    # --- Statistics ---
    _print_resolution_stats(image_data)


# Run the analysis.
# Guarded so that worker processes started with the "spawn" method, which
# re-import the main module, do not each launch the whole analysis again.
if __name__ == "__main__":
    analyze_image_quality_multithreaded()

Using 96 processes for image quality analysis.
Analyzing images in class: 0


Processing Class 0:   4%|▎         | 4902/134049 [16:06<5:32:43,  6.47it/s]  