In [None]:
import os
import hashlib
import shutil
import sys
from typing import Dict, List, Optional, Generator, Any
from PIL import Image

import numpy as np
from insightface.app import FaceAnalysis

# --- HEIC/HEIF Support ---
# This is required to make PIL.Image.open() support HEIC/HEIF formats.
# You must install this library: pip install pillow-heif
try:
    import pillow_heif
    pillow_heif.register_heif_opener()
    print("HEIC/HEIF support enabled.")
except ImportError:
    print("Warning: 'pillow-heif' not installed. HEIC/HEIF files will be skipped.")
    print("Install with: pip install pillow-heif")
# --- End HEIC/HEIF Support ---

try:
    from insightface.app import FaceAnalysis
    print("FaceAnalysis support enabled.")
except ImportError:
    print("Warning: FaceAnalysis not installed.")
    print("pip install insightface")
# --- End FaceAnalysis (insightface) Support ---

# --- Hugging Face transformers ---
# You must install this library: pip install transformers torch
try:
    from transformers import pipeline, Pipeline
    from transformers import DetrImageProcessor, DetrForObjectDetection
    import torch
except ImportError:
    print("CRITICAL: 'transformers' or 'torch' not found.")
    print("Please install them to run this script: pip install transformers torch")
    sys.exit(1)


class ImageDataloader:
    """
    Scans a directory tree for unique images and provides batches for processing.

    Files are de-duplicated by the SHA-256 hash of their content. When several
    files share the same content, the first one in *sorted* walk order is kept,
    so the same copy is chosen on every run and on every filesystem.

    This class is implemented as a Python generator. It does not
    inherit from torch.utils.data.DataLoader, as our use case
    requires a simple, stateful iterator.
    """
    # Matched case-insensitively against file names in _scan_and_deduplicate.
    IMAGE_EXTENSIONS: tuple = ('.jpg', '.jpeg', '.png', '.heic', '.heif')

    def __init__(self, root_dir: str, batch_size: int = 32):
        """
        Args:
            root_dir: Directory scanned recursively for images.
            batch_size: Maximum number of images per yielded batch.

        Raises:
            ValueError: If ``root_dir`` is not a directory or ``batch_size <= 0``.
        """
        if not os.path.isdir(root_dir):
            raise ValueError(f"Root directory not found: {root_dir}")
        if batch_size <= 0:
            raise ValueError("Batch size must be greater than 0")

        self.root_dir = root_dir
        self.batch_size = batch_size

        # self.labels will hold the "working state" of all images
        # {image_path: "unknown"}; insertion order is the batch order.
        self.labels: Dict[str, str] = {}
        self._scan_and_deduplicate()

    def _calculate_hash(self, filepath: str, block_size: int = 65536) -> str:
        """
        Calculates the SHA256 hash of a file's content.

        The file is read in ``block_size`` chunks so large images are never
        loaded into memory at once.

        Returns:
            The hex digest, or ``""`` if the file could not be read.
        """
        sha256 = hashlib.sha256()
        try:
            with open(filepath, 'rb') as f:
                while chunk := f.read(block_size):
                    sha256.update(chunk)
            return sha256.hexdigest()
        except (IOError, OSError) as e:
            print(f"Warning: Could not read file for hashing: {filepath}. Skipping. Error: {e}")
            return ""

    def _scan_and_deduplicate(self):
        """
        Walks the root directory, finds all unique images, and
        populates self.labels with the default 'unknown' label.

        Unreadable files are skipped (and not counted as duplicates).
        """
        print(f"Scanning directory: {self.root_dir}...")
        image_hashes: set[str] = set()
        total_files = 0
        duplicates_skipped = 0

        for root, dirs, files in os.walk(self.root_dir):
            # os.walk yields entries in filesystem-dependent order. Sorting
            # `dirs` in place also fixes the descent order, so the copy kept
            # for each duplicate (and the batch order) is reproducible.
            dirs.sort()
            for file in sorted(files):
                if not file.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue

                total_files += 1
                full_path = os.path.join(root, file)
                file_hash = self._calculate_hash(full_path)

                if not file_hash:
                    continue

                if file_hash not in image_hashes:
                    image_hashes.add(file_hash)
                    # All images start as 'unknown'
                    self.labels[full_path] = "unknown"
                else:
                    duplicates_skipped += 1

        print("--- Scan Complete ---")
        print(f"Total image files found: {total_files}")
        print(f"Duplicate images skipped: {duplicates_skipped}")
        print(f"Total unique images to process: {len(self.labels)}")
        if not self.labels:
            print("Warning: No valid, unique images were found.")

    def __len__(self) -> int:
        """
        Returns the total number of unique images to be processed.
        """
        return len(self.labels)

    def __iter__(self) -> Generator[Dict[str, str], None, None]:
        """
        Yields batches of images as dictionaries {image_path: label}.

        Each batch is a fresh dict, so callers may modify it without
        touching ``self.labels``. The final batch may be smaller than
        ``batch_size``.
        """
        # Snapshot the paths so changes to self.labels during iteration
        # do not affect which images are yielded.
        all_paths = list(self.labels.keys())

        for i in range(0, len(all_paths), self.batch_size):
            batch_paths = all_paths[i : i + self.batch_size]

            # Create the batch dict
            batch_data = {path: self.labels[path] for path in batch_paths}

            if batch_data:
                yield batch_data


