In [1]:
import numpy as np
from npeet.entropy_estimators import mi


def ksg_factory(k=3, pool=None):
    """
    Build a sentence-similarity function backed by the KSG estimator.

    :param k: number of nearest neighbours forwarded to ``mi``
    :param pool: optional reducer invoked as ``pool(matrix, axis=0)``
        (e.g. np.mean or np.max); ``None`` means no pooling
    :return: function ``ksg(x, y)`` bound to the given k and pooling (if any)
    """

    def _as_samples(matrix):
        # No pooling: hand the transposed matrix to the estimator.
        # Pooling: reduce along axis 0 and lay the result out as one column.
        if pool is None:
            return matrix.T
        return pool(matrix, axis=0).reshape(-1, 1)

    def ksg(x, y):
        """
        Kraskov–Stogbauer–Grassberger (KSG) estimate of the mutual information
        between two sentences given as word embedding matrices.

        ``mi`` is looked up when this function is called, not when the
        factory runs.

        :param x: word embeddings of the first sentence (must support ``.T``,
            so an array rather than a plain list)
        :param y: word embeddings of the second sentence (same requirement)
        :return: KSG similarity measure between the two sentences (natural log base)
        """
        return mi(_as_samples(x), _as_samples(y), base=np.e, k=k)

    return ksg


# Unpooled similarities: the estimator sees the transposed embedding matrices.
ksg3, ksg10 = ksg_factory(k=3), ksg_factory(k=10)

# Pooled similarities (k=10): each sentence is first reduced along axis 0.
mean_ksg10, max_ksg10 = (
    ksg_factory(k=10, pool=reducer) for reducer in (np.mean, np.max)
)

ModuleNotFoundError: No module named 'npeet'

In [2]:
import warnings

import numpy as np
import numpy.linalg as la
from numpy import log
from scipy.special import digamma
from sklearn.neighbors import BallTree, KDTree


In [3]:
def add_noise(x, intens=1e-10):
    """
    Return ``x`` plus a small random perturbation of the same shape.

    The perturbation is drawn from NumPy's global random state as uniform
    values in [0, 1) scaled by ``intens``; its purpose is to break ties
    (degenerate, exactly equal samples). ``x`` itself is not modified.

    :param x: array exposing ``.shape``
    :param intens: scale of the perturbation
    :return: new array ``x + intens * u``
    """
    jitter = np.random.random_sample(x.shape)
    scaled = intens * jitter
    return x + scaled

In [4]:
def build_tree(points):
    """
    Build a Chebyshev-metric (max-norm) neighbour index over ``points``.

    :param points: 2-D array, one sample per row
    :return: a ``BallTree`` when there are 20 or more columns,
        otherwise a ``KDTree``
    """
    # Pick the index type by dimensionality, then construct it once.
    tree_cls = BallTree if points.shape[1] >= 20 else KDTree
    return tree_cls(points, metric="chebyshev")

In [6]:
def count_neighbors(tree, x, r):
    """
    Count the indexed points within radius ``r`` of each query in ``x``.

    :param tree: index exposing ``query_radius``
    :param x: query points
    :param r: search radius (passed through unchanged)
    :return: whatever ``tree.query_radius(..., count_only=True)`` yields
    """
    counts = tree.query_radius(x, r, count_only=True)
    return counts

In [7]:
def avgdigamma(points, dvec):
    """
    Mean digamma of the neighbour counts in a marginal space.

    Gives the expectation value <psi(n_x)>, where n_x is the number of
    points found around each sample within its own radius.

    :param points: 2-D array of samples in the marginal space
    :param dvec: per-sample search radii
    :return: mean of ``digamma(count)`` over all samples
    """
    marginal_tree = build_tree(points)
    # Shrink every radius by a tiny epsilon before counting — presumably so
    # that points lying exactly at distance dvec are left out.
    radii = dvec - 1e-15
    neighbour_counts = count_neighbors(marginal_tree, points, radii)
    psi_values = digamma(neighbour_counts)
    return np.mean(psi_values)


