In [2]:
# Automatically reload edited Python modules before each cell executes.
%load_ext autoreload
%autoreload 2

# Render inline matplotlib figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'

In [1]:
import matplotlib.pyplot as plt
import scanpy as sc
import numpy as np
import seaborn as sns
import torch
import pandas as pd

from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import r2_score


In [7]:
import sys
sys.path.append('../src')

from spaceoracle.models.parallel_estimators import SpatialCellularProgramsEstimator

In [48]:
# NOTE(review): absolute, user-specific path; consider making it relative to a configurable data dir.
ADATA_TRAIN_PATH = '/ix/djishnu/alw399/SpaceOracle/notebooks/.cache/adata_train.h5ad'

# Load the cached training AnnData (cluster labels, spatial coordinates, count layers).
adata_train = sc.read_h5ad(ADATA_TRAIN_PATH)
adata_train

AnnData object with n_obs × n_vars = 11567 × 5013
    obs: 'cluster', 'rctd_cluster', 'rctd_celltypes'
    uns: 'log1p'
    obsm: 'X_spatial', 'rctd_results', 'spatial'
    layers: 'imputed_count', 'normalized_count', 'raw_count'

In [6]:
# Instantiate an estimator for Pax5 only to collect the genes its model uses as
# modulators; these (plus Pax5 itself) define the gene subset for the test run.
# NOTE(review): judging by the displayed outputs, building the estimator adds
# var/uns/obsm fields to adata_train — confirm this side effect is intended.
estimator = SpatialCellularProgramsEstimator(
    adata=adata_train,
    target_gene='Pax5',
)

pax5_modulators = (
    estimator.receptors
    + estimator.ligands
    + estimator.tfl_regulators
    + estimator.tfl_ligands
)
subset = set(pax5_modulators) | {'Pax5'}
len(subset)

90

In [None]:
# Subsample for a quick test run: keep only the Pax5-relevant genes and the first 100 cells.
gene_mask = adata_train.var_names.isin(subset)
adata = adata_train[:100, gene_mask]
adata

View of AnnData object with n_obs × n_vars = 100 × 90
    obs: 'cluster', 'rctd_cluster', 'rctd_celltypes'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'received_ligands', 'received_ligands_tfl', 'ligand_receptor', 'ligand_regulator', 'hvg'
    obsm: 'X_spatial', 'rctd_results', 'spatial', 'spatial_maps'
    layers: 'imputed_count', 'normalized_count', 'raw_count'

In [38]:
from spaceoracle.oracles import SpaceOracle

# Constructor options for the quick test run on the subsampled data.
oracle_kwargs = dict(
    annot='rctd_cluster',
    max_epochs=5,
    learning_rate=7e-4,
    spatial_dim=35,
    batch_size=256,
    rotate_maps=True,
    alpha=0.4,
)
so = SpaceOracle(adata=adata, **oracle_kwargs)


In [None]:
%%time
# Train the oracle on `adata`; %%time reports how long the run takes.
so.run()

In [None]:
# Collect the fitted spatial betas into so.beta_dict.
# Presumably keyed by target gene, like the synthetic `betas_dict` in the testing section — TODO confirm.
# NOTE(review): relies on a private helper; consider a public accessor.
so.beta_dict = so._get_spatial_betas_dict()

In [None]:
# Inspect the fitted model for the perturbation target.
# Fix: `betaoutput` was never defined in this notebook (NameError on a fresh kernel).
# Assumes so.beta_dict maps target gene -> BetaOutput, as the synthetic
# `betas_dict` in the testing section does — TODO confirm.
betaoutput = so.beta_dict['Pax5']
betas_df = betaoutput.betas

ligands = betaoutput.ligands
receptors = betaoutput.receptors
tfl_regulators = betaoutput.tfl_regulators
tfl_ligands = betaoutput.tfl_ligands
modulators = betaoutput.modulator_genes


In [None]:
# Perturb Pax5 to an expression of 0 (presumably a knockout) and propagate for 3 rounds, serially.
so.perturb(target='Pax5', gene_expr=0, n_propagation=3, n_jobs=1)

In [None]:
## Testing: verify weighted betas and perturbation on a synthetic six-gene network

In [216]:
# Synthetic system: A and B play the role of TFs, C and D of ligands, E and F of receptors.
tfs = ['A', 'B']
ligs = ['C', 'D']
recs = ['E', 'F']
genes = tfs + ligs + recs

# Model structure per target gene (each b is a separate random beta):
#   X    -> TF term, b * X
#   X$Y  -> ligand-receptor term (X ligand, Y receptor)
#   X#Z  -> ligand-regulator term (X ligand, Z regulator)
# A = bB + b(C$E) + b(C#B)
# B = bA + b(C$E)
# C = bA + bB
# D = bA + b(C$E) + b(C#A)
# E = bA + b(D#B)
# F = bB + b(D#A) + b(D#B)

