In [1]:
import csv
import os
import re

import numpy as np
import pandas as pd

import pylsh

In [2]:
# Model / image settings.
# NOTE(review): none of these five names is referenced by any later cell
# visible in this notebook — confirm whether they are still needed here.
model_path = 'model/'
model_name = 'model-20170512-110547'
checkpoint_path = 'model-20170512-110547.ckpt-250000'
src_images = '../cropped_photos/'
IMG_SHAPE = (160, 160)


In [3]:
# Load the stored embeddings as raw text: the first line of the file is
# skipped, the columns are named 'index' and 'embedding', and every value is
# read as a string (the 'embedding' text is parsed in the following cells).
embeddings = pd.read_csv("data/cropped_embeddings.csv", names=['index', 'embedding'], dtype=str, skiprows=1)

In [4]:
def _tokenize_embedding(raw):
    """Split the textual form of one embedding into string tokens.

    Newlines become spaces, runs of spaces are collapsed to a single space,
    square brackets are removed and the result is split on single spaces.
    Empty-string tokens may remain in the returned list.
    """
    flattened = str(raw).replace('\n', ' ')
    collapsed = re.sub(' +', ' ', flattened)
    without_brackets = collapsed.replace('[', '').replace(']', '')
    return without_brackets.split(' ')


embeddings['float_values'] = embeddings['embedding'].apply(_tokenize_embedding)

In [5]:
def _to_vector(tokens):
    """Convert a sequence of tokens to a NumPy array of floats, skipping '' tokens."""
    return np.array([float(token) for token in tokens if token != ''])


embeddings['float_values'] = embeddings['float_values'].apply(_to_vector)

In [6]:
# Number of components in the vector stored under row label 0.
len(embeddings.loc[0, 'float_values'])

128

In [7]:
# embeddings.float_values[0]

In [8]:
# Plain Python list holding the per-row vectors, in row order.
embeddings_list = embeddings['float_values'].tolist()

In [9]:
# Number of rows in the embeddings table.
len(embeddings.index)

157220

In [10]:
# NOTE(review): the meaning of the three constructor arguments is not visible
# in this notebook; 128 equals the embedding length displayed above — confirm
# the other two against the pylsh sources.
index = pylsh.PyLSH(50, 64, 128)

In [11]:
# Build the index's splits before any vectors are added.
# NOTE(review): presumably these are the "planes" written to disk in the next
# cell — confirm in pylsh.
index.create_splits()

In [12]:
# Save the planes to ./data/split.txt; the same path is passed as planes_path
# when a second index is filled from files later in the notebook.
index.write_planes_to_file(b'./data/split.txt')

True

In [13]:
# Add every parsed vector to the index under its DataFrame row label and
# print the label each time it is a multiple of 10000 as a progress marker.
for i, cur_emb in embeddings['float_values'].items():
    index.add_to_table(i, cur_emb)
    if not i % 10000:
        print(i)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000


In [14]:
# Create the directory whose path is passed to write_hash_tables_to_files in
# the next cell. exist_ok=True makes the cell safe to re-run, so the notebook
# no longer depends on a one-off manual `mkdir`.
os.makedirs('./data/index', exist_ok=True)

In [15]:
# Write the hash tables to the ./data/index path; the value displayed below
# the cell is the call's return value.
index.write_hash_tables_to_files(b'./data/index')

True

In [16]:
# Write the index-to-embedding mapping to ./data/index_embedding.txt; the same
# path is passed as index_embedding_dict_path when reloading later.
index.write_index_embedding_dict(b'./data/index_embedding.txt')

True

## А теперь найдем для них соседей

In [43]:
# Ask the LSH index for 10 neighbours of the embedding in row 0.
# `bad` is this index-based result; it is compared below with `good`, the
# list returned by dummy_k_neighbors for the same query.
bad = index.find_k_neighbors(10, embeddings.loc[0, 'float_values'])
bad

[0, 7772, 10432, 27671, 28029, 79317, 96549, 112643, 113408, 116541]

In [44]:
# Reference result for the same query: dummy_k_neighbors receives all row
# labels and all vectors explicitly (the notebook text describes it as an
# exhaustive search).
good = index.dummy_k_neighbors(
    10,
    embeddings.index.values.tolist(),
    embeddings_list,
    embeddings.loc[0, 'float_values'],
)
good

[0, 7772, 28029, 31961, 51484, 79317, 96549, 112643, 113408, 116541]

In [45]:
# How many ids appear in both the `bad` and the `good` result lists.
len(set(bad).intersection(good))

8

### 8 из 10 ближайших соседей совпадают с результатом полного перебора.

## Сравним скорость:

In [41]:
# Time the index-based query for the embedding in row 0.
%timeit index.find_k_neighbors(10, embeddings.loc[0, 'float_values'])

2.5 ms ± 487 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [42]:
# Time the dummy_k_neighbors baseline for the same query, for comparison with
# the index-based timing above.
%timeit index.dummy_k_neighbors(10, embeddings.index.values.tolist(), embeddings_list, \
                         embeddings.loc[0, 'float_values'])

1.51 s ± 56.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### На этом запросе наш приближённый метод работает примерно в 604 раза быстрее полного перебора (1.51 с против 2.5 мс).

## Проверю, как работает инициализация LSH данными с диска

In [49]:
# Second, empty index created with the same constructor arguments as `index`,
# to be filled from the files written earlier in the notebook.
index2 = pylsh.PyLSH(50, 64, 128)

In [50]:
# Fill index2 from the three artefacts saved by the first index: the planes
# file, the hash-table directory and the index-to-embedding file.
index2.fill_data_from_files(
    planes_path=b'./data/split.txt',
    hash_tables_dir_path=b'./data/index',
    index_embedding_dict_path=b'./data/index_embedding.txt',
)

True

In [51]:
# Repeat the row-0 query against the index that was filled from disk.
# NOTE(review): the recorded output of this cell is an empty list, whereas the
# same query against the in-memory `index` returned ten ids. The reload path
# does not reproduce the original index here — investigate
# fill_data_from_files and the paths passed to it before relying on it.
bad2 = index2.find_k_neighbors(10, embeddings.loc[0, 'float_values'])
bad2

[]