In [1]:
import os
import hashlib
import shutil
import sys
from typing import Dict, List, Optional, Generator, Any
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter  
import numpy as np
# Optional dependencies: each import below is guarded so that a missing
# package produces a readable warning instead of an ImportError.

# --- HEIC/HEIF Support ---
# Registering the pillow-heif opener is what makes PIL.Image.open() able to
# read HEIC/HEIF files.
# Install with: pip install pillow-heif
try:
    import pillow_heif
    pillow_heif.register_heif_opener()
    print("HEIC/HEIF support enabled.")
except ImportError:
    print("Warning: 'pillow-heif' not installed. HEIC/HEIF files will be skipped.")
    print("Install with: pip install pillow-heif")
# --- End HEIC/HEIF Support ---

# --- FaceAnalysis (insightface) Support ---
# insightface is imported only here, inside the guard. An additional
# unconditional import would raise before this warning could be printed.
# Install with: pip install insightface
try:
    from insightface.app import FaceAnalysis
    print("FaceAnalysis support enabled.")
except ImportError:
    # Keep the name defined so later code can test for availability.
    FaceAnalysis = None
    print("Warning: FaceAnalysis not installed.")
    print("pip install insightface")
# --- End FaceAnalysis Support ---

# --- Hugging Face transformers / PyTorch (required) ---
# These packages are mandatory for the script. If either import fails,
# install instructions are printed and the interpreter exits with status 1.
# Install with: pip install transformers torch
try:
    from transformers import (
        DetrForObjectDetection,
        DetrImageProcessor,
        Pipeline,
        pipeline,
    )
    import torch
except ImportError:
    print("CRITICAL: 'transformers' or 'torch' not found.")
    print("Please install them to run this script: pip install transformers torch")
    raise SystemExit(1)


class ImageDataloader:
    """
    Scans a directory tree for unique images and serves them in batches.

    Uniqueness is decided by file content (SHA256 digest), not by file name.
    The directory tree is traversed in sorted order, so which copy of a
    duplicate is kept and how batches are composed is reproducible between
    runs and platforms.

    This is a plain Python iterable; it does not inherit from
    torch.utils.data.DataLoader.
    """
    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS: tuple = ('.jpg', '.jpeg', '.png', '.heic', '.heif')

    def __init__(self, root_dir: str, batch_size: int = 32):
        """
        Validates the arguments and immediately scans `root_dir`.

        Args:
            root_dir: Directory that is walked recursively for images.
            batch_size: Maximum number of images per yielded batch.

        Raises:
            ValueError: If `root_dir` is not an existing directory or
                `batch_size` is not positive.
        """
        if not os.path.isdir(root_dir):
            raise ValueError(f"Root directory not found: {root_dir}")
        if batch_size <= 0:
            raise ValueError("Batch size must be greater than 0")

        self.root_dir = root_dir
        self.batch_size = batch_size

        # Working state of every unique image: {image_path: label}.
        # Every image starts with the label "unknown".
        self.labels: Dict[str, str] = {}
        self._scan_and_deduplicate()

    def _calculate_hash(self, filepath: str, block_size: int = 65536) -> str:
        """
        Returns the SHA256 hex digest of a file's content.

        The file is read in chunks of `block_size` bytes so large images are
        never loaded into memory at once. If the file cannot be read, a
        warning is printed and an empty string is returned.
        """
        sha256 = hashlib.sha256()
        try:
            with open(filepath, 'rb') as f:
                while chunk := f.read(block_size):
                    sha256.update(chunk)
            return sha256.hexdigest()
        except (IOError, OSError) as e:
            print(f"Warning: Could not read file for hashing: {filepath}. Skipping. Error: {e}")
            return ""

    def _scan_and_deduplicate(self):
        """
        Walks the root directory and fills self.labels with one entry per
        unique image, each labelled 'unknown'.

        Files whose extension is not in IMAGE_EXTENSIONS are ignored. Files
        that cannot be hashed are skipped. When several files have identical
        content, only the first one in sorted traversal order is kept.
        """
        print(f"Scanning directory: {self.root_dir}...")
        seen_hashes: set = set()
        total_files = 0
        duplicates_skipped = 0
        unreadable_skipped = 0

        for root, dirs, files in os.walk(self.root_dir):
            # Sorting in place makes os.walk descend in a deterministic order.
            dirs.sort()
            for file in sorted(files):
                if not file.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue

                total_files += 1
                full_path = os.path.join(root, file)
                file_hash = self._calculate_hash(full_path)

                if not file_hash:
                    unreadable_skipped += 1
                    continue

                if file_hash in seen_hashes:
                    duplicates_skipped += 1
                    continue

                seen_hashes.add(file_hash)
                self.labels[full_path] = "unknown"

        print("--- Scan Complete ---")
        print(f"Total image files found: {total_files}")
        print(f"Duplicate images skipped: {duplicates_skipped}")
        if unreadable_skipped:
            print(f"Unreadable files skipped: {unreadable_skipped}")
        print(f"Total unique images to process: {len(self.labels)}")
        if not self.labels:
            print("Warning: No valid, unique images were found.")

    def __len__(self) -> int:
        """
        Returns the number of unique images (not the number of batches).
        """
        return len(self.labels)

    def __iter__(self) -> Generator[Dict[str, str], None, None]:
        """
        Yields batches of images as dictionaries {image_path: label}.

        Each batch holds at most `batch_size` entries; the last one may be
        smaller. Nothing is yielded when no images were found.
        """
        # Snapshot the keys so iteration is stable even if labels change.
        all_paths = list(self.labels)

        for start in range(0, len(all_paths), self.batch_size):
            batch_paths = all_paths[start:start + self.batch_size]
            yield {path: self.labels[path] for path in batch_paths}


