In [88]:
import numba
import numpy as np
import scanpy as sc
import anndata
import fast_histogram
from collections import Counter

In [89]:
@numba.jit(nopython=True)
def euclidean(x, y):
    return np.sqrt(np.sum((x - y)**2))

In [90]:
#Read in anndata object
preprocessed_results = '../../test_data/inputs/10x/PBMC/3k/pre-processed/pbmc3k_preprocessed.h5ad'
adata = anndata.read_h5ad(preprocessed_results)
adata.shape

(2496, 10499)

In [91]:
#convert anndata to pandas dataframe
frame = adata.to_df()
frame.shape

(2496, 10499)

In [92]:
frame_np = frame.to_numpy()
arr1 = frame_np[0,:]
arr2 = frame_np[2068,:]
#indices1 = np.nonzero(arr1)
#print(indices1)
#print(arr1[indices1])

In [93]:
row1=0
row1start=adata.X.indptr[row1]
row1end=adata.X.indptr[row1+1]
#print(row1start,row1end) 
arr1_indices = adata.X.indices[row1start:row1end]
arr1_csr = adata.X.data[row1start:row1end]
#print(arr1_indices)
#print(arr1_csr)

row2=2068
row2start=adata.X.indptr[row2]
row2end=adata.X.indptr[row2+1]
arr2_indices = adata.X.indices[row2start:row2end]
arr2_csr = adata.X.data[row2start:row2end]
#print(arr2_indices)
#print(arr2_csr)

In [94]:
%timeit euclidean(arr1, arr2)

11.2 µs ± 24.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [138]:
def calc_norm_mi_numpy(arr1, arr2, bins, m):
    """ Calculates a normalized mutual information distance D(X, Y) = 1 - I(X, Y)/H(X, Y) using bin-based method

    It takes gene expression data from single cells, and compares them using standard calculation for
    mutual information and joint entropy. It builds a 2d histogram, which is used to calculate P(arr1, arr2).

    Args:
        arr1 (pandas series): gene expression data for cell 1
        arr2 (pandas series): gene expression data for cell 2
        marginals  (ndarray): marginal probability matrix
        index1         (int): index of cell 1
        index2         (int): index of cell 2
        bins           (int): number of bins
        m              (int): number of genes
    Returns:
        a float between 0 and 1
    """
    fq = np.histogram2d(arr1, arr2, bins=(bins, bins))[0] / float(m)
    sm = np.sum(fq * float(m), axis=1)
    tm = np.sum(fq * float(m), axis=0)
    sm = np.asmatrix(sm / float(sm.sum()))
    tm = np.asmatrix(tm / float(tm.sum()))
    sm_tm = np.matmul(np.transpose(sm), tm)

    div = np.divide(fq, sm_tm, where=sm_tm != 0, out=np.zeros_like(fq))
    ent = np.log(div, where=div != 0, out=np.zeros_like(div))
    agg = np.multiply(fq, ent, out=np.zeros_like(fq), where=fq != 0)
    joint_ent = -np.multiply(fq, np.log(fq, where=fq != 0, out=np.zeros_like(fq)),
                             out=np.zeros_like(fq), where=fq != 0).sum()
    return joint_ent - agg.sum()

