# In [None]:
#!/usr/bin/env python3

"""
Image Data Loading Pipeline (v1)

This module provides a custom PyTorch Dataset and DataLoader setup
designed specifically for inference tasks where we need to process
images and know their original file paths.

It encapsulates all "Pass 1" logic:
1. Recursive file scanning.
2. Filtering by image extension.
3. SHA256 hashing for de-duplication.
4. Safe image loading (handles corrupt files, RGBA, etc.).

It provides:
- `InferenceDataset`: A PyTorch Dataset that finds all unique
  images and returns (PIL_Image, file_path) tuples.
- `collate_fn`: A custom collate function for the DataLoader that
  batches images and paths separately, and isolates failures.

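USAGE (in another script). If num_workers > 0 on a platform whose DataLoader
workers use the "spawn" start method (Windows, macOS), run this code under an
`if __name__ == "__main__":` guard so the worker processes do not re-execute
it when they import your script:

import sys
from image_dataloader import check_dependencies, InferenceDataset, collate_fn
from torch.utils.data import DataLoader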

try:
    check_dependencies()
    
    SOURCE_PATH = r"C:\\path\\to\\your\\photos"
    BATCH_SIZE = 32

    dataset = InferenceDataset(SOURCE_PATH)
    print("Building file list (scanning and hashing)...")
    dataset.build_file_list()
    print(f"Found {len(dataset)} unique images.")

    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=4,  # Adjust based on your CPU
        shuffle=False # Order doesn't matter for inference
    )

    for batch_num, (batch_images, batch_paths, failed_paths) in enumerate(loader):
        print(f"\n--- Processing Batch {batch_num + 1} ---")
        print(f"  Images to process: {len(batch_images)}")
        print(f"  Paths: {batch_paths}")
        print(f"  Failed to load: {failed_paths}")
        
        # --- YOUR INFERENCE LOGIC HERE ---
        # e.g., results_pass_1 = model_1(batch_images)
        # ...then map results back to batch_paths to move files.

except ImportError as e:
    print(e, file=sys.stderr)

"""

import os
import sys
import hashlib
from contextlib import contextmanager

# --- Dependency Imports (grouped for checking) ---
try:
    from PIL import Image
except ImportError:
    Image = None

try:
    import torch
    from torch.utils.data import Dataset, DataLoader
except ImportError:
    torch = None
    Dataset = object # Dummy class for inheritance if torch is missing

try:
    import pillow_heif
except ImportError:
    pillow_heif = None

# --- Constants ---
# File extensions treated as images during the scan. They are compared
# against the lower-cased filename, so keep them lower-case with a dot.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".heic", ".heif")
# Number of bytes read per iteration while hashing a file (8 KiB).
HASH_CHUNK_SIZE = 8 * 1024

# --- Prerequisite Functions ---

def check_dependencies():
    """
    Verify that the libraries this module relies on are usable.

    Pillow and torch are mandatory: an ImportError naming the first missing
    one is raised. pillow-heif is optional. When it is installed, its HEIF
    opener is registered with Pillow in the *current process only*.
    Otherwise a warning is written to stderr and HEIC/HEIF files will not
    be processed.

    Raises:
        ImportError: if Pillow or torch is not installed.
    """
    print("Checking dependencies...")

    required = (
        (Image, "Error: 'Pillow' library not found. Please install with: pip install Pillow"),
        (torch, "Error: 'torch' library not found. Please install with: pip install torch"),
    )
    for module, install_hint in required:
        if module is None:
            raise ImportError(install_hint)

    if pillow_heif is not None:
        # Teach PIL's Image.open() to read HEIC/HEIF files.
        pillow_heif.register_heif_opener()
        print("HEIF/HEIC support enabled.")
    else:
        print("Warning: 'pillow-heif' library not found. HEIC/HEIF files will not be processed.", file=sys.stderr)
        print("Install with: pip install pillow-heif", file=sys.stderr)

    print("All critical dependencies are present.")

# --- Helper Functions (used by the Dataset) ---

def _calculate_hash(file_path):
    """Calculates the SHA256 hash of a file efficiently."""
    hasher = hashlib.sha256()
    try:
        with open(file_path, 'rb') as f:
            while chunk := f.read(HASH_CHUNK_SIZE):
                hasher.update(chunk)
        return hasher.hexdigest()
    except (IOError, OSError) as e:
        print(f"   [Warning] Could not hash file {file_path}: {e}", file=sys.stderr)
        return None

@contextmanager
def _suppress_pil_warnings():
    """Context manager to suppress known PIL warnings."""
    import warnings
    warnings.filterwarnings(
        "ignore",
        "(Possibly corrupt EXIF data|Image size.*exceeds pixel limit)"
    )
    yield
    warnings.resetwarnings()

# True once the pillow-heif opener has been registered in *this* process.
_heif_opener_registered = False


def _ensure_heif_opener():
    """
    Register the pillow-heif opener with PIL once per process, if available.

    check_dependencies() registers it only in the process that calls it.
    DataLoader workers started with the "spawn" method (the default on
    Windows and macOS) re-import this module without running that call, so
    without this helper HEIC/HEIF files would fail to open inside workers.
    """
    global _heif_opener_registered
    if _heif_opener_registered or pillow_heif is None:
        return
    pillow_heif.register_heif_opener()
    _heif_opener_registered = True


def _load_image_for_batch(path):
    """
    Open one image, convert it to RGB, and fully load its pixel data.

    Args:
        path: Filesystem path of the image file.

    Returns:
        A PIL image in mode 'RGB' whose pixel data is already in memory, so
        it stays usable after the file handle is closed. Returns None if the
        file cannot be opened or decoded. Failures are not printed here;
        collate_fn reports them as failed paths.
    """
    try:
        _ensure_heif_opener()
        with _suppress_pil_warnings(), Image.open(path) as image:
            if image.mode == 'RGB':
                # Image.open() decodes lazily; force the decode now so the
                # data survives the file being closed when the with-block exits.
                image.load()
                return image
            # convert() decodes the data and returns a new in-memory image.
            # This covers RGBA (alpha is dropped) and also the other non-RGB
            # modes (L, P, CMYK, ...), which previously slipped through
            # unconverted.
            return image.convert('RGB')
    except Exception:
        # Best-effort by design: any open/decode/format error marks this
        # file as failed instead of aborting the whole batch.
        return None

# --- Core Dataloader Components ---

class InferenceDataset(Dataset):
    """
    A PyTorch Dataset over the unique images below a directory tree.

    Call build_file_list() once before handing the dataset to a DataLoader.
    It walks ``root_path`` and keeps files whose lower-cased name ends with
    one of ``image_extensions``. It also drops files whose SHA-256 content
    hash was already seen. Indexing returns ``(PIL_Image, file_path)``
    tuples, with the image set to None when it could not be loaded.

    Args:
        root_path: Directory to scan recursively.
        image_extensions: A str or an iterable of file extensions such as
            '.jpg'. They are matched case-insensitively.

    Raises:
        NotADirectoryError: if ``root_path`` is not an existing directory.
    """
    def __init__(self, root_path, image_extensions=IMAGE_EXTENSIONS):
        if not os.path.isdir(root_path):
            raise NotADirectoryError(f"Root path not found: {root_path}")

        self.root_path = root_path
        # str.endswith() only accepts a str or a tuple, and filenames are
        # lower-cased before matching. Normalise to a lower-case tuple so a
        # list, or an upper-case entry like '.JPG', behaves as expected.
        if isinstance(image_extensions, str):
            image_extensions = (image_extensions,)
        self.image_extensions = tuple(ext.lower() for ext in image_extensions)
        self._reset_scan_state()

    def _reset_scan_state(self):
        """Clear the file list, hash index and scan statistics."""
        self.unique_image_paths = []
        # Maps SHA-256 hex digest -> first path seen with that content.
        self.seen_hashes = {}

        # Statistics
        self.total_files_scanned = 0
        self.skipped_duplicates = 0
        self.hashing_errors = 0

    @staticmethod
    def _report_walk_error(err):
        """os.walk onerror hook: warn about unreadable directories instead of skipping them silently."""
        print(f"   [Warning] Could not scan directory {err.filename}: {err}", file=sys.stderr)

    def build_file_list(self):
        """
        Perform the "Pass 1" scan: find, hash and de-duplicate all images.

        This must be called before the Dataset is used with a DataLoader.
        Each call rescans from scratch instead of accumulating onto earlier
        results. Directories and files are visited in sorted order. When
        duplicates exist, the path kept is therefore deterministic: the
        first one in that order.
        """
        self._reset_scan_state()

        print(f"Scanning {self.root_path} for images...")
        for dirpath, dirnames, filenames in os.walk(self.root_path, onerror=self._report_walk_error):
            # Sorting dirnames in place makes os.walk descend in sorted order.
            dirnames.sort()
            print(f"  Scanning: {dirpath}")
            for filename in sorted(filenames):
                if not filename.lower().endswith(self.image_extensions):
                    continue

                self.total_files_scanned += 1
                full_path = os.path.join(dirpath, filename)

                file_hash = _calculate_hash(full_path)
                if not file_hash:
                    self.hashing_errors += 1
                    continue

                if file_hash in self.seen_hashes:
                    # Same bytes as an earlier file; keep only the first copy.
                    self.skipped_duplicates += 1
                    continue

                self.seen_hashes[file_hash] = full_path
                self.unique_image_paths.append(full_path)

        print("\n--- Scan Complete ---")
        print(f"Total files scanned: {self.total_files_scanned}")
        print(f"Duplicate files found: {self.skipped_duplicates}")
        print(f"Hashing/Read errors: {self.hashing_errors}")
        print(f"Unique images to process: {len(self.unique_image_paths)}")

    def __len__(self):
        """Return the number of unique images found by build_file_list()."""
        return len(self.unique_image_paths)

    def __getitem__(self, index):
        """
        Load one image and return it together with its original path.

        Returns (None, path) if loading fails.
        """
        file_path = self.unique_image_paths[index]
        image = _load_image_for_batch(file_path)
        return image, file_path


def collate_fn(batch):
    """
    Collate function for a DataLoader that wraps InferenceDataset.

    Args:
        batch: Sequence of ``(image, file_path)`` pairs as produced by
            InferenceDataset.__getitem__. ``image`` is None for files that
            failed to load.

    Returns:
        A 3-tuple of lists ``(batch_images, batch_paths, failed_paths)``:
        the loaded images, their paths in the same order, and the paths
        whose image was None.
    """
    loaded, loaded_paths, failed = [], [], []

    for img, src in batch:
        if img is None:
            failed.append(src)
            continue
        loaded.append(img)
        loaded_paths.append(src)

    return loaded, loaded_paths, failed




# In [None]:
# --- Example Usage ---
if __name__ == "__main__":
    # Demonstrates how to use InferenceDataset and collate_fn with a
    # PyTorch DataLoader on a small test folder. Errors are reported on
    # stderr and the process exits with status 1, so shells and CI can
    # detect the failure.

    # --- CONFIGURE YOUR TEST PATH HERE ---
    # IMPORTANT: Update this to a small test directory of images
    SOURCE_PATH = r"C:\path\to\your\test_photos"
    BATCH_SIZE = 8

    if "C:\\path\\to" in SOURCE_PATH:
        print("Error: Please update the 'SOURCE_PATH' variable in the", file=sys.stderr)
        print("if __name__ == '__main__': block to point to a test folder.", file=sys.stderr)
        sys.exit(1)

    try:
        # 1. Check dependencies (raises ImportError if Pillow/torch are missing)
        check_dependencies()

        # 2. Create dataset and build the file list
        dataset = InferenceDataset(SOURCE_PATH)
        dataset.build_file_list()

        if len(dataset) == 0:
            print("\nNo images found in the source path. Exiting.")
            # SystemExit is not an Exception subclass, so the handlers below
            # do not intercept this clean exit.
            sys.exit(0)

        # 3. Create the DataLoader
        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            collate_fn=collate_fn,
            num_workers=0,  # 0 is safer for testing. Increase to 2 or 4 for speed.
            shuffle=False,
        )

        print("\n--- Starting DataLoader Test (1 Batch) ---")

        # 4. Loop through the loader
        for i, (batch_images, batch_paths, failed_paths) in enumerate(loader):
            print(f"\nSuccessfully loaded batch {i+1}")
            print(f"  Images in batch: {len(batch_images)}")

            if batch_images:
                print(f"  First image size: {batch_images[0].size}")
                print(f"  First image path: {batch_paths[0]}")

            if failed_paths:
                print(f"  Failed to load {len(failed_paths)} images:")
                for path in failed_paths:
                    print(f"    - {path}")

            # Stop after one batch for this demo
            print("\nDemo complete. This is where your inference logic would run.")
            break

    except ImportError as e:
        print(f"ImportError: {e}", file=sys.stderr)
        sys.exit(1)
    except NotADirectoryError as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        import traceback
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        # Keep the traceback; the one-line message alone is rarely enough to debug.
        traceback.print_exc()
        sys.exit(1)
