In [1]:
import pandas as pd
import numpy as np
import scanpy as sc
import jax.numpy as jnp
import cloudpickle

from cellrank.kernels import VelocityKernel
from cellrank.estimators import GPCCA
from moscot.problems.time import TemporalProblem
import moscot.plotting as mtp
import seaborn as sns
from cellrank.kernels import RealTimeKernel

In [2]:
Notochord = ['Notochord', 'Ciliated nodal cells']
Gut = ['Gut', 'Foregut epithelial cells', 'Pancreatic islets', 'Pancreatic acinar cells', 'Biliary epithelial cells']
Intermediate_mesoderm_and_kidney = ['Anterior intermediate mesoderm', 
                                    'Posterior intermediate mesoderm', 'Ureteric bud', 'Metanephric mesenchyme',
                                    'Collecting duct principal cells', 'Nephron progenitors',
                                    'Distal convoluted tubule', 'Ascending loop of Henle',
                                    'Podocytes', 'Proximal tubule cells', 'Connecting tubule', 
                                    'Collecting duct intercalated cells'
                                   ]
Eye_and_other = ['Naive retinal progenitor cells', 'Retinal progenitor cells', 'Ciliary margin cells', 
                 'Suprachiasmatic nucleus', 'Bipolar precursor cells', 'Photoreceptor precursor cells', 
                 'Rod precursor cells', 'Cone precursor cells'
                ]
Epithelial_cells = ['Placodal area', 'Olfactory epithelial cells', 'Olfactory bulb cells', 'Thyroid gland cells',
                    'Olfactory pit cells', 'Pituitary/Pineal gland progenitors', 'Thymic epithelial cells',
                    'Basal keratinocytes', 'Apical ectodermal ridge', 'Granular keratinocytes',
                    'Lens epithelial cells', 'Branchial arch epithelium', 'Conjunctival goblet cells',
                    'Corneal epithelial cells', 'Bladder urothelial cells', 'Parathyroid epithelial cells',
                    'Tooth junctional epithelium', 'Dental epithelial cells', 'Amniotic ectoderm', 
                    'Pre-epidermal keratinocytes', 'Otic epithelial cells'
                   ]
Glands = ['Nonsensory cochlear epithelium', 'Pineal gland', 'Pituitary gland cells', 'Cochlear hair cells']
Mesoderm = ['Chondrocytes (Atp1a2+)', 'Chondrocytes (Otor+)', 'Dermatome', 'Dermomyotome', 'Early chondrocytes',
            'Facial mesenchyme', 'Fibroblasts', 'Lateral plate and intermediate mesoderm', 
            'Limb mesenchyme progenitors', 'Mesodermal progenitors (Tbx6+)', 'Pre-osteoblasts (Sp7+)', 'Sclerotome'
           ]
Cardiocytes = ['Atrial cardiomyocytes', 'First heart field', 'Second heart field', 'Ventricular cardiomyocytes']
Aidpocytes = ['Adipocyte cells (Cyp2e1+)', 'Adipocyte progenitor cells', 'Brown adipocyte cells']
Muscle_cells = ['Muscle progenitor cells', 'Muscle progenitor cells (Prdm1+)', 'Myoblasts', 'Myofibroblasts',
                'Myotubes'
               ]
Testis_and_adrenal = ['Adrenocortical cells', 'Leydig cells']
Neural_crest_PNS_neurons = ['Dorsal root ganglion neurons', 'Enteric neurons', 'Neural crest (PNS neurons)',
                            'Otic sensory neurons', 'Parasympathetic neurons', 'Sympathetic neurons'
                           ]
Neural_crest_PNS_glia = ['Melanocyte cells', 'Myelinating Schwann cells', 'Myelinating Schwann cells (Tgfb2+)',
                         'Neural crest (PNS glia)', 'Olfactory ensheathing cells', 'Satellite glial cells'
                        ]
Olfactory_sensory_neurons = ['Corticofugal neurons', 'Olfactory sensory neurons']
Neuroectoderm_and_glia = ['Anterior floor plate', 'Anterior roof plate', 'Astrocytes', 'Diencephalon',
                          'Cerebellum-related cells', 'Dorsal telencephalon', 'Eye field', 'Floorplate and p3 domain',
                          'Hindbrain', 'Hypothalamus', 'Hypothalamus (Sim1+)', 'Midbrain', 
                          'Midbrain-hindbrain boundary', 'Multiciliated ependymal cells', 
                          'NMPs and spinal cord progenitors', 'Posterior roof plate', 'Retinal pigment cells', 
                          'Spinal cord/r7/r8', 'Telencephalon'
                         ]
