In [1]:
from imagehash import phash, dhash, average_hash
import numpy as np

In [22]:
from types import FunctionType
import os
import random
from PIL import Image
from copy import deepcopy
from numpy import array

class Dataset:
    """
    Holds two image collections read from separate directories: the query
    images and the test images they are searched against.

    Each collection is kept as a ``{file name: joined path}`` dictionary;
    only directory entries whose name ends with '.jpg' are included.
    """
    def __init__(self, path_to_queries: str, path_to_test: str) -> None:
        """Index the '.jpg' files of the query directory and of the test directory."""
        self.query_docs = self.load_image_set(path_to_queries)
        self.test_docs = self.load_image_set(path_to_test)

    @staticmethod
    def load_image_set(path: str) -> dict:
        """Map every '.jpg' entry listed in `path` to `os.path.join(path, entry)`."""
        image_paths = {}
        for entry in os.listdir(path):
            if not entry.endswith('.jpg'):
                continue
            image_paths[entry] = os.path.join(path, entry)
        return image_paths

    
class HashedDataset(Dataset):
    """
    Dataset whose query and test images are fingerprinted with a
    caller-supplied hashing function.

    Parameters
    ----------
    hashing_function:
        Callable that is given an opened PIL image. Its return value is
        converted with `str()` and stored as the image's fingerprint.
    *args, **kwargs:
        Forwarded unchanged to `Dataset.__init__`.

    Attributes
    ----------
    test_hashes / query_hashes:
        `{doc: hash string}` for the test and the query images.
    doc2hash:
        Copy of `test_hashes` (query hashes are not included).
    hash2doc:
        Reverse of `doc2hash`.
    """
    def __init__(self, hashing_function: FunctionType, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.hasher = hashing_function
        self.fingerprint()
        # Copy so that changes made to the lookup table do not touch `test_hashes`.
        self.doc2hash = deepcopy(self.test_hashes)
        # NOTE: documents sharing a hash collapse to one entry here; the last
        # document iterated wins.
        self.hash2doc = {doc_hash: doc for doc, doc_hash in self.doc2hash.items()}

    def _hash_image_set(self, docs: dict) -> dict:
        """Return `{doc: str(hash)}` for every `{doc: image path}` entry in `docs`."""
        hashes = {}
        for doc, image_path in docs.items():
            # The context manager closes the image file as soon as it has been
            # hashed instead of leaving that to garbage collection.
            with Image.open(image_path) as image:
                hashes[doc] = str(self.hasher(image))
        return hashes

    def fingerprint(self) -> None:
        """(Re)compute `test_hashes` and `query_hashes` with the current hasher."""
        self.test_hashes = self._hash_image_set(self.test_docs)
        self.query_hashes = self._hash_image_set(self.query_docs)

    def get_hashes(self) -> dict:
        """Return the `{doc: hash}` lookup table built from the test images."""
        return self.doc2hash

    def get_query_hashes(self) -> dict:
        """Return `{doc: hash}` for the query images."""
        return self.query_hashes

    def get_test_hashes(self) -> dict:
        """Return `{doc: hash}` for the test images."""
        return self.test_hashes

In [23]:
# Fingerprint the query and retrieval image folders with `dhash`.
hull = HashedDataset(
    hashing_function=dhash,
    path_to_queries='/Users/zubin.john/forge/image-dedup/Transformed_dataset/Query/',
    path_to_test='/Users/zubin.john/forge/image-dedup/Transformed_dataset/Retrieval/',
)

In [26]:
import shelve

class ResultSet:
    """Exact-match search index over image hashes, backed by a `shelve` database.

    To retrieve duplicate images an index is built against which the search
    runs. On construction the candidates are written to a shelf keyed by hash,
    every query hash is looked up in it, and the on-disk index is removed
    again. The matches are then available through `retrieve_results()`.

    Parameters
    ----------
    index_save_path:
        Base path of the shelf. The dbm backend picked by `shelve` may append
        its own suffix (for example '.db') to this path.
    candidates:
        `{document: hash}` of the documents to index.
    queries:
        `{query: hash}` of the hashes to look up.
    """
    # Suffixes that dbm backends may append to the base path ('' = the base
    # path itself). Used to remove the index whichever backend created it.
    _INDEX_FILE_SUFFIXES = ('', '.db', '.dat', '.dir', '.bak', '.pag', '-wal', '-shm')

    def __init__(self, index_save_path: str, candidates:dict, queries: dict) -> None:
        # Base path handed to `shelve`; the same path is used to create,
        # reopen and delete the index so the three always agree.
        self.index_path = index_save_path
        # Kept for backward compatibility with code that reads this attribute.
        self.db_path = f'{index_save_path}.db'
        self.query_results = {}
        self.db = self.create_db_index(index_save_path)
        try:
            self.populate_db(candidates)
            self.fetch_nearest_neighbors(queries)
        finally:
            # Remove the index even when indexing or searching failed, so a
            # later run does not find a half-written database.
            self.destroy_db_index()

    @staticmethod
    def create_db_index(path) -> shelve.DbfilenameShelf:
        """Create a new, empty writeback shelf at `path`.

        flag='n' discards any index left behind by an interrupted run;
        otherwise `populate_db` would append to the stale entries.
        """
        return shelve.open(path, flag='n', writeback=True)

    def refresh_db_buffer(self) -> shelve.DbfilenameShelf:
        """Reopen the shelf under the base path it was created with."""
        return shelve.open(self.index_path)

    def populate_db(self, candidates: dict):
        """Store `{hash: [documents with that hash]}` in the shelf, then close it."""
        try:
            for doc, doc_hash in candidates.items():
                self.db[doc_hash] = self.db.get(doc_hash, []) + [doc]
        finally:
            # Closing flushes the writeback cache to disk.
            self.db.close()

    def fetch_nearest_neighbors(self, queries) -> None:
        """Look up every query hash and keep the matches in `query_results`.

        A query whose hash is not in the index gets an empty list.
        """
        self.db = self.refresh_db_buffer()
        try:
            self.query_results = {
                query: self.db.get(query_hash, [])
                for query, query_hash in queries.items()
            }
        finally:
            self.db.close()

    def destroy_db_index(self) -> None:
        """Delete every on-disk file belonging to the index."""
        for suffix in self._INDEX_FILE_SUFFIXES:
            index_file = f'{self.index_path}{suffix}'
            if os.path.isfile(index_file):
                os.remove(index_file)

    def retrieve_results(self):
        """Return `{query: list of documents whose hash equals the query's hash}`."""
        return self.query_results

In [27]:
# Index the test-image hashes, then look up every query hash against them.
hashes, queries = hull.get_hashes(), hull.get_query_hashes()

res = ResultSet(
    index_save_path='imageset',
    candidates=hashes,
    queries=queries,
).retrieve_results()

In [29]:
import pickle
import numpy as np
from types import FunctionType


class EvalPerformance:
    """Retrieval-quality metrics (MAP, NDCG, Jaccard) averaged over all queries.

    Parameters
    ----------
    dict_correct:
        Ground truth: the correct duplicates for each query.
    dict_retrieved:
        Search output: the retrieved duplicates for each query.
    """
    def __init__(self, dict_correct: dict, dict_retrieved: dict) -> None:
        self.dict_correct = dict_correct
        self.dict_retrieved = dict_retrieved

    @staticmethod
    def avg_prec(correct_duplicates: list, retrieved_duplicates: list) -> float:
        """Average precision for one query.

        Input: (list of correct duplicates (i.e., ground truth), list of
        retrieved duplicates) for one single query.
        return: average precision; 0.0 when nothing was retrieved or when the
        ground truth is empty.
        """
        count_real_correct = len(correct_duplicates)
        # With an empty ground truth the original 0/0 produced nan, which then
        # turned the mean over all queries into nan as well.
        if not len(retrieved_duplicates) or not count_real_correct:
            return 0.0
        relevance = np.array([1 if doc in correct_duplicates else 0 for doc in retrieved_duplicates])
        ranks = np.arange(1, len(relevance) + 1)
        precision_at_k = np.cumsum(relevance) / ranks
        # Only ranks holding a correct duplicate contribute their precision.
        return np.sum(relevance * precision_at_k) / count_real_correct

    @staticmethod
    def ndcg(correct_duplicates: list, retrieved_duplicates: list) -> float:
        """Normalized Discounted Cumulative Gain (NDCG) for one query.

        Input: (list of correct duplicates (i.e., ground truth), list of
        retrieved duplicates) for one single query.
        return: NDCG; 0.0 when nothing was retrieved or when the ground truth
        is empty.
        """
        if not len(retrieved_duplicates):
            return 0.0
        relevance = np.array([1 if doc in correct_duplicates else 0 for doc in retrieved_duplicates])
        # Discount of rank k (0-based) is log2(k + 2), so the first one is log2(2).
        discounts = np.log2(np.arange(len(relevance)) + 2)
        dcg = np.sum((2 ** relevance - 1) / discounts)

        # The ideal ranking puts a correct duplicate (gain 1) on each of the
        # first min(#retrieved, #correct) ranks and gain 0 on the rest.
        ideal_hits = min(len(relevance), len(correct_duplicates))
        if not ideal_hits:
            # Empty ground truth: the ideal DCG is 0, avoid dividing by it.
            return 0.0
        ideal_dcg = np.sum(1.0 / discounts[:ideal_hits])
        return dcg / ideal_dcg

    @staticmethod
    def jaccard_similarity(correct_duplicates: list, retrieved_duplicates: list) -> float:
        """Jaccard similarity for one query.

        Input: (list of correct duplicates (i.e., ground truth), list of
        retrieved duplicates) for one single query.
        return: |intersection| / |union| of the two lists taken as sets; 0.0
        when nothing was retrieved.
        """
        if not len(retrieved_duplicates):
            return 0.0
        set_correct_duplicates = set(correct_duplicates)
        set_retrieved_duplicates = set(retrieved_duplicates)

        intersection_dups = set_retrieved_duplicates & set_correct_duplicates
        union_dups = set_retrieved_duplicates | set_correct_duplicates
        return len(intersection_dups) / len(union_dups)

    def mean_all_func(self, metric_func: FunctionType) -> float:
        """Mean of `metric_func` over every query of the ground truth.

        Input: metric function taking (correct duplicates, retrieved duplicates).
        return: mean of the metric across all queries. A query that has no
        entry in the retrievals is scored as if nothing had been retrieved.
        """
        all_metrics = [
            metric_func(correct, self.dict_retrieved.get(query, []))
            for query, correct in self.dict_correct.items()
        ]
        return np.mean(all_metrics)

    def get_all_metrics(self, save: bool=True) -> dict:
        """Compute the mean of every metric.

        Input: save flag; when true the result dictionary is also pickled to
        'all_average_metrics.pkl' in the current working directory.
        return: dictionary with the keys 'MAP', 'NDCG' and 'Jaccard'.
        """
        dict_average_metrics = {
            'MAP': self.mean_all_func(self.avg_prec),
            'NDCG': self.mean_all_func(self.ndcg),
            'Jaccard': self.mean_all_func(self.jaccard_similarity)
        }

        if save:
            with open('all_average_metrics.pkl', 'wb') as f:
                pickle.dump(dict_average_metrics, f)
        return dict_average_metrics

In [33]:
import pickle
# Load the ground truth that is handed to EvalPerformance as `dict_correct`.
# NOTE(review): absolute, machine-specific path — the notebook cannot be re-run
# elsewhere without editing it.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file when it comes from a trusted source.
with open('/Users/zubin.john/forge/image-dedup/Transformed_dataset/ground_truth_transformed.pkl', 'rb') as rb:
    correct_dict = pickle.load(rb)

In [35]:
# Score the retrieved duplicates against the ground truth. With save=True the
# metrics are also pickled to 'all_average_metrics.pkl'.
e1 = EvalPerformance(dict_correct=correct_dict, dict_retrieved=res)
e1.get_all_metrics(save=True)

## Appendix

In [28]:
## Unit test

from PIL import Image

# The hash of an image must survive a PIL -> numpy array -> PIL round trip.
y = Image.open('/Users/zubin.john/forge/image-dedup/Transformed_dataset/Retrieval/ukbench04754_vflip.jpg')
x = np.array(
    Image.open('/Users/zubin.john/forge/image-dedup/Transformed_dataset/Retrieval/ukbench04754_vflip.jpg')
)

assert dhash(y) == dhash(Image.fromarray(x))