def calc_norm_mi_fast(arr1, arr2, bins, m):
    """ Calculates a normalized mutual information distance D(X, Y) = 1 - I(X, Y)/H(X, Y) using bin-based method

    It takes gene expression data from single cells, and compares them using standard calculation for
    mutual information and joint entropy. It builds a 2d histogram, which is used to calculate P(arr1, arr2).

    Args:
        arr1 (pandas series): gene expression data for cell 1
        arr2 (pandas series): gene expression data for cell 2
        marginals  (ndarray): marginal probability matrix
        index1         (int): index of cell 1
        index2         (int): index of cell 2
        bins           (int): number of bins
        m              (int): number of genes
    Returns:
        a float between 0 and 1
    """
    fq = fast_histogram.histogram2d(arr1, arr2, range=[[arr1.min(), arr1.max()+1e-9], [arr2.min(), arr2.max()+1e-9]],
                                    bins=(bins, bins)) / float(m)
    sm = np.sum(fq * float(m), axis=1)
    tm = np.sum(fq * float(m), axis=0)
    sm = np.asmatrix(sm / float(sm.sum()))
    tm = np.asmatrix(tm / float(tm.sum()))
    sm_tm = np.matmul(np.transpose(sm), tm)

    div = np.divide(fq, sm_tm, where=sm_tm != 0, out=np.zeros_like(fq))
    ent = np.log(div, where=div != 0, out=np.zeros_like(div))
    agg = np.multiply(fq, ent, out=np.zeros_like(fq), where=fq != 0)
    joint_ent = -np.multiply(fq, np.log(fq, where=fq != 0, out=np.zeros_like(fq)),
                             out=np.zeros_like(fq), where=fq != 0).sum()
    return joint_ent - agg.sum()

In [139]:
num_bins = int((frame.shape[0]) ** (1 / 3.0))
num_genes = frame.shape[1]

In [135]:
#calc_norm_mi(arr1, arr2, num_bins, num_genes)

In [136]:
@numba.jit(nopython=True, fastmath=True)
def compute_bin(x, min, max, num_bins):
    """ Compute bin index for a give number.
    """
    # special case to mirror NumPy behavior for last bin
    if x == max:
        return num_bins - 1 # a_max always in last bin

    bin = int(num_bins * (x - min) / (max - min))

    if bin < 0 or bin >= num_bins:
        return None
    else:
        return bin
    
    
@numba.jit(nopython=True, fastmath=True)
def compute_bin_upperbound(x, max, num_bins):
    """ Compute bin index for a give number.
        Assume that min is always zero
    """
    # special case to mirror NumPy behavior for last bin
    if x == max:
        return num_bins - 1 # a_max always in last bin

    bin = int(num_bins * x / max)

    if bin >= num_bins:
        return None
    else:
        return bin
    
    

@numba.jit(nopython=True, fastmath=True)
def numba_histogram2d(arr1, arr2, num_bins):
    """ Compute the bi-dimensional histogram of two data samples.
    Args:
        arr1 (array_like, shape (N,)): An array containing the x coordinates of the points to be histogrammed.
        arr2 (array_like, shape (N,)): An array containing the y coordinates of the points to be histogrammed.
        num_bins (int): int
    Return:
        hist (2D ndarray)
    """
    bin_indices1 = np.zeros((arr1.shape[0],), dtype=np.int16)
    min1 = arr1.min()
    max1 = arr1.max()
    #note that bin_indices has same size/indices as full array x and y
    for i, x in enumerate(arr1.flat):
        bin_indices1[i] = compute_bin(x, min1, max1, num_bins)

    bin_indices2 = np.zeros((arr2.shape[0],), dtype=np.int16)
    min2 = arr2.min()
    max2 = arr2.max()
    for i, y in enumerate(arr2.flat):
        bin_indices2[i] = compute_bin(y, min2, max2, num_bins)

    hist = np.zeros((num_bins, num_bins), dtype=np.int16)
    for i, b in enumerate(bin_indices1):
        hist[b, bin_indices2[i]] += 1
    return hist