class ImageClassifier:
    """
    Uses a "waterfall" method to classify images from a dataloader
    using multiple, chained AI models.
    """
    def __init__(self, dataloader: ImageDataloader, output_dir: str = "output"):
        """
        Prepares output folders and loads every model used by the waterfall.

        Args:
            dataloader: Source of {image_path: label} batches.
            output_dir: Root folder under which the 'people', 'view' and
                'unknown' sub-folders are created.
        """
        self.dataloader = dataloader
        self.output_dir = output_dir

        # Accumulates the final {image_path: label} result of every batch.
        self.processed_images = {}

        # Torch device, plus the integer form passed to the pipelines
        # (0 when CUDA is available, -1 otherwise).
        cuda_available = torch.cuda.is_available()
        self.device = torch.device('cuda' if cuda_available else 'cpu')
        self.pipeline_device = 0 if cuda_available else -1
        print(f"Using device: {self.device} (Pipeline device: {self.pipeline_device})")

        # One destination folder per category label.
        self.output_paths = {
            category: os.path.join(output_dir, category)
            for category in ("people", "view", "unknown")
        }
        self._create_output_dirs()

        print("Loading models...")

        # Stage 1: people / face detectors
        self.retinaface_model = self.load_retinaface_model()
        self.detr_pipeline = self.load_detr_pipeline()

        # Stage 2: view detector
        self.clip_pipeline = self.load_clip_pipeline()

        print("Model loading complete.")


    def _create_output_dirs(self):
        """Creates the output directories if they don't exist."""
        print(f"Ensuring output directories exist at: {self.output_dir}")
        for path in self.output_paths.values():
            os.makedirs(path, exist_ok=True)


    def _move_file(self, batch_data: Dict[str, str], action='move'):
        """
        batch_data: dict {image_path: label}

        Moves a batch of image files to their classified directories.
        """
        assert action in ['move', 'copy']
        # Iterate over each image path and its assigned label in the dictionary
        for image_path, new_label in batch_data.items():
            try:
                # 1. Get destination directory. Default to 'unknown' if label is invalid.
                dest_dir = self.output_paths.get(new_label, self.output_paths["unknown"])
                
                # 2. Get the base filename from the full path
                filename = os.path.basename(image_path)
                
                # 3. Create the full destination path
                dest_path = os.path.join(dest_dir, filename)
                

                # 5. Move the file
                if action=='move':
                    shutil.move(image_path, dest_path)
                else:
                    shutil.copy2(image_path, dest_path)


            except FileNotFoundError:
                print(f"File not found, cannot move: {image_path}")
            except PermissionError:
                print(f"Permission denied, cannot move: {image_path}")
            except Exception as e:
                print(f"Failed to move {image_path}: {e}")


    def _display_batch(self, batch_data: Dict[str, str], show_image=False):
        """
        Displays a summary of batch data (category counts) and optionally
        renders each image with its category as the title.

        Args:
            batch_data (Dict[str, str]): Dictionary where keys are file paths 
                                        and values are the categories/labels.
            show_image (bool): A flag to control whether to display images.
        """
        # --- Batch Summary Section ---
        
        # Use collections.Counter to efficiently count the occurrences of each 
        # unique value (category) in the batch_data dictionary.
        value_counts = Counter(batch_data.values())
        
        print("Counts for each unique value in this batch:")
        
        # Loop through the resulting Counter object (value_counts.items())
        for value, count in value_counts.items():
            # Prints each unique value and its total count
            print(f"  Image category '{value}': {count} time(s)")
            
        print("-----------------------------------------")
        
        # --- Image Display Section ---
        
        # This entire block is conditional. It only runs if show_image=True
        if show_image:
            for key in batch_data:
                # Print the file path being processed
                print(key)
                plt.title(batch_data[key])
                plt.imshow(Image.open(key).convert('RGB'))
                plt.show()


    def classify_images(self):
        """
        The main classification loop (orchestrator).
        Iterates through all batches and applies the waterfall logic.
        """
        print("\n--- Starting Image Classification Waterfall ---")
        total_batches = (len(self.dataloader) + self.dataloader.batch_size - 1) // self.dataloader.batch_size
        
        for i, initial_batch in enumerate(self.dataloader):
            print(f"\nProcessing Batch {i+1} / {total_batches} (Size: {len(initial_batch)})...")
            
            remaining_batch = initial_batch.copy()

            # --- STAGE 1: Detect faces ---
            if self.retinaface_model:
                remaining_batch = self.detect_human_detr_pipeline(remaining_batch, self.detr_pipeline)
            
            # --- STAGE 1: Detect people ---
            if self.retinaface_model:
                remaining_batch = self.detect_with_RetinaFace(remaining_batch, self.retinaface_model)

            # --- STAGE 2: View Models ---
            if self.clip_pipeline:
                remaining_batch = self.detect_view_clip_pipeline(
                    remaining_batch, self.clip_pipeline,
                    target_label="view",
                    prompts=["a photo of a landscape", "a beautiful view", "a flower", "a photo of food", "a city skyline", "a beach", "mountains", "a forest"]
                )


            self._move_file(remaining_batch, 'copy')
            self.processed_images |=remaining_batch
            print(f"Batch {i+1} complete.")
        
        self._display_batch(self.processed_images)

        print("\n--- Image Classification Finished ---")

    # --- Specific Model Classification Methods ---
    # --- Specific retinaface Model ---
    def load_retinaface_model(self):
        # try:
        #     app = FaceAnalysis(name="buffalo_l")  # uses RetinaFace + ArcFace
        #     app.prepare(ctx_id=0, det_size=(640, 640))  # GPU: ctx_id=0
        #     print("Loaded: retinaface")
        # except Exception as e:
        #     print(f"CRITICAL: Failed to load retinaface. {e}")
        #     app = None

        app = FaceAnalysis(name="buffalo_l")  # uses RetinaFace + ArcFace
        app.prepare(ctx_id=0, det_size=(640, 640))  # GPU: ctx_id=0
        return app
    

    def detect_with_RetinaFace(self, batch_data, model, target_label='people'):
        """
        batch_data: dict {image_path: label}
        model: RetinaFace model (insightface FaceAnalysis)

        Processes only images labeled 'unknown'.
        Returns updated batch_data with detections.
        """
        updated_data = batch_data.copy()

        # 1️⃣ Filter for unknown images
        unknown_items = [(path, label) for path, label in batch_data.items() if label == "unknown"]
        if not unknown_items:
            return updated_data
        
        print(f"\nProcessing {len(unknown_items)} 'unknown' images using the RetinaFace model...")

        image_paths = [p for p, _ in unknown_items]

        # 2️⃣ Load all images once
        loaded_images = {}
        for img_path in image_paths:
            try:
                img = Image.open(img_path).convert("RGB")
                img_bgr = np.array(img)[:, :, ::-1]  # RGB → BGR
                loaded_images[img_path] = img_bgr
            except Exception as e:
                print(f"⚠️ Error reading {img_path}: {e}")
                updated_data[img_path] = "invalid"

        update_count=0
        # 3️⃣ Run inference per image (insightface doesn't support batch)
        for img_path, img_bgr in loaded_images.items():
            faces = model.get(img_bgr)  # must be called one by one

            if len(faces) > 0:
                updated_data[img_path] = target_label
                update_count+=1

        print(f"\nSuccessfully processed {len(unknown_items)} images using detr. Updated {update_count} labels to '{target_label}'.")
        return updated_data

    # --- Specific detr Model ---
    def load_detr_pipeline(self):
        try:
            model = "facebook/detr-resnet-50"
            detr_pipeline = pipeline("object-detection", model=model, device=self.pipeline_device)
            print("Loaded: facebook/detr-resnet-50")
        except Exception as e:
            print(f"CRITICAL: Failed to load DETR. {e}")
            detr_pipeline = None
        return detr_pipeline

    def detect_human_detr_pipeline(self, batch_data: Dict[str, str], detr_pipeline: Any, target_label='people') -> Dict[str, str]:
        """
        Detects humans ('person' label) in ALL images labeled 'unknown' using the 
        Hugging Face pipeline in a single batch operation and updates the labels.

        Args:
            batch_data (dict): A dictionary of {image_path: label}.
            detr_pipeline (transformers.Pipeline): The loaded DETR object detection pipeline.
            confidence_threshold (float): Minimum confidence for a detection to be considered.

        Returns:
            dict: The updated batch_data dictionary.
        """
        # DETR is trained on COCO, where the label for a human is 'person'
        PERSON_LABEL = 'person' 
        confidence_threshold = 0.9

        if detr_pipeline is None:
            print("ERROR: Pipeline is not loaded. Cannot process data.")
            return batch_data

        # 1. Identify images to process, load, and validate paths
        unknown_paths = [path for path, label in batch_data.items() if label == 'unknown']
        
        if not unknown_paths:
            print("No images labeled 'unknown' found. Returning original data.")
            return batch_data
        
        print(f"\nProcessing all {len(unknown_paths)} 'unknown' images using the detr pipeline...")

        # Load PIL Images for the pipeline
        batch_images: List[Image.Image] = []
        valid_paths: List[str] = []
        for path in unknown_paths:
            try:
                if not os.path.exists(path):
                    print(f"⚠️ Warning: Image not found at {path}. Skipping.")
                    continue
                batch_images.append(Image.open(path).convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"❌ An error occurred loading {path}: {e}. Skipping.")

        if not valid_paths:
            print("No valid images could be loaded. Returning original data.")
            return batch_data

        # 2. Perform single batch inference
        # The pipeline handles moving data to the device defined during load.
        # We pass the list of PIL images directly.
        try:
            # The result is a list of lists: [[det1, det2, ...], [det1, ...], ...]
            results: List[List[Dict[str, Any]]] = detr_pipeline(batch_images)
        except Exception as e:
            print(f"\nRUNTIME ERROR during pipeline execution: {e}")
            print("The batch might be too large for available memory.")
            return batch_data

        update_count = 0
        # 3. Post-process the results
        for idx, image_results in enumerate(results):
            image_path = valid_paths[idx]
            is_human_detected = False
            
            # image_results is a list of dictionaries (one for each detection)
            for detection in image_results:
                # The pipeline provides the score and the label (e.g., 'person')
                if detection['score'] >= confidence_threshold and detection['label'] == PERSON_LABEL:
                    is_human_detected = True
                    # print(f"✅ Detected {PERSON_LABEL} in: {image_path} with score {detection['score']:.2f}")
                    break # Found a person, no need to check other detections for this image

            if is_human_detected:
                batch_data[image_path] = target_label
                update_count+=1

        print(f"\nSuccessfully processed {len(valid_paths)} images using detr. Updated {update_count} labels to '{target_label}'.")

        return batch_data

    # --- Specific clip Model ---
    def load_clip_pipeline(self):
        """
        Loads the CLIP zero-shot classification pipeline.
        """
        try:
            clip_pipeline = pipeline(
                "zero-shot-image-classification",
                model="openai/clip-vit-large-patch14",
                device=self.pipeline_device  
            )
            print("Loaded: clip-vit-large-patch14")
            return clip_pipeline  # <-- FIX: You must return the loaded pipeline
        except Exception as e:
            print(f"CRITICAL: Failed to load CLIP. {e}")
            return None # Return None on failure

    def detect_view_clip_pipeline(self, batch_data: Dict[str, str], clip_pipeline: Any, target_label: str, prompts: List[str], confidence_threshold: float = 0.80) -> Dict[str, str]:
        """
        Classifies images labeled 'unknown' using the CLIP pipeline and a list of prompts.
        
        If the top-scoring prompt meets the confidence threshold, the image label
        is updated to the single `target_label`.

        Args:
            batch_data (dict): A dictionary of {image_path: label}.
            clip_pipeline (transformers.Pipeline): The loaded CLIP pipeline.
            target_label (str): The new label to assign if a match is found (e.g., 'view').
            prompts (list): A list of candidate labels to check against (e.g., "a flower").
            confidence_threshold (float): Minimum confidence for a classification.

        Returns:
            dict: The updated batch_data dictionary.
        """
        if clip_pipeline is None:
            print("ERROR: CLIP Pipeline is not loaded. Cannot process data.")
            return batch_data

        # 1. Identify images to process, load, and validate paths
        unknown_paths = [path for path, label in batch_data.items() if label == 'unknown']
        
        if not unknown_paths:
            print("No images labeled 'unknown' found for CLIP processing. Returning original data.")
            return batch_data
        
        print(f"\nProcessing {len(unknown_paths)} 'unknown' images using CLIP with {len(prompts)} prompts...")

        # Load PIL Images for the pipeline
        batch_images: List[Image.Image] = []
        valid_paths: List[str] = []
        for path in unknown_paths:
            try:
                if not os.path.exists(path):
                    print(f"⚠️ Warning: Image not found at {path}. Skipping.")
                    continue
                batch_images.append(Image.open(path).convert("RGB"))
                valid_paths.append(path)
            except Exception as e:
                print(f"❌ An error occurred loading {path}: {e}. Skipping.")

        if not valid_paths:
            print("No valid images could be loaded for CLIP. Returning original data.")
            return batch_data

        # 2. Perform single batch inference
        # We pass the list of PIL images and the candidate labels
        try:
            # The result is a list of lists: [[class1, class2, ...], [class1, ...], ...]
            results: List[List[Dict[str, Any]]] = clip_pipeline(
                batch_images, 
                candidate_labels=prompts
            )
        except Exception as e:
            print(f"\nRUNTIME ERROR during CLIP pipeline execution: {e}")
            print("The batch might be too large for available memory.")
            return batch_data

        # 3. Post-process the results
        update_count = 0
        for idx, image_results in enumerate(results):
            image_path = valid_paths[idx]
            
            # The pipeline returns all prompts, sorted by score.
            # We only care about the top-scoring one.
            top_detection = image_results[0]
            
            # Check if the top score meets our threshold.
            # Since the pipeline only returns labels from our `prompts` list,
            # we know the label is one we are looking for.
            if top_detection['score'] >= confidence_threshold:
                # print(f"✅ Classified {image_path} as '{top_detection['label']}' (score: {top_detection['score']:.2f}). Setting label to '{target_label}'.")
                
                # Update the label to the generic target_label
                batch_data[image_path] = target_label
                update_count += 1
            else:
                # The top score was too low, so we "leave it" (it remains 'unknown')
                pass

        print(f"\nSuccessfully processed {len(valid_paths)} images using CLIP. Updated {update_count} labels to '{target_label}'.")
        return batch_data


  from .autonotebook import tqdm as notebook_tqdm


