In [1]:
# add parent directory to path for importing HSNE modules
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import scanpy as sc
import anndata
import plotting as pl
import tools as tl
import numpy as np
import time as time

# Load Dataset
filelocation = r"datasets/VBh_converted.h5ad"
adata = anndata.read_h5ad(filelocation)

# subsampling
sc.pp.subsample(adata, 0.1)

# normalizing
adata.X = np.arcsinh(adata.X / 10)

# calc knn
sc.pp.neighbors(adata, n_neighbors=20)

  from pandas.core.index import RangeIndex


In [2]:
def _calc_P(T):
    return (T + T.transpose()) / (2 * T.shape[0])

In [3]:
from scipy.sparse import csr_matrix, spdiags
from scipy.sparse.linalg import eigs

def fresh_calc_T(adata):
    

    # get connectivities from adata
    c = adata.obsp['connectivities']

    # make sure connectivities are symmetric
    assert(len((c - c.T).data) == 0), "connectivities are not symmetric"

    # row-normalise c to give a transition matrix
    T = c.multiply(csr_matrix(1.0 / np.abs(c).sum(1)))

    # make sure it's correctly row-normalised
    assert(np.allclose(T.sum(1), 1)), "T is not row-normalised"

    # compute the stationary distribution
    #from scipy.sparse.linalg import eigs
    D, V = eigs(T.T, which='LM')
    pi = V[:, 0]

    # make sure pi is entirely real
    assert((pi.imag == 0).all()), "This is not the stationary vector, found imaginary entries"
    pi = pi.real

    # make sure all entries have the same sign
    assert((pi > 0).all() or (pi < 0).all()), "This is not the stationary vector, found positive and negative entries"
    pi /= pi.sum()

    # check pi is normalised correctly
    assert(np.allclose(pi.sum(), 1)), "Pi is not normalized correctly"

    # put the stationary dist into a diag matrix
    Pi = spdiags(pi, 0, pi.shape[0], pi.shape[0])

    # finally, check for reversibility of T
    assert(np.allclose((Pi @ T - T.T @ Pi).data, 0))
    
    return T
    
t0_new = time.time()
T_new = fresh_calc_T(adata)
t1_new = time.time()

In [4]:
# compared to old method for calculating T
import multiprocessing as mp
from scipy.special import softmax
from scipy.stats import entropy

def _calc_first_T(distances_nn, dim):
    p = mp.Pool(mp.cpu_count())
    probs = p.map(_helper_method_calc_T, [dist.data for dist in distances_nn])
    p.terminate()
    p.join()
    data = []
    for pr in probs:
        data.extend(pr)
    T = csr_matrix((data, distances_nn.indices, distances_nn.indptr), shape=(dim,dim))
    return T

def _helper_method_calc_T(dist):
    d = dist / np.max(dist)
    return softmax((-d ** 2) / _binary_search_sigma(d, len(d)))

def _binary_search_sigma(d, n_neigh):
    # binary search
    sigma = 10  # Start Sigma
    goal = np.log(n_neigh)  # log(k) with k being n_neighbors
    # Do binary search until entropy ~== log(k)
    while True:
        ent = entropy(softmax((-d ** 2) / sigma))
        # check sigma
        if np.isclose(ent, goal):
            return sigma
        if ent > goal:
            sigma *= 0.5
        else:
            sigma /= 0.5


t0_old = time.time()
T_old = _calc_first_T(adata.obsp['distances'], len(adata.X))

t1_old = time.time()



In [5]:
print("--NEW--")
print(f"shape: {np.shape(T_new)}")
print(f"length: {len(T_new.data)}")
print(f"sum of data: {sum(T_new.data)}")
P_new = _calc_P(T_new)
print(f"sum first row P {sum(sum((P_new.getrow(0)).toarray()))}")
print(f"sum first row T {sum(sum((T_new.getrow(0)).toarray()))}")
print(f"time: {t1_new-t0_new}\n")

print("--OLD--")
print(f"shape: {np.shape(T_old)}")
print(f"length: {len(T_old.data)}")
print(f"sum of data: {sum(T_old.data)}")
P_old = _calc_P(T_old)
print(f"sum first row P {sum(sum((P_old.getrow(0)).toarray()))}")
print(f"sum first row T {sum(sum((T_old.getrow(0)).toarray()))}")
print(f"time: {t1_old-t0_old}\n")


--NEW--
shape: (12946, 12946)
length: 348030
sum of data: 12945.99999087235
sum first row P 7.912657000918663e-05
sum first row T 1.0000000521540642
time: 1.366792917251587

--OLD--
shape: (12946, 12946)
length: 245974
sum of data: 12945.99999999991
sum first row P 8.126529182022665e-05
sum first row T 0.9999999999999999
time: 8.002951622009277



In [6]:
print("Old")
print(sum(T_new.toarray()[0]))
print(max(T_new.data))


print("\nNew")
print(sum(T_old.toarray()[0]))
print(max(T_old.data))


Old
1.0000000521540642
0.23137873

New
0.9999999999999999
0.05415735182233375