#ceb create csr version of numba_histogram2d, also compute_bin with knowledge that minx will always be zero
@numba.jit(nopython=True, fastmath=True)
def numba_histogram2d_csr(arr1, cols1, arr2, cols2, ncols, num_bins):
    """ Compute the bi-dimensional histogram of two data samples.
    Args:
        arr1 (array_like, shape (N,)): An array containing the x coordinates of the points to be histogrammed.
        arr2 (array_like, shape (N,)): An array containing the y coordinates of the points to be histogrammed.
        num_bins (int): int
    Return:
        hist (2D ndarray)
    """
    #for csr arrays we have to compute zero bins ahead of time 
        
    bin_indices1 = np.zeros((ncols,), dtype=np.int16)
    max1 = arr1.max()
    #note that bin_indices has same size/indices as full array x and y
    for i, x in enumerate(arr1.flat):
        #assume zero min
        bin_indices1[cols1[i]] = compute_bin_upperbound(x, max1, num_bins)
        #bin_indices1[cols1[i]] = compute_bin(x, 0, max1, num_bins)

    bin_indices2 = np.zeros((ncols,), dtype=np.int16)
    max2 = arr2.max()
    for i, y in enumerate(arr2.flat):
        #assume zero min
        bin_indices2[cols2[i]] = compute_bin_upperbound(y, max2, num_bins)
        #bin_indices2[cols2[i]] = compute_bin(y, 0, max2, num_bins)

    hist = np.zeros((num_bins, num_bins), dtype=np.int16)
    for i, b in enumerate(bin_indices1):
        hist[b, bin_indices2[i]] += 1
        
    return hist

@numba.jit(nopython=True)
def numba_nan_fill(x):
    shape = x.shape
    x = x.ravel()
    x[np.isnan(x)] = 0.0
    x = x.reshape(shape)
    return x

@numba.jit(nopython=True)
def numba_inf_fill(x):
    shape = x.shape
    x = x.ravel()
    x[np.isinf(x)] = 0.0
    x = x.reshape(shape)
    return x

@numba.jit(nopython=True, fastmath=True)
def numba_calc_mi_dis(arr1, arr2, bins, m):
    """ Calculates a mutual information distance D(X, Y) = H(X, Y) - I(X, Y) using bin-based method

    It takes gene expression data from single cells, and compares them using standard calculation for
    mutual information and joint entropy. It builds a 2d histogram, which is used to calculate P(arr1, arr2).

    Args:
        arr1 (pandas series): gene expression data for cell 1
        arr2 (pandas series): gene expression data for cell 2
        marginals  (ndarray): marginal probability matrix
        index1         (int): index of cell 1
        index2         (int): index of cell 2
        bins           (int): number of bins
        m              (int): number of genes
    Returns:
        a float between 0 and 1
    """
    hist = numba_histogram2d(arr1, arr2, bins)
    sm = np.sum(hist, axis=1)
    tm = np.sum(hist, axis=0)
    sm = sm / float(sm.sum())
    tm = tm / float(tm.sum())

    sm_tm = np.zeros((bins, bins), dtype=np.float32)
    for i, s in enumerate(sm):
        for j, t in enumerate(tm):
            sm_tm[i, j] = s * t

    fq = hist / float(m)
    div = np.true_divide(fq, sm_tm)
    numba_nan_fill(div)
    ent = np.log(div)
    numba_inf_fill(ent)
    agg = np.multiply(fq, ent)
    joint_ent = -np.multiply(fq, numba_inf_fill(np.log(fq))).sum()
    return joint_ent - agg.sum()



@numba.jit(nopython=True, fastmath=True)
def numba_calc_mi_dis_csr(arr1, cols1, arr2, cols2, ncols, bins, m):
    """ Calculates a mutual information distance D(X, Y) = H(X, Y) - I(X, Y) using bin-based method

    It takes gene expression data from single cells, and compares them using standard calculation for
    mutual information and joint entropy. It builds a 2d histogram, which is used to calculate P(arr1, arr2).

    Args:
        arr1 (pandas series): gene expression data for cell 1
        arr2 (pandas series): gene expression data for cell 2
        marginals  (ndarray): marginal probability matrix
        index1         (int): index of cell 1
        index2         (int): index of cell 2
        bins           (int): number of bins
        m              (int): number of genes
    Returns:
        a float between 0 and 1
    """
    hist = numba_histogram2d_csr(arr1, cols1, arr2, cols2, ncols, bins)
    sm = np.sum(hist, axis=1)
    tm = np.sum(hist, axis=0)
    sm = sm / float(sm.sum())
    tm = tm / float(tm.sum())

    sm_tm = np.zeros((bins, bins), dtype=np.float32)
    for i, s in enumerate(sm):
        for j, t in enumerate(tm):
            sm_tm[i, j] = s * t

    fq = hist / float(m)
    div = np.true_divide(fq, sm_tm)
    numba_nan_fill(div)
    ent = np.log(div)
    numba_inf_fill(ent)
    agg = np.multiply(fq, ent)
    joint_ent = -np.multiply(fq, numba_inf_fill(np.log(fq))).sum()
    return joint_ent - agg.sum()


