In [75]:
%%writefile solution_6.py
from collections import OrderedDict, defaultdict
from typing import Callable, Tuple, Dict, List

import numpy as np
import scipy.spatial.distance as dist
from tqdm.auto import tqdm


def distance(pointA: np.ndarray, documents: np.ndarray) -> np.ndarray:
    return dist.cdist(pointA, documents).reshape(-1, 1)


def create_sw_graph(
        data: np.ndarray,
        num_candidates_for_choice_long: int = 10,
        num_edges_long: int = 5,
        num_candidates_for_choice_short: int = 10,
        num_edges_short: int = 5,
        use_sampling: bool = False,
        sampling_share: float = 0.05,
        dist_f: Callable = distance
    ) -> Dict[int, List[int]]:    
    graph = {}
    
    for i in range(len(data)):
        distances = dist_f(data[i].reshape(1,-1),
                           data)
        distances = distances.reshape(-1)
                
        order_short = distances.argsort()
        order_short = order_short[order_short != i]
        idx_short_select = np.random.choice(num_candidates_for_choice_short,
                                      size=num_edges_short,
                                      replace=False)
        
        graph[i] = order_short[idx_short_select].tolist()
        
        order_long = distances.argsort()[::-1]
        order_long = order_long[order_long != i]
        idx_long_select = np.random.choice(num_candidates_for_choice_long,
                                      size=num_edges_long,
                                      replace=False)
        
        graph[i] += order_long[idx_long_select].tolist()
    
    return graph
    

def nsw(query_point: np.ndarray, all_documents: np.ndarray, 
        graph_edges: Dict[int, List[int]],
        search_k: int = 10, num_start_points: int = 5,
        dist_f: Callable = distance) -> np.ndarray:
    documents = []
        
    all_edges = np.array(list(graph_edges.keys()))
    
    queue_flag = False
    queue_edges = list()
    queue_distances = list()
    candidates_edges = list()
    candidates_distances = list()
    explored_edges = list()
    distance_min = 1e12
    parent_idx = None
    idx = np.random.choice(all_edges,
                           size=1,
                           replace=False)[0]
    while len(documents) != search_k:
        edges = np.array(graph_edges[idx])
        edges = edges[~np.isin(edges, explored_edges)]
        
        distances = dist_f(query_point.reshape(1, -1),
                           all_documents[edges])
        distances = distances.reshape(-1)
        
        print('-'*50)
        print(f'idx {idx}',
              f'distance_min {distance_min}',
              f'explored_edges {explored_edges}',
              f'queue_edges {queue_edges}',
              f'queue_distances {queue_distances}',
              f'candidates_edges {candidates_edges}',
              f'candidates_distances {candidates_distances}',
              f'dicuments {documents}',
              f'all_edges {graph_edges[idx]}',
              f'filtered_edges {edges}',
              f'distances {distances}',
              f'current_min {distances.min()}',
              f'current_min_idx {edges[distances.argmin()]}',
              sep='\n')
        
        if (distances.min() > distance_min
            and idx not in documents):
            if not queue_flag:
                print('Ini queue')
                queue_flag = True
    
#                 if parent_idx:
#                     edges = np.array(graph_edges[parent_idx])
#                     edges = edges[~np.isin(edges, explored_edges)]

#                     distances = dist_f(query_point.reshape(1, -1),
#                                        all_documents[edges])
#                     distances = distances.reshape(-1)
                    
                queue_edges = list(edges[distances.argsort()])[1:]
                queue_distances = list(np.sort(distances))[1:]
                candidates_edges = list()
                candidates_distances = list()
            explored_edges.append(idx) 
            
            if len(queue_edges) != 0:
                print('Pop queue')
                candidates_edges.append(idx)
                candidates_distances.append(distance_min)
                
                idx = queue_edges[0]
                distance_min = queue_distances[0]
                queue_edges.remove(queue_edges[0])
                queue_distances.remove(queue_distances[0])
            else:
                print('Append')
                queue_flag = False
                documents.append(candidates_edges[np.argmin(candidates_distances)])
                explored_edges = list(documents)
                distance_min = 1e12
                parent_idx = None
                idx = np.random.choice(all_edges[~np.isin(all_edges,
                                                          explored_edges)],
                                       size=1,
                                       replace=False)[0]
        elif (distances.min() > distance_min
              and idx in documents):
            explored_edges.append(idx)     
#             print('Restart')            
            distance_min = 1e12
            idx = np.random.choice(all_edges[~np.isin(all_edges,
                                                      explored_edges)],
                                   size=1,
                                   replace=False)[0]
        else:
            explored_edges.append(idx)     
            parent_idx = idx
            idx = edges[distances.argmin()]
            distance_min = distances.min()   
    return np.array(documents)



Overwriting solution_6.py


In [76]:
import solution_6
import importlib
importlib.reload(solution_6)

<module 'solution_6' from '/home/jupyter-v.pashentsev-2/2.matching/6/solution_6.py'>

In [77]:
import numpy as np

