In [1]:
import cv2
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances
import os
import seaborn as sns
import math
import random
import falconn
import timeit
# sns.set_theme()


def kmeans_trans(img, K=10, attempts=10, epsilon=0.1, max_iter=10, lab=False):
    """Quantize the colors of ``img`` to ``K`` k-means cluster centers.

    The image is flattened to rows of 3 channel values, clustered with
    ``cv2.kmeans`` (random initial centers), and every pixel is then replaced
    by the center of the cluster it was assigned to.

    Parameters:
        img: image array whose elements can be regrouped into rows of 3.
        K: number of clusters.
        attempts: number of attempts passed to ``cv2.kmeans``.
        epsilon: accuracy term of the termination criteria.
        max_iter: iteration-count term of the termination criteria.
        lab: only when this is exactly ``True`` the image is converted
            BGR -> LAB before clustering and LAB -> BGR afterwards.

    Returns:
        (res, label): the quantized image (same shape as the clustered image)
        and the label array returned by ``cv2.kmeans``.
    """
    use_lab = lab is True
    if use_lab:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

    pixels = np.float32(img.reshape((-1, 3)))
    stop_flags = cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER
    _, label, center = cv2.kmeans(
        pixels,
        K,
        None,
        (stop_flags, max_iter, epsilon),
        attempts,
        cv2.KMEANS_RANDOM_CENTERS,
    )

    # Cast the centers to uint8 so they can be used as pixel values, then
    # look up each pixel's center through its label.
    palette = np.uint8(center)
    res = palette[label.flatten()].reshape(img.shape)
    if use_lab:
        res = cv2.cvtColor(res, cv2.COLOR_LAB2BGR)
    return res, label

def proportion_list(label):
    """
    Percentage of elements carrying each distinct label, largest first.

    input: a label list/array of an image after kmeans transformation;
        every element is treated as one pixel's label, whatever the shape
    return: a 1-D array of color proportions (in percent) sorted in
        descending order
    """
    counts = np.unique(label, return_counts=True)[1]
    # Normalise by the total number of elements. len() would only count the
    # first axis, which over-states every proportion for 2-D label maps.
    total = np.size(label)
    return (np.sort(counts / total) * 100)[::-1]
    
def get_proportion_list(img, K=10, attempts=10, epsilon=0.1, max_iter=10, lab=False):
    """Cluster ``img`` with ``kmeans_trans`` and return the sorted color
    proportions of the resulting labels (see ``proportion_list``).

    All keyword parameters are forwarded unchanged to ``kmeans_trans``; the
    quantized image it returns is discarded.
    """
    _quantized, pixel_labels = kmeans_trans(img, K, attempts, epsilon, max_iter, lab)
    return proportion_list(pixel_labels)

def euclidean_dist_df(proportion_df):
    """Pairwise *squared* Euclidean distances between the columns of
    ``proportion_df``.

    Returns a square DataFrame whose columns are the columns of
    ``proportion_df`` and whose index (named "name") carries the same labels.
    """
    pairwise = euclidean_distances(proportion_df.T, squared=True)
    dist_df = pd.DataFrame(pairwise, columns=proportion_df.columns)
    # Mirror the column labels onto the rows so cells can be looked up by
    # (name, name).
    dist_df["name"] = dist_df.columns
    return dist_df.set_index("name")

def embedding(proportion_df):
    """
        Unary-embed every column of ``proportion_df`` into a binary string.

        Values are floored first. Each floored value v is written as v ones
        followed by (m - v) zeros, where m is the largest floored value in
        the whole frame, and the per-value pieces of a column are
        concatenated in row order.

        return: dict mapping column label -> binary string
    """
    floored = np.floor(proportion_df)
    width = np.max(floored.max())

    def unary(value):
        # Fixed-width unary code: `value` ones padded with zeros to `width`.
        return "1" * int(value) + "0" * int(width - value)

    return {
        col: "".join(unary(value) for value in list(floored[col]))
        for col in proportion_df.columns
    }

