In [83]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

In [134]:
class LSH:
    """
    Random-hyperplane locality-sensitive hashing.

    ``k`` unit vectors are drawn at random in ``d`` dimensions. The digest of a
    sample is the tuple of ``k`` bits telling on which side of each hyperplane
    (through the origin, normal to one of the vectors) the sample lies. Every
    distinct digest is assigned an integer bucket label in order of first
    appearance.

    Data matrices are laid out as ``d x N``: one sample per column.
    """

    def __init__(self, k, store_digest=False, seed=42):
        """
        k            -- number of random hyperplanes (bits per digest)
        store_digest -- kept on the instance; no method of this class reads it
        seed         -- seed for the private random generator used by generate_Wt
        """
        self.k = k
        self.store_digest = store_digest
        self.seed = seed
        self.unique_classes = 0
        # Fitted state. Defined up front so that using the object before fit()
        # fails with the explicit error below instead of an AttributeError.
        self.d = None
        self.Wt = None
        self.hash_table = {}

    def generate_Wt(self):
        """
        Randomly generate k points on the unit sphere in d dimensions and
        store them as the rows of self.Wt (shape k x d).

        Raises RuntimeError if self.d has not been set yet (fit() sets it).
        """
        if self.d is None:
            raise RuntimeError("Fit error: please invoke .fit()")
        # A private generator keeps the result reproducible without reseeding
        # NumPy's global RNG. RandomState(seed) yields the same stream that
        # np.random.seed(seed) would, so the hyperplanes are unchanged.
        rng = np.random.RandomState(self.seed)
        normal_deviates = rng.normal(size=(self.d, self.k))
        # Normalising i.i.d. Gaussian columns gives uniformly distributed
        # directions on the sphere.
        radius = np.sqrt((normal_deviates ** 2).sum(axis=0))
        points = normal_deviates / radius
        self.Wt = points.T

    def digest(self, X):
        """
        X -- d x N matrix

        Returns a list of N tuples, each holding k bits (0 or 1) as plain
        Python ints; bit j is 1 when the sample's projection on hyperplane
        normal j is >= 0.

        Raises RuntimeError if the hyperplanes have not been generated.
        """
        if self.Wt is None:
            raise RuntimeError("Fit error: please invoke .fit()")
        hashes = ((self.Wt @ X) >= 0).astype(int)
        # tolist() turns NumPy scalars into Python ints: they hash and compare
        # the same, but give clean dictionary keys when displayed.
        return list(map(tuple, hashes.T.tolist()))

    def transform(self, X):
        """
        X -- d x N matrix

        Returns an N x 1 integer array of bucket labels. A digest that has not
        been seen before is added to self.hash_table with the next free label,
        so calling transform on new data can grow the table.
        """
        labels = []
        for code in self.digest(X):
            if code not in self.hash_table:
                self.hash_table[code] = self.unique_classes
                self.unique_classes += 1
            labels.append(self.hash_table[code])
        # Explicit dtype so an empty batch still yields an integer array.
        return np.array(labels, dtype=int).reshape(len(labels), 1)

    def fit(self, X):
        """
        X -- d x N matrix

        Draws fresh hyperplanes for d = X.shape[0] and rebuilds the bucket
        table from the digests of X. Any previously fitted state is discarded.
        """
        self.hash_table = {}
        # Reset the label counter together with the table; otherwise labels
        # after a refit would continue from the previous fit's count.
        self.unique_classes = 0
        self.d = X.shape[0]
        self.generate_Wt()
        self.transform(X)
            
        

In [135]:
# Toy data: 100 points in 2-D drawn uniformly from [0, 1), one point per
# column (the d x N layout that LSH.transform documents).
# NOTE(review): no seed is set in the visible cells before this draw, so the
# outputs shown below will differ on a fresh run.
X = np.random.rand(2, 100)
X.shape

(2, 100)

In [136]:
# Hasher with k=5 random hyperplanes, i.e. 5-bit digests (default seed=42).
lsh = LSH(5)

In [137]:
# Draw the hyperplanes for d = X.shape[0] and fill lsh.hash_table with the
# digests that occur in X.
lsh.fit(X)

In [138]:
# Bucket label for every column of X, returned as an (N, 1) array. The same X
# was hashed during fit, so no new buckets are created here.
z = lsh.transform(X)
z

array([[0],
       [0],
       [1],
       [2],
       [3],
       [3],
       [3],
       [0],
       [3],
       [2],
       [3],
       [3],
       [0],
       [3],
       [4],
       [0],
       [3],
       [3],
       [3],
       [2],
       [3],
       [2],
       [3],
       [0],
       [3],
       [1],
       [0],
       [3],
       [0],
       [2],
       [3],
       [4],
       [3],
       [3],
       [2],
       [3],
       [3],
       [2],
       [2],
       [3],
       [3],
       [2],
       [0],
       [0],
       [0],
       [1],
       [3],
       [1],
       [3],
       [3],
       [2],
       [3],
       [1],
       [2],
       [3],
       [3],
       [3],
       [4],
       [0],
       [2],
       [3],
       [2],
       [3],
       [0],
       [1],
       [3],
       [2],
       [2],
       [1],
       [0],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
       [3],
    

In [140]:
# Mapping from each observed 5-bit digest to its integer bucket label.
lsh.hash_table

{(0, 1, 1, 0, 1): 0,
 (0, 1, 1, 1, 1): 1,
 (1, 1, 1, 1, 0): 2,
 (1, 1, 1, 1, 1): 3,
 (1, 0, 1, 1, 0): 4}