In [3]:
import numpy as np
import pandas as pd 
import scanpy as sc 
import copy 

%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../src')

from spaceoracle.plotting.transitions import estimate_transition_probabilities
from celloracle.trajectory.oracle_core import Oracle 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import copy 

adata = sc.read_h5ad('.cache/adata_train.h5ad')
adata = adata[:1000, :]
delta_X = np.random.rand(adata.shape[0], adata.shape[1])

n_neighbors=200
n_pcs=20
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
sc.pl.umap(adata)

adata

AnnData object with n_obs × n_vars = 1000 × 5013
    obs: 'cluster', 'rctd_cluster', 'rctd_celltypes'
    uns: 'log1p', 'pca', 'neighbors', 'umap'
    obsm: 'X_spatial', 'rctd_results', 'spatial', 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'imputed_count', 'normalized_count', 'raw_count'
    obsp: 'distances', 'connectivities'

In [88]:
co = Oracle()

co.adata = adata.copy()
co.embedding_name = 'X_umap'
co.adata.layers['delta_X'] = delta_X

co.estimate_transition_prob(
    n_neighbors=199,
    knn_random=False,
    n_jobs=10
)

co.calculate_embedding_shift()
co_P = co.transition_prob
co_P

array([[0.        , 0.        , 0.        , ..., 0.0044442 , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.00423118],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.004538  , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.00487962],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [89]:
my_P = estimate_transition_probabilities(adata, delta_X, adata.obsm['X_umap'])
my_P

array([[0.        , 0.        , 0.        , ..., 0.0044442 , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.00423118],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.004538  , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.00487962],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [90]:
np.max(abs(my_P - co_P))

0.0

In [91]:
np.max(my_P), np.max(co_P)

(0.016763245380754358, 0.016763245380754358)

In [92]:
co.embedding_knn.A

array([[0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [93]:
indices = np.argwhere(co.embedding_knn.A == 1)[:, 1].reshape(1000, -1)
indices

array([[  4,  18,  25, ..., 992, 993, 997],
       [ 14,  18,  19, ..., 980, 994, 999],
       [  3,  10,  12, ..., 984, 991, 996],
       ...,
       [  0,   4,   6, ..., 988, 992, 993],
       [  5,  13,  16, ..., 994, 995, 999],
       [  5,  13,  16, ..., 972, 973, 994]])

In [94]:
unmatched_neighbors = set(indices[0]) - set(adata.uns['indices'][0])
unmatched_neighbors

set()

In [95]:
corr = adata.uns['corrcoef']
P = adata.uns['P']
P *= np.exp(corr / 0.05)
P

array([[0.        , 0.        , 0.        , ..., 0.75399288, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.67846618],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.99631097, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.88227567],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [96]:
adata.uns['indices']

array([[  0, 951, 582, ..., 993, 482,   4],
       [  1, 145, 580, ..., 994, 780, 884],
       [  2, 144, 710, ..., 300,  70, 210],
       ...,
       [997, 244, 126, ..., 712, 521, 482],
       [998, 160, 165, ..., 555, 745, 636],
       [999, 839, 130, ...,  99, 714,  16]])

In [97]:
P = adata.uns['P']
P

array([[0.        , 0.        , 0.        , ..., 0.75399288, 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.67846618],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.99631097, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.88227567],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [98]:
adata.uns['corrcoef']

array([[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
        -1.41186177e-02,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  1.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00, -1.93960326e-02],
       [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       ...,
       [-1.84792695e-04,  0.00000000e+00,  0.00000000e+00, ...,
         1.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  1.00000000e+00, -6.26253604e-03],
       [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  1.00000000e+00]])

In [99]:
co.corrcoef

array([[ 0.        , -0.00741942,  0.0095989 , ..., -0.01411862,
        -0.00676787, -0.01400986],
       [ 0.00170498,  0.        ,  0.00518266, ...,  0.00109803,
        -0.02399032, -0.01939603],
       [ 0.01040247,  0.00569889,  0.        , ...,  0.01292022,
         0.01009845,  0.00970909],
       ...,
       [-0.00018479, -0.01047261,  0.01245032, ...,  0.        ,
        -0.00590852, -0.01013122],
       [-0.01631466, -0.035558  , -0.01062188, ..., -0.02172894,
         0.        , -0.00626254],
       [-0.00470144, -0.02084683,  0.00638379, ...,  0.00521196,
        -0.03382686,  0.        ]])