In [11]:
from PIL import Image
import numpy as np
import cv2
import os
from pathlib import Path
from imagehash import phash
from itertools import combinations
from typing import List, Set, Tuple

class ImageDescriptor:
    """Outcome of a detection pass over a collection of images.

    Attributes:
        unique_images: Paths for which no similar counterpart was found.
        similar_groups: Lists of paths that were judged similar to each other.

    Both arguments are stored by reference; no copy is made.
    """

    def __init__(
        self,
        unique_images: Set[Path],
        similar_groups: List[List[Path]],
    ):
        self.unique_images, self.similar_groups = unique_images, similar_groups

class HashDetector:
    """Groups images whose perceptual hashes are exactly equal."""

    def __init__(self, precision: int):
        # Side length (in pixels) the image is shrunk to before hashing.
        self.precision = precision

    def detect(self, images: List[Path]) -> ImageDescriptor:
        """Bucket ``images`` by perceptual hash.

        Each image is opened, converted to greyscale, resized to
        ``precision x precision`` and hashed with ``phash``. Images sharing a
        hash form a similar group; an image alone in its bucket is unique.

        Args:
            images: Paths of the image files to compare.

        Returns:
            An ``ImageDescriptor`` with the unique images and similar groups.
        """
        print(f"Detecting duplicates using perceptual hash, precision: {self.precision}\nimages: {images}")
        side = self.precision
        buckets = {}
        for image_path in images:
            with Image.open(image_path) as img:
                # NOTE(review): precision controls the pre-shrink size here,
                # not phash's own hash_size argument -- confirm this is intended.
                shrunk = img.convert("L").resize((side, side))
                buckets.setdefault(phash(shrunk), []).append(image_path)
        print(f"Found {len(buckets)} unique hashes")

        unique_images = {
            members[0] for members in buckets.values() if len(members) == 1
        }
        similar_groups = [
            members for members in buckets.values() if len(members) != 1
        ]
        print(f"Found {len(unique_images)} unique images, {len(similar_groups)} similar groups")
        return ImageDescriptor(unique_images, similar_groups)

class ORBDetector:
    """Groups images by the number of cross-checked ORB descriptor matches."""

    def __init__(self, nfeatures: int, threshold: float):
        # Maximum number of ORB features requested per image.
        self.nfeatures = nfeatures
        # Two images are similar when their match count exceeds
        # nfeatures * threshold.
        self.threshold = threshold

    def detect(self, images: List[Path]) -> ImageDescriptor:
        """Split ``images`` into unique images and groups of similar images.

        Every pair of images is compared; a pair is similar when the number of
        descriptor matches is greater than ``nfeatures * threshold``. Similar
        pairs are merged transitively, so each image ends up in at most one
        group. (Previously an image that matched two others caused a
        ``KeyError`` when it was removed from the unique set a second time.)

        Args:
            images: Paths of the image files to compare.

        Returns:
            An ``ImageDescriptor`` whose ``similar_groups`` are disjoint lists
            ordered as in ``images`` and whose ``unique_images`` holds every
            image that matched nothing (including images without descriptors).
        """
        features = {img: self._extract_features(img) for img in images}

        # Union-find over the images: parent[x] leads to x's group root.
        parent = {img: img for img in images}

        def find(node):
            while parent[node] != node:
                parent[node] = parent[parent[node]]  # path halving
                node = parent[node]
            return node

        min_matches = self.nfeatures * self.threshold
        for img1, img2 in combinations(images, 2):
            _, des1 = features[img1]
            _, des2 = features[img2]
            if des1 is None or des2 is None:
                continue
            root1, root2 = find(img1), find(img2)
            if root1 == root2:
                # Already linked through another image; skip the costly match.
                continue
            if self._match_features(des1, des2) > min_matches:
                # Attach to the earlier image's root to keep input ordering.
                parent[root2] = root1

        clusters = {}
        for img in images:
            clusters.setdefault(find(img), []).append(img)

        similar_groups = [
            members for members in clusters.values() if len(members) > 1
        ]
        unique_images = {
            members[0] for members in clusters.values() if len(members) == 1
        }
        return ImageDescriptor(unique_images, similar_groups)

    def _extract_features(self, image_path: Path):
        """Return ``(keypoints, descriptors)`` for the greyscale image at ``image_path``."""
        orb = cv2.ORB_create(self.nfeatures)
        img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        return orb.detectAndCompute(img, None)

    def _match_features(self, des1, des2):
        """Return the number of cross-checked Hamming matches between two descriptor sets."""
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = bf.match(des1, des2)
        return len(matches)

