In [1]:
#!/usr/bin/env python3

"""
AI Photo Sorter (v4 - Class-based)
Sorts a directory of images using a two-stage "waterfall" classification
and proper, fixed-size batching for memory safety and performance.

This version refactors the v3 script into an object-oriented class
for better organization and maintainability.

REQUIREMENTS:
pip install torch transformers Pillow pillow-heif

--- USAGE IN JUPYTER NOTEBOOK ---
1. Make sure you have run: !pip install torch transformers Pillow pillow-heif
2. Define your paths in a notebook cell:

    source_folder = r"C:\path\to\your\iphone_photos_copy"
    output_folder = r"C:\path\to\my_sorted_photos"

3. In the next cell, import, create, and run the sorter:

    if 'source_folder' in locals() and os.path.isdir(source_folder):
        try:
            # Check dependencies once
            AIPhotoSorter.check_dependencies()
            
            # Create an instance and run the sort
            sorter = AIPhotoSorter(source_folder, output_folder, batch_size=32)
            sorter.run_sort()
            
        except Exception as e:
            print(f"An error occurred: {e}")
    else:
        print("Please define 'source_folder' and 'output_folder' in your notebook cell.")
"""

import os
import sys
import shutil
import hashlib
import traceback
from contextlib import contextmanager
from collections import defaultdict

# --- Dependency Imports (grouped for checking) ---
try:
    from PIL import Image
except ImportError:
    Image = None

try:
    import torch
except ImportError:
    torch = None

try:
    from transformers import pipeline
except ImportError:
    pipeline = None

try:
    import pillow_heif
except ImportError:
    pillow_heif = None