def hashfunction(K, max_digit, dim, rng=None):
    """
    Build a bit-sampling hash function for binary strings.

    K positions are drawn (with replacement) from 0 .. max_digit*dim - 1.
    The returned function hashes a string by concatenating its characters
    at those positions.

    Parameters:
        K: number of sampled positions, i.e. length of every hash code.
        max_digit, dim: their product is the number of leading string
            positions that may be sampled; strings must be at least that long.
        rng: optional object with a ``randint(a, b)`` method (for example
            ``random.Random(seed)``) used to draw the positions. Defaults to
            the module-level ``random`` generator, as before, so existing
            callers are unaffected while new callers can get reproducible
            tables.

    Returns:
        h(ham_dict) -> dict mapping hash code -> list of keys of ``ham_dict``
        whose strings produce that code (keys kept in iteration order).
    """
    if rng is None:
        rng = random
    last_position = max_digit * dim - 1
    hashes = [rng.randint(0, last_position) for _ in range(K)]

    def h(ham_dict):
        hashtable = {}
        for key, ham_str in ham_dict.items():
            hashcode = "".join(ham_str[j] for j in hashes)
            hashtable.setdefault(hashcode, []).append(key)
        return hashtable

    return h

In [2]:
group = "blur_modified"
# List the image folder, leaving out the macOS Finder metadata file.
images = [name for name in os.listdir(f"images/{group}") if name != ".DS_Store"]
proportion_dict = {}
t1 = timeit.default_timer()
for image in images:
    img = cv2.imread(f"images/{group}/{image}")
    if img is None:
        # cv2.imread does not raise on unreadable / non-image files; it
        # returns None, which would otherwise crash inside the color
        # conversion with an unrelated error.
        print(f"Skipping unreadable image: {image}")
        continue
    proportion_dict[image] = get_proportion_list(img, K=20, epsilon=0.0001, max_iter=500, lab=True)
t2 = timeit.default_timer()
# max(..., 1) keeps the average well-defined when no image could be loaded.
print(f'Dataset construction time:  {(t2 - t1) / float(max(len(proportion_dict), 1))} per image')
# One column per image, one row per color rank.
proportion_df = pd.DataFrame(proportion_dict)

Dataset construction time:  113.03372778137498 per image


In [12]:
# Split the proportion vectors into the reference set (file names containing
# 'origin') and the query set (everything else), keeping labels aligned with
# their vectors.
dataset, dataset_label = [], []
queries, queries_label = [], []
for name, proportions in proportion_dict.items():
    if 'origin' not in name:
        queries_label.append(name)
        queries.append(proportions)
        continue
    dataset_label.append(name)
    dataset.append(list(proportions))
dataset