CNS_neurons = ['Amacrine cells', 'Amacrine/Horizontal precursor cells', 'Cajal-Retzius cells', 
               'Cerebellar Purkinje cells', 'Cholinergic amacrine cells', 'Cranial motor neurons', 
               'GABAergic cortical interneurons', 'GABAergic neurons', 'Glutamatergic neurons', 'Horizontal cells',
               'Neural progenitor cells (Neurod1+)', 'Neural progenitor cells (Ror1+)', 'Neurons (Slc17a8+)',
               'PV-containing retinal ganglion cells', 'Retinal ganglion cells', 'Spinal cord dorsal progenitors', 
               'Spinal cord motor neurons', 'Spinal cord ventral progenitors', 'Thalamic neuronal precursors'
              ]
Ependymal_cells = ['Choroid plexus', 'Ependymal cells']
Olidendrocytes = ['Committed oligodendrocyte precursors', 'Oligodendrocyte progenitor cells']
Intermediate_neuronal_progenitors = ['Cortical Interneurons (Prox1+)', 'Deep-layer neurons', 
                                     'Intermediate neuronal progenitors', 'Subplate neurons', 'Upper-layer neurons'
                                    ]
Endothelium = ['Arterial endothelial cells', 'Brain capillary endothelial cells', 'Brain pericytes', 
               'Endocardial cells', 'Endothelium', 'Glomerular endothelial cells', 'Hematoendothelial progenitors',
               'Liver sinusoidal endothelial cells', 'Lymphatic vessel endothelial cells',
               'Microvascular endothelial cells', 'Pericytes', 'Venous and capillary endothelial cells'
              ]
Definitive_erythroid = ['Definitive early erythroblasts (CD36-)', 'Definitive erythroblasts (CD36+)']
B_cells = ['B cell progenitors', 'B cells']
Hepatocytes = ['Hepatocytes']
Intestine = ['Intestinal enteroendocrine cells', 'Intestinal goblet cells', 'Midgut/Hindgut epithelial cells']
Lung_and_airways = ['Airway club cells', 'Airway goblet cells', 'Alveolar Type 1 cells', 'Alveolar Type 2 cells',
                    'Lung cells (Eln+)', 'Lung progenitor cells'
                   ]
Mast_cells = ['Mast cells', 'Mast cells (P2rx7+)']
Megakaryocytes = ['Megakaryocytes']
Primitive_erythroid = ['Primitive erythroid cells']
T_cells = ['Activated T cells', 'Natural killer cells', 'Regulatory T cells', 'T cells']
White_blood_cells = ['Adipose tissue macrophages', 'Border-associated macrophages', 
                     'Border-associated macrophages (Cd74+)', 'Border-associated macrophages (Ms4a8a+)',
                     'Conventional dendritic cells', 'Granulocytes', 'Hematopoietic stem cells (Cd34+)',
                     'Hematopoietic stem cells (Mpo+)', 'Kupffer cells', 'Microglia', 'Monocytes', 
                     'Monocytic myeloid-derived suppressor cells', 'Osteoclasts', 
                     'PMN myeloid-derived suppressor cells', 'Plasmacytoid dendritic cells'
                    ]
Extraembryonic_visceral_endoderm = ['Extraembryonic visceral endoderm']
Primordial_germ_cells = ['Primordial germ cells']

liste1 = ['Notochord','Gut','Intermediate_mesoderm_and_kidney','Eye_and_other','Epithelial_cells','Glands','Mesoderm',
          'Cardiocytes','Aidpocytes','Muscle_cells','Testis_and_adrenal','Neural_crest_PNS_neurons',
          'Neural_crest_PNS_glia','Olfactory_sensory_neurons','Neuroectoderm_and_glia','CNS_neurons',
          'Ependymal_cells','Olidendrocytes','Intermediate_neuronal_progenitors','Endothelium','Definitive_erythroid',
          'B_cells','Hepatocytes','Intestine','Lung_and_airways','Mast_cells','Megakaryocytes','Primitive_erythroid',
          'T_cells','White_blood_cells','Extraembryonic_visceral_endoderm','Primordial_germ_cells'
         ]
