In [1]:
import torch
from annoy import AnnoyIndex
import random
import networkx as nx
import numpy as np
import bisect
from scipy.spatial.distance import pdist

In [2]:
import h5py
filename = './data/glove-25-angular.hdf5'
f = h5py.File(filename, 'r')

# List all groups
# print("Keys: %s" % f.keys())
# a_group_key = list(f.keys())[0]

# Get the data
train_data = list(f['train'])
test_data = list(f['test'])

  from ._conv import register_converters as _register_converters


In [3]:
# conf
d = 25

In [4]:
G = nx.DiGraph()
G.add_edge(0, 1)
G.nodes[0]['val'] = np.random.random(d)
G.nodes[1]['val'] = np.random.random(d)

In [5]:
def dist(a, b):
    #return np.linalg.norm(a-b)
    return pdist([a, b], "euclidean")

In [6]:
class Results:
    def __init__(self, max_len=16):
        self.data = []
        self.max_len = max_len
        
    def insert(self, x):
        bisect.insort(self.data, x)
        if self.max_len and len(self.data) > self.max_len:
            del[self.data[-1]]

In [38]:
def knn_search(query, num_restarts=5, max_results=4, max_greedy_steps=10):
    candidates = Results(max_len=5)
    visited = set([])
    results = Results(max_len=None)
    
    for i in range(num_restarts):
        random_entry_point = random.randint(0, len(G.nodes) - 1)
        candidates.insert((dist(query, G.nodes[random_entry_point]['val']), random_entry_point))
        tempResults = Results(max_len=None)
        #TODO: move candidate selection out of the loop to break local minima
        for _ in range(max_greedy_steps):
            if len(candidates.data) > 0:
                best_candidate_val, best_candidate = candidates.data[0]
                del candidates.data[0]
                if len(tempResults.data) >= max_results and best_candidate_val > tempResults.data[-1][0]:
                    break
                for n in G.neighbors(best_candidate):
                    if n not in visited:
                        candidates.insert((dist(G.nodes[n]['val'], query), n))
                        tempResults.insert((dist(G.nodes[n]['val'], query), n))
                        visited.add(n)
        for val, node in tempResults.data:
            results.insert((val, node))
    return results

In [39]:
def insert(idx, val):
    id = idx
    results = knn_search(val, max_greedy_steps=20, max_results=50, num_restarts=50)
    for v, n in results.data[:50]:
        G.add_edge(id, n)
        G.add_edge(n, id)
        G.nodes[id]['val'] = val

In [40]:
def prune():
    for n in G.nodes:
        nhbrs = sorted([(dist(G.nodes[n]['val'], G.nodes[nbhr]['val']), nbhr) for nbhr in G.neighbors(n)])
        nhbrs = nhbrs[:10] + random.sample(nhbrs[10:], min(40, len(nhbrs[10:])))
        nhbrs_ids = [i[1] for i in nhbrs]
        for nhbr in list(G.neighbors(n)):
            if nhbr not in nhbrs_ids:
                G.remove_edge(n, nhbr)
        
            

In [41]:
t = AnnoyIndex(d, metric="euclidean")  # Length of item vector that will be indexed

In [44]:
%%timeit -n 1 -r 1
for idx, p in enumerate(test_data):
    if idx % 500 == 0:
        print (idx)
    insert(idx, p)
    t.add_item(idx, p)


0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000
9500
1h 13min 3s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [58]:
prune()
#t.build(10) # 10 tree

In [63]:
import torch
from torch import nn
device = torch.device('cpu')

In [64]:
t1 = torch.FloatTensor([1,2,3])
t2 = torch.FloatTensor([1,2])
torch.cat([t1, t2])

tensor([1., 2., 3., 1., 2.])

In [65]:
class Router(nn.Module):
    def __init__(self):
        super(Router, self).__init__()
        self.embedding = nn.Embedding(len(G.nodes),8)
        self.encoder = torch.nn.Sequential(
          torch.nn.Linear(d, 8),
          torch.nn.ReLU(),
          torch.nn.Linear(8, 8),
          torch.nn.ReLU()
        ).to(device)
        self.route_net = torch.nn.Sequential(
          torch.nn.Linear(16, 8),
          torch.nn.ReLU(),
          torch.nn.Linear(8, 50),
          torch.nn.Softmax()
        ).to(device)
        
    def forward(self, node_id_list, query):
        emd = self.embedding(torch.LongTensor(node_id_list))
        encoded_query = self.encoder(torch.FloatTensor(query).view(1, -1))
        inp = torch.cat((emd.view(-1), encoded_query.view(-1)))
        pred_dir = self.route_net(inp)
        return pred_dir

In [66]:
r = Router()

In [67]:
nodes = list(G.nodes(data=False))

In [68]:
def get_w_nighbours(node_id, query):
    t = torch.FloatTensor([-dist(query, G.nodes[n]['val']) for n in G.neighbors(node_id)])
    return torch.softmax(t, 0)
    

In [143]:
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(r.parameters(), lr=1e-3)

for i in range(1):
    random.shuffle(nodes)
    r.route_net.requires_grad = True
    for idx, random_entry in enumerate(nodes):
        if idx > 1000:
            r.route_net.requires_grad = False
        try:
            for j in range(3):
                random_point = G.nodes[random.sample(nodes, 1)[0]]['val']
                y_pred = r([random_entry], random_point)
                loss = loss_fn(y_pred.view(1, -1), get_w_nighbours(random_entry, query=random_point))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if idx % 1000 == 0:
                print(loss.item())
        except:
            pass

  input = module(input)
  "Please ensure they have the same size.".format(target.size(), input.size()))


