In [1]:
%load_ext autoreload
%autoreload 2
import torch 
torch.set_default_dtype(torch.float64)

In [2]:
# experiment 1:
# first, creating a set of N randomly generated (m, 1) vectors:
import numpy as np
m = 10
N = 10000
# vectors = [np.random.rand(*(m, 1)) for _ in range(N)] # uniform distribution between 0 and 1
vectors = [np.random.randn(*(m, 1)) for _ in range(N)] # normal dist so can get negative vals too
# print (vectors)

In [3]:
# finding nearest neighbor of each vector using O(N^2) brute force approach:
'''
nearest_neighbors = {} # the keys will be the indices of the vectors from 0 to N-1, and the corresponding values will be the indices of its nearest neighbor
for i in range(len(vectors)):
    nearest_dist = float('inf')
    nearest_neighbors[i] = None
    for j in range(len(vectors)):
        if j != i:
            dist = np.linalg.norm(vectors[i] - vectors[j])
            if dist < nearest_dist:
                nearest_dist = dist
                nearest_neighbors[i] = j
print (nearest_neighbors) 
'''
# more efficient approach (although technically same big O runtime complexiity i think cuz N^2 vs N(N+1)/2) which will also allow us to find top 5 nearest neighbors easily
pairwise_distances = {} # {0: {1: 0.1, 2: 0.3, 3: 0.6}, 1: {0: 0.1, 2: 0.7, 3: 0.8}}
for i in range(len(vectors)):
    for j in range(i+1, len(vectors)):
        dist = np.linalg.norm(vectors[i] - vectors[j])
        #pairwise_distances[(i, j)] = dist
        try:
            pairwise_distances[i][j] = dist
        except:
            pairwise_distances[i] = {j: dist}
        try:
            pairwise_distances[j][i] = dist
        except:
            pairwise_distances[j] = {i: dist}
nearest_neighbors = {}
five_nearest_neighbors = {}
ten_nearest_neighbors = {}
for i in list(pairwise_distances.keys()):
    nearest_neighbors[i] = min(pairwise_distances[i], key=pairwise_distances[i].get)
    five_nearest_neighbors[i] = list(dict(sorted(pairwise_distances[i].items(), key=lambda item: item[1])).keys())[:5]
    ten_nearest_neighbors[i] = list(dict(sorted(pairwise_distances[i].items(), key=lambda item: item[1])).keys())[:10]
print (nearest_neighbors)
print (five_nearest_neighbors)
print (ten_nearest_neighbors)

