In [None]:
import zlib

import numpy as np
import scipy.stats

In [None]:
class H1:
    """Randomly shifted floor quantiser.

    Maps a value ``x`` to the integer index ``floor((x + b) / epsilon)``,
    where ``b`` is an offset drawn once, at construction time, with
    ``np.random.uniform(0.0, epsilon)``.
    """

    def __init__(self, epsilon):
        """Store the bin width ``epsilon`` and draw the random offset ``b``."""
        self.epsilon = epsilon
        self.b = np.random.uniform(low=0.0, high=epsilon)

    def perform_hash(self, x):
        """Return the int32 bin index (or indices) of ``x``.

        ``x`` may be a NumPy scalar/array or a plain number; the arithmetic is
        element-wise.
        """
        shifted = x + self.b
        bin_index = np.floor(shifted / self.epsilon)
        return bin_index.astype(np.int32)
    
    
class H2:
    """Hash an arbitrary sample into one of ``N * cH`` buckets.

    The sample is converted to its ``str`` representation and that text is
    hashed with CRC-32. Unlike the built-in ``hash`` for strings, which is
    salted per interpreter process unless ``PYTHONHASHSEED`` is fixed, CRC-32
    gives the same bucket for the same sample in every run.
    """

    def __init__(self, N, cH):
        """Store the bucket count ``N * cH`` as ``cHN``."""
        self.cHN = N * cH

    def perform_hash(self, sample):
        """Return the bucket index of ``sample`` in ``[0, cHN)``.

        Two samples with the same ``str`` representation always land in the
        same bucket.
        """
        encoded = str(sample).encode("utf-8")
        # zlib.crc32 returns an unsigned 32-bit integer, so the modulo is
        # non-negative and can be used directly as an array index.
        return np.mod(zlib.crc32(encoded), self.cHN)
    

