In [4]:
from imagehash import phash, dhash, average_hash
import numpy as np

In [5]:
from types import FunctionType
import os
import random
from PIL import Image
from copy import deepcopy
from numpy import array

class Dataset:
    """
    Pair of image collections used to evaluate duplicate search and retrieval.

    ``query_docs`` holds the (smaller) set of query images and ``test_docs`` the
    retrieval set that queries are searched against. Both map an image's file
    name to its full path. Fingerprinting is left to subclasses such as
    ``HashedDataset``.
    """
    def __init__(self, path_to_queries: str, path_to_test: str) -> None:
        self.query_docs = self.load_image_set(path_to_queries)
        self.test_docs = self.load_image_set(path_to_test)

    @staticmethod
    def load_image_set(path: str, extensions: tuple = ('.jpg',)) -> dict:
        """Map every image file name directly inside ``path`` to its full path.

        Args:
            path: directory to scan (non-recursive).
            extensions: file-name suffixes to keep, matched case-sensitively.
                Defaults to ``('.jpg',)``; any iterable of strings is accepted.

        Returns:
            ``{file_name: os.path.join(path, file_name)}`` in sorted file-name
            order, so iteration is reproducible (``os.listdir`` order is arbitrary).
        """
        # str.endswith needs a tuple; convert once rather than per file.
        suffixes = tuple(extensions)
        return {
            doc: os.path.join(path, doc)
            for doc in sorted(os.listdir(path))
            if doc.endswith(suffixes)
        }

    
