In [None]:
#!/usr/bin/env python3

"""
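AI Photo Sorter (v3)
Sorts a directory of images using a two-stage "waterfall" classification
and proper, fixed-size batching for memory safety and performance.
Files are MOVED (not copied) into family_photos/, views/ and junk/ under
the output folder; exact duplicates are skipped and left in the source.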

This version fixes a major flaw in v2: images are now loaded and classified
one fixed-size batch at a time, so memory use stays bounded and it is safe
to run on very large photo libraries (e.g., 20,000+ images) without
crashing with "Out of Memory" errors.

REQUIREMENTS:
pip install torch transformers Pillow pillow-heif

--- USAGE IN JUPYTER NOTEBOOK ---
1. Make sure you have run: !pip install torch transformers Pillow pillow-heif
2. Define your paths in a notebook cell:

   source_folder = r"C:\path\to\your\iphone_photos_copy"
   output_folder = r"C:\path\to\my_sorted_photos"

3. Call the function in the next cell:

   if 'source_folder' in locals() and os.path.isdir(source_folder):
       sort_images(source_folder, output_folder)
   else:
       print("Please define 'source_folder' and 'output_folder' in your notebook cell.")
       print("See the docstring at the top of this file for an example.")
"""

import os
import sys
import shutil
import hashlib
from contextlib import contextmanager

# --- Dependency Checking ---
try:
    from PIL import Image
except ImportError:
    print("Error: 'Pillow' library not found. Please install with: pip install Pillow", file=sys.stderr)
    sys.exit(1)

try:
    import torch
except ImportError:
    print("Error: 'torch' library not found. Please install with: pip install torch", file=sys.stderr)
    sys.exit(1)

try:
    from transformers import pipeline
except ImportError:
    print("Error: 'transformers' library not found. Please install with: pip install transformers", file=sys.stderr)
    sys.exit(1)

try:
    import pillow_heif
    # Register the HEIF plugin with PIL
    pillow_heif.register_heif_opener()
except ImportError:
    print("Warning: 'pillow-heif' library not found. HEIC/HEIF files will not be processed.", file=sys.stderr)
    print("Install with: pip install pillow-heif", file=sys.stderr)


# --- Constants ---

# Number of images decoded into memory and run through the models per batch.
# - Increase (e.g., 64, 128) if you have a powerful GPU with lots of VRAM.
# - Decrease (e.g., 16, 8) if you get "Out of Memory" errors or have an older GPU.
BATCH_SIZE = 32

# Model 1 (zero-shot CLIP) prompt -> destination folder name.
# The classifier's top-scoring prompt decides the folder; images whose prompt
# maps to "junk" get a second opinion from the object detector (Model 2).
# This table is the single source of truth for the prompts: CATEGORIES is
# derived from it so the two can never disagree.
CATEGORY_TO_FOLDER = {
    # people
    "a photo of a person, face, or family": "family_photos",
    # scenery and food both count as "views"
    "a scenic landscape, mountain, beach, or flower": "views",
    "a photo of delicious food or a meal": "views",
    # everything else is treated as junk
    "a screenshot of a phone or computer screen": "junk",
    "a photo of a receipt, document, or white-board": "junk",
    "a graphic, logo, or drawing": "junk",
    "a blurry or dark photo": "junk",
}

# Candidate labels handed to the zero-shot classifier (dict order preserved).
CATEGORIES = list(CATEGORY_TO_FOLDER)

# --- Constants for Model 2 (object-detection fallback for "junk") ---
# A detection rescues a junk candidate only if its score is strictly greater
# than the matching threshold; a confident person wins over any "view" object.
OBJECT_DETECTOR_PERSON_LABELS = {'person'}
OBJECT_DETECTOR_PERSON_THRESHOLD = 0.8  # -> family_photos; be pretty confident
OBJECT_DETECTOR_VIEW_LABELS = {
    # animals
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe',
    # food
    'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    # table setting
    'dining table', 'bowl', 'cup', 'fork', 'knife', 'spoon',
    # vehicles and street furniture
    'airplane', 'bus', 'train', 'truck', 'boat', 'motorcycle', 'bicycle',
    'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    # sports and outdoor play
    'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'frisbee',
}
OBJECT_DETECTOR_VIEW_THRESHOLD = 0.7  # -> views; moderately confident

# Lower-case suffixes; compared against filename.lower().
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.heic', '.heif')
HASH_CHUNK_SIZE = 8 * 1024  # bytes read per step when hashing a file (8 KiB)

# --- Helper Functions ---

def calculate_hash(file_path):
    """
    Return the hex SHA-256 digest of the file at ``file_path``.

    The file is streamed in ``HASH_CHUNK_SIZE``-byte pieces, so large files
    never have to fit in memory at once. If the file cannot be read, a
    warning is printed to stderr and ``None`` is returned instead.
    """
    digest = hashlib.sha256()
    try:
        with open(file_path, 'rb') as stream:
            # iter() with a b'' sentinel stops at end-of-file.
            for block in iter(lambda: stream.read(HASH_CHUNK_SIZE), b''):
                digest.update(block)
    except OSError as exc:  # IOError is an alias of OSError
        print(f"  [Warning] Could not hash file {file_path}: {exc}", file=sys.stderr)
        return None
    return digest.hexdigest()

def get_unique_dest_path(dest_dir, file_name):
    """
    Return a path inside ``dest_dir`` for ``file_name`` that does not exist yet.

    If ``dest_dir/file_name`` is free it is returned as-is; otherwise a
    counter (``_1``, ``_2``, ...) is inserted before the extension until an
    unused name is found.
    """
    stem, suffix = os.path.splitext(file_name)
    candidate = os.path.join(dest_dir, file_name)
    attempt = 0
    while os.path.exists(candidate):
        attempt += 1
        candidate = os.path.join(dest_dir, f"{stem}_{attempt}{suffix}")
    return candidate

@contextmanager
def suppress_pil_warnings():
    """
    Context manager that silences known-noisy PIL warnings inside its block.

    Ignores warnings whose message starts with "Possibly corrupt EXIF data"
    or matches "Image size ... exceeds ... limit". The latter pattern covers
    Pillow's DecompressionBombWarning text ("Image size (N pixels) exceeds
    limit of M pixels, ..."); the older pattern "exceeds pixel limit" never
    matched it.

    The warning-filter state is saved on entry and restored on exit, even if
    the block raises. Filters installed by the caller (or by Jupyter) are left
    intact. The previous implementation called ``warnings.resetwarnings()``,
    which wiped every filter in the process.
    """
    import warnings

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="(Possibly corrupt EXIF data|Image size.*exceeds.*limit)",
        )
        yield

def load_image_for_batch(path):
    """
    Open one image file and return it as a fully-loaded RGB ``PIL.Image``.

    Every image is converted to RGB, not only RGBA. Model pre-processing is
    not guaranteed to accept other modes (e.g. "L", "P", "CMYK"). Because a
    whole batch is sent to the models in one call, a single such image could
    otherwise make the entire batch fail.

    ``Image.convert`` always returns a new in-memory image (a copy when the
    mode is already RGB). The result therefore stays usable after the ``with``
    block has closed the underlying file.

    Args:
        path: Path to the image file.

    Returns:
        The RGB image, or ``None`` if the file could not be opened, decoded or
        converted (the reason is printed to stderr).
    """
    try:
        with suppress_pil_warnings(), Image.open(path) as image:
            return image.convert('RGB')
    except Exception as e:
        # Best-effort: a corrupt/unsupported file must not stop the whole run;
        # the caller counts a None result as an error.
        print(f"  [Error] Failed to load {path}: {e}", file=sys.stderr)
        return None

# --- Main Function ---

def _find_unique_images(root_path, excluded_dirs=()):
    """
    Pass 1: walk ``root_path`` and collect one path per distinct image file.

    Files are selected by extension (``IMAGE_EXTENSIONS``) and de-duplicated
    by content hash (``calculate_hash``); only the first path seen for each
    hash is kept. Directories listed in ``excluded_dirs`` (the sorter's own
    output folders) are never descended into, so previously sorted files are
    not picked up again when the output lives inside the source.

    Returns:
        ``(unique_paths, files_scanned, duplicates_skipped, hash_errors)``.
    """
    excluded = {os.path.normcase(os.path.realpath(d)) for d in excluded_dirs}
    seen_hashes = {}  # {hash_value: path of the first file with that hash}
    unique_paths = []
    files_scanned = 0
    duplicates_skipped = 0
    hash_errors = 0

    for dirpath, dirnames, filenames in os.walk(root_path):
        # Prune in place (top-down walk) so os.walk skips our output folders.
        dirnames[:] = [
            d for d in dirnames
            if os.path.normcase(os.path.realpath(os.path.join(dirpath, d))) not in excluded
        ]
        print(f"Scanning: {dirpath}")
        for filename in filenames:
            if not filename.lower().endswith(IMAGE_EXTENSIONS):
                continue

            files_scanned += 1
            full_path = os.path.join(dirpath, filename)

            file_hash = calculate_hash(full_path)
            if file_hash is None:
                hash_errors += 1  # calculate_hash has already printed the reason
                continue

            first_seen = seen_hashes.get(file_hash)
            if first_seen is not None:
                print(f"  -> {filename} is a duplicate of: {first_seen} (Skipping)")
                duplicates_skipped += 1
                continue

            seen_hashes[file_hash] = full_path
            unique_paths.append(full_path)

    return unique_paths, files_scanned, duplicates_skipped, hash_errors


def _choose_destination(clip_result, detected_objects):
    """
    Pick the destination folder key for one image from both models' output.

    ``clip_result[0]['label']`` (the classifier's top label) is mapped through
    ``CATEGORY_TO_FOLDER``. If that yields "junk", the object detections are
    consulted: a person scoring above ``OBJECT_DETECTOR_PERSON_THRESHOLD``
    sends it to "family_photos"; otherwise any view object scoring above
    ``OBJECT_DETECTOR_VIEW_THRESHOLD`` sends it to "views".
    """
    dest_folder_name = CATEGORY_TO_FOLDER.get(clip_result[0]['label'], "junk")
    if dest_folder_name != "junk":
        return dest_folder_name

    found_view_object = False
    for obj in detected_objects:
        if obj['label'] in OBJECT_DETECTOR_PERSON_LABELS and obj['score'] > OBJECT_DETECTOR_PERSON_THRESHOLD:
            return "family_photos"  # a person always wins over view objects
        if obj['label'] in OBJECT_DETECTOR_VIEW_LABELS and obj['score'] > OBJECT_DETECTOR_VIEW_THRESHOLD:
            found_view_object = True
    return "views" if found_view_object else "junk"


def _process_batch(batch_paths, batch_number, classifier, detector, output_dirs):
    """
    Pass 2 for one batch: load, classify and move the images in ``batch_paths``.

    The decoded images are local to this function, so they are released as
    soon as it returns. At most one batch of images is held in memory.

    Returns:
        ``(moved, errors)``: files moved, and files that failed to load,
        failed classification (whole batch skipped), or failed to move.
    """
    images = []
    loaded_paths = []  # parallel to ``images``; failed loads are left out
    errors = 0
    for path in batch_paths:
        image = load_image_for_batch(path)
        if image is None:
            errors += 1
        else:
            images.append(image)
            loaded_paths.append(path)

    if not images:
        print("  -> All images in batch failed to load. Skipping.")
        return 0, errors

    try:
        # Model 1 (CLIP): the code below treats result[0] as the top label.
        print(f"  Running Model 1 (Zero-Shot) on {len(images)} images...")
        clip_results = classifier(images, candidate_labels=CATEGORIES)

        # Model 2 (DETR) only re-checks images Model 1 placed in "junk".
        junk_indices = [
            idx for idx, result in enumerate(clip_results)
            if CATEGORY_TO_FOLDER.get(result[0]['label'], "junk") == "junk"
        ]
        detr_results = {}  # {index into images: list of detected objects}
        if junk_indices:
            print(f"  Running Model 2 (Object Detector) on {len(junk_indices)} 'junk' candidates...")
            detections = detector([images[idx] for idx in junk_indices])
            detr_results = dict(zip(junk_indices, detections))

        destinations = [
            _choose_destination(clip_results[idx], detr_results.get(idx, []))
            for idx in range(len(loaded_paths))
        ]
    except Exception as e:
        # Batch boundary: a model failure (e.g. GPU out of memory) skips only
        # this batch. Nothing has been moved yet at this point.
        print(f"  [CRITICAL BATCH ERROR] Failed to process batch {batch_number}: {e}", file=sys.stderr)
        print("    -> Skipping this entire batch. These files will not be moved.", file=sys.stderr)
        return 0, errors + len(loaded_paths)

    print("  Categorizing and moving files...")
    moved = 0
    for full_path, dest_folder_name in zip(loaded_paths, destinations):
        original_filename = os.path.basename(full_path)
        try:
            dest_path = get_unique_dest_path(output_dirs[dest_folder_name], original_filename)
            shutil.move(full_path, dest_path)
        except OSError as e:
            # One unmovable file (locked, permissions, disk full) must not
            # abort the remaining moves in the batch.
            print(f"    [Error] Failed to move {full_path}: {e}", file=sys.stderr)
            errors += 1
            continue
        print(f"    -> Moved {original_filename} to {dest_folder_name}")
        moved += 1

    return moved, errors


def sort_images(root_path, output_path):
    """
    Recursively scan ``root_path``, classify every unique image, and move it
    into ``output_path/family_photos``, ``output_path/views`` or
    ``output_path/junk``.

    Pass 1 hashes every image file and skips byte-identical duplicates (they
    are left where they are). Pass 2 works through the remaining images
    ``BATCH_SIZE`` at a time:
      * it runs the zero-shot classifier (Model 1) on each batch;
      * it gives images classified as junk a second look with the object
        detector (Model 2);
      * it moves each file to its folder. Name collisions get a ``_1``,
        ``_2``, ... suffix.

    Args:
        root_path: Folder to scan recursively. Files are MOVED out of it.
        output_path: Folder that receives the three category subfolders; they
            are created if missing. Those subfolders are never rescanned, even
            if they lie inside ``root_path``.

    Returns:
        None. Progress and a summary are printed to stdout, errors to stderr.
        Returns early without moving anything if ``root_path`` is not a
        directory or the models fail to load.
    """
    if not os.path.isdir(root_path):
        print(f"Error: source folder does not exist or is not a directory: {root_path}", file=sys.stderr)
        return

    # 1. Setup Output Directories
    output_dirs = {
        name: os.path.join(output_path, name)
        for name in ("family_photos", "views", "junk")
    }
    for dir_path in output_dirs.values():
        os.makedirs(dir_path, exist_ok=True)

    print("--- AI Photo Sorter (v3) ---")
    print(f"Source: {root_path}")
    print(f"Destination: {output_path}")
    print(f"Batch Size: {BATCH_SIZE}")
    print("-" * 25)

    # 2. Load AI Models
    print("Loading AI models... (This may take a few minutes and download files on first run)")
    try:
        device = 0 if torch.cuda.is_available() else -1  # first GPU, else CPU

        print("Loading Model 1: Zero-Shot Classifier (openai/clip-vit-large-patch14)...")
        classifier = pipeline(
            "zero-shot-image-classification",
            model="openai/clip-vit-large-patch14",
            device=device
        )

        print("Loading Model 2: Object Detector (facebook/detr-resnet-50)...")
        detector = pipeline(
            "object-detection",
            model="facebook/detr-resnet-50",
            device=device
        )

        print("Models loaded successfully.")
    except Exception as e:
        print(f"Error loading models: {e}", file=sys.stderr)
        print("Please ensure you have an internet connection and 'transformers' and 'torch' are installed.", file=sys.stderr)
        return

    # 3. Pass 1: de-duplicate and build the master file list
    print("\n--- Pass 1: Scanning for duplicates and building file list ---")
    (unique_image_paths, total_files_scanned,
     skipped_duplicates, total_errors_pass1) = _find_unique_images(root_path, output_dirs.values())

    print("\n--- Pass 1 Complete ---")
    print(f"Total files scanned: {total_files_scanned}")
    print(f"Duplicate files found: {skipped_duplicates}")
    print(f"Unique images to process: {len(unique_image_paths)}")
    print(f"Hashing/Read errors: {total_errors_pass1}")

    # 4. Pass 2: classify and move unique images in fixed-size batches
    print("\n--- Pass 2: Processing unique images in batches ---")
    processed_files_moved = 0
    total_errors_pass2 = 0
    num_batches = (len(unique_image_paths) + BATCH_SIZE - 1) // BATCH_SIZE  # ceiling division

    for batch_number, start in enumerate(range(0, len(unique_image_paths), BATCH_SIZE), start=1):
        batch_paths = unique_image_paths[start:start + BATCH_SIZE]
        print(f"\nProcessing Batch {batch_number} / {num_batches} (Size: {len(batch_paths)})...")
        moved, errors = _process_batch(batch_paths, batch_number, classifier, detector, output_dirs)
        processed_files_moved += moved
        total_errors_pass2 += errors

    # 5. Final Report
    print("\n--- Sorting Complete ---")
    print(f"Total files scanned: {total_files_scanned}")
    print(f"Duplicate files skipped: {skipped_duplicates}")
    print(f"Unique files processed: {processed_files_moved}")
    print(f"Total errors (Pass 1 + Pass 2): {total_errors_pass1 + total_errors_pass2}")
    print(f"\nCheck the folders in: {output_path}")




In [None]:
# --- EXAMPLE USAGE IN JUPYTER NOTEBOOK ---
# 1. Make sure you have run: !pip install torch transformers Pillow pillow-heif
# 2. Set the paths below. Images are MOVED out of source_folder into
#    family_photos/, views/ and junk/ under output_folder.
source_folder = "images"
output_folder = "output"

# 3. Run the sort. source_folder is assigned just above, so the only thing
#    worth checking is that it actually points at an existing directory.
if os.path.isdir(source_folder):
    print("Starting photo sort...")
    sort_images(source_folder, output_folder)
else:
    print(f"Source folder not found: {source_folder!r}")
    print("Set 'source_folder' to an existing directory and run this cell again.")