In [22]:
import numpy as np
import matplotlib.pyplot as plt
import math
from hashlib import md5

# Dataset

In [20]:
# Restore the `Datasets` variable that was previously persisted with IPython's
# `%store` magic; the cells below read the 'MNIST-Hamming' entry from it.
%store -r Datasets

In [25]:
# Pull the 'MNIST-Hamming' entry out of the restored `Datasets` mapping and
# bind its four fields to notebook-level names used by the cells below.
mnist = Datasets['MNIST-Hamming']

train, test = mnist['train'], mnist['test']
k_near_neighbors, nearest_neighbor = (
    mnist['k_near_neighbors'],
    mnist['nearest_neighbor'],
)

# One Permutation with Rotation (OPR)

In [23]:
def generate_md5(H):
    """
    Desc:
        Return the hexadecimal MD5 digest of the text form of H.
    Args:
        H: object to fingerprint (here: a NumPy hash vector). It is rendered
           with str() and encoded as UTF-8 before hashing.
    Returns:
        32-character lowercase hex string.
    """
    import sys  # local import so the notebook's import cell need not change

    # With NumPy's default print options, str() abbreviates arrays of more
    # than 1000 elements as "[a b c ... x y z]", so long vectors differing
    # only in the middle would receive the same digest. Lifting the
    # summarisation threshold renders every element; the text (and digest)
    # for shorter arrays and for non-array objects is unchanged.
    with np.printoptions(threshold=sys.maxsize):
        text = str(H)

    hmd5 = md5()
    hmd5.update(text.encode(encoding='utf-8'))
    return hmd5.hexdigest()

In [24]:
def _opr_fill_empty_bins(H, k, C):
    """Fill empty bins (marked -1) of one hash vector in place by rotation."""
    if (H == -1).all():
        # Every bin is empty (e.g. an all-zero input vector): there is nothing
        # to borrow from, and searching for a non-empty bin would never
        # terminate. Keep the all -1 vector as the hash of such points.
        return H
    for j in range(k):
        if H[j] == -1:
            # Walk right (circularly) to the next non-empty bin and borrow its
            # value, shifted by C so borrowed values stay distinguishable.
            nj = j
            while H[nj] == -1:
                nj = (1 + nj) % k
            H[j] = H[nj] + C
    return H


def OPR(P, k, L, C):
    """
    Desc:
        One Permutation with Rotation: build L hash tables over the rows of P.
        For each table the columns of every point are shuffled with one
        table-specific permutation, the shuffled vector is cut into k segments
        of length sd = D / k, and each segment contributes
        sd - max(weights * segment) with weights sd, sd-1, ..., 1 (for 0/1
        vectors: the offset of the first non-zero entry). Segments without any
        non-zero entry are "empty" and borrow the value of the next non-empty
        segment to the right (circularly) plus C. The resulting vector is
        fingerprinted with generate_md5 and used as the bucket key.
    Args:
        P: point set, array of shape (n, D)
        k: number of segments
            sd: D / k, length of each segment
        L: number of hash tables (one bucket dictionary per table)
        C: anti-collision offset added to values borrowed by empty segments
    Returns:
        buckets: list of L dicts mapping hash key -> list of row indices of P
        seeds: array of L distinct integer seeds, one per table's permutation
        auxiliary_vector: weights sd, sd-1, ..., 1 repeated k times
    Raises:
        ValueError: if D is not divisible by k.
    """
    P = np.asarray(P)
    D = P.shape[1]
    if D % k != 0:
        # The segment reshape below needs exactly k segments of equal length;
        # fail with a clear message instead of an obscure broadcast error.
        raise ValueError(
            'dimension D = {} is not divisible by k = {}'.format(D, k))
    sd = D // k

    print('k = {}, L = {}'.format(k, L))
    print('Every segament length is {}'.format(sd))
    buckets = []
    seeds = np.random.choice(np.arange(k * L * D), L, replace=False)

    one_segament_vector = list(range(sd, 0, -1))
    auxiliary_vector = np.array(one_segament_vector * k)

    for i in range(L):

        bucket = dict()

        # 1. Permute. The permutation only depends on the table's seed, so it
        # is drawn once per table (not once per point) from a private
        # RandomState. This yields the same shuffle as
        # np.random.seed(seed); np.random.permutation(p), which OPR_query
        # uses, without clobbering the global RNG state.
        order = np.random.RandomState(seeds[i]).permutation(np.arange(D))
        P_ = P[:, order]

        # 2. Compute the hash of every permuted point.
        for idx, p in enumerate(P_):
            H = -((auxiliary_vector * p).reshape(-1, sd).max(axis=1) - sd)
            H[H == sd] = -1  # segment maximum was 0 -> mark the bin as empty
            _opr_fill_empty_bins(H, k, C)
            bucket.setdefault(generate_md5(H), []).append(idx)
        buckets.append(bucket)

    return buckets, seeds, auxiliary_vector

