In [1]:
import re
import pandas as pd
import numpy as np
import csv

import pylsh

In [2]:
# --- Input data -------------------------------------------------------------
# CSV with one embedding per row; loaded below with pandas.
cropped_embeddings_path = 'data/cropped_embeddings.csv'
src_images = '../cropped_photos/'

# --- Model location (not referenced by the cells shown in this notebook) -----
model_path = 'model/'
model_name = 'model-20170512-110547'
checkpoint_path = 'model-20170512-110547.ckpt-250000'

# --- LSH index artefacts -----------------------------------------------------
# These three are bytes literals and are the paths handed to the pylsh
# write_* / fill_data_from_files calls below (presumably the extension expects
# bytes paths — confirm against pylsh).
split_file_path = b'data/split.txt'
index_dir_path = b'data/index/'
index_embedding_path = b'data/index_embedding.txt'

In [3]:
# Read everything as strings; the first line of the file is skipped and the two
# columns are named explicitly. The 'embedding' column is parsed in later cells.
embeddings = pd.read_csv(
    cropped_embeddings_path,
    skiprows=1,
    names=['index', 'embedding'],
    dtype=str,
)

In [4]:
def _split_embedding_text(raw):
    """Turn the textual form of one embedding into a list of string tokens.

    Newlines become spaces, runs of spaces are collapsed to one, square
    brackets are dropped, and the result is split on single spaces. Empty
    tokens may remain in the list; they are filtered out in the next cell.
    """
    single_line = str(raw).replace('\n', ' ')
    collapsed = re.sub(' +', ' ', single_line)
    without_brackets = collapsed.replace('[', '').replace(']', '')
    return without_brackets.split(' ')


embeddings['float_values'] = embeddings['embedding'].apply(_split_embedding_text)

In [5]:
def _tokens_to_vector(tokens):
    """Convert a list of string tokens to a NumPy float array, skipping '' tokens."""
    return np.array([float(token) for token in tokens if token != ''])


# Note: this overwrites the token lists produced by the previous cell.
embeddings['float_values'] = embeddings['float_values'].apply(_tokens_to_vector)

In [6]:
# Length of the parsed embedding stored at row label 0.
len(embeddings.loc[0, 'float_values'])

128

In [7]:
# Plain Python list of the per-row arrays; passed to the brute-force search below.
embeddings_list = embeddings['float_values'].tolist()

In [8]:
# Number of rows (one embedding per row). len() on the frame gives the same
# count as len(embeddings.values) without first materialising the whole frame
# as a NumPy object array.
len(embeddings)

157220

In [9]:
# Create the LSH index object. The last argument (128) equals the embedding
# length shown by the check above; the first two (5, 64) are presumably the
# number of hash tables and the number of planes per table — TODO confirm
# against the pylsh constructor.
index = pylsh.PyLSH(5, 64, 128)

In [10]:
# Presumably generates the splitting planes used for hashing (they are written
# to disk by write_planes_to_file in the next cell) — confirm in pylsh.
index.create_splits()

In [11]:
# Save the planes to split_file_path so the index can be restored later with
# fill_data_from_files. The call displays True in this run.
index.write_planes_to_file(split_file_path)

True

In [12]:
# Insert every embedding under its DataFrame row label, reporting progress
# each time the label is a multiple of 10000.
for row_label, vector in embeddings['float_values'].items():
    index.add_to_table(row_label, vector)
    if row_label % 10000 == 0:
        print(row_label)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000


In [13]:
# Save the filled hash tables under index_dir_path (the directory is not
# created anywhere in this notebook, so it presumably has to exist already).
# The call displays True in this run.
index.write_hash_tables_to_files(index_dir_path)

True

In [14]:
# Save the index -> embedding mapping to index_embedding_path; together with
# the planes file and the hash-table directory this is what
# fill_data_from_files is given further down. Displays True in this run.
index.write_index_embedding_dict(index_embedding_path)

True

## А теперь найдем для них соседей

In [15]:
# Approximate (LSH) search: 10 neighbours of the embedding at row label 0.
query_vector = embeddings.loc[0, 'float_values']
bad = index.find_k_neighbors(10, query_vector)
bad

[0, 9228, 9647, 25807, 25982, 28029, 104060, 112167, 131103, 151771]

In [16]:
# Reference result for the same query: dummy_k_neighbors is handed every row
# label and every embedding (the exhaustive-search baseline this notebook
# compares against).
all_row_labels = embeddings.index.values.tolist()
query_vector = embeddings.loc[0, 'float_values']
good = index.dummy_k_neighbors(10, all_row_labels, embeddings_list, query_vector)
good

[0, 7772, 28029, 31961, 51484, 79317, 96549, 112643, 113408, 116541]

In [17]:
# How many ids the LSH result and the exhaustive result have in common.
len(set(bad).intersection(good))

2

### Только 2 из 10 ближайших соседей совпадают с результатом полного перебора.

## Сравним скорость:

In [20]:
# Time the LSH lookup for the embedding at row label 0. The embeddings.loc[...]
# access is part of the timed statement, so its cost is included in the figure.
%timeit index.find_k_neighbors(10, embeddings.loc[0, 'float_values'])

385 µs ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [22]:
%timeit index.dummy_k_neighbors(10, embeddings.index.values.tolist(), embeddings_list, \
                         embeddings.loc[0, 'float_values'])

KeyboardInterrupt: 

### Наш приближенный метод в среднем работает примерно в 604 раза быстрее, чем полный перебор (оценка не из этого запуска: замер полного перебора выше был прерван — KeyboardInterrupt).

## Проверю, как работает инициализация LSH данными с диска

In [18]:
# Second, empty index built with the same constructor arguments as `index`;
# it is meant to be populated from the files written above.
index2 = pylsh.PyLSH(5, 64, 128)

In [None]:
# Load planes, hash tables and the index -> embedding mapping written by
# `index`. This cell has to be executed before index2 is queried.
index2.fill_data_from_files(
    planes_path=split_file_path,
    hash_tables_dir_path=index_dir_path,
    index_embedding_dict_path=index_embedding_path,
)

In [53]:
# Same query as for `index` (embedding at row label 0), now against index2.
query_vector = embeddings.loc[0, 'float_values']
bad2 = index2.find_k_neighbors(10, query_vector)
bad2

[]

In [52]:
# TODO: index2.find_k_neighbors returns an empty list for a query that `index`
# answers with 10 ids. The cell calling index2.fill_data_from_files(...) carries
# no execution count in this notebook, so index2 was presumably still empty
# when it was queried. Re-run the load cell, repeat the query and compare the
# result with `bad` before looking for a bug in pylsh itself.

In [54]:
# Dump index2's planes to a separate file (presumably to compare them with
# data/split.txt written from `index` — confirm). Displays True in this run.
index2.write_planes_to_file(b'./data/split2.txt')

True

In [55]:
!mkdir ./data/index2

In [59]:
# Dump index2's hash tables into the directory created in the previous cell.
# Displays True in this run.
index2.write_hash_tables_to_files(b'./data/index2/')

True

In [62]:
# Dump index2's index -> embedding mapping to a separate file. Displays True
# in this run.
index2.write_index_embedding_dict(b'./data/index_embedding2.txt')

True