In [1]:
# Reload edited project modules (e.g. the memristor package used below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import torch 
# Newly created floating-point torch tensors default to float64 instead of float32.
torch.set_default_dtype(torch.float64)

In [2]:
# Experiment 1: build a dataset of N random (m, 1) column vectors.
import numpy as np

SEED = 0  # fix the RNGs so "Restart Kernel -> Run All" reproduces the same data and results
np.random.seed(SEED)
torch.manual_seed(SEED)  # NOTE(review): assumes NaiveLSH draws its random projections from torch -- confirm
m = 10    # dimensionality of each vector
N = 1000  # number of vectors in the dataset
# Standard-normal entries (rather than np.random.rand's uniform [0, 1)) so components can be negative too.
vectors = [np.random.randn(m, 1) for _ in range(N)]

In [3]:
# Ground truth: exact nearest neighbours of every vector by brute force.
# Each unordered pair (i, j) is measured once and the distance mirrored into both rows, i.e.
# N(N-1)/2 norm evaluations instead of the N(N-1) of a naive per-vector scan. That is still
# O(N^2), but keeping every distance gives each vector's full ranked neighbour list, from which
# the top-5 / top-10 sets used by the accuracy metrics fall out for free.
#
# pairwise_distances[i][j] == ||vectors[i] - vectors[j]||_2 for every j != i, e.g.
#   {0: {1: 0.1, 2: 0.3, 3: 0.6}, 1: {0: 0.1, 2: 0.7, 3: 0.8}, ...}
# Every row is created up front (no try/except as control flow), so even a vector with no
# neighbours (N == 1) gets an empty row instead of silently vanishing from the results.
pairwise_distances = {i: {} for i in range(len(vectors))}
for i in range(len(vectors)):
    for j in range(i + 1, len(vectors)):
        dist = np.linalg.norm(vectors[i] - vectors[j])
        pairwise_distances[i][j] = dist
        pairwise_distances[j][i] = dist

nearest_neighbors = {}       # i -> index of the closest other vector (None if there is none)
five_nearest_neighbors = {}  # i -> indices of the 5 closest vectors, nearest first
ten_nearest_neighbors = {}   # i -> indices of the 10 closest vectors, nearest first
for i, row in pairwise_distances.items():
    # Rank each row once. Row keys are inserted in ascending index order and sorted() is
    # stable, so distance ties resolve to the lower index (the same choice min() would make).
    ranked = sorted(row, key=row.get)
    nearest_neighbors[i] = ranked[0] if ranked else None
    five_nearest_neighbors[i] = ranked[:5]
    ten_nearest_neighbors[i] = ranked[:10]
print(nearest_neighbors)
print(five_nearest_neighbors)
print(ten_nearest_neighbors)

