In [286]:
import numba
import numpy as np
import scanpy as sc
import anndata
import fast_histogram
from collections import Counter

In [26]:
@numba.jit(nopython=True)
def euclidean(x, y):
    return np.sqrt(np.sum((x - y)**2))

In [288]:
preprocessed_results = '../../test_data/inputs/10x/PBMC/3k/pre-processed/pbmc3k_preprocessed.h5ad'
adata = anndata.read_h5ad(preprocessed_results)
frame = adata.to_df()
frame.shape

(2496, 10499)

In [297]:
frame_np = frame.to_numpy()
arr1 = frame_np[0,:]
arr2 = frame_np[2068,:]

In [28]:
%timeit euclidean(arr1, arr2)

15.7 µs ± 74.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [269]:
def calc_norm_mi(arr1, arr2, bins, m):
    """ Calculates a normalized mutual information distance D(X, Y) = 1 - I(X, Y)/H(X, Y) using bin-based method

    It takes gene expression data from single cells, and compares them using standard calculation for
    mutual information and joint entropy. It builds a 2d histogram, which is used to calculate P(arr1, arr2).

    Args:
        arr1 (pandas series): gene expression data for cell 1
        arr2 (pandas series): gene expression data for cell 2
        marginals  (ndarray): marginal probability matrix
        index1         (int): index of cell 1
        index2         (int): index of cell 2
        bins           (int): number of bins
        m              (int): number of genes
    Returns:
        a float between 0 and 1
    """
    fq = fast_histogram.histogram2d(arr1, arr2, range=[[arr1.min(), arr1.max()+1e-9], [arr2.min(), arr2.max()+1e-9]],
                                    bins=(bins, bins)) / float(m)
    sm = np.sum(fq * float(m), axis=1)
    tm = np.sum(fq * float(m), axis=0)
    sm = np.asmatrix(sm / float(sm.sum()))
    tm = np.asmatrix(tm / float(tm.sum()))
    sm_tm = np.matmul(np.transpose(sm), tm)

    div = np.divide(fq, sm_tm, where=sm_tm != 0, out=np.zeros_like(fq))
    ent = np.log(div, where=div != 0, out=np.zeros_like(div))
    agg = np.multiply(fq, ent, out=np.zeros_like(fq), where=fq != 0)
    joint_ent = -np.multiply(fq, np.log(fq, where=fq != 0, out=np.zeros_like(fq)),
                             out=np.zeros_like(fq), where=fq != 0).sum()
    return joint_ent - agg.sum()

In [33]:
num_bins = int((frame.shape[0]) ** (1 / 3.0))
num_genes = frame.shape[1]

In [270]:
calc_norm_mi(arr1, arr2, num_bins, num_genes)

0.6575192265042646

In [287]:
@numba.jit(nopython=True, fastmath=True)
def compute_bin(x, min, max, num_bins):
    """ Compute bin index for a give number.
    """
    # special case to mirror NumPy behavior for last bin
    if x == max:
        return num_bins - 1 # a_max always in last bin

    bin = int(num_bins * (x - min) / (max - min))

    if bin < 0 or bin >= num_bins:
        return None
    else:
        return bin

@numba.jit(nopython=True, fastmath=True)
def numba_histogram2d(arr1, arr2, num_bins):
    """ Compute the bi-dimensional histogram of two data samples.
    Args:
        arr1 (array_like, shape (N,)): An array containing the x coordinates of the points to be histogrammed.
        arr2 (array_like, shape (N,)): An array containing the y coordinates of the points to be histogrammed.
        num_bins (int): int
    Return:
        hist (2D ndarray)
    """
    bin_indices1 = np.zeros((arr1.shape[0],), dtype=np.int16)
    min1 = arr1.min()
    max1 = arr1.max()
    for i, x in enumerate(arr1.flat):
        bin_indices1[i] = compute_bin(x, min1, max1, num_bins)

    bin_indices2 = np.zeros((arr2.shape[0],), dtype=np.int16)
    min2 = arr2.min()
    max2 = arr2.max()
    for i, y in enumerate(arr2.flat):
        bin_indices2[i] = compute_bin(y, min2, max2, num_bins)

    hist = np.zeros((num_bins, num_bins), dtype=np.int16)
    for i, b in enumerate(bin_indices1):
        hist[b, bin_indices2[i]] += 1
    return hist

@numba.jit(nopython=True)
def numba_nan_fill(x):
    shape = x.shape
    x = x.ravel()
    x[np.isnan(x)] = 0.0
    x = x.reshape(shape)
    return x

@numba.jit(nopython=True)
def numba_inf_fill(x):
    shape = x.shape
    x = x.ravel()
    x[np.isinf(x)] = 0.0
    x = x.reshape(shape)
    return x

@numba.jit(nopython=True, fastmath=True)
def numba_calc_mi_dis(arr1, arr2, bins, m):
    """ Calculates a mutual information distance D(X, Y) = H(X, Y) - I(X, Y) using bin-based method

    It takes gene expression data from single cells, and compares them using standard calculation for
    mutual information and joint entropy. It builds a 2d histogram, which is used to calculate P(arr1, arr2).

    Args:
        arr1 (pandas series): gene expression data for cell 1
        arr2 (pandas series): gene expression data for cell 2
        marginals  (ndarray): marginal probability matrix
        index1         (int): index of cell 1
        index2         (int): index of cell 2
        bins           (int): number of bins
        m              (int): number of genes
    Returns:
        a float between 0 and 1
    """
    hist = numba_histogram2d(arr1, arr2, bins)
    sm = np.sum(hist, axis=1)
    tm = np.sum(hist, axis=0)
    sm = sm / float(sm.sum())
    tm = tm / float(tm.sum())

    sm_tm = np.zeros((bins, bins), dtype=np.float32)
    for i, s in enumerate(sm):
        for j, t in enumerate(tm):
            sm_tm[i, j] = s * t

    fq = hist / float(m)
    div = np.true_divide(fq, sm_tm)
    numba_nan_fill(div)
    ent = np.log(div)
    numba_inf_fill(ent)
    agg = np.multiply(fq, ent)
    joint_ent = -np.multiply(fq, numba_inf_fill(np.log(fq))).sum()
    return joint_ent - agg.sum()

In [99]:
%timeit numba_histogram2d(arr1, arr2, num_bins)


102 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [88]:
%timeit np.histogram2d(arr1, arr2, bins=(num_bins, num_bins))[0]

499 µs ± 4.17 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [300]:
%timeit fast_histogram.histogram2d(arr1, arr2, range=[[arr1.min(), arr1.max()+1e-9], [arr2.min(), arr2.max()+1e-9]], bins=(num_bins, num_bins))

84.7 µs ± 5.29 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [301]:
numba_calc_mi_dis(arr1, arr2, num_bins, num_genes)

0.39198347724449967

In [279]:
arr3 = frame_np[1000,:]
arr4 = frame_np[2401,:]

In [284]:
%timeit calc_norm_mi(arr3, arr4, num_bins, num_genes)

231 µs ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [285]:
%timeit numba_calc_mi_dis(arr3, arr4, num_bins, num_genes)

109 µs ± 413 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