In [78]:
query = np.array([[0.5, 0.5, 0.5]])
docs = np.random.randn(1000, 3)
query.shape, docs.shape

((1, 3), (1000, 3))

In [79]:
graph = solution_6.create_sw_graph(docs,
                                   num_candidates_for_choice_long=10,
                                   num_candidates_for_choice_short=10,
                                   num_edges_long=5,
                                   num_edges_short=5)

In [80]:
idxs = solution_6.nsw(query, docs, graph, search_k=10)
idxs

--------------------------------------------------
idx 298
distance_min 1000000000000.0
explored_edges []
queue_edges []
queue_distances []
candidates_edges []
candidates_distances []
dicuments []
all_edges [334, 365, 617, 670, 848, 965, 777, 197, 567, 233]
filtered_edges [334 365 617 670 848 965 777 197 567 233]
distances [1.87577611 1.66493859 2.00731911 1.40790311 1.53691626 3.44371948
 3.10281949 3.53741185 2.57544111 2.70177146]
current_min 1.4079031133791322
current_min_idx 670
--------------------------------------------------
idx 670
distance_min 1.4079031133791322
explored_edges [298]
queue_edges []
queue_distances []
candidates_edges []
candidates_distances []
dicuments []
all_edges [114, 351, 501, 319, 506, 965, 862, 17, 173, 397]
filtered_edges [114 351 501 319 506 965 862  17 173 397]
distances [1.39566629 1.36315269 1.62939074 1.22212865 1.15021392 3.44371948
 4.13253805 2.54338893 3.70646558 4.20630395]
current_min 1.1502139169198593
current_min_idx 506
-----------------

idx 411
distance_min 0.7859687292110517
explored_edges [991, 4, 515, 199, 881, 751, 457, 48, 372, 402, 936, 832, 258, 720, 913, 141]
queue_edges [530, 284, 862, 397]
queue_distances [3.936091354161183, 3.970570225271179, 4.132538050669699, 4.20630394892048]
candidates_edges [751, 372, 936, 832]
candidates_distances [0.43063996162556983, 0.294661984193148, 0.5194628647916658, 0.7564628956146718]
dicuments [991]
all_edges [29, 776, 221, 132, 446, 588, 112, 645, 258, 397]
filtered_edges [ 29 776 221 132 446 588 112 645 397]
distances [0.54268127 1.14621494 1.12623346 0.78694828 1.20243191 3.62695616
 3.63921412 3.75524275 4.20630395]