liste2 = [Notochord,Gut,Intermediate_mesoderm_and_kidney,Eye_and_other,Epithelial_cells,Glands,Mesoderm,Cardiocytes,
          Aidpocytes,Muscle_cells,Testis_and_adrenal,Neural_crest_PNS_neurons,Neural_crest_PNS_glia,
          Olfactory_sensory_neurons,Neuroectoderm_and_glia,CNS_neurons,Ependymal_cells,Olidendrocytes,
          Intermediate_neuronal_progenitors,Endothelium,Definitive_erythroid,B_cells,Hepatocytes,Intestine,
          Lung_and_airways,Mast_cells,Megakaryocytes,Primitive_erythroid,T_cells,White_blood_cells,
          Extraembryonic_visceral_endoderm,Primordial_germ_cells
         ]

dic = {}

for i in range(len(liste1)):
    for celltype in liste2[i]:
        dic[celltype] = liste1[i]

In [3]:
adata = sc.read("/lustre/groups/ml01/workspace/monge_velo/data/adata_JAX_dataset_1.h5ad")
meta = pd.read_csv('/lustre/groups/ml01/workspace/monge_velo/data/df_cell.csv', index_col=0)
adata.obs['celltype'] = pd.Categorical(
    adata.obs['cell_id'].to_frame().merge(meta, on='cell_id', how='inner')['celltype_update']
)
adata.obs['major_trajectory'] = pd.Categorical(
    adata.obs['cell_id'].to_frame().merge(meta, on='cell_id', how='inner')['major_trajectory']
)
adata.obs['annotations_moscot'] = pd.Categorical(
    [dic[celltype] for celltype in adata.obs['celltype']]
)

leave_in = [8.5, 8.75]
adata_time = adata[adata.obs['day'].isin(leave_in)].copy()
del adata
adata_time.obs['day'] = adata_time.obs['day'].astype('category')

In [4]:
sc.pp.subsample(adata_time, n_obs=80000)
sc.pp.pca(adata_time)

adata_time.obs["source_marginals"] = np.ones(adata_time.n_obs)
adata_time.obs["target_marginals"] = np.ones(adata_time.n_obs)
adata_time.obs["target_marginals"] -= (1-1/16)*(adata_time.obs['annotations_moscot']=='Mesoderm')
adata_time.obs["target_marginals"] -= (1-1/16)*(adata_time.obs['annotations_moscot']=='Neuroectoderm_and_glia')
adata_time.obs["target_marginals"] -= (1-1/4)*(adata_time.obs['annotations_moscot']=='Endothelium')
adata_time.obs["target_marginals"] -= (1-1/4)*(adata_time.obs['annotations_moscot']=='Epithelial_cells')
adata_time.obs["target_marginals"] -= (1-1/4)*(adata_time.obs['annotations_moscot']=='Neural_crest_PNS_glia')
adata_time.obs["target_marginals"] -= (1-1/2)*(adata_time.obs['annotations_moscot']=='Primitive_erythroid')
adata_time.obs["target_marginals"] -= (1-50)*(adata_time.obs['annotations_moscot']=='CNS_Neurons')
adata_time.obs["target_marginals"] -= (1-50)*(adata_time.obs['annotations_moscot']=='Hepatocytes')
adata_time.obs["target_marginals"] -= (1-50)*(adata_time.obs['annotations_moscot']=='Primordial_germ_cells')
adata_time.obs["target_marginals"] -= (1-50)*(adata_time.obs['annotations_moscot']=='Neural_crest_PNS_neurons')
adata_time.obs["target_marginals"] -= (1-50)*(adata_time.obs['annotations_moscot']=='Muscle_cells')

tp = TemporalProblem(adata_time)
tp = tp.prepare("day", joint_attr="X_pca", a='source_marginals', b='target_marginals')
tp = tp.solve(epsilon=1e-2,
              #initializer="random",
              #rank=5000,
              batch_size=2048
             )
tp[(8.5, 8.75)].solution.to('cpu')