[[13.801727572859745,
  12.08376024590164,
  11.227800546448087,
  10.926471425318761,
  8.540741689435336,
  5.534921448087432,
  4.93226320582878,
  3.252006489071038,
  3.2448912795992713,
  3.0815972222222223,
  2.9083418715846996,
  2.8777464708561022,
  2.7368653233151186,
  2.60985883424408,
  2.5998975409836067,
  2.5177168715846996,
  2.2918089708561022,
  1.8172244990892532,
  1.747495446265938,
  1.2668630464480874],
 [25.790298333650934,
  11.643799360562367,
  10.179127231150352,
  9.60162188485888,
  8.708103072264922,
  5.890448665015139,
  3.6222448072159055,
  3.5814859514281485,
  3.1585466556564823,
  3.091321010396155,
  2.4349446314764234,
  2.3364881746384634,
  2.1840394672764614,
  1.9458383620233333,
  1.288403311524699,
  1.2439391052107815,
  1.0486141989032163,
  0.953333756801965,
  0.7135446441804822,
  0.5838573757648903],
 [10.037163013353489,
  9.429327286470143,
  8.854035441337029,
  8.837238599143362,
  8.805744520030235,
  8.112874779541446,
  7.707

In [13]:
# Scale every row to unit L2 norm (row norms are kept as a column so they
# broadcast across each row).
dataset = np.asarray(dataset) / np.linalg.norm(dataset, axis=1, keepdims=True)
queries = np.asarray(queries) / np.linalg.norm(queries, axis=1, keepdims=True)

In [14]:
# Make sure both collections are NumPy arrays from here on.
dataset, queries = np.array(dataset), np.array(queries)

In [15]:
# Sanity check: (number of reference images, number of color proportions).
np.shape(dataset)

(4, 20)

In [21]:
# Perform linear scan using NumPy to get answers to the queries.
# For every query the answer is the index of the dataset row with the largest
# dot product; these indices are the ground truth for the LSH evaluation.
print('Solving queries using linear scan')
t1 = timeit.default_timer()
answers = []
for query in queries:
    # Hand the array to np.dot directly: wrapping it in list() made NumPy
    # rebuild the matrix on every iteration, inside the timed region.
    answers.append(np.dot(dataset, query).argmax())
t2 = timeit.default_timer()
print('Done')
print('Linear scan time: {} per query'.format((t2 - t1) / float(len(queries))))
# Translate row indices back to file names.
anss = [dataset_label[ans] for ans in answers]

Solving queries using linear scan
Done
Linear scan time: 0.0001810657502119284 per query


In [22]:
# Side-by-side table: each query file next to the reference file the linear
# scan matched it to.
cp_dict = {"queries_label": queries_label, "answer_label": anss}
cp_df = pd.DataFrame(cp_dict)
cp_df

Unnamed: 0,queries_label,answer_label
0,4_blurred.jpg,4_origin.jpg
1,1_blurred.jpg,1_origin.png
2,2_blurred.jpg,2_origin.jpg
3,3_blurred.jpg,3_origin.jpg


In [23]:
# Query file names, in the same order as the rows of `queries`.
queries_label

['4_blurred.jpg', '1_blurred.jpg', '2_blurred.jpg', '3_blurred.jpg']

In [24]:
# Cross-polytope LSH index over the reference vectors.
number_of_tables = 50
params_cp = falconn.LSHConstructionParameters()
params_cp.dimension = len(dataset[0])
params_cp.lsh_family = falconn.LSHFamily.CrossPolytope
params_cp.distance_function = falconn.DistanceFunction.EuclideanSquared
# Use the constant defined above instead of repeating the literal, so the
# table count and the probe search (which starts from number_of_tables)
# cannot drift apart.
params_cp.l = number_of_tables
params_cp.num_rotations = 1
# Fixed seed so the index is built the same way on every run.
params_cp.seed = 5721840
params_cp.num_setup_threads = 0
params_cp.storage_hash_table = falconn.StorageHashTable.BitPackedFlatHashTable
# Derives the remaining hash parameters on params_cp for 18-bit hashes.
# NOTE(review): 18 bits looks carried over from a much larger dataset;
# confirm it is intended for this handful of reference images.
falconn.compute_number_of_hash_functions(18, params_cp)

print('Constructing the LSH table')
t1 = timeit.default_timer()
table = falconn.LSHIndex(params_cp)
table.setup(dataset)
t2 = timeit.default_timer()
print('Done')
print(f'Construction time: {t2-t1}')

Constructing the LSH table
Done
Construction time: 0.04291926999940188


In [25]:
# Query handle for the LSH index; the probe count is configured on this object.
query_object = table.construct_query_object()

In [26]:
print('Choosing number of probes')
# Start the search at one probe per hash table.
number_of_probes = number_of_tables

def evaluate_number_of_probes(number_of_probes):
    """Fraction of queries whose linear-scan answer is among the LSH candidates.

    Sets ``number_of_probes`` on the shared ``query_object`` (the setting
    stays in effect after the call) and checks, for every query, whether
    ``answers[i]`` appears in the candidates returned for it.
    """
    query_object.set_num_probes(number_of_probes)
    hits = sum(
        1
        for idx, query in enumerate(queries)
        if answers[idx] in query_object.get_candidates_with_duplicates(query)
    )
    return float(hits) / len(queries)

# Phase 1: keep doubling the probe count until at least 90% of the queries
# find their true answer among the candidates.
while True:
    accuracy = evaluate_number_of_probes(number_of_probes)
    print('{} -> {}'.format(number_of_probes, accuracy))
    if accuracy >= 0.9:
        break
    number_of_probes = number_of_probes * 2
# Phase 2: if doubling was needed, bisect between the last failing count
# (left) and the first passing count (right) for the smallest passing value.
if number_of_probes > number_of_tables:
    left = number_of_probes // 2
    right = number_of_probes
    while right - left > 1:
        number_of_probes = (left + right) // 2
        accuracy = evaluate_number_of_probes(number_of_probes)
        print('{} -> {}'.format(number_of_probes, accuracy))
        if accuracy >= 0.9:
            right = number_of_probes
        else:
            left = number_of_probes
    number_of_probes = right
# evaluate_number_of_probes leaves query_object at the LAST count it tried,
# which after bisection can be a rejected one. Re-apply the selected count so
# later queries really run with the value reported below.
query_object.set_num_probes(number_of_probes)
print('Done')
print('{} probes'.format(number_of_probes))

# final evaluation: time nearest-neighbor lookups through the LSH index and
# count how often they agree with the linear-scan answers
t1 = timeit.default_timer()
score = 0
for query, expected in zip(queries, answers):
    score += int(query_object.find_nearest_neighbor(query) == expected)
t2 = timeit.default_timer()

print('Query time: {}'.format((t2 - t1) / len(queries)))
print('Precision: {}'.format(float(score) / len(queries)))

Choosing number of probes
50 -> 1.0
Done
50 probes
Query time: 0.00018558499982646026
Precision: 1.0


In [None]:
# Hand-rolled bit-sampling LSH over the unary embedding: build L tables, each
# hashing every image by K randomly sampled bits of its embedded string.
max_digit = int(np.max(np.floor(proportion_df).max()))
K, L = 50, 5
# The embedded strings contain max_digit bits per row of proportion_df, so
# the sampling range must use the actual row count. The previous hard-coded
# dim=10 only ever sampled the first 10 rows' bits once the frame had more.
dim = proportion_df.shape[0]
# The embedding does not depend on the table, so compute it once.
ham_dict = embedding(proportion_df)
final_tables = []
ft = {}
gs = []
for i in range(L):
    g = hashfunction(K, max_digit, dim=dim)
    gs.append(g)
    # Named bucket_table so the FALCONN index stored in `table` is not
    # overwritten by a plain dict.
    bucket_table = g(ham_dict)
    final_tables.append(bucket_table)
    # NOTE(review): this merge lets later tables overwrite earlier buckets
    # that share a hash code; `ft` is not read by the cells visible here.
    ft = {**ft, **bucket_table}
final_tables

In [None]:
# query: collect, for every table, the bucket the query image falls into
# NOTE(review): the name must be a column of proportion_df — confirm it
# exists in the image group that was loaded.
query_name = "3_whitegem.jpg"
# Take the query's string from the embedding of the WHOLE frame. Embedding
# the single column on its own would pad each value to that column's own
# maximum, so the sampled bit positions would no longer line up with the
# strings the tables were built from.
query_ham = {query_name: embedding(proportion_df)[query_name]}

res = []
for i, g in enumerate(gs):
    # g returns a one-entry table here; its only key is the query's hash code.
    ql = next(iter(g(query_ham)))
    # A code with no bucket in this table simply yields no candidates.
    res.append(final_tables[i].get(ql, []))
res