class HashedDataset(Dataset):
    """
    ``Dataset`` whose images are fingerprinted with a user-supplied hashing function.

    Args:
        hashing_function: callable applied to an opened PIL image; ``str()`` of its
            result is stored as the fingerprint (e.g. ``imagehash.dhash``).
        *args, **kwargs: forwarded to ``Dataset`` (``path_to_queries``, ``path_to_test``).

    Attributes:
        test_hashes: ``{file_name: hash_str}`` for the retrieval set.
        query_hashes: ``{file_name: hash_str}`` for the query set.
        doc2hash: union of both; if a query file shares a name with a test file,
            the query's hash wins.
        hash2doc: inverse of ``doc2hash``. NOTE: images with identical hashes
            (i.e. duplicates) collapse to a single entry; only the last file name
            in ``doc2hash`` order survives.
    """
    def __init__(self, hashing_function: FunctionType, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.hasher = hashing_function
        self.fingerprint()
        # A plain merge suffices: values are immutable strings, so a deep copy
        # adds nothing. Key order matches copy-then-update.
        self.doc2hash = {**self.test_hashes, **self.query_hashes}
        self.hash2doc = {doc_hash: doc for doc, doc_hash in self.doc2hash.items()}

    def _hash_image(self, path: str) -> str:
        """Fingerprint one image file, closing its file handle afterwards."""
        with Image.open(path) as image:
            return str(self.hasher(image))

    def fingerprint(self) -> None:
        """(Re)compute ``test_hashes`` and ``query_hashes`` using ``self.hasher``."""
        self.test_hashes = {doc: self._hash_image(path) for doc, path in self.test_docs.items()}
        self.query_hashes = {doc: self._hash_image(path) for doc, path in self.query_docs.items()}

    def get_hashes(self) -> dict:
        """Return ``{file_name: hash_str}`` for test and query images combined."""
        return self.doc2hash

    def get_query_hashes(self) -> dict:
        """Return ``{file_name: hash_str}`` for the query images only."""
        return self.query_hashes

    def get_test_hashes(self) -> dict:
        """Return ``{file_name: hash_str}`` for the retrieval (test) images only."""
        return self.test_hashes

In [3]:
# Root of the transformed dataset. The IMAGE_DEDUP_DATA environment variable
# overrides it; otherwise it is resolved relative to the notebook's working
# directory instead of a machine-specific absolute path.
DATA_DIR = os.environ.get('IMAGE_DEDUP_DATA', os.path.join('..', 'Transformed_dataset'))

hull = HashedDataset(
    dhash,
    os.path.join(DATA_DIR, 'Query'),
    os.path.join(DATA_DIR, 'Retrieval'),
)

In [16]:
import shelve

class ResultSet:
    """Search-and-retrieval interface over an on-disk index of image hashes.

    Retrieving duplicate images needs an index that search operations can run
    against. On construction this class builds that index with ``shelve`` at
    ``index_save_path`` (key: hash string, value: list of documents with that
    hash), looks up every query, and keeps the matches for downstream tasks.

    Args:
        index_save_path: path handed to ``shelve.open``. Any index already stored
            there is discarded, so re-running produces the same results.
        candidates: ``{doc: hash_str}`` of documents to index.
        queries: ``{query_doc: hash_str}`` to look up in the index.
    """
    def __init__(self, index_save_path: str, candidates: dict, queries: dict) -> None:
        self.db_path = index_save_path
        # flag='n' always starts from an empty shelf. With the default 'c', a
        # shelf left by an earlier run would be reused and every document
        # appended a second time.
        self.db = self.create_db_index(index_save_path, flag='n')
        self.populate_db(candidates)
        self.fetch_nearest_neighbors(queries)

    @staticmethod
    def create_db_index(path: str, flag: str = 'c') -> shelve.DbfilenameShelf:
        """Open a writeback shelf at ``path`` with the given dbm ``flag``."""
        return shelve.open(path, flag=flag, writeback=True)

    def refresh_db_buffer(self) -> shelve.DbfilenameShelf:
        """Reopen the index stored at ``self.db_path`` (without writeback)."""
        return shelve.open(self.db_path)

    def populate_db(self, candidates: dict) -> None:
        """Append each candidate doc to the list stored under its hash, then close the shelf."""
        # Group in memory first so each distinct hash is written once.
        grouped = {}
        for doc, doc_hash in candidates.items():
            grouped.setdefault(doc_hash, []).append(doc)
        try:
            for doc_hash, docs in grouped.items():
                self.db[doc_hash] = self.db.get(doc_hash, []) + docs
        finally:
            # Closing flushes the writeback cache to disk, even on error.
            self.db.close()

    def fetch_nearest_neighbors(self, queries: dict) -> None:
        """Set ``self.query_results`` to ``{query: [docs sharing its hash]}``.

        A query whose hash is absent from the index maps to an empty list
        instead of raising ``KeyError``.
        """
        self.db = self.refresh_db_buffer()
        try:
            self.query_results = {
                query: self.db.get(query_hash, [])
                for query, query_hash in queries.items()
            }
        finally:
            self.db.close()

    def destroy_db_index(self) -> None:
        """Delete the on-disk index files, but only when query results are non-empty."""
        if not self.query_results:
            return
        # The dbm backend chosen by shelve may store the index at db_path itself
        # or under suffixed names (e.g. dbm.dumb uses .dat/.dir/.bak), so
        # removing db_path alone can miss files or raise FileNotFoundError.
        for suffix in ('', '.db', '.dat', '.dir', '.bak', '.pag'):
            candidate = self.db_path + suffix
            if os.path.isfile(candidate):
                os.remove(candidate)

    def retrieve_results(self) -> dict:
        """Return the ``{query: [matching docs]}`` mapping built at construction."""
        return self.query_results

In [17]:
# Index every image (test + query) and look up each query's hash in that index.
hashes = hull.get_hashes()
queries = hull.get_query_hashes()

res = ResultSet(
    index_save_path='../Transformed_dataset/imageset',
    candidates=hashes,
    queries=queries,
).retrieve_results()

## Appendix

In [28]:
## Unit test

from PIL import Image

# Sample image from the retrieval set. IMAGE_DEDUP_DATA overrides the dataset
# root; otherwise it is resolved relative to the notebook's working directory.
SAMPLE_IMAGE_PATH = os.path.join(
    os.environ.get('IMAGE_DEDUP_DATA', os.path.join('..', 'Transformed_dataset')),
    'Retrieval',
    'ukbench04754_vflip.jpg',
)

# Open the file once and derive the array from that same image, so only a
# single file handle is created.
y = Image.open(SAMPLE_IMAGE_PATH)
x = np.array(y)

# dhash must be unchanged by a PIL -> numpy -> PIL round trip.
assert dhash(Image.fromarray(x)) == dhash(y)