[34mINFO    [0m Ordering [1;35mIndex[0m[1m([0m[1m[[0m[32m'run_4_P2-04H.AAACTGAACTTATAGACGCA-0'[0m,                                                    
                [32m'run_4_P2-01B.AACGCGTCTCTTGGTAATG-0'[0m,                                                              
                [32m'run_4_PE-10D.CGCTAACCTTGGCCGGCCT-0'[0m,                                                              
                [32m'run_15_PB-11D_S180.CTGGAAGATGCTAACTTGC-1'[0m,                                                        
                [32m'run_4_PC-08G.CTACCTGGTCCATAAGTCC-0'[0m,                                                              
                [32m'run_4_PD-12D.CCGTCGATTATGGCTCTGC-0'[0m,                                                              
                [32m'run_4_PD-07G.AAGGCTACTTATAGACGCA-0'[0m,                                                              
                [32m'run_4_PC-09G.ACGATATCATTCAAGCCGAT-0'[0m,                          

2023-09-26 15:55:44.033365: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.2.140). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.




OTTOutput[shape=(61467, 18533), cost=5652738.5, converged=False]

In [5]:
myorder00 = adata_time[adata_time.obs['day']==8.50].obs['annotations_moscot'].unique().sort_values()
myorder25 = adata_time[adata_time.obs['day']==8.75].obs['annotations_moscot'].unique().sort_values()

In [6]:
res = tp.cell_transition(
    source=8.50,
    target=8.75,
    source_groups={'annotations_moscot': list(myorder00)},
    target_groups={'annotations_moscot': list(myorder25)},
    forward=True,
    key_added="tp_transitions",
    batch_size=4096
)
res

Unnamed: 0,CNS_neurons,Cardiocytes,Endothelium,Epithelial_cells,Extraembryonic_visceral_endoderm,Gut,Hepatocytes,Intermediate_mesoderm_and_kidney,Megakaryocytes,Mesoderm,Muscle_cells,Neural_crest_PNS_glia,Neural_crest_PNS_neurons,Neuroectoderm_and_glia,Notochord,Primitive_erythroid,Primordial_germ_cells
CNS_neurons,0.003303,0.117799,0.061381,0.092522,0.02655,0.086774,0.034635,0.022794,0.000662,0.120528,0.093911,0.026295,0.090767,0.121093,0.011841,0.026643,0.062502
Cardiocytes,0.000173,0.8634,0.011122,0.016098,0.004624,0.010164,0.007733,0.003459,0.000112,0.021567,0.010764,0.00425,0.011677,0.016856,0.001665,0.005873,0.010461
Definitive_erythroid,0.001109,0.113395,0.075605,0.10338,0.034378,0.063733,0.019516,0.02307,0.00073,0.141574,0.037228,0.028846,0.097069,0.128616,0.008893,0.046825,0.076032
Endothelium,0.000618,0.085407,0.332172,0.07043,0.026837,0.046999,0.093292,0.015141,0.001005,0.071822,0.044013,0.019508,0.041356,0.054218,0.007458,0.037284,0.05244
Epithelial_cells,0.00108,0.095303,0.056336,0.165624,0.027758,0.083564,0.087884,0.024157,0.000619,0.098029,0.074709,0.027519,0.063824,0.085089,0.011154,0.030277,0.067074
Extraembryonic_visceral_endoderm,0.000548,0.05688,0.039311,0.053563,0.488436,0.040219,0.059708,0.010109,0.000352,0.059826,0.025858,0.012896,0.036213,0.051367,0.005207,0.022104,0.037404
Gut,0.001158,0.095874,0.05433,0.096235,0.037381,0.133116,0.15047,0.019027,0.000529,0.092657,0.052632,0.021676,0.057858,0.078281,0.014691,0.030522,0.063563
Intermediate_mesoderm_and_kidney,0.001253,0.106579,0.053872,0.101573,0.01721,0.077812,0.069719,0.102491,0.000827,0.106172,0.099429,0.031492,0.062257,0.076198,0.010844,0.029146,0.053126
Mesoderm,0.001149,0.116929,0.056366,0.097858,0.023423,0.083501,0.11154,0.025836,0.000906,0.127649,0.076457,0.029251,0.063768,0.080753,0.010744,0.034392,0.05948
Neural_crest_PNS_glia,0.00792,0.082463,0.058564,0.107204,0.020119,0.066212,0.069846,0.030752,0.000683,0.109899,0.070448,0.109497,0.058875,0.097602,0.009265,0.035663,0.064988


In [7]:
res.to_csv('/home/icb/jonas.flor/gastrulation_atlas/moscot/data/tp0850_unbalanced.csv')