class ImageDeduplicator:
    """Runs a cascade of detectors over the images in a directory."""

    def __init__(self, directory: str):
        # Directory whose direct children are scanned for images.
        self.directory = directory
        # Detectors are applied in order; each one only re-examines the
        # groups the previous detector still considered similar.
        self.detectors = [
            HashDetector(8), HashDetector(16),
            ORBDetector(500, 0.7), ORBDetector(1000, 0.8)
        ]

    def deduplicate(self):
        """Refine similar-image groups through every detector in turn.

        All ``.jpg`` / ``.png`` files directly inside ``self.directory`` start
        out as one big candidate group. Each detector splits every remaining
        group into unique images and smaller similar groups.

        Returns:
            The ``ImageDescriptor`` after the last detector. Images still in
            ``similar_groups`` are not included in ``unique_images``.
        """
        images = [
            path
            for path in Path(self.directory).glob("*")
            if path.is_file() and path.suffix.lower() in [".jpg", ".png"]
        ]
        descriptor = ImageDescriptor(set(), [images])

        for detector in self.detectors:
            print(f"Processed with {type(detector).__name__}")
            # Accumulate into fresh containers. The original extended the list
            # it was iterating over (re-feeding groups to the same detector)
            # and then discarded the groups, so later detectors saw nothing.
            next_unique = set(descriptor.unique_images)
            next_groups = []
            for group in descriptor.similar_groups:
                result = detector.detect(group)
                next_unique |= result.unique_images
                next_groups.extend(result.similar_groups)
            descriptor = ImageDescriptor(next_unique, next_groups)
            print(f"unique count: {len(descriptor.unique_images)}")

        return descriptor

# Example usage: scan a hard-coded local directory, run the full detector
# cascade, and report how many images ended up classified as unique.
deduplicator = ImageDeduplicator("/Users/chenweichu/dev/data/test")
# Runs at cell/module execution time and prints progress from each detector.
final_descriptor = deduplicator.deduplicate()
print(f"Final unique images count: {len(final_descriptor.unique_images)}")


Processed with HashDetector
Detecting duplicates using perceptual hash, precision: 8
images: [PosixPath('/Users/chenweichu/dev/data/test/k.jpg'), PosixPath('/Users/chenweichu/dev/data/test/af.jpg'), PosixPath('/Users/chenweichu/dev/data/test/aq.jpg'), PosixPath('/Users/chenweichu/dev/data/test/bj.jpg'), PosixPath('/Users/chenweichu/dev/data/test/bk.jpg'), PosixPath('/Users/chenweichu/dev/data/test/ap.jpg'), PosixPath('/Users/chenweichu/dev/data/test/ag.jpg'), PosixPath('/Users/chenweichu/dev/data/test/j.jpg'), PosixPath('/Users/chenweichu/dev/data/test/ae.jpg'), PosixPath('/Users/chenweichu/dev/data/test/h.jpg'), PosixPath('/Users/chenweichu/dev/data/test/ar.jpg'), PosixPath('/Users/chenweichu/dev/data/test/bi.jpg'), PosixPath('/Users/chenweichu/dev/data/test/bh.jpg'), PosixPath('/Users/chenweichu/dev/data/test/as.jpg'), PosixPath('/Users/chenweichu/dev/data/test/i.jpg'), PosixPath('/Users/chenweichu/dev/data/test/ad.jpg'), PosixPath('/Users/chenweichu/dev/data/test/aw.jpg'), PosixPath

KeyboardInterrupt: 