In [1]:
#!/usr/bin/env python3

"""
AI Photo Sorter (v6 - Hybrid Detection/Classification)
Sorts a directory of images using a multi-stage "waterfall" classification
process that combines specialized object detectors with powerful CLIP models
for high-accuracy sorting.

This version implements a new 5-pass logic:
1. Pass 1: Face Detector (MobileFaceDet)
2. Pass 2: Person Detector (DETR)
3. Pass 3: Human Classifier (CLIP-L-14)
4. Pass 4: View Classifier (CLIP-L-14)
5. Pass 5: View Classifier (CLIP-L-14-336)
6. Remaining images are classified as 'junk'.

REQUIREMENTS:
pip install torch transformers Pillow pillow-heif

--- USAGE IN JUPYTER NOTEBOOK ---
(Usage is identical to v5)
"""

import os
import sys
import shutil
import hashlib
import traceback
from contextlib import contextmanager
from collections import defaultdict

# --- Dependency Imports (grouped for checking) ---
try:
    from PIL import Image
except ImportError:
    Image = None

try:
    import torch
except ImportError:
    torch = None

try:
    from transformers import pipeline
except ImportError:
    pipeline = None

try:
    import pillow_heif
except ImportError:
    pillow_heif = None


class AIPhotoSorter:
    """
    Encapsulates all logic for scanning, classifying, and sorting images
    using a multi-stage hybrid waterfall classification system.
    """

    # --- Constants as Class Attributes ---
    BATCH_SIZE_DEFAULT = 32
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.heic', '.heif')
    HASH_CHUNK_SIZE = 8192  # Read files in 8KB chunks for hashing

    # --- NEW: Labels and Thresholds for Hybrid Waterfall Logic ---
    
    # Stage 1: Face Detector
    FACE_DETECTOR_THRESHOLD = 0.9  # High confidence for 'face'

    # Stage 2: Person Detector
    PERSON_DETECTOR_LABEL = 'person'
    PERSON_DETECTOR_THRESHOLD = 0.9 # High confidence for 'person'

    # Stage 3: Human Classifier (CLIP)
    HUMAN_LABELS = [
        "a photo of a person", "a photo of a face", "a portrait", "a selfie",
        "a photo of a baby", "a photo of a child", "a photo of a family",
        "a photo of a group of people"
    ]
    HUMAN_CLIP_THRESHOLD = 0.7  # For openai/clip-vit-large-patch14

    # Stage 4 & 5: View Classifiers (CLIP)
    VIEW_LABELS = [
        "a scenic landscape", "a photo of a mountain", "a photo of a beach",
        "a photo of a flower", "a photo of a forest", "a photo of a sunset",
        "a photo of a building", "architecture", "a photo of a city",
        "a photo of delicious food", "a photo of a meal"
    ]
    VIEW_CLIP_L14_THRESHOLD = 0.75 # For openai/clip-vit-large-patch14
    VIEW_CLIP_L14_336_THRESHOLD = 0.75 # For openai/clip-vit-large-patch14-336


    def __init__(self, root_path, output_path, batch_size=BATCH_SIZE_DEFAULT):
        self.root_path = root_path
        self.output_path = output_path
        self.batch_size = batch_size

        # Output directories
        self.output_dirs = {
            "family_photos": os.path.join(self.output_path, "family_photos"),
            "views": os.path.join(self.output_path, "views"),
            "junk": os.path.join(self.output_path, "junk")
        }

        # --- NEW: AI Models for Hybrid Waterfall ---
        self.face_detector_pipeline = None  # d-li/mobilefacedet
        self.person_detector_pipeline = None # facebook/detr-resnet-50
        self.clip_l14_pipeline = None      # openai/clip-vit-large-patch14
        self.clip_l14_336_pipeline = None  # openai/clip-vit-large-patch14-336
        
        # State
        self.seen_hashes = {}
        self.unique_image_paths = []
        
        # Statistics
        self.stats = defaultdict(int)

    @staticmethod
    def check_dependencies():
        """
        Checks if all required libraries are installed.
        Raises ImportError if a critical one is missing.
        """
        print("Checking dependencies...")
        if Image is None:
            raise ImportError("Error: 'Pillow' library not found. Please install with: pip install Pillow")
        if torch is None:
            raise ImportError("Error: 'torch' library not found. Please install with: pip install torch")
        if pipeline is None:
            raise ImportError("Error: 'transformers' library not found. Please install with: pip install transformers")
        
        if pillow_heif is None:
            print("Warning: 'pillow-heif' library not found. HEIC/HEIF files will not be processed.", file=sys.stderr)
            print("Install with: pip install pillow-heif", file=sys.stderr)
        else:
            pillow_heif.register_heif_opener()
            print("HEIF/HEIC support enabled.")
        print("All critical dependencies are present.")

    # --- Public Entry Point ---

    def run_sort(self):
        """
        Executes the entire sorting process from start to finish.
        This is the main public method to call.
        """
        print("--- AI Photo Sorter (v6) ---")
        print(f"Source: {self.root_path}")
        print(f"Destination: {self.output_path}")
        print(f"Batch Size: {self.batch_size}")
        print("-" * 25)

        try:
            self._setup_output_dirs()
            self._load_models()
            self._run_pass_1_scan_files()
            self._run_pass_2_process_batches()
            self._print_final_report()

        except Exception as e:
            print(f"\n--- [FATAL ERROR] ---", file=sys.stderr)
            print(f"An unexpected error occurred: {e}", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)

    # --- Private Methods: Setup ---

    def _setup_output_dirs(self):
        """Creates all necessary output directories."""
        print("Setting up output directories...")
        for dir_path in self.output_dirs.values():
            os.makedirs(dir_path, exist_ok=True)
        print("Output directories ready.")

    def _load_models(self):
        """Loads and initializes the four AI models (pipelines)."""
        print("Loading AI models... (This may take a few minutes and download files on first run)")
        try:
            device = 0 if torch.cuda.is_available() else -1
            
            print("Loading Model 1: d-li/mobilefacedet (Face Detector)...")
            self.face_detector_pipeline = pipeline(
                "object-detection",
                model="d-li/mobilefacedet",
                device=device
            )

            print("Loading Model 2: facebook/detr-resnet-50 (Person Detector)...")
            self.person_detector_pipeline = pipeline(
                "object-detection",
                model="facebook/detr-resnet-50",
                device=device
            )
            
            print("Loading Model 3: openai/clip-vit-large-patch14 (Classifier)...")
            self.clip_l14_pipeline = pipeline(
                "zero-shot-image-classification",
                model="openai/clip-vit-large-patch14",
                device=device
            )
            
            print("Loading Model 4: openai/clip-vit-large-patch14-336 (Classifier)...")
            self.clip_l14_336_pipeline = pipeline(
                "zero-shot-image-classification",
                model="openai/clip-vit-large-patch14-336",
                device=device
            )
            
            print("Models loaded successfully.")
        except Exception as e:
            print(f"Error loading models: {e}", file=sys.stderr)
            print("Please ensure you have an internet connection.", file=sys.stderr)
            raise # Re-raise to stop the process

    # --- Private Methods: Pass 1 (Scanning) ---

    def _run_pass_1_scan_files(self):
        """
        Recursively scans the root_path, finds all images, calculates hashes,
        and builds a de-duplicated list of unique images to process.
        """
        # (This method is identical to v5)
        print("\n--- Pass 1: Scanning for duplicates and building file list ---")
        
        for dirpath, _, filenames in os.walk(self.root_path):
            print(f"Scanning: {dirpath}")
            for filename in filenames:
                if not filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                
                self.stats['total_files_scanned'] += 1
                full_path = os.path.join(dirpath, filename)

                file_hash = self._calculate_hash(full_path)
                if not file_hash:
                    self.stats['total_errors_pass1'] += 1
                    continue

                if file_hash in self.seen_hashes:
                    print(f"   -> Duplicate of: {self.seen_hashes[file_hash]} (Skipping)")
                    self.stats['skipped_duplicates'] += 1
                    continue
                else:
                    self.seen_hashes[file_hash] = full_path
                    self.unique_image_paths.append(full_path)

        print(f"\n--- Pass 1 Complete ---")
        print(f"Total files scanned: {self.stats['total_files_scanned']}")
        print(f"Duplicate files found: {self.stats['skipped_duplicates']}")
        print(f"Unique images to process: {len(self.unique_image_paths)}")
        print(f"Hashing/Read errors: {self.stats['total_errors_pass1']}")

    # --- Private Methods: Pass 2 (Processing) ---

    def _run_pass_2_process_batches(self):
        """
        Iterates over the list of unique image paths in fixed-size batches
        and calls _process_one_batch for each.
        """
        # (This method is identical to v5)
        print("\n--- Pass 2: Processing unique images in batches ---")
        
        num_batches = (len(self.unique_image_paths) + self.batch_size - 1) // self.batch_size
        
        for i in range(0, len(self.unique_image_paths), self.batch_size):
            batch_paths = self.unique_image_paths[i : i + self.batch_size]
            current_batch_num = (i // self.batch_size) + 1
            
            print(f"\nProcessing Batch {current_batch_num} / {num_batches} (Size: {len(batch_paths)})...")
            
            try:
                self._process_one_batch(batch_paths)
            except Exception as e:
                print(f"   [CRITICAL BATCH ERROR] Failed to process batch: {e}", file=sys.stderr)
                print("     -> Skipping this entire batch. These files will not be moved.", file=sys.stderr)
                traceback.print_exc(file=sys.stderr)
                self.stats['total_errors_pass2'] += len(batch_paths)

    def _process_one_batch(self, batch_paths):
        """
        Processes a single batch of images through the 5-stage hybrid waterfall.
        """
        
        # 1. Load image objects for this batch
        valid_paths_in_batch, batch_image_objects = self._load_batch_images(batch_paths)
        
        if not batch_image_objects:
            print("   -> All images in batch failed to load. Skipping.")
            return

        # This tracks which *original index* in the batch still needs processing
        remaining_indices = list(range(len(batch_image_objects)))
        # This stores the final decision for each *original index*
        final_decisions = {} # { 0: ("family_photos", "face"), 1: ("junk", "other"), ... }

        # --- Helper for Object Detection Pass ---
        def _run_object_detection_pass(model, images, indices, target_labels, threshold, category_name):
            if not images:
                return
            
            print(f"   Running Pass ({category_name} Detector) on {len(images)} images...")
            all_results = self._run_object_detection_batch(model, images)
            
            # Iterate backwards so we can safely remove indices
            for i in range(len(indices) - 1, -1, -1):
                original_index = indices[i]
                detections = all_results[i] # This is a list of dicts
                
                found_match = False
                for obj in detections:
                    if obj['label'] in target_labels and obj['score'] > threshold:
                        final_decisions[original_index] = (category_name, obj['label'])
                        indices.pop(i) # Remove index from remaining list
                        found_match = True
                        break # Move to the next image
        
        # --- Helper for Classification Pass ---
        def _run_classification_pass(model, images, indices, labels, threshold, category_name):
            if not images:
                return
            
            print(f"   Running Pass ({category_name} Classifier) on {len(images)} images...")
            results = self._run_zero_shot_batch(model, images, labels)
            
            for i in range(len(indices) - 1, -1, -1):
                original_index = indices[i]
                result = results[i][0] # Get top result
                
                if result['label'] in labels and result['score'] > threshold:
                    final_decisions[original_index] = (category_name, result['label'])
                    indices.pop(i)
        
        # --- Stage 1: Face Detector Pass ---
        images_to_check = [batch_image_objects[i] for i in remaining_indices]
        _run_object_detection_pass(self.face_detector_pipeline, images_to_check, remaining_indices, 
                                   {'face'}, self.FACE_DETECTOR_THRESHOLD, "family_photos")

        # --- Stage 2: Person Detector Pass ---
        images_to_check = [batch_image_objects[i] for i in remaining_indices]
        _run_object_detection_pass(self.person_detector_pipeline, images_to_check, remaining_indices, 
                                   {self.PERSON_DETECTOR_LABEL}, self.PERSON_DETECTOR_THRESHOLD, "family_photos")

        # --- Stage 3: Human Classifier (CLIP L-14) ---
        images_to_check = [batch_image_objects[i] for i in remaining_indices]
        _run_classification_pass(self.clip_l14_pipeline, images_to_check, remaining_indices, 
                                 self.HUMAN_LABELS, self.HUMAN_CLIP_THRESHOLD, "family_photos")

        # --- Stage 4: View Classifier (CLIP L-14) ---
        images_to_check = [batch_image_objects[i] for i in remaining_indices]
        _run_classification_pass(self.clip_l14_pipeline, images_to_check, remaining_indices, 
                                 self.VIEW_LABELS, self.VIEW_CLIP_L14_THRESHOLD, "views")

        # --- Stage 5: View Classifier (CLIP L-14-336) ---
        images_to_check = [batch_image_objects[i] for i in remaining_indices]
        _run_classification_pass(self.clip_l14_336_pipeline, images_to_check, remaining_indices, 
                                 self.VIEW_LABELS, self.VIEW_CLIP_L14_336_THRESHOLD, "views")

        # --- Stage 6: Junk Pass ---
        print(f"   ... {len(remaining_indices)} images remaining, classifying as 'junk'.")
        for index in remaining_indices:
            final_decisions[index] = ("junk", "other")
            
        # --- Final Step: Move Files ---
        print("   Moving files to destinations...")
        for i in range(len(valid_paths_in_batch)):
            full_path = valid_paths_in_batch[i]
            original_filename = os.path.basename(full_path)
            
            dest_folder_name, label = final_decisions[i]
            
            dest_dir = self.output_dirs[dest_folder_name]
            dest_path = self._get_unique_dest_path(dest_dir, original_filename)
            
            try:
                shutil.move(full_path, dest_path)
                print(f"     -> Moved {original_filename} to {dest_folder_name} (as: {label})")
                self.stats['processed_files_moved'] += 1
            except Exception as e:
                print(f"   [Error] Failed to move {original_filename}: {e}", file=sys.stderr)
                self.stats['total_errors_pass2'] += 1


    def _load_batch_images(self, batch_paths):
        """Helper to load all images for a batch, skipping failures."""
        # (This method is identical to v5)
        batch_image_objects = []
        valid_paths_in_batch = []
        for path in batch_paths:
            image = self._load_image_for_batch(path)
            if image:
                batch_image_objects.append(image)
                valid_paths_in_batch.append(path)
            else:
                self.stats['total_errors_pass2'] += 1
        return valid_paths_in_batch, batch_image_objects

    # --- Private Methods: Inference ---

    def _run_object_detection_batch(self, model_pipeline, image_batch):
        """
        Runs an object detection pipeline on a batch of images.
        Returns a list of detection results (list of lists).
        """
        try:
            # The pipeline returns a list[list[dict]]
            batch_results = model_pipeline(image_batch)
            return batch_results
        except Exception as e:
            print(f"   [Error] Object detection model failed: {e}", file=sys.stderr)
            raise  # Re-raise to be caught by the batch processor

    def _run_zero_shot_batch(self, model_pipeline, image_batch, candidate_labels):
        """
        Runs a zero-shot classifier on a batch of images.
        Returns a list of classification results.
        """
        # (This method is identical to v5's)
        try:
            batch_results = model_pipeline(image_batch, candidate_labels=candidate_labels)
            return batch_results
        except Exception as e:
            print(f"   [Error] Zero-shot model failed: {e}", file=sys.stderr)
            raise  # Re-raise to be caught by the batch processor

    # --- Private Methods: Helpers & Reporting ---

    def _calculate_hash(self, file_path):
        """Calculates the SHA256 hash of a file efficiently."""
        # (This method is identical to v5)
        hasher = hashlib.sha256()
        try:
            with open(file_path, 'rb') as f:
                while chunk := f.read(self.HASH_CHUNK_SIZE):
                    hasher.update(chunk)
            return hasher.hexdigest()
        except (IOError, OSError) as e:
            print(f"   [Warning] Could not hash file {file_path}: {e}", file=sys.stderr)
            return None

    def _get_unique_dest_path(self, dest_dir, file_name):
        """
        Checks if a file exists. If so, appends a counter
        (e.g., '_1', '_2') until a unique name is found.
        """
        # (This method is identical to v5)
        dest_path = os.path.join(dest_dir, file_name)
        if not os.path.exists(dest_path):
            return dest_path

        base, ext = os.path.splitext(file_name)
        counter = 1
        while True:
            new_name = f"{base}_{counter}{ext}"
            new_dest_path = os.path.join(dest_dir, new_name)
            if not os.path.exists(new_dest_path):
                return new_dest_path
            counter += 1

    @staticmethod
    @contextmanager
    def _suppress_pil_warnings():
        """Context manager to suppress known PIL warnings."""
        # (This method is identical to v5)
        import warnings
        warnings.filterwarnings(
            "ignore",
            "(Possibly corrupt EXIF data|Image size.*exceeds pixel limit)"
        )
        yield
        warnings.resetwarnings()

    def _load_image_for_batch(self, path):
        """
        Safely opens, converts, and loads one image.
        Returns None if the image fails to load.
        """
        # (This method is identical to v5)
        try:
            with self._suppress_pil_warnings(), Image.open(path) as image:
                if image.mode == 'RGBA':
                    image_rgb = image.convert('RGB')
                else:
                    image.load() # Must load to keep in memory
                    image_rgb = image
                return image_rgb
        except Exception as e:
            print(f"   [Error] Failed to load {path}: {e}", file=sys.stderr)
            return None

    def _print_final_report(self):
        """Prints a summary of all actions taken."""
        # (This method is identical to v5)
        print("\n--- Sorting Complete ---")
        print(f"Total files scanned: {self.stats['total_files_scanned']}")
        print(f"Duplicate files skipped: {self.stats['skipped_duplicates']}")
        print(f"Unique files processed: {self.stats['processed_files_moved']}")
        
        total_errors = self.stats['total_errors_pass1'] + self.stats['total_errors_pass2']
        print(f"Total errors: {total_errors}")
        
        print(f"\nCheck the folders in: {self.output_path}")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Main execution ---
def main():
    """
    Main function to run the script.
    Set your source and output folders here.
    """
    
    # --- CONFIGURE YOUR PATHS HERE ---
    source_folder = "D:\images\HockingHills"
    output_folder = "D:\images\output"
    # --- END CONFIGURATION ---

    
    # Check if the user has updated the default paths
    if "C:\\path\\to" in source_folder:
        print("Error: Please update the 'source_folder' and 'output_folder' variables", file=sys.stderr)
        print("inside the main() function at the bottom of the script.", file=sys.stderr)
        return

    if not os.path.isdir(source_folder):
        print(f"Error: Source folder not found: {source_folder}", file=sys.stderr)
        return
        
    try:
        # 1. Check dependencies first
        AIPhotoSorter.check_dependencies()
        
        # 2. Create and run the sorter instance
        sorter = AIPhotoSorter(source_folder, output_folder, batch_size=32)
        sorter.run_sort()

    except Exception as e:
        print(f"\n--- [FATAL ERROR] ---", file=sys.stderr)
        print(f"An unexpected error occurred during setup: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)


main()

Checking dependencies...
HEIF/HEIC support enabled.
All critical dependencies are present.
--- AI Photo Sorter (v6) ---
Source: D:\images\HockingHills
Destination: D:\images\output
Batch Size: 32
-------------------------
Setting up output directories...
Output directories ready.
Loading AI models... (This may take a few minutes and download files on first run)
Loading Model 1: d-li/mobilefacedet (Face Detector)...


  source_folder = "D:\images\HockingHills"
  output_folder = "D:\images\output"
Error loading models: d-li/mobilefacedet is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`
Please ensure you have an internet connection.

--- [FATAL ERROR] ---
An unexpected error occurred: d-li/mobilefacedet is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`
Traceback (most recent call last):
  File "c:\Users\xiaom\.conda\envs\llm\Lib\site-packages\huggingface_hub\utils\_http.py", line 409, in hf_raise_for_status
    response.raise_for_status()
  File "c:\Users\xia