In [8]:
def lnc_correction(tree, points, k, alpha):
    """
    Local non-uniformity correction term for the joint-space estimate.

    For every sample, the ``k + 1`` nearest hits are looked up, shifted so the
    first hit is the origin, and two log-volumes are compared: the box aligned
    with the eigenvectors of the local covariance (PCA box) and the
    axis-aligned box. When the PCA box is smaller than ``alpha`` times the
    axis-aligned one, the log-volume difference is added to the correction.

    :param tree: neighbour index built over ``points``; must support
        ``query(X, k=..., return_distance=False)`` returning one row of
        neighbour indices per query row
    :param points: 2-D array of joint-space samples, one per row
    :param k: number of nearest neighbours (``k + 1`` hits are requested)
    :param alpha: volume-ratio threshold; expected to be > 0 since its
        logarithm is taken
    :return: sum of the per-sample corrections divided by the sample count
        (0.0 when ``points`` is empty or nothing needed correcting)
    """
    n_sample = points.shape[0]
    if n_sample == 0:
        # Nothing to correct; also avoids querying the index with no rows.
        return 0.0

    # One batched lookup instead of one tree query per sample.
    neighbour_idx = tree.query(points, k=k + 1, return_distance=False)
    log_alpha = np.log(alpha)

    correction = 0.0
    for knn in neighbour_idx:
        # Shift the neighbourhood so its first hit sits at the origin. This is
        # the closest hit (the query sample itself unless there are exact
        # duplicates), NOT the neighbourhood mean.
        knn_points = points[knn]
        knn_points = knn_points - knn_points[0]
        # Local second-moment matrix about that origin and its eigenvectors
        covr = knn_points.T @ knn_points / k
        _, v = la.eig(covr)
        # Log-volume of the bounding box aligned with the eigenvectors
        V_rect = np.log(np.abs(knn_points @ v).max(axis=0)).sum()
        # Log-volume of the axis-aligned bounding box
        log_knn_dist = np.log(np.abs(knn_points).max(axis=0)).sum()

        # Local non-uniformity check: correct only clearly non-uniform
        # neighbourhoods.
        if V_rect < log_knn_dist + log_alpha:
            correction += (log_knn_dist - V_rect) / n_sample
    return correction

In [9]:
def query_neighbors(tree, x, k):
    """
    Distance from each query point to its k-th returned neighbour.

    ``k + 1`` hits are requested and column ``k`` of the distance matrix is
    kept (when the queries are the indexed points themselves, column 0 is
    presumably the zero-distance self match).

    :param tree: index exposing ``query(x, k=...)`` whose first return item
        is a 2-D distance array
    :param x: query points
    :param k: neighbour rank to report
    :return: 1-D array with one distance per query point
    """
    distances = tree.query(x, k=k + 1)[0]
    return distances[:, k]

In [10]:
def mi(x, y, z=None, k=3, base=2, alpha=0):
    """Mutual information of x and y (conditioned on z if z is not None).

    x, y (and z) should be lists of vectors, one per sample, e.g.
    x = [[1.3], [3.7], [5.1], [2.4]] for four samples of a scalar variable.
    Each input is flattened to shape (n_samples, -1). x and y are perturbed
    with ``add_noise`` before use; z is used as given.

    :param x: samples of the first variable
    :param y: samples of the second variable (same length as x)
    :param z: optional samples of the conditioning variable
    :param k: neighbour rank used in the joint space; must be <= len(x) - 1
    :param base: logarithm base of the returned value
    :param alpha: when > 0 and z is None, ``lnc_correction`` is added
    :return: the estimate, divided by ``log(base)``
    :raises AssertionError: if x and y differ in length or k is too large
    """
    assert len(x) == len(y), "Arrays should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"

    # Normalise to 2-D (n_samples, n_features), then jitter x before y.
    x = np.asarray(x)
    y = np.asarray(y)
    x = x.reshape(x.shape[0], -1)
    y = y.reshape(y.shape[0], -1)
    x = add_noise(x)
    y = add_noise(y)

    conditional = z is not None
    columns = [x, y]
    if conditional:
        z = np.asarray(z)
        z = z.reshape(z.shape[0], -1)
        columns.append(z)
    joint = np.hstack(columns)

    # k-th neighbour distance of every sample in the joint space (max-norm)
    joint_tree = build_tree(joint)
    radii = query_neighbors(joint_tree, joint, k)

    if conditional:
        xz = np.c_[x, z]
        yz = np.c_[y, z]
        psi_x = avgdigamma(xz, radii)
        psi_y = avgdigamma(yz, radii)
        plus_first = avgdigamma(z, radii)
        plus_second = digamma(k)
    else:
        psi_x = avgdigamma(x, radii)
        psi_y = avgdigamma(y, radii)
        plus_first = digamma(k)
        plus_second = digamma(len(x))
        if alpha > 0:
            plus_second += lnc_correction(joint_tree, joint, k, alpha)

    return (-psi_x - psi_y + plus_first + plus_second) / log(base)


In [12]:
# Demo: x (10-D) and y (3-D) are drawn independently, so the estimate
# should come out close to zero.
np.random.seed(0)  # fixed seed so the printed value is reproducible
x = np.random.randn(1000, 10)
y = np.random.randn(1000, 3)

print("MI between two independent continuous random variables X and Y:")
print(mi(x, y, base=2, alpha=0))

MI between two independent continuous random variables X and Y:
-0.022484075103376248