class AIPhotoSorter:
    """
    Encapsulates all logic for scanning, classifying, and sorting images
    into a structured, class-based system.
    """

    # --- Constants as Class Attributes ---
    BATCH_SIZE_DEFAULT = 32

    CATEGORIES = [
        "a photo of a person, face, or family",
        "a scenic landscape, mountain, beach, or flower",
        "a photo of delicious food or a meal",
        "a screenshot of a phone or computer screen",
        "a photo of a receipt, document, or white-board",
        "a graphic, logo, or drawing",
        "a blurry or dark photo"
    ]

    CATEGORY_TO_FOLDER = {
        "a photo of a person, face, or family": "family_photos",
        "a scenic landscape, mountain, beach, or flower": "views",
        "a photo of delicious food or a meal": "views",
        "a screenshot of a phone or computer screen": "junk",
        "a photo of a receipt, document, or white-board": "junk",
        "a graphic, logo, or drawing": "junk",
        "a blurry or dark photo": "junk"
    }

    OBJECT_DETECTOR_PERSON_LABELS = {'person'}
    OBJECT_DETECTOR_PERSON_THRESHOLD = 0.8  # Be pretty confident it's a person
    OBJECT_DETECTOR_VIEW_LABELS = {
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
        'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
        'dining table', 'bowl', 'cup', 'fork', 'knife', 'spoon',
        'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
        'parking meter', 'bench', 'motorcycle', 'bicycle',
        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
        'frisbee'
    }
    OBJECT_DETECTOR_VIEW_THRESHOLD = 0.7 # Moderately confident

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.heic', '.heif')
    HASH_CHUNK_SIZE = 8192  # Read files in 8KB chunks for hashing

    def __init__(self, root_path, output_path, batch_size=BATCH_SIZE_DEFAULT):
        self.root_path = root_path
        self.output_path = output_path
        self.batch_size = batch_size

        # Output directories
        self.output_dirs = {
            "family_photos": os.path.join(self.output_path, "family_photos"),
            "views": os.path.join(self.output_path, "views"),
            "junk": os.path.join(self.output_path, "junk")
        }

        # AI Models
        self.classifier = None
        self.detector = None
        
        # State
        self.seen_hashes = {}  # { 'hash_value': 'path_to_first_file_seen' }
        self.unique_image_paths = [] # Master list of paths to process
        
        # Statistics
        self.stats = defaultdict(int)

    @staticmethod
    def check_dependencies():
        """
        Checks if all required libraries are installed.
        Raises ImportError if a critical one is missing.
        """
        print("Checking dependencies...")
        if Image is None:
            raise ImportError("Error: 'Pillow' library not found. Please install with: pip install Pillow")
        if torch is None:
            raise ImportError("Error: 'torch' library not found. Please install with: pip install torch")
        if pipeline is None:
            raise ImportError("Error: 'transformers' library not found. Please install with: pip install transformers")
        
        if pillow_heif is None:
            print("Warning: 'pillow-heif' library not found. HEIC/HEIF files will not be processed.", file=sys.stderr)
            print("Install with: pip install pillow-heif", file=sys.stderr)
        else:
            # Register the HEIF plugin with PIL
            pillow_heif.register_heif_opener()
            print("HEIF/HEIC support enabled.")
        print("All critical dependencies are present.")

    # --- Public Entry Point ---

    def run_sort(self):
        """
        Executes the entire sorting process from start to finish.
        This is the main public method to call.
        """
        print("--- AI Photo Sorter (v4) ---")
        print(f"Source: {self.root_path}")
        print(f"Destination: {self.output_path}")
        print(f"Batch Size: {self.batch_size}")
        print("-" * 25)

        try:
            self._setup_output_dirs()
            self._load_models()
            self._run_pass_1_scan_files()
            self._run_pass_2_process_batches()
            self._print_final_report()

        except Exception as e:
            print(f"\n--- [FATAL ERROR] ---", file=sys.stderr)
            print(f"An unexpected error occurred: {e}", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)

    # --- Private Methods: Setup ---

    def _setup_output_dirs(self):
        """Creates all necessary output directories."""
        print("Setting up output directories...")
        for dir_path in self.output_dirs.values():
            os.makedirs(dir_path, exist_ok=True)
        print("Output directories ready.")

    def _load_models(self):
        """Loads and initializes the AI models (pipelines)."""
        print("Loading AI models... (This may take a few minutes and download files on first run)")
        try:
            device = 0 if torch.cuda.is_available() else -1
            
            print("Loading Model 1: Zero-Shot Classifier (openai/clip-vit-large-patch14)...")
            self.classifier = pipeline(
                "zero-shot-image-classification",
                model="openai/clip-vit-large-patch14",
                device=device
            )
            
            print("Loading Model 2: Object Detector (facebook/detr-resnet-50)...")
            self.detector = pipeline(
                "object-detection",
                model="facebook/detr-resnet-50",
                device=device
            )
            
            print("Models loaded successfully.")
        except Exception as e:
            print(f"Error loading models: {e}", file=sys.stderr)
            print("Please ensure you have an internet connection.", file=sys.stderr)
            raise # Re-raise to stop the process

    # --- Private Methods: Pass 1 (Scanning) ---

    def _run_pass_1_scan_files(self):
        """
        Recursively scans the root_path, finds all images, calculates hashes,
        and builds a de-duplicated list of unique images to process.
        """
        print("\n--- Pass 1: Scanning for duplicates and building file list ---")
        
        for dirpath, _, filenames in os.walk(self.root_path):
            print(f"Scanning: {dirpath}")
            for filename in filenames:
                if not filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                
                self.stats['total_files_scanned'] += 1
                full_path = os.path.join(dirpath, filename)

                file_hash = self._calculate_hash(full_path)
                if not file_hash:
                    self.stats['total_errors_pass1'] += 1
                    continue  # Skip if hashing failed

                if file_hash in self.seen_hashes:
                    print(f"   -> Duplicate of: {self.seen_hashes[file_hash]} (Skipping)")
                    self.stats['skipped_duplicates'] += 1
                    continue
                else:
                    self.seen_hashes[file_hash] = full_path
                    self.unique_image_paths.append(full_path) # Add to our master list

        print(f"\n--- Pass 1 Complete ---")
        print(f"Total files scanned: {self.stats['total_files_scanned']}")
        print(f"Duplicate files found: {self.stats['skipped_duplicates']}")
        print(f"Unique images to process: {len(self.unique_image_paths)}")
        print(f"Hashing/Read errors: {self.stats['total_errors_pass1']}")

    # --- Private Methods: Pass 2 (Processing) ---

    def _run_pass_2_process_batches(self):
        """
        Iterates over the list of unique image paths in fixed-size batches
        and calls _process_one_batch for each.
        """
        print("\n--- Pass 2: Processing unique images in batches ---")
        
        num_batches = (len(self.unique_image_paths) + self.batch_size - 1) // self.batch_size
        
        for i in range(0, len(self.unique_image_paths), self.batch_size):
            batch_paths = self.unique_image_paths[i : i + self.batch_size]
            current_batch_num = (i // self.batch_size) + 1
            
            print(f"\nProcessing Batch {current_batch_num} / {num_batches} (Size: {len(batch_paths)})...")
            
            self._process_one_batch(batch_paths)

    def _process_one_batch(self, batch_paths):
        """
        Processes a single batch of images:
        1. Loads images into memory.
        2. Runs Model 1 (Classification).
        3. Runs Model 2 (Detection) on 'junk' candidates.
        4. Categorizes and moves files.
        """
        
        # 1. Load image objects for this batch
        valid_paths_in_batch, batch_image_objects = self._load_batch_images(batch_paths)
        
        if not batch_image_objects:
            print("   -> All images in batch failed to load. Skipping.")
            return
        
        try:
            # 2. Classify Batch with Model 1 (CLIP)
            batch_clip_results = self._run_classification_batch(batch_image_objects)

            # 3. Check "junk" with Model 2 (DETR)
            batch_detr_results = self._run_detection_on_junk(batch_clip_results, batch_image_objects)

            # 4. Iterate over batch results, decide destination, and move files
            print("   Categorizing and moving files...")
            self._categorize_and_move_batch(valid_paths_in_batch, batch_clip_results, batch_detr_results)
        
        except Exception as e:
            print(f"   [CRITICAL BATCH ERROR] Failed to process batch: {e}", file=sys.stderr)
            print("     -> Skipping this entire batch. These files will not be moved.", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            self.stats['total_errors_pass2'] += len(valid_paths_in_batch)

    def _load_batch_images(self, batch_paths):
        """Helper to load all images for a batch, skipping failures."""
        batch_image_objects = []
        valid_paths_in_batch = []
        for path in batch_paths:
            image = self._load_image_for_batch(path)
            if image:
                batch_image_objects.append(image)
                valid_paths_in_batch.append(path)
            else:
                self.stats['total_errors_pass2'] += 1
        return valid_paths_in_batch, batch_image_objects

    def _run_detection_on_junk(self, batch_clip_results, batch_image_objects):
        """Finds junk images, runs detection, and maps results back."""
        junk_indices = [
            idx for idx, result in enumerate(batch_clip_results)
            if self.CATEGORY_TO_FOLDER.get(result[0]['label'], "junk") == "junk"
        ]
        
        junk_images_to_check = [batch_image_objects[idx] for idx in junk_indices]
        batch_detr_results = {} # Will store {original_index: [results]}
        
        if junk_images_to_check:
            detr_results_list = self._run_detection_batch(junk_images_to_check)
            
            # Map results back to their original batch index
            for j, detr_result in enumerate(detr_results_list):
                original_batch_index = junk_indices[j]
                batch_detr_results[original_batch_index] = detr_result
        
        return batch_detr_results

    def _categorize_and_move_batch(self, valid_paths, clip_results, detr_results):
        """Loops through batch results, determines final category, and moves files."""
        for j in range(len(valid_paths)):
            full_path = valid_paths[j]
            original_filename = os.path.basename(full_path)
            clip_result = clip_results[j]
            
            top_label = clip_result[0]['label']
            dest_folder_name = self.CATEGORY_TO_FOLDER.get(top_label, "junk")

            if dest_folder_name == "junk":
                # This image was in the junk pile, check Model 2's results
                objects = detr_results.get(j, []) # Get results, or empty list
                dest_folder_name = self._reclassify_junk(objects)

            # F) Handle Filename Collisions and Move
            dest_dir = self.output_dirs[dest_folder_name]
            dest_path = self._get_unique_dest_path(dest_dir, original_filename)
            
            shutil.move(full_path, dest_path)
            print(f"     -> Moved {original_filename} to {dest_folder_name}")
            self.stats['processed_files_moved'] += 1

    def _reclassify_junk(self, objects):
        """Uses object detection results to potentially save a 'junk' image."""
        found_person = False
        found_view_object = False
        
        for obj in objects:
            if obj['label'] in self.OBJECT_DETECTOR_PERSON_LABELS and obj['score'] > self.OBJECT_DETECTOR_PERSON_THRESHOLD:
                found_person = True
                break # Person is highest priority
            if obj['label'] in self.OBJECT_DETECTOR_VIEW_LABELS and obj['score'] > self.OBJECT_DETECTOR_VIEW_THRESHOLD:
                found_view_object = True
        
        if found_person:
            return "family_photos"
        elif found_view_object:
            return "views"
        
        # If nothing found, it remains junk
        return "junk"

    # --- Private Methods: Inference ---

    def _run_classification_batch(self, image_batch):
        """
        Runs the zero-shot classifier (Model 1) on a batch of images.
        Returns a list of classification results.
        """
        print(f"   Running Model 1 (Zero-Shot) on {len(image_batch)} images...")
        try:
            batch_results = self.classifier(image_batch, candidate_labels=self.CATEGORIES)
            return batch_results
        except Exception as e:
            print(f"   [Error] Model 1 (Zero-Shot) failed: {e}", file=sys.stderr)
            raise  # Re-raise to be caught by the batch processor

    def _run_detection_batch(self, image_batch):
        """
        Runs the object detector (Model 2) on a batch of images.
        Returns a list of detection results.
        """
        print(f"   Running Model 2 (Object Detector) on {len(image_batch)} 'junk' candidates...")
        try:
            batch_results = self.detector(image_batch)
            return batch_results
        except Exception as e:
            print(f"   [Error] Model 2 (Object Detector) failed: {e}", file=sys.stderr)
            raise # Re-raise to be caught by the batch processor

    # --- Private Methods: Helpers & Reporting ---

    def _calculate_hash(self, file_path):
        """Calculates the SHA256 hash of a file efficiently."""
        hasher = hashlib.sha256()
        try:
            with open(file_path, 'rb') as f:
                while chunk := f.read(self.HASH_CHUNK_SIZE):
                    hasher.update(chunk)
            return hasher.hexdigest()
        except (IOError, OSError) as e:
            print(f"   [Warning] Could not hash file {file_path}: {e}", file=sys.stderr)
            return None

    def _get_unique_dest_path(self, dest_dir, file_name):
        """
        Checks if a file exists. If so, appends a counter
        (e.g., '_1', '_2') until a unique name is found.
        """
        dest_path = os.path.join(dest_dir, file_name)
        if not os.path.exists(dest_path):
            return dest_path

        base, ext = os.path.splitext(file_name)
        counter = 1
        while True:
            new_name = f"{base}_{counter}{ext}"
            new_dest_path = os.path.join(dest_dir, new_name)
            if not os.path.exists(new_dest_path):
                return new_dest_path
            counter += 1

    @staticmethod
    @contextmanager
    def _suppress_pil_warnings():
        """Context manager to suppress known PIL warnings."""
        import warnings
        warnings.filterwarnings(
            "ignore",
            "(Possibly corrupt EXIF data|Image size.*exceeds pixel limit)"
        )
        yield
        warnings.resetwarnings()

    def _load_image_for_batch(self, path):
        """
        Safely opens, converts, and loads one image.
        Returns None if the image fails to load.
        """
        try:
            with self._suppress_pil_warnings(), Image.open(path) as image:
                if image.mode == 'RGBA':
                    image_rgb = image.convert('RGB')
                else:
                    image.load() # Must load to keep in memory
                    image_rgb = image
                return image_rgb
        except Exception as e:
            print(f"   [Error] Failed to load {path}: {e}", file=sys.stderr)
            return None

    def _print_final_report(self):
        """Prints a summary of all actions taken."""
        print("\n--- Sorting Complete ---")
        print(f"Total files scanned: {self.stats['total_files_scanned']}")
        print(f"Duplicate files skipped: {self.stats['skipped_duplicates']}")
        print(f"Unique files processed: {self.stats['processed_files_moved']}")
        
        total_errors = self.stats['total_errors_pass1'] + self.stats['total_errors_pass2']
        print(f"Total errors: {total_errors}")
        
        print(f"\nCheck the folders in: {self.output_path}")




  """
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Main execution ---
def main():
    """
    Main function to run the script.
    Set your source and output folders here.
    """
    
    # --- CONFIGURE YOUR PATHS HERE ---
    source_folder = "D:\images\HockingHills"
    output_folder = "D:\images\output"
    # --- END CONFIGURATION ---

    
    # Check if the user has updated the default paths
    if "C:\\path\\to" in source_folder:
        print("Error: Please update the 'source_folder' and 'output_folder' variables", file=sys.stderr)
        print("inside the main() function at the bottom of the script.", file=sys.stderr)
        return

    if not os.path.isdir(source_folder):
        print(f"Error: Source folder not found: {source_folder}", file=sys.stderr)
        return
        
    try:
        # 1. Check dependencies first
        AIPhotoSorter.check_dependencies()
        
        # 2. Create and run the sorter instance
        sorter = AIPhotoSorter(source_folder, output_folder, batch_size=32)
        sorter.run_sort()

    except Exception as e:
        print(f"\n--- [FATAL ERROR] ---", file=sys.stderr)
        print(f"An unexpected error occurred during setup: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)


main()

  source_folder = "D:\images\HockingHills"
  output_folder = "D:\images\output"


Checking dependencies...
HEIF/HEIC support enabled.
All critical dependencies are present.
--- AI Photo Sorter (v4) ---
Source: D:\images\HockingHills
Destination: D:\images\output
Batch Size: 32
-------------------------
Setting up output directories...
Output directories ready.
Loading AI models... (This may take a few minutes and download files on first run)
Loading Model 1: Zero-Shot Classifier (openai/clip-vit-large-patch14)...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


Loading Model 2: Object Detector (facebook/detr-resnet-50)...


Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


Models loaded successfully.

--- Pass 1: Scanning for duplicates and building file list ---
Scanning: D:\images\HockingHills
Scanning: D:\images\HockingHills\videos

--- Pass 1 Complete ---
Total files scanned: 91
Duplicate files found: 0
Unique images to process: 91
Hashing/Read errors: 0

--- Pass 2: Processing unique images in batches ---

Processing Batch 1 / 3 (Size: 32)...
   Running Model 1 (Zero-Shot) on 32 images...
   Running Model 2 (Object Detector) on 3 'junk' candidates...
   Categorizing and moving files...
     -> Moved AILF6859.JPG to family_photos
     -> Moved AKIU2926.JPG to views
     -> Moved APWO0546.JPG to family_photos
     -> Moved BAXC5641.JPG to junk
     -> Moved BWVB7885.JPG to family_photos
     -> Moved DPTE4051.JPG to views
     -> Moved EKJO8625.JPG to views
     -> Moved EMXD5711.JPG to family_photos
     -> Moved EUJU2177.JPG to family_photos
     -> Moved FGCU1920.JPG to family_photos
     -> Moved GAHE9208.JPG to family_photos
     -> Moved HNSX664