In [1]:
from PIL import Image
import numpy as np
import cv2
import os
from pathlib import Path
from imagehash import phash
from itertools import combinations
from typing import List, Set, Tuple
from termcolor import colored
import datetime

class ImageDescriptor:
    def __init__(self, unique_images: Set[Path], similar_groups: List[List[Path]]):
        self.unique_images = unique_images
        self.similar_groups = similar_groups

    def serialize(self, filepath: str):
        """将描述信息保存为文本文件"""
        print(f"Saving description to {filepath}")
        with open(filepath, 'w') as file:
            file.write("Unique Images:\n")
            for image in self.unique_images:
                file.write(f"{image.name}\n")
            file.write("\nSimilar Groups:\n")
            for group in self.similar_groups:
                file.write(f"Group:\n")
                for image in group:
                    file.write(f"{image.name}\n")
                file.write("\n")

class HashDetector:
    def __init__(self, precision: int):
        self.precision = precision

    def detect(self, images: List[Path]) -> ImageDescriptor:
        print(f"Detecting duplicates using perceptual hash, precision: {self.precision}\nimages cnt: {len(images)}")
        hash_dict = {}
        for image_path in images:
            with Image.open(image_path) as img:
                img_hash = phash(img.convert("L").resize((self.precision, self.precision)))
                if img_hash in hash_dict:
                    hash_dict[img_hash].append(image_path)
                else:
                    hash_dict[img_hash] = [image_path]
        print(f"Found {len(hash_dict)} unique hashes")
        unique_images = set()
        similar_groups = []
        for paths in hash_dict.values():
            if len(paths) == 1:
                unique_images.add(paths[0])
            else:
                similar_groups.append(paths)
        print(f"Found {len(unique_images)} unique images, {len(similar_groups)} similar groups")
        return ImageDescriptor(unique_images, similar_groups)

class ORBDetector:
    def __init__(self, nfeatures: int, threshold: float):
        self.nfeatures = nfeatures
        self.threshold = threshold

    def detect(self, images: List[Path]) -> ImageDescriptor:
        keypoints_dict = {img: self._extract_features(img) for img in images}
        similar_groups = []
        unique_images = set(images)
        
        for img1, img2 in combinations(images, 2):
            kp1, des1 = keypoints_dict[img1]
            kp2, des2 = keypoints_dict[img2]
            if des1 is not None and des2 is not None:
                if self._match_features(des1, des2) > self.nfeatures * self.threshold:
                    similar_groups.append([img1, img2])
                    unique_images.discard(img1)
                    unique_images.discard(img2)

        return ImageDescriptor(unique_images, similar_groups)

    def _extract_features(self, image_path: Path):
        orb = cv2.ORB_create(self.nfeatures)
        img = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        return orb.detectAndCompute(img, None)

    def _match_features(self, des1, des2):
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = bf.match(des1, des2)
        return len(matches)

class ImageDeduplicator:
    def __init__(self, directory: str):
        self.directory = directory
        self.detectors = [
            HashDetector(8),
            ORBDetector(500, 0.7)
        ]

    def deduplicate(self):
        images = [path for path in Path(self.directory).glob("*") if path.suffix.lower() in [".jpg", ".png"]]
        descriptor = ImageDescriptor(set(), [images])

        for detector in self.detectors:
            print(colored(f"Processed with {type(detector).__name__}[{id(detector)}], similar_groups count: {len(descriptor.similar_groups)}", "yellow"))
            result = detector.detect(images)
            descriptor.unique_images.update(result.unique_images)
            descriptor.similar_groups.extend(result.similar_groups)

            # 序列化当前描述对象
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            filepath = f"{self.directory}/descriptor_{type(detector).__name__}_{timestamp}.txt"
            descriptor.serialize(filepath)

            print(colored(f"unique count: {len(descriptor.unique_images)}", "yellow"))
        
        return descriptor

# Example usage
deduplicator = ImageDeduplicator("/Users/chenweichu/dev/data/test")
final_descriptor = deduplicator.deduplicate()
print(f"Final unique images count: {len(final_descriptor.unique_images)}")


[33mProcessed with HashDetector[4776618784], similar_groups count: 1[0m
Detecting duplicates using perceptual hash, precision: 8
images cnt: 86
Found 51 unique hashes
Found 29 unique images, 22 similar groups
Saving description to /Users/chenweichu/dev/data/test/descriptor_HashDetector_20240508194321.txt
[33munique count: 29[0m
[33mProcessed with ORBDetector[4776619792], similar_groups count: 23[0m
Saving description to /Users/chenweichu/dev/data/test/descriptor_ORBDetector_20240508194335.txt
[33munique count: 31[0m
Final unique images count: 31