{0: 5536, 1: 9684, 2: 2992, 3: 8346, 4: 4229, 5: 3752, 6: 9447, 7: 7494, 8: 8860, 9: 9886, 10: 3035, 11: 7883, 12: 5976, 13: 8782, 14: 9113, 15: 7345, 16: 907, 17: 1962, 18: 8426, 19: 7348, 20: 1767, 21: 2584, 22: 9224, 23: 1303, 24: 9799, 25: 764, 26: 8730, 27: 9490, 28: 7471, 29: 2839, 30: 438, 31: 6491, 32: 2212, 33: 6068, 34: 2667, 35: 7515, 36: 851, 37: 1873, 38: 6670, 39: 6793, 40: 5982, 41: 7259, 42: 5204, 43: 7573, 44: 6187, 45: 68, 46: 1616, 47: 1530, 48: 2957, 49: 3355, 50: 3872, 51: 5755, 52: 6120, 53: 4610, 54: 9593, 55: 7815, 56: 8796, 57: 2214, 58: 6272, 59: 5813, 60: 2404, 61: 3650, 62: 9857, 63: 3332, 64: 4726, 65: 4031, 66: 2228, 67: 3439, 68: 8995, 69: 3645, 70: 1301, 71: 1188, 72: 666, 73: 9911, 74: 4327, 75: 8490, 76: 1907, 77: 6195, 78: 6076, 79: 5457, 80: 5968, 81: 1496, 82: 7378, 83: 2034, 84: 9850, 85: 269, 86: 9092, 87: 137, 88: 3442, 89: 6272, 90: 3218, 91: 568, 92: 6678, 93: 2384, 94: 9430, 95: 6754, 96: 6277, 97: 3634, 98: 3962, 99: 956, 100: 8253, 101: 9234

In [4]:
# finding nearest neighbor of each vector using NaiveLSH:
from memristor.engine.model import NaiveLSH
from memristor.crossbar.model import LineResistanceCrossbar
from memristor.devices import StaticMemristor
# naive_lsh = NaiveLSH(
#     hash_size=10, # adjustable hyperparameter
#     crossbar_class=LineResistanceCrossbar,
#     crossbar_params={'r_wl': 20, 'r_bl': 20, 'r_in':10, 'r_out':10, 'V_SOURCE_MODE':'|_|'},
#     memristor_model_class=StaticMemristor,
#     memristor_params={'frequency': 1e8, 'temperature': 273 + 40},
#     m=m,
#     r=1, # adjustable hyperparameter
# )
reps = 3 # adjustable hyperparameter (repetitions of the hashing)
all_bins = []
for _ in range(reps):
    naive_lsh = NaiveLSH(
        hash_size=5, # adjustable hyperparameter
        crossbar_class=LineResistanceCrossbar,
        crossbar_params={'r_wl': 20, 'r_bl': 20, 'r_in':10, 'r_out':10, 'V_SOURCE_MODE':'|_|'},
        memristor_model_class=StaticMemristor,
        memristor_params={'frequency': 1e8, 'temperature': 273 + 40},
        m=m,
        r=1, # adjustable hyperparameter
    )
    bins = {}
    for i in range(len(vectors)):
        hash = naive_lsh.inference(vectors[i])
        # print (hash)
        if hash not in bins.keys():
            bins[hash] = [i]
        else:
            bins[hash].append(i)
    for bin in list(bins.values()):
        all_bins.append(bin)
#      {010:[1,5,7], 111:[5,6,7]}

# now at this point all_bins is a list like [[1,2], [1,3,5], ... ] where each element of all_bins is a bin containing indices of vectors that are likely
# to be close to each other. so now to find the nearest neighbor for each vector, we simply iterate through and check only those vectors that share a bin
# with it, so in this case for 1 we would check 2, 3, and 5, to find the nearest neighbor
nearest_neighbors_approx = {}
for i in range(len(vectors)):
    nearest_dist = float('inf')
    nearest_neighbors_approx[i] = None
    for bin in all_bins:
        if len(bin) > N/reps:
            continue
        if i in bin:
            for j in bin:
                if j != i:
                    dist = np.linalg.norm(vectors[i] - vectors[j])
                    if dist < nearest_dist:
                        nearest_dist = dist 
                        nearest_neighbors_approx[i] = j
print (nearest_neighbors_approx)

{0: 5536, 1: 5538, 2: 2992, 3: 1807, 4: 4229, 5: 3752, 6: 9447, 7: 7494, 8: 8860, 9: 9886, 10: 1912, 11: 7883, 12: 5976, 13: 8782, 14: 9113, 15: 7345, 16: 907, 17: 1962, 18: 6355, 19: 9700, 20: 1767, 21: 2584, 22: 9224, 23: 1303, 24: 9799, 25: 764, 26: 8730, 27: 9490, 28: 7471, 29: 2839, 30: 8064, 31: 6491, 32: 2516, 33: 6068, 34: 2667, 35: 7515, 36: 851, 37: 1873, 38: 6670, 39: 6793, 40: 5982, 41: 7259, 42: 5204, 43: 7573, 44: 8400, 45: 68, 46: 4272, 47: 1530, 48: 2957, 49: 3355, 50: 3872, 51: 5755, 52: 6120, 53: 4610, 54: 9593, 55: 7815, 56: 8796, 57: 9821, 58: 6272, 59: 5813, 60: 2404, 61: 3650, 62: 9857, 63: 6371, 64: 4173, 65: 4031, 66: 2228, 67: 3439, 68: 45, 69: 8622, 70: 1301, 71: 1188, 72: 666, 73: 7311, 74: 4327, 75: 8490, 76: 7754, 77: 6195, 78: 6076, 79: 5457, 80: 5968, 81: 1496, 82: 7378, 83: 2034, 84: 9850, 85: 9769, 86: 9092, 87: 9262, 88: 3442, 89: 6272, 90: 3218, 91: 568, 92: 2651, 93: 2384, 94: 9430, 95: 6754, 96: 6277, 97: 3634, 98: 30, 99: 956, 100: 8253, 101: 9234,

In [5]:
print (nearest_neighbors)
print (nearest_neighbors_approx)

{0: 5536, 1: 9684, 2: 2992, 3: 8346, 4: 4229, 5: 3752, 6: 9447, 7: 7494, 8: 8860, 9: 9886, 10: 3035, 11: 7883, 12: 5976, 13: 8782, 14: 9113, 15: 7345, 16: 907, 17: 1962, 18: 8426, 19: 7348, 20: 1767, 21: 2584, 22: 9224, 23: 1303, 24: 9799, 25: 764, 26: 8730, 27: 9490, 28: 7471, 29: 2839, 30: 438, 31: 6491, 32: 2212, 33: 6068, 34: 2667, 35: 7515, 36: 851, 37: 1873, 38: 6670, 39: 6793, 40: 5982, 41: 7259, 42: 5204, 43: 7573, 44: 6187, 45: 68, 46: 1616, 47: 1530, 48: 2957, 49: 3355, 50: 3872, 51: 5755, 52: 6120, 53: 4610, 54: 9593, 55: 7815, 56: 8796, 57: 2214, 58: 6272, 59: 5813, 60: 2404, 61: 3650, 62: 9857, 63: 3332, 64: 4726, 65: 4031, 66: 2228, 67: 3439, 68: 8995, 69: 3645, 70: 1301, 71: 1188, 72: 666, 73: 9911, 74: 4327, 75: 8490, 76: 1907, 77: 6195, 78: 6076, 79: 5457, 80: 5968, 81: 1496, 82: 7378, 83: 2034, 84: 9850, 85: 269, 86: 9092, 87: 137, 88: 3442, 89: 6272, 90: 3218, 91: 568, 92: 6678, 93: 2384, 94: 9430, 95: 6754, 96: 6277, 97: 3634, 98: 3962, 99: 956, 100: 8253, 101: 9234

In [6]:
print (nearest_neighbors == nearest_neighbors_approx)
count = 0
cnt = 0
cn = 0
for i in range(N):
    if nearest_neighbors_approx[i] == nearest_neighbors[i]:
        count += 1
    if nearest_neighbors_approx[i] in five_nearest_neighbors[i]:
        cnt += 1
    if nearest_neighbors_approx[i] in ten_nearest_neighbors[i]:
        cn += 1    
accuracy = count/N
top5_accuracy = cnt/N
top10_accuracy = cn/N
print (accuracy)
print (top5_accuracy)
print (top10_accuracy)

False
0.8391
0.9975
0.9999


In [7]:
print (len(all_bins))
print (all_bins)
print ([len(bin) for bin in all_bins])
print ([len(bin) for bin in all_bins if len(bin) <= N/reps])

96
[[0, 1, 40, 113, 131, 182, 337, 366, 433, 489, 536, 594, 600, 693, 800, 921, 996, 1057, 1144, 1285, 1350, 1418, 1481, 1485, 1519, 1638, 1817, 1915, 1928, 1994, 2011, 2111, 2289, 2308, 2319, 2323, 2472, 2551, 2608, 2622, 2628, 2692, 2693, 2719, 2759, 2815, 2882, 2963, 2970, 2976, 2994, 3017, 3027, 3039, 3055, 3127, 3133, 3240, 3300, 3308, 3335, 3376, 3428, 3459, 3568, 3583, 3616, 3724, 3781, 3798, 3820, 3970, 4097, 4141, 4392, 4404, 4553, 4572, 4684, 4734, 4736, 4751, 4831, 4853, 4907, 4973, 4983, 5177, 5455, 5473, 5481, 5645, 5801, 5830, 5865, 5873, 5923, 5960, 6106, 6138, 6311, 6314, 6344, 6383, 6510, 6577, 6691, 6712, 6753, 6891, 6963, 6981, 7112, 7156, 7283, 7390, 7396, 7635, 7675, 7718, 7737, 8009, 8033, 8099, 8214, 8388, 8587, 8616, 8661, 8708, 8711, 8722, 8929, 8968, 9009, 9043, 9046, 9476, 9534, 9555, 9596, 9616, 9622, 9921, 9948], [2, 94, 109, 124, 130, 149, 209, 219, 225, 248, 268, 321, 335, 356, 359, 397, 398, 408, 423, 480, 517, 539, 540, 561, 607, 618, 673, 690, 695, 708

In [8]:
# when experimenting with like a bigger dataset and stuff come up with  a metric to compare these 2 dicts
# also compare the runtime complexities, cuz its possible that its working so well because of sth wrong in the implementation whiich results in the runtime
# just being the same as the brute force method

In [9]:
'''
things to check: 
- sizes of the bins, should not be big because otherwise there would be no runtime improvement -> get metrics for reduction of search space/runtime or space complexity
- varying the parameters like N, m, hash size, etc. (try bigger/more data)
- using lineres_memristive_vmm not naive_memristive_vmm
- experiment with varying non-idealities
- experiment 2, create visualizations
- write section 4 of the paper    (rn highest priority is finalizing the experiment and section 4 of the paper)
- is the change made in StaticMemristor fine?
'''

'\nthings to check: \n- sizes of the bins, should not be big because otherwise there would be no runtime improvement -> get metrics for reduction of search space/runtime or space complexity\n- varying the parameters like N, m, hash size, etc. (try bigger/more data)\n- using lineres_memristive_vmm not naive_memristive_vmm\n- experiment with varying non-idealities\n- experiment 2, create visualizations\n- write section 4 of the paper    (rn highest priority is finalizing the experiment and section 4 of the paper)\n- is the change made in StaticMemristor fine?\n'