Compare runtimes for for different histogram2d implementations

In [107]:
%timeit np.histogram2d(arr1, arr2, bins=(num_bins, num_bins))[0]

387 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [108]:
%timeit fast_histogram.histogram2d(arr1, arr2, range=[[arr1.min(), arr1.max()+1e-9], [arr2.min(), arr2.max()+1e-9]], bins=(num_bins, num_bins))

120 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [109]:
%timeit numba_histogram2d(arr1, arr2, num_bins)

62.8 µs ± 11.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [120]:
%timeit numba_histogram2d_csr(arr1_csr, arr1_indices, arr2_csr, arr2_indices, adata.shape[1], num_bins)

25.8 µs ± 9.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [122]:
#print("numpy_hist =   ",(np.histogram2d(arr1, arr2, bins=(num_bins, num_bins))[0]).astype(int) ) 
#print("fast_hist =    ",(fast_histogram.histogram2d(arr1, arr2, range=[[arr1.min(), arr1.max()+1e-9], [arr2.min(), arr2.max()+1e-9]], bins=(num_bins, num_bins))).astype(int) )  
#print("numba_hist =   ",numba_histogram2d(arr1, arr2, num_bins))
#print("numba_csr_hist=",numba_histogram2d_csr(arr1_csr, arr1_indices, arr2_csr, arr2_indices, adata.shape[1], num_bins))

In [126]:
arr3 = frame_np[1000,:]
arr4 = frame_np[2401,:]

In [140]:
%timeit calc_norm_mi_numpy(arr3, arr4, num_bins, num_genes)

552 µs ± 2.34 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [141]:
%timeit calc_norm_mi_fast(arr3, arr4, num_bins, num_genes)

252 µs ± 713 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [146]:
%timeit numba_calc_mi_dis(arr3, arr4, num_bins, num_genes)

68.8 µs ± 139 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [143]:
ncols=adata.shape[1]
row1=1000
row1start=adata.X.indptr[row1]
row1end=adata.X.indptr[row1+1]
#print(row1start,row1end) 
arr1_indices = adata.X.indices[row1start:row1end]
arr1_csr = adata.X.data[row1start:row1end]


row2=2401
row2start=adata.X.indptr[row2]
row2end=adata.X.indptr[row2+1]
arr2_indices = adata.X.indices[row2start:row2end]
arr2_csr = adata.X.data[row2start:row2end]


In [145]:
%timeit numba_calc_mi_dis_csr(arr1_csr, arr1_indices, arr2_csr, arr2_indices, ncols, num_bins, num_genes)

32.9 µs ± 20.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [147]:
print(calc_norm_mi_numpy(arr3, arr4, num_bins, num_genes))
print(calc_norm_mi_fast(arr3, arr4, num_bins, num_genes))
print(numba_calc_mi_dis(arr3, arr4, num_bins, num_genes))
print(numba_calc_mi_dis_csr(arr1_csr, arr1_indices, arr2_csr, arr2_indices, ncols, num_bins, num_genes))

0.7153987171007082
0.7153987171007082
0.7153987389627494
0.7153987389627494