class EDGE:
    """Hash-based mutual information estimator over a bucket-collision graph.

    Each sample of X and Y is quantised with ``H1`` and hashed into one of
    ``N * cH`` buckets with ``H2``. The estimate is a weighted sum over all
    (X-bucket, Y-bucket) pairs that share at least one sample.

    NOTE(review): the structure resembles the EDGE estimator of Noshad et al.;
    confirm the formula details against that reference before relying on the
    numeric values.

    Parameters
    ----------
    N : int, optional
        Number of samples. When omitted, it is taken from the data passed to
        ``estimate_mi`` (and re-derived whenever the data size changes).
    cH : int, optional
        Number of hash buckets per sample.
    epsilon : float, optional
        Bin width handed to ``H1``.
    """

    def __init__(self, N=None, cH=4, epsilon=0.008):
        self.cH = cH
        self.epsilon = epsilon
        self.h1 = H1(self.epsilon)

        # Everything that depends on the sample count is filled in by
        # _set_sample_count, either right here or on the first data seen.
        self.N = None
        self.n_buckets = None
        self.h2 = None
        if N is not None:
            self._set_sample_count(N)

    def _set_sample_count(self, n_samples):
        """Size the bucket table and the second-stage hash for ``n_samples``."""
        n_samples = int(n_samples)
        if n_samples <= 0:
            raise ValueError("EDGE needs at least one sample")
        self.N = n_samples
        self.n_buckets = n_samples * self.cH
        self.h2 = H2(self.N, self.cH)

    def _g(self, x):
        """Return ``x * log(x)`` element-wise, with the value 0 where ``x == 0``."""
        x = np.asarray(x, dtype=float)
        result = np.zeros(x.shape, dtype=float)
        nonzero = x != 0.0
        result[nonzero] = x[nonzero] * np.log(x[nonzero])
        return result

    def _count_collisions(self, X, Y):
        """Count how many samples fall into each X bucket, Y bucket and pair.

        Returns
        -------
        counts_i : ndarray, shape (n_buckets,)
            Samples per X bucket.
        counts_j : ndarray, shape (n_buckets,)
            Samples per Y bucket.
        counts_ij : ndarray, shape (n_buckets, n_buckets)
            Samples per (X bucket, Y bucket) pair.

        Raises
        ------
        ValueError
            If X and Y have a different number of rows, or no rows at all.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        n_samples = X.shape[0]
        if Y.shape[0] != n_samples:
            raise ValueError("X and Y must contain the same number of samples")
        if n_samples != self.N:
            self._set_sample_count(n_samples)

        # H1 is element-wise, so quantising everything up front gives the same
        # per-row values as quantising row by row inside the loop.
        quantised_x = self.h1.perform_hash(X)
        quantised_y = self.h1.perform_hash(Y)

        counts_i = np.zeros(self.n_buckets, dtype='int32')
        counts_j = np.zeros(self.n_buckets, dtype='int32')
        counts_ij = np.zeros((self.n_buckets, self.n_buckets), dtype='int32')
        for k in range(self.N):
            h_x = self.h2.perform_hash(quantised_x[k])
            h_y = self.h2.perform_hash(quantised_y[k])
            counts_i[h_x] += 1
            counts_j[h_y] += 1
            counts_ij[h_x, h_y] += 1

        return counts_i, counts_j, counts_ij

    def _compute_edge_weights(self, counts_i, counts_j, counts_ij):
        """Turn collision counts into vertex and edge weights.

        Returns ``(w_i, w_j, w_ij)`` with ``w_i = counts_i / N``,
        ``w_j = counts_j / N`` and
        ``w_ij[i, j] = counts_ij[i, j] * N / (counts_i[i] * counts_j[j])``.
        Entries whose denominator is zero are left at 0.
        """
        counts_i = np.asarray(counts_i)
        counts_j = np.asarray(counts_j)
        counts_ij = np.asarray(counts_ij)

        w_i = counts_i / self.N
        w_j = counts_j / self.N

        # Only pairs that actually collided can have a non-zero weight, so the
        # ratio is evaluated on those entries alone. Row index i selects the
        # X-bucket count and column index j the Y-bucket count.
        w_ij = np.zeros(counts_ij.shape, dtype=float)
        rows, cols = np.nonzero(counts_ij)
        # Float arithmetic keeps the product from overflowing int32 counts.
        denom = counts_i[rows].astype(float) * counts_j[cols]
        valid = denom > 0
        rows, cols, denom = rows[valid], cols[valid], denom[valid]
        w_ij[rows, cols] = counts_ij[rows, cols] * float(self.N) / denom

        return w_i, w_j, w_ij

    def estimate_mi(self, X, Y):
        """Estimate the mutual information between the samples X and Y.

        Parameters
        ----------
        X, Y : array-like
            Sample arrays indexed by sample along the first axis; both must
            have the same number of rows.

        Returns
        -------
        float
            Sum over all non-zero edges of ``w_i * w_j * max(g(w_ij), U)``.
        """
        counts_i, counts_j, counts_ij = self._count_collisions(X, Y)
        w_i, w_j, w_ij = self._compute_edge_weights(counts_i, counts_j, counts_ij)

        # Edges are the bucket pairs with a non-zero weight.
        rows, cols = np.nonzero(w_ij)
        g_applied = self._g(w_ij[rows, cols])

        # NOTE(review): this sums the counts of the non-empty Y buckets, which
        # always equals N. The original comment described the bound as the
        # *number* of used Y bins (np.count_nonzero(counts_j)). The computation
        # is kept as it was -- confirm which one is intended.
        U = np.sum(counts_j[counts_j != 0])
        g_tilde = np.maximum(g_applied, U)

        return np.sum(w_i[rows] * w_j[cols] * g_tilde)

In [None]:
import time

In [None]:
# Demo / timing run: draw standard-normal samples for X and Y with two
# separate calls, estimate their mutual information and report the elapsed
# wall-clock time.
X = scipy.stats.norm.rvs(size=(1000, 10))  # N x dims
Y = scipy.stats.norm.rvs(size=(1000, 1))  # N x dims

estimator = EDGE()

# perf_counter is monotonic and high-resolution, so the measured interval is
# not affected by system clock adjustments the way time.time() can be.
start = time.perf_counter()
MI = estimator.estimate_mi(X, Y)
end = time.perf_counter()
print(end - start)

print(MI)