class ImageClassifier:
    """
    Uses a "waterfall" method to classify images from a dataloader
    using multiple, chained AI models, moving each file into
    ``<output_dir>/<category>``.

    Per batch:
      1. People: DETR 'person' detections and RetinaFace faces relabel
         images as 'people'; those files are moved to ``<output_dir>/people``.
      2. View: CLIP zero-shot prompts, then DETR object classes, move
         matching images (still labeled 'unknown') to ``<output_dir>/view``.
      3. Everything left, including unreadable ('invalid') images, is
         moved to ``<output_dir>/unknown``.

    NOTE: files are MOVED, not copied, out of the dataloader's root directory.
    """
    def __init__(self, dataloader: ImageDataloader, output_dir: str = "output"):
        """
        Args:
            dataloader: Source of {image_path: label} batches.
            output_dir: Root folder that receives the 'people', 'view' and
                'unknown' sub-directories (created if missing).
        """
        self.dataloader = dataloader
        self.output_dir = output_dir

        # 1. Determine the device (transformers pipelines: 0 = first GPU, -1 = CPU)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.pipeline_device = 0 if self.device == torch.device('cuda') else -1
        print(f"Using device: {self.device} (Pipeline device: {self.pipeline_device})")

        self.output_paths = {
            "people": os.path.join(output_dir, "people"),
            "view": os.path.join(output_dir, "view"),
            "unknown": os.path.join(output_dir, "unknown")
        }
        self._create_output_dirs()

        # --- Load Models in __init__ ---
        # Every loader stores None on failure; classify_images skips None models.
        print("Loading models...")

        # Face detector (insightface RetinaFace)
        self.retinaface_model = self.load_retinaface_model()

        # Object detector (DETR, COCO classes)
        self.detr_pipeline = self.load_detr_pipeline()

        # Face Detector (Pipeline). Only used by classify_with_mobilefacedet,
        # which the waterfall in classify_images does not call.
        try:
            self.mobilefacedet_pipeline = pipeline(
                "object-detection",
                model="d-li/mobilefacedet",
                device=self.pipeline_device
            )
            print("Loaded: d-li/mobilefacedet")
        except Exception as e:
            print(f"CRITICAL: Failed to load mobilefacedet. {e}")
            self.mobilefacedet_pipeline = None

        # CLIP (Pipeline) for zero-shot "view" classification
        try:
            self.clip_pipeline = pipeline(
                "zero-shot-image-classification",
                model="openai/clip-vit-large-patch14",
                device=self.pipeline_device
            )
            print("Loaded: clip-vit-large-patch14")
        except Exception as e:
            print(f"CRITICAL: Failed to load CLIP. {e}")
            self.clip_pipeline = None

        print("Model loading complete.")

    # --- Specific retinaface Model ---

    def load_retinaface_model(self):
        """
        Loads insightface's ``buffalo_l`` pack and prepares its face detector.

        Returns:
            The prepared ``FaceAnalysis`` app, or ``None`` if loading failed
            (same failure convention as the Hugging Face pipeline loaders).
        """
        # insightface selects a GPU for ctx_id >= 0 and the CPU for a negative
        # id; always passing 0 would request a GPU on CPU-only machines.
        ctx_id = 0 if self.device.type == "cuda" else -1
        try:
            # Only face counts are used, so load just the detection model and
            # skip the recognition / landmark / gender-age models.
            app = FaceAnalysis(name="buffalo_l", allowed_modules=["detection"])
            app.prepare(ctx_id=ctx_id, det_size=(640, 640))
            print("Loaded: insightface buffalo_l (RetinaFace detection)")
            return app
        except Exception as e:
            print(f"CRITICAL: Failed to load RetinaFace. {e}")
            return None

    def detect_with_RetinaFace(self, batch_data, model):
        """
        Labels 'unknown' images that contain at least one face as 'people'.

        Args:
            batch_data: dict {image_path: label}. Not modified.
            model: Prepared insightface ``FaceAnalysis`` app, or None.

        Returns:
            A new dict with 'people' for images with a detected face and
            'invalid' for images that could not be read; other labels are
            unchanged. If ``model`` is None, an unchanged copy is returned.
        """
        updated_data = dict(batch_data)
        if model is None:
            return updated_data

        unknown_paths = [path for path, label in batch_data.items() if label == "unknown"]

        # One image at a time: insightface has no batch API, and this keeps
        # only a single decoded image in memory.
        for img_path in unknown_paths:
            try:
                # `with` closes the file handle promptly; a lingering handle
                # can make the later shutil.move fail on Windows.
                with Image.open(img_path) as img:
                    rgb = np.asarray(img.convert("RGB"))
            except Exception as e:
                print(f"⚠️ Error reading {img_path}: {e}")
                updated_data[img_path] = "invalid"
                continue

            # RGB -> BGR (insightface's expected channel order). The reversed
            # view has a negative stride, so copy it into a contiguous array.
            img_bgr = np.ascontiguousarray(rgb[:, :, ::-1])

            try:
                faces = model.get(img_bgr)
            except Exception as e:
                print(f"⚠️ Face detection failed for {img_path}: {e}")
                continue

            if len(faces) > 0:
                updated_data[img_path] = "people"

        return updated_data

    # --- Specific detr Model ---
    def load_detr_pipeline(self):
        """Loads the DETR object-detection pipeline; returns None on failure."""
        try:
            model = "facebook/detr-resnet-50"
            detr_pipeline = pipeline("object-detection", model=model, device=self.pipeline_device)
            print("Loaded: facebook/detr-resnet-50")
        except Exception as e:
            print(f"CRITICAL: Failed to load DETR. {e}")
            detr_pipeline = None
        return detr_pipeline

    @staticmethod
    def detect_human_detr_pipeline(batch_data: Dict[str, str], detr_pipeline: Any,
                                   confidence_threshold: float = 0.9) -> Dict[str, str]:
        """
        Detects humans ('person' label) in ALL images labeled 'unknown' using the
        Hugging Face pipeline in a single call and relabels them 'people'.

        A static method, so it works both as ``self.detect_human_detr_pipeline(...)``
        and ``ImageClassifier.detect_human_detr_pipeline(...)``.

        Args:
            batch_data (dict): A dictionary of {image_path: label}. Not modified.
            detr_pipeline (transformers.Pipeline): The loaded DETR object detection pipeline.
            confidence_threshold (float): Minimum confidence for a detection to be considered.

        Returns:
            dict: A new {image_path: label} dict. If the pipeline is missing
            or fails, the input is returned unchanged.
        """
        # DETR is trained on COCO, where the label for a human is 'person'
        PERSON_LABEL = 'person'

        if detr_pipeline is None:
            print("ERROR: Pipeline is not loaded. Cannot process data.")
            return batch_data

        # 1. Identify images to process, load, and validate paths
        unknown_paths = [path for path, label in batch_data.items() if label == 'unknown']

        if not unknown_paths:
            print("No images labeled 'unknown' found. Returning original data.")
            return batch_data

        print(f"\nProcessing all {len(unknown_paths)} 'unknown' images using the pipeline...")

        # Load PIL Images for the pipeline
        batch_images: List[Image.Image] = []
        valid_paths: List[str] = []
        for path in unknown_paths:
            try:
                if not os.path.exists(path):
                    print(f"⚠️ Warning: Image not found at {path}. Skipping.")
                    continue
                # convert() returns a fully loaded copy, so the file can be closed.
                with Image.open(path) as img:
                    batch_images.append(img.convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"❌ An error occurred loading {path}: {e}. Skipping.")

        if not valid_paths:
            print("No valid images could be loaded. Returning original data.")
            return batch_data

        # 2. Perform inference on the whole list.
        # The pipeline handles moving data to the device defined during load.
        try:
            # The result is a list of lists: [[det1, det2, ...], [det1, ...], ...]
            results: List[List[Dict[str, Any]]] = detr_pipeline(batch_images)
        except Exception as e:
            print(f"\nRUNTIME ERROR during pipeline execution: {e}")
            print("The batch might be too large for available memory.")
            return batch_data

        # 3. Post-process the results into a copy (the input is left untouched)
        updated_data = dict(batch_data)
        for image_path, image_results in zip(valid_paths, results):
            # image_results is a list of dictionaries (one for each detection)
            if any(
                detection['score'] >= confidence_threshold and detection['label'] == PERSON_LABEL
                for detection in image_results
            ):
                updated_data[image_path] = 'people'

        print(f"\nSuccessfully processed {len(valid_paths)} images using the pipeline.")
        return updated_data

    def _create_output_dirs(self):
        """Creates the output directories if they don't exist."""
        print(f"Ensuring output directories exist at: {self.output_dir}")
        for path in self.output_paths.values():
            os.makedirs(path, exist_ok=True)

    def _move_file(self, image_path: str, new_label: str):
        """
        Moves a file to its new classified directory.

        A ``_1``, ``_2``, ... suffix is appended when the target name is taken.
        Errors are printed, not raised, so one bad file cannot stop a run.
        """
        if new_label not in self.output_paths:
            print(f"Warning: Unknown label '{new_label}'. Cannot move file.")
            return

        target_dir = self.output_paths[new_label]
        filename = os.path.basename(image_path)
        name, ext = os.path.splitext(filename)
        target_path = os.path.join(target_dir, filename)

        i = 1
        while os.path.exists(target_path):
            target_path = os.path.join(target_dir, f"{name}_{i}{ext}")
            i += 1

        try:
            shutil.move(image_path, target_path)
        except Exception as e:
            print(f"Error moving file {image_path} to {target_path}. Error: {e}")

    def _move_labeled(self, batch_data: Dict[str, str], label: str) -> Dict[str, str]:
        """Moves every image whose label equals ``label`` into that folder; returns the rest."""
        remaining: Dict[str, str] = {}
        moved_count = 0
        for path, current_label in batch_data.items():
            if current_label == label:
                self._move_file(path, label)
                moved_count += 1
            else:
                remaining[path] = current_label
        print(f"    Moved {moved_count} images to '{label}'.")
        return remaining

    def classify_images(self):
        """
        The main classification loop (orchestrator).
        Iterates through all batches and applies the waterfall logic.
        """
        print("\n--- Starting Image Classification Waterfall ---")
        total_batches = (len(self.dataloader) + self.dataloader.batch_size - 1) // self.dataloader.batch_size

        for i, initial_batch in enumerate(self.dataloader):
            print(f"\nProcessing Batch {i+1} / {total_batches} (Size: {len(initial_batch)})...")

            remaining_batch = dict(initial_batch)

            # --- STAGE 1: Detect people (DETR 'person' boxes, then faces) ---
            if self.detr_pipeline is not None:
                remaining_batch = self.detect_human_detr_pipeline(remaining_batch, self.detr_pipeline)
            if self.retinaface_model is not None:
                remaining_batch = self.detect_with_RetinaFace(remaining_batch, self.retinaface_model)
            # Stage-1 detectors only relabel; move the 'people' files here so
            # later stages never see them.
            remaining_batch = self._move_labeled(remaining_batch, "people")

            # Only still-'unknown' images go to the view models. Anything else
            # (e.g. 'invalid' unreadable files) would make a whole pipeline
            # call fail, so it skips straight to the 'unknown' folder.
            candidates = {p: lbl for p, lbl in remaining_batch.items() if lbl == "unknown"}
            skipped = {p: lbl for p, lbl in remaining_batch.items() if lbl != "unknown"}

            # --- STAGE 2: View Models ---
            if self.clip_pipeline is not None:
                candidates = self.classify_with_clip(
                    candidates, self.clip_pipeline,
                    target_label="view",
                    prompts=["a photo of a landscape", "a beautiful view", "a flower", "a photo of food", "a city skyline", "a beach", "mountains", "a forest"]
                )

            if self.detr_pipeline is not None:
                # NOTE: 'flower' is not a COCO class, so DETR never emits it;
                # 'potted plant' / 'vase' are the closest COCO labels.
                candidates = self.classify_with_detr(
                    candidates, self.detr_pipeline,
                    target_label="view",
                    target_classes={"flower", "bird", "cat", "dog", "horse", "pizza", "donut", "cake", "boat", "airplane", "bench", "car", "bus"}
                )

            # --- STAGE 3: Custom Models (Example) ---
            # if self.my_custom_model:
            #     candidates = self.classify_with_custom_model(candidates, self.my_custom_model)

            # --- STAGE 4: unknown ---
            remaining_batch = {**candidates, **skipped}
            print(f"Moving {len(remaining_batch)} remaining images to 'unknown'...")
            for path in remaining_batch:
                self._move_file(path, "unknown")

            print(f"Batch {i+1} complete.")

        print("\n--- Image Classification Finished ---")

    # --- Specific Model Classification Methods ---

    def classify_with_mobilefacedet(self, batch_data: Dict[str, str], model: Pipeline) -> Dict[str, str]:
        """
        Runs the mobilefacedet pipeline and moves images with any detection to 'people'.

        Returns:
            The images that were NOT moved; ``batch_data`` unchanged on a model error.
        """
        if not batch_data:
            return {}

        print(f"  Running 'mobilefacedet' on {len(batch_data)} images...")
        paths = list(batch_data.keys())

        try:
            results = model(paths, batch_size=self.dataloader.batch_size)
        except Exception as e:
            print(f"    ERROR: Model 'mobilefacedet' failed. {e}")
            return batch_data

        remaining_batch = batch_data.copy()
        moved_count = 0
        for path, result_list in zip(paths, results):
            if result_list:  # Found a face
                self._move_file(path, "people")
                del remaining_batch[path]
                moved_count += 1

        print(f"    Moved {moved_count} images to 'people'.")
        return remaining_batch

    def classify_with_clip(self, batch_data: Dict[str, str], model: Pipeline, target_label: str,
                           prompts: List[str], threshold: float = 0.8) -> Dict[str, str]:
        """
        Runs the CLIP zero-shot pipeline and moves confidently matching images.

        Args:
            batch_data: {image_path: label}; every path in it is classified.
            model: The zero-shot-image-classification pipeline.
            target_label: Destination category (a key of ``self.output_paths``).
            prompts: Candidate labels describing the target category.
            threshold: Minimum top-prompt score required to move an image.

        Returns:
            The images that were NOT moved; ``batch_data`` unchanged on a model error.
        """
        if not batch_data:
            return {}

        print(f"  Running 'CLIP' for '{target_label}' on {len(batch_data)} images...")
        paths = list(batch_data.keys())

        try:
            results = model(paths, candidate_labels=prompts, batch_size=self.dataloader.batch_size)
        except Exception as e:
            print(f"    ERROR: Model 'CLIP' failed. {e}")
            return batch_data

        remaining_batch = batch_data.copy()
        moved_count = 0
        for path, result_list in zip(paths, results):
            # The first entry is the highest-scoring prompt. Scores are relative
            # to the given prompts only, so the threshold measures how strongly
            # one prompt wins, not an absolute "is a view" probability.
            top_label = result_list[0]['label']
            top_score = result_list[0]['score']

            if top_label in prompts and top_score > threshold:
                self._move_file(path, target_label)
                del remaining_batch[path]
                moved_count += 1

        print(f"    Moved {moved_count} images to '{target_label}'.")
        return remaining_batch

    def classify_with_detr(self, batch_data: Dict[str, str], model: Any, target_label: str,
                           target_classes: set, confidence_threshold: float = 0.9) -> Dict[str, str]:
        """
        Moves images in which DETR detects any of ``target_classes``.

        Args:
            batch_data: {image_path: label}; every path in it is checked.
            model: The DETR object-detection pipeline.
            target_label: Destination category (a key of ``self.output_paths``).
            target_classes: COCO label names that qualify an image.
            confidence_threshold: Minimum detection score that counts.

        Returns:
            The images that were NOT moved; ``batch_data`` unchanged on a model error.
        """
        if not batch_data:
            return {}

        print(f"  Running 'DETR' for '{target_label}' on {len(batch_data)} images...")
        paths = list(batch_data.keys())

        try:
            results = model(paths, batch_size=self.dataloader.batch_size)
        except Exception as e:
            print(f"    ERROR: Model 'DETR' failed. {e}")
            return batch_data

        remaining_batch = batch_data.copy()
        moved_count = 0
        for path, detections in zip(paths, results):
            if any(d['label'] in target_classes and d['score'] >= confidence_threshold for d in detections):
                self._move_file(path, target_label)
                del remaining_batch[path]
                moved_count += 1

        print(f"    Moved {moved_count} images to '{target_label}'.")
        return remaining_batch