In [244]:
# Seed numpy's global RNG so the synthetic counts, betas and spatial coordinates
# drawn in this and the following cells are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

n_cells = 100
n_genes = len(genes)

# Integer "counts" in [0, 11]: one row per cell, one column per synthetic gene.
gene_mtx = (np.random.rand(n_cells, n_genes) * 12).astype(int)
gene_mtx = pd.DataFrame(gene_mtx, columns=genes, index=np.arange(n_cells).astype(str))

In [245]:
from spaceoracle.oracles import BetaOutput

# Column position of each synthetic gene in gene_mtx.
gene2index = dict(zip(genes, range(len(genes))))
ligands = ligs  # alias later handed to so.ligands and received_ligands

def get_betaoutputs(all_modulators, ncells=100):
    """Build a random BetaOutput for one synthetic target gene.

    Each entry of ``all_modulators`` is one term of the target's model: a plain
    gene name (TF term), ``'LIG$REC'`` (ligand-receptor term) or ``'LIG#REG'``
    (ligand-regulator term).

    Args:
        all_modulators: list of term names as described above.
        ncells: number of rows (cells) in the random beta table.

    Returns:
        BetaOutput whose ``betas`` frame has ``beta0`` followed by one
        ``beta_<term>`` column per term (input order). ``modulator_genes`` and
        ``modulator_gene_indices`` list the distinct genes involved, sorted by
        name, with indices looked up in the global ``gene2index``.
    """
    # Classify terms by separator; a term with neither is a TF term.
    tf_terms = [t for t in all_modulators if '$' not in t and '#' not in t]
    lr_terms = [t for t in all_modulators if '$' in t]
    tfl_terms = [t for t in all_modulators if '#' in t]

    lr_parts = [t.split('$') for t in lr_terms]
    ligands = [parts[0] for parts in lr_parts]
    receptors = [parts[1] for parts in lr_parts]

    tfl_parts = [t.split('#') for t in tfl_terms]
    tfl_ligands = [parts[0] for parts in tfl_parts]
    tfl_regulators = [parts[1] for parts in tfl_parts]

    # Distinct genes touched by any term, sorted by name; indices follow that order.
    involved = sorted(set(tf_terms + ligands + receptors + tfl_ligands + tfl_regulators))
    modulator_gene_indices = [gene2index[g] for g in involved]
    modulator_genes = [f'beta_{g}' for g in involved]

    beta_columns = [f'beta_{t}' for t in all_modulators]
    betadata = create_betadata(ncells, beta_columns)

    return BetaOutput(
        betas=betadata[['beta0'] + beta_columns],
        modulator_genes=modulator_genes,
        modulator_gene_indices=modulator_gene_indices,
        ligands=ligands,
        receptors=receptors,
        tfl_ligands=tfl_ligands,
        tfl_regulators=tfl_regulators,
    )

def create_betadata(ncells, all_modulators):
    """Return a random beta table of shape (ncells, 1 + len(all_modulators)).

    Args:
        ncells: number of rows (cells).
        all_modulators: column names to add after ``'beta0'`` (presumably the
            intercept).

    Returns:
        DataFrame of uniform [0, 1) values with columns ``['beta0', *all_modulators]``
        and string cell ids ``'0' .. str(ncells - 1)`` as index, matching the
        ``gene_mtx`` index convention.
    """
    columns = ['beta0'] + list(all_modulators)
    values = np.random.rand(ncells, len(columns))
    # Fix: the index previously used the global `n_cells` instead of the
    # `ncells` argument, which fails (or mis-sizes) for any other ncells.
    index = np.arange(ncells).astype(str)
    return pd.DataFrame(values, index=index, columns=columns)

# Model terms per synthetic target gene; mirrors the equations in the gene-role cell.
network_terms = {
    'A': ['B', 'C$E', 'C#B'],
    'B': ['A', 'C$E'],
    'C': ['A', 'B'],
    'D': ['A', 'C$E', 'C#A'],
    'E': ['A', 'D#B'],
    'F': ['B', 'D#A', 'D#B'],
}
# Insertion order is preserved, so the random betas are drawn in the same order as before.
betas_dict = {gene: get_betaoutputs(terms) for gene, terms in network_terms.items()}

In [246]:
# Wrap the synthetic counts in an AnnData with raw and library-size-normalized
# layers plus random 2-D spatial coordinates.
# Fix: `ad` was never imported in this notebook (NameError on a fresh kernel);
# scanpy re-exports AnnData, so use the already-imported `sc`.
adata = sc.AnnData(gene_mtx)
adata.layers['raw_count'] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
adata.layers['normalized_count'] = adata.X.copy()

adata.obsm['spatial'] = np.random.rand(n_cells, 2)
adata

AnnData object with n_obs × n_vars = 100 × 6
    obsm: 'spatial'
    layers: 'raw_count', 'normalized_count'

In [247]:
# Build an oracle on the synthetic data and inject the hand-made betas (no training).
so = SpaceOracle(adata=adata)
so.beta_dict, so.ligands = betas_dict, ligands