{0: 581, 1: 437, 2: 644, 3: 30, 4: 792, 5: 602, 6: 388, 7: 44, 8: 342, 9: 684, 10: 893, 11: 604, 12: 227, 13: 201, 14: 307, 15: 89, 16: 369, 17: 157, 18: 976, 19: 963, 20: 152, 21: 675, 22: 548, 23: 580, 24: 921, 25: 425, 26: 228, 27: 362, 28: 216, 29: 98, 30: 7, 31: 503, 32: 721, 33: 921, 34: 229, 35: 468, 36: 794, 37: 923, 38: 123, 39: 395, 40: 745, 41: 542, 42: 750, 43: 684, 44: 7, 45: 724, 46: 703, 47: 310, 48: 290, 49: 324, 50: 532, 51: 500, 52: 486, 53: 690, 54: 927, 55: 125, 56: 684, 57: 504, 58: 671, 59: 173, 60: 280, 61: 803, 62: 107, 63: 981, 64: 68, 65: 964, 66: 344, 67: 331, 68: 363, 69: 806, 70: 676, 71: 932, 72: 265, 73: 986, 74: 737, 75: 632, 76: 541, 77: 309, 78: 114, 79: 572, 80: 393, 81: 519, 82: 922, 83: 992, 84: 772, 85: 702, 86: 366, 87: 365, 88: 164, 89: 225, 90: 10, 91: 490, 92: 73, 93: 168, 94: 275, 95: 296, 96: 273, 97: 675, 98: 609, 99: 463, 100: 222, 101: 637, 102: 391, 103: 683, 104: 280, 105: 963, 106: 622, 107: 579, 108: 550, 109: 483, 110: 236, 111: 190, 

In [4]:
# Approximate nearest neighbours with NaiveLSH (memristor-crossbar hashing).
# Step 1: hash every vector with `reps` independently constructed NaiveLSH instances; each
# distinct hash value within a repetition becomes one bin of vector indices.
from memristor.engine.model import NaiveLSH
from memristor.crossbar.model import LineResistanceCrossbar
from memristor.devices import StaticMemristor

reps = 3       # adjustable hyperparameter: number of independent hashing repetitions
HASH_SIZE = 5  # adjustable hyperparameter passed as NaiveLSH(hash_size=...)
R = 1          # adjustable hyperparameter passed as NaiveLSH(r=...)

all_bins = []  # every bin from every repetition, e.g. [[1, 5, 7], [5, 6, 7], ...]
for _ in range(reps):
    # A fresh instance per repetition.
    # NOTE(review): assumes each construction draws new random projections -- confirm.
    naive_lsh = NaiveLSH(
        hash_size=HASH_SIZE,
        crossbar_class=LineResistanceCrossbar,
        crossbar_params={'r_wl': 20, 'r_bl': 20, 'r_in':10, 'r_out':10, 'V_SOURCE_MODE':'|_|'},
        memristor_model_class=StaticMemristor,
        memristor_params={'frequency': 1e8, 'temperature': 273 + 40},  # presumably 40 °C in kelvin -- confirm
        m=m,
        r=R,
    )
    # hash value -> indices of the vectors producing it, e.g. {010: [1, 5, 7], 111: [5, 6, 7]}
    # (the concrete type of the hash value is whatever NaiveLSH.inference returns).
    bins = {}
    for i in range(len(vectors)):
        # Named `code` rather than `hash`/`bin` so the builtins stay usable in later cells.
        code = naive_lsh.inference(vectors[i])
        bins.setdefault(code, []).append(i)
    all_bins.extend(bins.values())

# Step 2: approximate nearest-neighbour search restricted to the LSH bins.
# all_bins is a list like [[1, 2], [1, 3, 5], ...]; vectors sharing a bin are likely to be close,
# so each vector is compared only with the vectors it shares at least one bin with (for vector 1
# above: 2, 3 and 5). Bins with more than N / reps members are skipped (heuristic threshold): an
# overly large bin would push the candidate set back toward a brute-force scan.
MAX_BIN_SIZE = N / reps

# Inverted index: vector index -> the size-filtered bins that contain it, in all_bins order.
# A direct lookup avoids testing `i in bin` against every bin for every vector, which alone is
# O(reps * N^2) list-membership work -- as much as the brute-force search it is meant to beat.
bins_containing = {i: [] for i in range(len(vectors))}
for bin_members in all_bins:
    if len(bin_members) > MAX_BIN_SIZE:
        continue
    for j in bin_members:
        bins_containing[j].append(bin_members)

nearest_neighbors_approx = {}  # i -> closest candidate found via the bins (None if no candidates)
for i in range(len(vectors)):
    # Distinct candidates in first-seen order, so a vector sharing several bins with i is measured
    # only once; min() returns the first minimum, so ties resolve to the candidate seen first.
    candidates = dict.fromkeys(
        j for bin_members in bins_containing[i] for j in bin_members if j != i
    )
    nearest_neighbors_approx[i] = min(
        candidates,
        key=lambda j: np.linalg.norm(vectors[i] - vectors[j]),
        default=None,
    )
print(nearest_neighbors_approx)

  self.fitted_w = torch.tensor([[self.memristors[i][j].g_linfit for j in range(ideal_w.shape[1])]


{0: 467, 1: 437, 2: 155, 3: 30, 4: 792, 5: 576, 6: 511, 7: 44, 8: 342, 9: 597, 10: 893, 11: 604, 12: 227, 13: 201, 14: 307, 15: 89, 16: 834, 17: 450, 18: 976, 19: 963, 20: 802, 21: 675, 22: 248, 23: 580, 24: 921, 25: 425, 26: 113, 27: 419, 28: 216, 29: 98, 30: 7, 31: 443, 32: 476, 33: 24, 34: 229, 35: 468, 36: 943, 37: 923, 38: 123, 39: 395, 40: 745, 41: 542, 42: 750, 43: 684, 44: 7, 45: 586, 46: 703, 47: 310, 48: 290, 49: 324, 50: 532, 51: 484, 52: 486, 53: 690, 54: 346, 55: 125, 56: 43, 57: 504, 58: 97, 59: 173, 60: 461, 61: 803, 62: 107, 63: 258, 64: 68, 65: 907, 66: 344, 67: 331, 68: 363, 69: 310, 70: 676, 71: 932, 72: 0, 73: 986, 74: 737, 75: 65, 76: 541, 77: 966, 78: 49, 79: 572, 80: 393, 81: 519, 82: 578, 83: 992, 84: 772, 85: 702, 86: 366, 87: 365, 88: 164, 89: 225, 90: 682, 91: 490, 92: 73, 93: 168, 94: 275, 95: 296, 96: 273, 97: 675, 98: 609, 99: 463, 100: 222, 101: 637, 102: 391, 103: 683, 104: 280, 105: 963, 106: 622, 107: 62, 108: 550, 109: 483, 110: 236, 111: 463, 112: 18

In [5]:
# Side-by-side view: exact (brute-force) nearest neighbours first, then the LSH approximation.
for neighbor_map in (nearest_neighbors, nearest_neighbors_approx):
    print(neighbor_map)

{0: 581, 1: 437, 2: 644, 3: 30, 4: 792, 5: 602, 6: 388, 7: 44, 8: 342, 9: 684, 10: 893, 11: 604, 12: 227, 13: 201, 14: 307, 15: 89, 16: 369, 17: 157, 18: 976, 19: 963, 20: 152, 21: 675, 22: 548, 23: 580, 24: 921, 25: 425, 26: 228, 27: 362, 28: 216, 29: 98, 30: 7, 31: 503, 32: 721, 33: 921, 34: 229, 35: 468, 36: 794, 37: 923, 38: 123, 39: 395, 40: 745, 41: 542, 42: 750, 43: 684, 44: 7, 45: 724, 46: 703, 47: 310, 48: 290, 49: 324, 50: 532, 51: 500, 52: 486, 53: 690, 54: 927, 55: 125, 56: 684, 57: 504, 58: 671, 59: 173, 60: 280, 61: 803, 62: 107, 63: 981, 64: 68, 65: 964, 66: 344, 67: 331, 68: 363, 69: 806, 70: 676, 71: 932, 72: 265, 73: 986, 74: 737, 75: 632, 76: 541, 77: 309, 78: 114, 79: 572, 80: 393, 81: 519, 82: 922, 83: 992, 84: 772, 85: 702, 86: 366, 87: 365, 88: 164, 89: 225, 90: 10, 91: 490, 92: 73, 93: 168, 94: 275, 95: 296, 96: 273, 97: 675, 98: 609, 99: 463, 100: 222, 101: 637, 102: 391, 103: 683, 104: 280, 105: 963, 106: 622, 107: 579, 108: 550, 109: 483, 110: 236, 111: 190, 

In [6]:
# Agreement between the LSH answer and the exact ground truth:
#   accuracy       -- LSH neighbour is exactly the true nearest neighbour
#   top5_accuracy  -- LSH neighbour is among the true 5 nearest neighbours
#   top10_accuracy -- LSH neighbour is among the true 10 nearest neighbours
print(nearest_neighbors == nearest_neighbors_approx)
exact_hits = sum(nearest_neighbors_approx[i] == nearest_neighbors[i] for i in range(N))
top5_hits = sum(nearest_neighbors_approx[i] in five_nearest_neighbors[i] for i in range(N))
top10_hits = sum(nearest_neighbors_approx[i] in ten_nearest_neighbors[i] for i in range(N))
accuracy, top5_accuracy, top10_accuracy = exact_hits / N, top5_hits / N, top10_hits / N
for score in (accuracy, top5_accuracy, top10_accuracy):
    print(score)

False
0.684
0.97
0.993


In [7]:
# Bin statistics: the number of bins, their contents, every bin's size, and the sizes of the
# bins that survive the N / reps size filter applied by the approximate search.
bin_sizes = [len(members) for members in all_bins]
print(len(all_bins))
print(all_bins)
print(bin_sizes)
print([size for size in bin_sizes if size <= N / reps])

93
[[0, 26, 31, 44, 45, 71, 72, 81, 83, 94, 108, 113, 201, 203, 217, 239, 259, 268, 284, 300, 310, 331, 339, 352, 358, 359, 374, 377, 400, 403, 422, 423, 435, 440, 442, 449, 455, 464, 467, 468, 487, 488, 519, 523, 546, 550, 558, 602, 612, 646, 651, 654, 688, 703, 708, 711, 712, 722, 727, 737, 748, 761, 767, 769, 775, 776, 782, 785, 786, 806, 817, 819, 835, 844, 847, 893, 895, 907, 914, 918, 921, 930, 945, 995], [1, 80, 87, 131, 149, 165, 178, 186, 231, 365, 375, 437, 514, 518, 534, 563, 700, 714, 749, 750, 803, 809, 822, 832, 851, 856, 902, 944], [2, 23, 158, 198, 204, 209, 234, 283, 313, 319, 321, 345, 417, 497, 589, 606, 653, 662, 674, 679, 680, 718, 721, 833, 850, 854, 866, 908, 985], [3, 61, 78, 84, 90, 128, 153, 194, 206, 223, 334, 351, 384, 397, 402, 447, 453, 457, 526, 532, 568, 607, 645, 666, 739, 772, 777, 788, 813, 831, 867, 874, 882, 911, 929, 950, 954, 957, 961, 968, 972, 974, 975, 982], [4, 50, 95, 232, 304, 399, 709], [5, 39, 52, 66, 86, 117, 122, 170, 202, 213, 255, 260,

In [8]:
# TODO: when experimenting with a bigger dataset, come up with a metric to compare these two dicts.
# Also compare the runtime complexities: the results may look this good because of a bug in the
# implementation that makes its runtime effectively the same as the brute-force method.

In [9]:
# Open TODOs for this experiment. The string is the cell's last expression, so Jupyter echoes it as the cell output.
'''
things to check: 
- sizes of the bins, should not be big because otherwise there would be no runtime improvement -> get metrics for reduction of search space/runtime or space complexity
- varying the parameters like N, m, hash size, etc. (try bigger/more data)
- using lineres_memristive_vmm not naive_memristive_vmm
- experiment with varying non-idealities
- experiment 2, create visualizations
- write section 4 of the paper    (rn highest priority is finalizing the experiment and section 4 of the paper)
- is the change made in StaticMemristor fine?
'''

'\nthings to check: \n- sizes of the bins, should not be big because otherwise there would be no runtime improvement -> get metrics for reduction of search space/runtime or space complexity\n- varying the parameters like N, m, hash size, etc. (try bigger/more data)\n- using lineres_memristive_vmm not naive_memristive_vmm\n- experiment with varying non-idealities\n- experiment 2, create visualizations\n- write section 4 of the paper    (rn highest priority is finalizing the experiment and section 4 of the paper)\n- is the change made in StaticMemristor fine?\n'