HEIC/HEIF support enabled.
pip install insightface


  from .autonotebook import tqdm as notebook_tqdm


In [25]:

# Source and destination folders. Override with the IMAGE_SOURCE_DIR /
# IMAGE_OUTPUT_DIR environment variables; the defaults are the author's
# local folders. NOTE: classification MOVES files out of `test_dir`.
test_dir = os.environ.get("IMAGE_SOURCE_DIR", r"D:\images\HockingHills")
output = os.environ.get("IMAGE_OUTPUT_DIR", r"D:\images\output")

# 1. Instantiate the Dataloader (recursively scans and de-duplicates test_dir;
#    raises ValueError if the folder does not exist).
dataloader = ImageDataloader(root_dir=test_dir, batch_size=32)

Scanning directory: D:\images\HockingHills...
--- Scan Complete ---
Total image files found: 91
Duplicate images skipped: 0
Total unique images to process: 91


In [None]:
# 3. Build the classifier (loads every model, which can take a while) and
# 4. run the classification waterfall over all dataloader batches.
# The classifier is created here so this cell does not depend on a variable
# left over from another (or a deleted) cell.
# NOTE: files are MOVED from `test_dir` into `output`; re-running this cell
# without re-running the dataloader cell will try to move files that are gone.
try:
    classifier = ImageClassifier(dataloader, output_dir=output)
    classifier.classify_images()
except Exception as e:
    print(f"\nAn error occurred during classification: {e}")
    print("This can happen if you are offline or models are unavailable.")

# 5. Print a summary of how many files ended up in each category.
print("\n--- Final Output Summary ---")
try:
    for category in sorted(os.listdir(output)):
        category_path = os.path.join(output, category)
        if os.path.isdir(category_path):
            files = os.listdir(category_path)
            print(f"Files in '{category}': {len(files)} {files}")
except FileNotFoundError:
    print(f"Output directory '{output}' not found. Did the script run?")