In [248]:
from spaceoracle.models.parallel_estimators import received_ligands

# Spatially weighted ligand expression per cell; string index so it presumably
# aligns with gene_mtx and the beta tables (both use '0'..'99').
weighted_ligands = received_ligands(
    so.adata.obsm['spatial'],
    gene_mtx[list(ligands)]
)
weighted_ligands.index = weighted_ligands.index.astype(str)
wbeta_dict = so._get_wbetas_dict(so.beta_dict, gene_mtx, n_jobs=1)

output = wbeta_dict['A'].wbetas
# Renamed from `input`, which shadowed the Python builtin.
input_betas = wbeta_dict['A'].betas
weights = gene_mtx  # NOTE(review): unused in the visible cells

  0%|          | 0/6 [00:00<?, ?it/s]

In [249]:
# Library-computed weighted betas for target A (checked against manual derivatives in the next cell).
output

Unnamed: 0,beta_B,beta_C,beta_E
0,4.254500,2.670818,2.268348
1,4.852398,7.721592,1.074011
2,1.454800,6.079462,2.739403
3,5.360922,6.984121,2.869091
4,3.307415,4.002586,2.524527
...,...,...,...
95,0.579218,1.009980,1.406438
96,0.759339,10.011951,4.915714
97,3.284269,4.990796,2.783348
98,3.070260,9.001718,3.672601


In [251]:
# A = bB + b(C$E) + b(C#B)
# Partial derivatives of A w.r.t. each modulator gene (wC = spatially weighted C):
# dA/dB = b + bwC
# dA/dC = bE + bB
# dA/dE = bwC
betas_A = wbeta_dict['A'].betas

dAdB = betas_A['beta_B'] + betas_A['beta_C#B'] * weighted_ligands['C']
dAdC = betas_A['beta_C$E'] * gene_mtx['E'] + betas_A['beta_C#B'] * gene_mtx['B']
dAdE = betas_A['beta_C$E'] * weighted_ligands['C']

manual_betas = pd.concat([dAdB, dAdC, dAdE], axis=1, keys=['beta_B', 'beta_C', 'beta_E'])
# Labels must match exactly; values are compared with a tolerance because exact
# float equality breaks if the library merely sums the same terms in another order.
assert manual_betas.index.equals(output.index) and manual_betas.columns.equals(output.columns)
np.allclose(manual_betas, output)

True

In [256]:
# F = bB + b(D#A) + b(D#B)
# Partial derivatives of F w.r.t. each modulator gene (wD = spatially weighted D):
# dF/dB = b + bwD
# dF/dD = bA + bB
# dF/dA = bwD
betas_F = wbeta_dict['F'].betas

dFdB = betas_F['beta_B'] + betas_F['beta_D#B'] * weighted_ligands['D']
dFdD = betas_F['beta_D#A'] * gene_mtx['A'] + betas_F['beta_D#B'] * gene_mtx['B']
dFdA = betas_F['beta_D#A'] * weighted_ligands['D']

manual_betas = pd.concat([dFdA, dFdB, dFdD], axis=1, keys=['beta_A', 'beta_B', 'beta_D'])
wbetas_F = wbeta_dict['F'].wbetas
# Labels must match exactly; values are compared with a tolerance because exact
# float equality breaks if the library merely sums the same terms in another order.
assert manual_betas.index.equals(wbetas_F.index) and manual_betas.columns.equals(wbetas_F.columns)
np.allclose(manual_betas, wbetas_F)

True

In [257]:
# Perturb gene A in the synthetic system and propagate for 2 rounds
# (one "Running simulation" progress bar per round).
# NOTE(review): displaying the whole array is noisy; consider gem_simulated[:5].
gem_simulated = so.perturb(target='A', gene_mtx=gene_mtx, n_propagation=2)
gem_simulated

  0%|          | 0/6 [00:00<?, ?it/s]

Running simulation 1/2: 100%|██████████| 100/100 [00:00<00:00, 18647.15it/s]


  0%|          | 0/6 [00:00<?, ?it/s]

Running simulation 2/2: 100%|██████████| 100/100 [00:00<00:00, 27545.18it/s]


array([[ 0.        ,  0.        ,  0.25573102,  0.        ,  0.        ,
         3.50841969],
       [ 0.        ,  8.56308223,  3.66582324,  0.        ,  0.        ,
         0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         3.52114887],
       [ 0.        ,  6.        , 10.        ,  2.        ,  2.        ,
         9.        ],
       [ 0.        ,  0.        ,  1.53584598,  0.        ,  5.72202772,
         3.19961141],
       [ 0.        ,  2.57159854,  7.06914131,  0.        ,  0.        ,
         0.11563628],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ],
       [ 0.        ,  0.        ,  5.        , 11.        ,  5.        ,
         4.        ],
       [ 0.        ,  0.        ,  1.04522017,  0.        ,  0.        ,
         0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.