HEIC/HEIF support enabled.
FaceAnalysis support enabled.


In [None]:

#test_dir = r"D:\images\HockingHills"
test_dir = r"D:\images\images"
output = r"D:\images\output"

import matplotlib.pyplot as plt
# 1. Scan the input folder; duplicates are dropped and images are served in
#    batches of 32.
dataloader = ImageDataloader(root_dir=test_dir, batch_size=32)

# 2. Create the output folders and load all models.
classifier = ImageClassifier(dataloader=dataloader, output_dir=output)

# 3. Run the waterfall. Errors are reported rather than raised so that the
#    summary below still runs.
try:
    classifier.classify_images()
except Exception as e:
    print(f"\nAn error occurred during classification: {e}")
    print("This can happen if you are offline or models are unavailable.")

# 4. Print a summary of what ended up in each category folder.
print("\n--- Final Output Summary ---")
try:
    for category in os.listdir(output):
        category_path = os.path.join(output, category)
        if os.path.isdir(category_path):
            files = os.listdir(category_path)
            print(f"Files in '{category}': {len(files)} {files}")
except FileNotFoundError:
    # Report the directory that was actually looked up.
    print(f"Output directory '{output}' not found. Did the script run?")

Scanning directory: D:\images\images...
--- Scan Complete ---
Total image files found: 8619
Duplicate images skipped: 0
Total unique images to process: 8619
Using device: cuda (Pipeline device: 0)
Ensuring output directories exist at: D:\images\output
Loading models...




Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:\Users\xiaom/.insightface\models\buffalo_l\1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:\Users\xiaom/.insightface\models\buffalo_l\2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:\Users\xiaom/.insightface\models\buffalo_l\det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:\Users\xiaom/.insightface\models\buffalo_l\genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:\Users\xiaom/.insightface\models\buffalo_l\w600k_r50.onnx recognition ['None', 3, 112, 112] 127.

Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