0.09693653136491776
0.09824513643980026
0.0980229303240776
0.09852825105190277
0.09840621054172516
0.0984981432557106
0.09746470302343369


  "Please ensure they have the same size.".format(target.size(), input.size()))


0.0976511687040329
0.09757306426763535
0.09762583673000336
0.09790787845849991
0.0978493019938469


In [144]:
def neural_knn(query, num_restarts=5, max_greedy_steps=10):
    results = Results(max_len=None)
    visited = set([])
    for i in range(num_restarts):
        current_point = random.randint(0, len(G.nodes) - 1)
        for j in range(max_greedy_steps):
            pred = r([current_point], query)
            nhbrs = list(G.neighbors(current_point))
            next_point_idx = [int(k) for k in torch.argsort(pred.view(-1), 0) if nhbrs[int(k)] not in visited]
            if not next_point_idx:
                break
            next_point = nhbrs[next_point_idx[0]]
            #this does not work, may be wrong decisions are rectified while proceeding withe algorithm
#             if dist(G.nodes[current_point]['val'], query) - dist(G.nodes[next_point]['val'], query) > 0.2:
#                 break
            current_point = next_point
            results.insert((dist(G.nodes[next_point]['val'], query), next_point))
            visited.add(next_point)
    return results

In [161]:
def neural_knn_(query, num_restarts=5, max_results=4, max_greedy_steps=10):
    candidates = Results(max_len=5)
    visited = set([])
    results = Results(max_len=None)
    
    for i in range(num_restarts):
        random_entry_point = random.randint(0, len(G.nodes) - 1)
        candidates.insert((dist(query, G.nodes[random_entry_point]['val']), random_entry_point))
        tempResults = Results(max_len=None)
        #TODO: move candidate selection out of the loop to break local minima
        for _ in range(max_greedy_steps):
            if len(candidates.data) > 0:
                best_candidate_val, best_candidate = candidates.data[0]
                del candidates.data[0]
                visited.add(best_candidate)
                
                if len(tempResults.data) >= max_results and best_candidate_val > tempResults.data[-1][0]:
                    break
                    
                pred = r([best_candidate], query)
                nhbrs = list(G.neighbors(best_candidate))
                next_point_idx = [int(k) for k in torch.argsort(pred.view(-1), 0) if nhbrs[int(k)] not in visited]
                to_add = [nhbrs[n] for n in next_point_idx[:5]]
                for n in to_add:
                    if n not in visited:
                        candidates.insert((dist(G.nodes[n]['val'], query), n))
                        tempResults.insert((dist(G.nodes[n]['val'], query), n))
                        visited.add(n)
        for val, node in tempResults.data:
            results.insert((val, node))
    return results

In [162]:
random_point = random.sample(test_data, 1)[0]

In [177]:
#%%timeit -n 1 -r 1
neural_knn(random_point, max_greedy_steps=10, num_restarts=50).data[:10]

  input = module(input)


[(array([4.13469063]), 10746),
 (array([5.13280338]), 10974),
 (array([5.33629806]), 10060),
 (array([5.34439976]), 6552),
 (array([5.44745123]), 10410),
 (array([5.45162408]), 11037),
 (array([5.46896645]), 7545),
 (array([5.47235525]), 3125),
 (array([5.48701356]), 2048),
 (array([5.53063763]), 2166)]

In [179]:
#%%timeit -n 1 -r 1
neural_knn_(random_point, max_greedy_steps=100, max_results=20, num_restarts=5).data[:10]

  input = module(input)


[(array([0.]), 4067),
 (array([4.13469063]), 10746),
 (array([4.48703518]), 2335),
 (array([4.58762304]), 2042),
 (array([4.68207971]), 10187),
 (array([4.69303712]), 3855),
 (array([4.75208318]), 7042),
 (array([4.83061656]), 8876),
 (array([4.84047982]), 3435),
 (array([4.86146194]), 925)]

In [181]:
#%%timeit -n 1 -r 1
knn_search(random_point, max_greedy_steps=10, max_results=20, num_restarts=5).data[:10]


[(array([0.]), 4067),
 (array([4.13469063]), 10746),
 (array([4.48703518]), 2335),
 (array([4.58762304]), 2042),
 (array([4.68207971]), 10187),
 (array([4.68427814]), 2054),
 (array([4.69303712]), 3855),
 (array([4.75208318]), 7042),
 (array([4.79639161]), 2974),
 (array([4.83061656]), 8876)]

In [174]:
#%%timeit -n 1 -r 1
sorted([(dist(G.nodes[n]['val'], random_point), n) for n in G.nodes])[:10]

[(array([0.]), 4067),
 (array([4.13469063]), 10746),
 (array([4.48703518]), 2335),
 (array([4.58762304]), 2042),
 (array([4.68207971]), 10187),
 (array([4.68427814]), 2054),
 (array([4.69303712]), 3855),
 (array([4.75208318]), 7042),
 (array([4.79639161]), 2974),
 (array([4.83061656]), 8876)]

In [184]:
#%%timeit -n 1 -r 1
t.get_nns_by_vector(random_point,10 , search_k=-1, include_distances=True)

([2984, 9663, 1252, 9104, 1891, 7202, 1365, 582, 3378, 5467],
 [0.0,
  4.134690761566162,
  4.487035274505615,
  4.682079315185547,
  4.796391487121582,
  4.887374401092529,
  5.054071426391602,
  5.093081951141357,
  5.236557960510254,
  5.268767356872559])