In [43]:
import numpy as np
class MultiIndexHash(object):

    def __init__(self,codes,m=None):
        self.N = codes.shape[0]
        self.Q = codes.shape[1]
        
        self.codes = codes
        
        if not m:
            m = codes.shape[1]//np.log2(self.N)
            
        self.m = int(m)
        self.s = np.array_split(np.arange(self.Q),self.m)
        
        self.tables = self.init_tables()
        self.lookup = list(np.asarray(list(t.keys())) for t in self.tables)
        
        
    def init_tables(self):
        '''creates multi-index hash tables
           codes - a NxQ binary array with N vectors of length Q
           m - number of tables to build, if empty, will compute optimal number'''
        tables = []

        for j in range(self.m):
            table = {}
            for i in range(self.N):
                substr = tuple(self.codes[i,self.s[j]])
                if substr not in table:
                    table[substr] = []
                table[substr].append(i)
            tables.append(table)

        return tables
    
    def r_search(self,query,r):
        
        r_ = r // self.m
        a = r % self.m
        
        neighbors = set()
        
        ## Search for neighbors using substring hash tables
        for j in range(self.m):
            if j < a:
                r_search = r_
            else:
                r_search = r_ - 1
            
            
            sub_index = self.s[j]
            q_sub = query[sub_index]
            
                        
            look_up = self.lookup[j]          
            q_sub = np.reshape(q_sub,(1,-1))
            dist = np.sum(np.logical_xor(q_sub,look_up), axis=1) ##Hamming Distance
            
            candidates = set()

            for n in np.argwhere(dist <= r_search).flatten():
                neighbors |= self.tables[j][tuple(look_up[n,:])]
            
        ## Check all neighbors using full Hamming Distance
        
        neighbors = np.array(list(neighbors))
        codes_n = self.codes[neighbors,:]
        dist = np.sum(np.logical_xor(query,codes_n), axis=1)
        
        results = {}
        for n in np.argwhere(dist <= r).flatten():
            results[neighbors[n]] = dist[n]
        return sorted(results.items(), key = lambda x: x[1])
    
    def k_nn(self,query,k):
        neighbors = [set() for i in range(self.Q)]
        near = 0
        j = 0
        r = 0
        r_ = 0
        while near < k:
            sub_index = self.s[j]
            q_sub = query[sub_index]
            
            look_up = self.lookup[j]            
            q_sub = np.reshape(q_sub,(1,-1))
            dist = np.sum(np.logical_xor(q_sub,look_up), axis=1) ##Hamming Distance
            
            candidates = set()
            for n in np.argwhere(dist <= r_).flatten():
                for l in self.tables[j][tuple(look_up[n,:])]:
                    candidates.add(l)
                
            candidates = np.array(list(candidates))
            codes_n = self.codes[candidates,:]
            dist = np.sum(np.logical_xor(query,codes_n), axis=1)
            
            for i in range(candidates.shape[0]):
                d = dist[i]
                neighbors[d].add(candidates[i])
                
            near = sum(list(len(neighbors[d]) for d in range(r)))
            
            j += 1
            if j >= self.m:
                j = 0
                r_ += 1
                
            r += 1
            
        out = []
        for d in range(r):
            for n in neighbors[d]:
                out.append((n,d))
        return out

In [36]:
def dist_knn(img_feature,k,feature_matrix):
    pairs = []
    dist = np.sum(np.logical_xor(img_feature,feature_matrix),axis=1)
        
    
    match = np.argsort(dist)[0:k]
    for m in match:
        pairs.append((m,dist[m]))
        
    return pairs

In [44]:
import pickle
feat = pickle.load(open('Data/feature_matrix/FC6_full_trained.p','rb'))
f = feat > 0

In [45]:
import time
start = time.time()
MIH = MultiIndexHash(f)
print('{:.2f} s'.format(time.time()-start))

50.93 s


In [48]:
%time MIH.r_search(f[41,:],1500)

CPU times: user 1.49 s, sys: 197 ms, total: 1.69 s
Wall time: 1.69 s


[(41, 0),
 (10750, 964),
 (9289, 977),
 (16356, 979),
 (7311, 984),
 (18918, 1013),
 (14597, 1017),
 (3138, 1021),
 (4246, 1023),
 (13399, 1024),
 (15516, 1036),
 (18728, 1036),
 (8396, 1040),
 (2929, 1044),
 (5075, 1046),
 (16675, 1046),
 (17435, 1050),
 (10227, 1054),
 (5183, 1059),
 (4886, 1061),
 (16521, 1062),
 (8163, 1063),
 (12161, 1064),
 (16622, 1065),
 (4094, 1074),
 (16030, 1074),
 (12579, 1075),
 (8661, 1079),
 (19359, 1081),
 (17836, 1083),
 (1232, 1084),
 (1019, 1085),
 (13646, 1086),
 (5525, 1089),
 (16732, 1090),
 (16686, 1091),
 (13724, 1092),
 (11337, 1093),
 (14773, 1093),
 (3814, 1094),
 (13060, 1097),
 (8377, 1098),
 (8382, 1098),
 (14064, 1100),
 (16974, 1100),
 (5867, 1102),
 (19496, 1105),
 (2605, 1110),
 (6421, 1110),
 (529, 1112),
 (4658, 1112),
 (17799, 1112),
 (19124, 1112),
 (10970, 1113),
 (12117, 1115),
 (16671, 1115),
 (4801, 1117),
 (6369, 1119),
 (7123, 1119),
 (1543, 1120),
 (11441, 1120),
 (7875, 1122),
 (11734, 1124),
 (15718, 1124),
 (19610, 1125),

In [50]:
%snakeviz MIH.k_nn(f[400,:],10)

 
*** Profile stats marshalled to file '/var/folders/xg/61d36gvd7c74014sl8wkr2bw0000gn/T/tmpq4211gf6'. 


In [38]:
dist_knn(f[400,:],10,f)

[(400, 0),
 (12380, 764),
 (10892, 900),
 (19923, 901),
 (6321, 907),
 (18297, 910),
 (10734, 912),
 (4800, 930),
 (18800, 931),
 (4997, 943)]

In [23]:
class MacOSFile(object):
    def __init__(self, f):
        self.f = f

    def __getattr__(self, item):
        return getattr(self.f, item)

    def read(self, n):
        if n >= (1 << 31):
            buffer = bytearray(n)
            pos = 0
            while pos < n:
                size = min(n - pos, 1 << 31 - 1)
                chunk = self.f.read(size)
                buffer[pos:pos + size] = chunk
                pos += size
            return buffer
        return self.f.read(n)

In [25]:
with open('Data/feature_matrix/fc6_full_set.p','rb') as f:
    feat = pickle.load(MacOSFile(f))

In [26]:
f = feat > 0

In [27]:
import time
start = time.time()
MIH = MultiIndexHash(f)
print('{:.2f} s'.format(time.time()-start))

268.15 s


In [29]:
%timeit MIH.r_search(f[40,:],1500)

KeyboardInterrupt: 

In [None]:
%timeit MIH.k_nn(f[40,:],10)

In [40]:
%load_ext snakeviz

In [41]:
%snakeviz MIH.k_nn(f[40,:],10)

 
*** Profile stats marshalled to file '/var/folders/xg/61d36gvd7c74014sl8wkr2bw0000gn/T/tmpzkzwoe87'. 


In [42]:
%snakeviz MIH.r_search(f[40,:],1000)

 
*** Profile stats marshalled to file '/var/folders/xg/61d36gvd7c74014sl8wkr2bw0000gn/T/tmppi6v68r9'. 