current_min 0.5426812730275331
current_min_idx 29
--------------------------------------------------
idx 29
distance_min 0.5426812730275331
explored_edges [991, 4, 515, 199, 881, 751, 457, 48, 372, 402, 936, 832, 258, 720, 913, 141, 411]
queue_edges [530, 284, 862, 397]
queue_distances [3.936091354161183, 3.970570225271179, 4.132538050669699, 4.20630394892048

candidates_edges [89, 773, 605, 321, 751, 372]
candidates_distances [0.4753156322316352, 0.3665854458427344, 0.5561491951248854, 0.2155590552705677, 0.43063996162556983, 0.294661984193148]
dicuments [991, 940]
all_edges [221, 308, 141, 50, 878, 56, 258, 645, 198, 112]
filtered_edges [221 308 141  50 878  56 645 198 112]
distances [1.12623346 1.0880446  0.81542387 1.50833012 1.1951593  3.71195037
 3.75524275 3.91011177 3.63921412]
current_min 0.8154238708865008
current_min_idx 141
--------------------------------------------------
idx 141
distance_min 0.8154238708865008
explored_edges [991, 940, 459, 139, 274, 602, 314, 579, 182, 89, 186, 773, 605, 294, 384, 973, 321, 588, 256, 439, 518, 832, 936, 751, 894, 400, 130, 151, 692, 92, 364, 819, 102, 450, 372, 258, 720, 913]
queue_edges [152, 862]
queue_distances [3.9304959949850096, 4.132538050669699]
candidates_edges [89, 773, 605, 321, 751, 372]
candidates_distances [0.4753156322316352, 0.3665854458427344, 0.5561491951248854, 0.2155590552

--------------------------------------------------
idx 889
distance_min 0.32212512809784805
explored_edges [991, 940, 321, 372, 954, 118, 8, 380, 294, 89, 186, 773, 605, 182, 969, 973]
queue_edges [588, 894, 258, 152, 862]
queue_distances [3.6269561566708184, 3.7590080903722005, 3.927173922264074, 3.9304959949850096, 4.132538050669699]
candidates_edges [89, 773, 605]
candidates_distances [0.4753156322316352, 0.3665854458427344, 0.5561491951248854]
dicuments [991, 940, 321, 372]
all_edges [643, 773, 940, 686, 384, 504, 258, 284, 56, 198]
filtered_edges [643 686 384 504 258 284  56 198]
distances [0.65318677 0.66217229 0.643045   3.85345474 3.92717392 3.97057023
 3.71195037 3.91011177]
current_min 0.6430450029674316
current_min_idx 384
Pop queue
--------------------------------------------------
idx 588
distance_min 3.6269561566708184
explored_edges [991, 940, 321, 372, 954, 118, 8, 380, 294, 89, 186, 773, 605, 182, 969, 973, 889]
queue_edges [894, 258, 152, 862]
queue_distances [3.75900

--------------------------------------------------
idx 640
distance_min 2.0553244138767814
explored_edges [991, 940, 321, 372, 889, 75, 515, 199, 881, 751, 457, 48, 936, 402, 165, 583, 832, 258, 720, 913, 141, 411, 29, 219, 304, 530, 370, 46, 911, 539, 116, 284]
queue_edges [862, 397]
queue_distances [4.132538050669699, 4.20630394892048]
candidates_edges [751, 936, 583, 832, 304, 116]
candidates_distances [0.43063996162556983, 0.5194628647916658, 0.6439184013369329, 0.7564628956146718, 0.33375965533537083, 0.4005953215502105]
dicuments [991, 940, 321, 372, 889]
all_edges [221, 678, 354, 930, 251, 897, 397, 112, 862, 152]
filtered_edges [221 678 354 930 251 897 397 112 862 152]
distances [1.12623346 1.98411941 1.29628558 1.45645622 2.33476868 3.49907348
 4.20630395 3.63921412 4.13253805 3.93049599]
current_min 1.1262334563603413
current_min_idx 221
--------------------------------------------------
idx 221
distance_min 1.1262334563603413
explored_edges [991, 940, 321, 372, 889, 75, 515,

candidates_distances [0.43598985225736914, 0.4753156322316352, 0.5124480680090537]
dicuments [991, 940, 321, 372, 889, 304, 773]
all_edges [84, 728, 984, 800, 853, 894, 504, 152, 258, 683]
filtered_edges [ 84 728 800 853 894 504 152 258 683]
distances [1.19176967 0.95666548 1.22561886 1.00631753 3.75900809 3.85345474
 3.93049599 3.92717392 3.72977164]
current_min 0.956665478732596
current_min_idx 728
--------------------------------------------------
idx 728
distance_min 0.956665478732596
explored_edges [991, 940, 321, 372, 889, 304, 773, 677, 572, 494, 661, 185, 117, 969, 973, 182, 89, 593, 759, 20, 198, 918, 984, 412]
queue_edges [152, 530, 862, 397]
queue_distances [3.9304959949850096, 3.936091354161183, 4.132538050669699, 4.20630394892048]
candidates_edges [973, 89, 20]
candidates_distances [0.43598985225736914, 0.4753156322316352, 0.5124480680090537]
dicuments [991, 940, 321, 372, 889, 304, 773]
all_edges [42, 686, 479, 492, 83, 764, 949, 258, 284, 894]
filtered_edges [ 42 686 479

filtered_edges [186 819 969 286 753 284 197 588 152 897]
distances [0.53628651 0.62508917 0.48443788 0.73769889 1.00470917 3.97057023
 3.53741185 3.62695616 3.93049599 3.49907348]
current_min 0.4844378823363078
current_min_idx 969
--------------------------------------------------
idx 969
distance_min 0.4844378823363078
explored_edges [991, 940, 321, 372, 889, 304, 773, 676, 522, 91, 762, 639, 925, 903, 679, 896, 926, 227, 962, 12, 748, 242, 154, 10, 433, 788, 225, 64, 871, 807, 947, 458, 432, 359, 297, 220, 599, 185, 117]
queue_edges [504, 530]
queue_distances [3.853454743495208, 3.936091354161183]
candidates_edges [91, 926, 962, 748, 64, 947]
candidates_distances [1.4615983616430344, 0.5118724402428889, 1.5250762708136427, 1.6512262541827483, 0.4412827576688875, 1.4489231593974454]
dicuments [991, 940, 321, 372, 889, 304, 773, 676]
all_edges [117, 23, 605, 773, 973, 284, 152, 844, 198, 645]
filtered_edges [ 23 605 973 284 152 844 198 645]
distances [0.82711028 0.5561492  0.43598985 3

array([991, 940, 321, 372, 889, 304, 773, 676, 973, 219])

In [81]:
idxs

array([991, 940, 321, 372, 889, 304, 773, 676, 973, 219])

In [82]:
import scipy.spatial.distance as dist

In [83]:
all_dist = dist.cdist(query, docs).reshape(-1)
all_dist.argsort()[:10]

array([991, 940, 321, 372, 889, 304, 219, 773, 676, 116])

In [84]:
[id_ for id_ in idxs if id_ not in all_dist.argsort()[:10]]

[973]

In [None]:
dict(zip(graph[949], dist.cdist(query, docs[graph[949]]).reshape(-1)))

In [None]:
dict(zip(graph[597], dist.cdist(query, docs[graph[597]]).reshape(-1)))

In [None]:
dict(zip(graph[826], dist.cdist(query, docs[graph[826]]).reshape(-1)))