Loaded: facebook/detr-resnet-50


Device set to use cuda:0


Loaded: clip-vit-large-patch14
Model loading complete.

--- Starting Image Classification Waterfall ---

Processing Batch 1 / 270 (Size: 32)...

Processing all 32 'unknown' images using the detr pipeline...

Successfully processed 32 images using detr. Updated 32 labels to 'people'.
No images labeled 'unknown' found for CLIP processing. Returning original data.
Batch 1 complete.

Processing Batch 2 / 270 (Size: 32)...

Processing all 32 'unknown' images using the detr pipeline...

Successfully processed 32 images using detr. Updated 30 labels to 'people'.

Processing 2 'unknown' images using the RetinaFace model...

Successfully processed 2 images using detr. Updated 1 labels to 'people'.

Processing 1 'unknown' images using CLIP with 8 prompts...

Successfully processed 1 images using CLIP. Updated 0 labels to 'view'.
Batch 2 complete.

Processing Batch 3 / 270 (Size: 32)...

Processing all 32 'unknown' images using the detr pipeline...

Successfully processed 32 images using detr. Up

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset



Successfully processed 32 images using detr. Updated 32 labels to 'people'.
No images labeled 'unknown' found for CLIP processing. Returning original data.
Batch 11 complete.

Processing Batch 12 / 270 (Size: 32)...

Processing all 32 'unknown' images using the detr pipeline...

Successfully processed 32 images using detr. Updated 32 labels to 'people'.
No images labeled 'unknown' found for CLIP processing. Returning original data.
Batch 12 complete.

Processing Batch 13 / 270 (Size: 32)...

Processing all 32 'unknown' images using the detr pipeline...

Successfully processed 32 images using detr. Updated 32 labels to 'people'.
No images labeled 'unknown' found for CLIP processing. Returning original data.
Batch 13 complete.

Processing Batch 14 / 270 (Size: 32)...

Processing all 32 'unknown' images using the detr pipeline...

Successfully processed 32 images using detr. Updated 31 labels to 'people'.

Processing 1 'unknown' images using the RetinaFace model...

Successfully processe