In [31]:
def OPR_query(args, q):
    """
    Desc:
        One Permutation with Rotation: look up the candidate neighbours of a
        query point in the hash tables built by OPR.
    Args:
        args:
            [0]: (D, k, L, C) - dimension, number of segments, number of
                 tables and anti-collision offset used when building
            [1]: buckets
            [2]: seeds
            [3]: auxiliary_vector
        q: query vector of length D
    Returns:
        Sorted array of the distinct point indices stored under the query's
        hash key in any of the L tables, or an empty list when no table
        contains the key.
    """
    D, k, L, C = args[0]
    buckets = args[1]
    seeds = args[2]
    auxiliary_vector = args[3]

    sd = int(D / k)

    result = []
    for i in range(L):
        # 1. Permute with the table's seed. A private RandomState gives the
        # same shuffle as np.random.seed(seed); np.random.permutation(q) but
        # leaves the global RNG state untouched by queries.
        q_ = np.random.RandomState(seeds[i]).permutation(q)

        # 2. Compute the hash.
        H = -((auxiliary_vector * q_).reshape(-1, sd).max(axis=1) - sd)
        H[H == sd] = -1  # segment maximum was 0 -> mark the bin as empty

        # Fill empty bins from the next non-empty bin to the right
        # (circularly). When every bin is empty (all-zero query) there is
        # nothing to borrow from and the search would never terminate, so the
        # all -1 vector is kept as the hash.
        if not (H == -1).all():
            for j, h in enumerate(H):
                if h == -1:
                    nj = j
                    while H[nj] == -1:
                        nj = (1 + nj) % k
                    H[j] = H[nj] + C

        bi = generate_md5(H)
        if bi in buckets[i]:
            result.append(buckets[i][bi])

    if len(result) != 0:
        # Merge the per-table candidate lists and drop duplicates.
        result = np.unique(np.concatenate(result))

    return result

In [56]:
def metrics(P, query_func, args, nearest_neighbor, test):
    """
    Desc:
        Evaluate a query function over a set of test queries.
    Args:
        P: indexed point set; only its number of rows is used
        query_func: callable invoked as query_func(args, q), returning the
                    candidate indices for query q
        args: opaque argument object forwarded to query_func
        nearest_neighbor: nearest_neighbor[i] is the expected index for the
                          i-th test query
        test: sized iterable of query points
    Returns:
        dict with
            'precision': fraction of queries whose expected index is among
                         the returned candidates
            'selectivity': mean of (number of candidates / number of points)
    """
    num_points = P.shape[0]
    hits = 0
    selectivity_sum = 0

    for query_idx, query in enumerate(test):
        found = query_func(args, query)
        hits += 1 if nearest_neighbor[query_idx] in found else 0
        selectivity_sum += len(found) / num_points

    return {
        'precision': hits / len(test),
        'selectivity': selectivity_sum / len(test),
    }

In [46]:
%%time
# k segments per permuted vector, L independent hash tables.
k, L = 8, 16

D = train.shape[1]
# Offset added to values that empty segments borrow from a neighbour.
C = np.ceil(L + D / k)

buckets, seeds, auxiliary_vector = OPR(train, k, L, C)

k = 8, L = 16
Every segament length is 98
CPU times: user 2min 10s, sys: 1.55 s, total: 2min 12s
Wall time: 2min 11s


In [57]:
%%time
# Score the OPR index on the test queries; `train` is passed only so that
# metrics can read the number of indexed points.
metrics(
    train,
    OPR_query,
    [(D, k, L, C), buckets, seeds, auxiliary_vector],
    nearest_neighbor,
    test,
)

CPU times: user 1.18 s, sys: 3.22 ms, total: 1.18 s
Wall time: 1.18 s


{'precision': 0.878, 'selectivity': 0.018446820143884907}