# **The Fog of Cellular Fate: TrajectoryNet Maps the Developmental Trajectories of Human Embryonic Stem Cells**

<div style="background-color:#FFF3E9; border: 1px solid #FFE0C3; border-radius: 10px; margin-top:1rem; margin-bottom:1rem">
    <p style="margin:1rem; padding-left: 1rem; padding-right: 1rem; line-height: 2.5;">
         “<span style="color:purple">It is not our abilities that show what we truly are, it is our choices.</span>” - Albus Percival Wulfric Brian Dumbledore
    </p>
</div>

At each decision point of fate, how do cells make their choices?

In this repository, I walk through the basic principles of TrajectoryNet [1] and show how to use TrajectoryNet together with PHATE [2] (Potential of Heat-diffusion for Affinity-based Transition Embedding) to analyze a 27-day time series of human embryoid body (EB) differentiation comprising roughly 31,000 cells.

This repository references the original author's [presentation slides](https://icml.cc/media/icml-2020/Slides/6491_ntqEW8z.pdf) and the actual demonstration is adapted from the [notebook provided by the original author](https://github.com/KrishnaswamyLab/TrajectoryNet/blob/master/notebooks/EmbryoidBody_TrajectoryInference.ipynb).

We will follow these steps:

[0. Basic Introduction to TrajectoryNet](#math)  
[1. Loading 10X Data](#loading)  
[2. Preprocessing: Filtering, Normalization, and Transformation](#preprocessing)  
[3. Using PHATE to Embed Data](#embedding)  
[4. Modeling Cell Dynamic Transitions with TrajectoryNet](#trajectory)

References:

1. Tong, A., Huang, J., Wolf, G., van Dijk, D. & Krishnaswamy, S. TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics. in Proceedings of the 37th International Conference on Machine Learning (2020). [url](http://proceedings.mlr.press/v119/tong20a/tong20a.pdf)
2. Moon, K. R. et al. Visualizing structure and transitions in high-dimensional biological data. Nature Biotechnology 37, 1482–1492 (2019). [url](https://doi.org/10.1038/s41587-019-0336-3)


<a id='math'></a>
## **0. TrajectoryNet's Background**

### **0.0 Introduction of TrajectoryNet**

TrajectoryNet is a dynamic optimal transport network for modeling cellular dynamics, while PHATE is a dimensionality-reduction method that embeds data using a potential distance derived from heat diffusion over a cell-cell affinity graph; it is used to visualize structure and transitions in high-dimensional biological data. Both tools help us understand how cells develop and differentiate over time, especially in large single-cell datasets.  
Below, we briefly introduce TrajectoryNet.

* **Problem Definition**

In the analysis of single-cell omics data, an important task is to determine the order in which individual cells differentiate or evolve, that is, to assign a **pseudotime** to each cell. Pseudotime is an ordering of cells along an inferred developmental trajectory. Pseudotime analysis can reveal the functional characteristics and transcriptional state of cells at different stages, giving insight into developmental processes. By ordering single-cell data in pseudotime, we can describe how cell types and states evolve, which informs the study of tissue development, disease progression, and signal transduction. Pseudotime analysis can also help identify potential determinants of cell fate, laying the groundwork for further experimental research and clinical applications.

Commonly used pseudotime methods rely on known marker genes or differential-expression analysis to order cells. Such methods can introduce bias, and there are documented cases of failure in the literature. Moreover, they ignore the actual experimental time at which each sample was collected.

TrajectoryNet leverages real experimental time, treating single-cell data at each time point as a sampling result from a distribution, thereby modeling the evolution between distributions. Thus, TrajectoryNet defines the pseudotime problem for single cells as an unbalanced dynamic transport problem.

In other words, TrajectoryNet aims to define an efficient and meaningful transformation equation that establishes the transition relationship between two distributions.

* **Continuous Normalizing Flows**

A natural tool for transforming one distribution into another is the Continuous Normalizing Flow (CNF). A CNF maps a complex distribution to a simple base distribution (such as a Gaussian) through a continuous, invertible transformation defined by an ordinary differential equation, which makes both sampling and probability density calculation tractable. CNFs are commonly applied in fields such as image generation, anomaly detection, reinforcement learning, and natural language processing. A CNF on its own does not constrain the path taken between the two distributions. Adding optimal transport as an optimization objective (viewed as a form of regularization, as shown in the figure) induces an optimal transport path between distributions. Intuitively, the regularizer is the **energy minimization in optimal transport**, and adding this energy penalty to a CNF forms the core of TrajectoryNet.

<p align="center">
  <img width="810" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/YbEY5FuEpW3NKYjyD4triQ.png">
</p>
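
To make the density calculation concrete, here is a minimal sketch of the instantaneous change-of-variables rule that CNFs rely on: along a trajectory of dz/dt = f(z, t), the log-density evolves as d(log p)/dt = -tr(∂f/∂z). The one-dimensional linear field f(z) = a·z and the Euler integrator are illustrative assumptions, chosen because the answer is known in closed form; they are not TrajectoryNet's network or ODE solver.

```python
import math

def integrate_cnf(z0, logp0, a, t_end=1.0, n_steps=10000):
    """Euler-integrate the 1D flow dz/dt = a*z together with its log-density."""
    dt = t_end / n_steps
    z, logp = z0, logp0
    for _ in range(n_steps):
        z += dt * a * z   # state update: dz/dt = f(z) = a*z
        logp -= dt * a    # density update: d(log p)/dt = -df/dz = -a
    return z, logp

def std_normal_logpdf(x):
    """Log-density of the standard normal base distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

a, z0 = 0.5, 1.2
z1, logp1 = integrate_cnf(z0, std_normal_logpdf(z0), a)

# Closed form for comparison: z(1) = z0 * exp(a), log p_1(z(1)) = log p_0(z0) - a
print(z1, z0 * math.exp(a))
print(logp1, std_normal_logpdf(z0) - a)
```

The numerical and closed-form values should agree up to the Euler discretization error. In a real CNF, f is a neural network and the trace term is estimated rather than computed exactly.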

* **Optimal Transport**

From another perspective, if we want the transformation between two distributions to achieve energy minimization, we can require the Wasserstein distance to be minimized. This achieves optimal transport (as shown in the figure), where the two distributions &mu; and &nu; represent the boundary conditions.

<p align="center">
  <img width="630" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/Lrr1vQmufGkx7oXH6T-ppg.png">
</p>
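
As a small illustration of the Wasserstein distance, the sketch below computes the 2-Wasserstein distance between two one-dimensional samples. It assumes equally weighted samples of equal size, in which case the optimal plan simply pairs sorted values; the function name `wasserstein2_1d` is hypothetical, and the high-dimensional, dynamic formulation used by TrajectoryNet is considerably more involved.

```python
import math

def wasserstein2_1d(xs, ys):
    """2-Wasserstein distance between two equally weighted 1D samples of equal size."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equally sized samples")
    xs, ys = sorted(xs), sorted(ys)
    # In 1D the optimal plan matches the i-th smallest x with the i-th smallest y.
    cost = sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return math.sqrt(cost)

mu = [0.0, 1.0, 2.0]
nu = [3.0, 1.0, 2.0]  # mu shifted by +1, listed in a different order
print(wasserstein2_1d(mu, nu))
```

Because every point of `mu` has to move by exactly one unit, the distance here is 1; identical samples give a distance of 0.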

Solving the equations for these boundary conditions is not straightforward. If we change the boundary condition &nu; to a soft constraint added to the loss, we obtain a regularized CNF (as shown in the figure).

<p align="center">
  <img width="630" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/RjuWCTdAK9te-SFwePKRdw.png">
</p>

* **TrajectoryNet**

By incorporating biological priors from RNAseq data, including density, velocity, and growth, TrajectoryNet constructs a method using dynamic optimal transport to simulate the continuous dynamics and nonlinear paths of entities within a system.

Ultimately, TrajectoryNet defines the loss function as a combination of dynamic optimal transport and biological priors.

<p align="center">
  <img width="360" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/X6YelKWXL-hYCJ9LGnbEzQ.png">
</p>

The density prior uses a k-nearest-neighbor (KNN) density estimate: cells should travel only through regions of the state space where cells were actually observed. The figure shows an unreasonable path that leaves these regions and provides the formula for the density penalty.

<p align="center">
  <img width="630" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/H2fdCJJ0SonT9eOf91E56A.png">
</p>
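
The idea behind the density penalty can be sketched as a hinge loss on the distances from a point on a trajectory to its nearest observed cells. This is a simplified illustration under stated assumptions: the function name, the number of neighbors `k`, and the tolerance `h` are hypothetical choices, and the exact form used by TrajectoryNet is the one given in the figure.

```python
import math

def knn_density_penalty(point, data, k=2, h=0.5):
    """Sum of hinge penalties on the distances from `point` to its k nearest data points."""
    dists = sorted(math.dist(point, z) for z in data)
    # A neighbor closer than h costs nothing; beyond h the cost grows linearly.
    return sum(max(0.0, d - h) for d in dists[:k])

observed = [(0.0, 0.0), (0.2, 0.0), (0.0, 0.3), (1.0, 1.0)]
on_manifold = (0.1, 0.1)    # surrounded by observed cells
off_manifold = (3.0, -2.0)  # far from every observed cell

print(knn_density_penalty(on_manifold, observed))
print(knn_density_penalty(off_manifold, observed))
```

A point surrounded by observed cells incurs no penalty, while a point far from the data is penalized in proportion to how far it has strayed.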

The velocity prior uses the direction of RNA change, that is, the RNA velocity vector of each cell in state space, as estimated by tools such as La Manno et al.'s velocyto (2018) and Bergen et al.'s scVelo.

<p align="center">
  <img width="630" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/bQZkxHOQWYXpO3lYBuTWPQ.png">
</p>

The growth prior accounts for cell states that may disappear during state transitions. TrajectoryNet therefore allows unbalanced transport, permitting cells to "die" rather than forcing them to move to unreasonable positions.

<p align="center">
  <img width="630" height="" src="https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/15237/f72c74a06947434fb94f825d46836676/flm0cb76yXB-cbKXrrXY9w.png">
</p>

### **0.1 Dataset Description: Time Series of Human Embryonic Stem Cells**

Researchers use specific culture conditions and techniques to culture and differentiate human embryonic stem cells (hESCs). These cells can differentiate into various types under certain conditions, simulating the developmental process of early human embryos. The research team collected cell samples at different time points to study the dynamic process of cell differentiation. Additionally, they analyzed these cells using single-cell sequencing technology (10x Genomics), a powerful technique that provides in-depth insights into each individual cell.

The timeline for the differentiation of human embryoid bodies is as follows: low passage H1 human embryonic stem cells (hESCs) are maintained in culture dishes coated with Matrigel, using DMEM/F12-N2B27 medium supplemented with FGF2. To form embryoid bodies, the cells are treated with Dispase, separated into small clumps, and cultured in a non-adherent manner in medium supplemented with 20% fetal bovine serum, which has been screened for suitability for embryoid body differentiation. During a 27-day differentiation time series, samples are collected every three days. An undifferentiated hESC sample is also included (Figure S7D). The induction of major germ layer markers in these embryoid body cultures was validated by qPCR (data not shown). For single-cell analysis, the embryoid body cultures were dissociated, sorted by flow cytometry (FACS) to remove doublets and dead cells, and processed on a 10x Genomics instrument to generate cDNA libraries, which were then sequenced. Small-scale sequencing confirmed that we successfully collected data from approximately 31,000 cells, evenly distributed throughout the time series.


### **0.2 Install Packages**

In [None]:
! pip install TrajectoryNet
! pip install python-magic
! apt-get update
! apt-get install -y libmagic1
! pip install phate
! pip install scprep
! pip install --user magic-impute

In [None]:
! wget https://codeload.github.com/KrishnaswamyLab/TrajectoryNet/zip/refs/heads/master
! unzip master


<a id='loading'></a>
## **1. Load 10X Data**

### **Importing Data from Mendeley Datasets**

This embryoid body dataset comes from the public Mendeley dataset <https://data.mendeley.com/datasets/v6n743h5ng/>, which contains single-cell RNA sequencing data from a time series of human embryoid body development. The data is organized in a `scRNAseq` directory and includes samples from five time points (T0_1A, T2_3B, T4_5C, T6_7D, and T8_9E). These files are in the standard format for single-cell RNA sequencing data, typically generated by the CellRanger software. For more information on how CellRanger generates these files, please refer to the [Gene-Barcode Matrices Documentation](https://support.10xgenomics.com/single-cell-gene-expression/software/pipelines/latest/output/matrices).

Here is the directory structure:
```
download_path
└── scRNAseq
    ├── scRNAseq.zip
    ├── T0_1A
    │   ├── barcodes.tsv
    │   ├── genes.tsv
    │   └── matrix.mtx
    ├── T2_3B
    │   ├── barcodes.tsv
    │   ├── genes.tsv
    │   └── matrix.mtx
    ├── T4_5C
    │   ├── barcodes.tsv
    │   ├── genes.tsv
    │   └── matrix.mtx
    ├── T6_7D
    │   ├── barcodes.tsv
    │   ├── genes.tsv
    │   └── matrix.mtx
    └── T8_9E
        ├── barcodes.tsv
        ├── genes.tsv
        └── matrix.mtx
```
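
Before loading, it can help to confirm that the files are where the loader expects them. The helper below is a small sketch, not part of scprep: `missing_files` is a hypothetical name, and the demonstration runs against a temporary directory rather than a real download.

```python
import os
import tempfile

SAMPLES = ["T0_1A", "T2_3B", "T4_5C", "T6_7D", "T8_9E"]
FILES = ["barcodes.tsv", "genes.tsv", "matrix.mtx"]

def missing_files(download_path):
    """Return the expected 10X files that are absent under download_path/scRNAseq."""
    missing = []
    for sample in SAMPLES:
        for name in FILES:
            path = os.path.join(download_path, "scRNAseq", sample, name)
            if not os.path.isfile(path):
                missing.append(path)
    return missing

# Demonstration on a temporary directory: first empty, then fully populated.
with tempfile.TemporaryDirectory() as tmp:
    n_missing_empty = len(missing_files(tmp))
    for sample in SAMPLES:
        os.makedirs(os.path.join(tmp, "scRNAseq", sample))
        for name in FILES:
            open(os.path.join(tmp, "scRNAseq", sample, name), "w").close()
    n_missing_full = len(missing_files(tmp))

print(n_missing_empty, n_missing_full)
```

On real data, calling `missing_files(download_path)` and checking for an empty list catches path mistakes before the loading step, which takes several minutes.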


### **Using scprep to Import Data into Pandas DataFrame**

Next, we will use a toolkit called scprep to load and manipulate single-cell data. The function `scprep.io.load_10X` can automatically load 10X single-cell RNA sequencing datasets (as well as other datasets) into a Pandas DataFrame.

Let’s load the data and create a single matrix that we can use for preprocessing, visualization, and analysis.

#### **1. Import data** 

In [2]:
import os
import pandas as pd
import numpy as np
import phate
import scprep
import magic
import matplotlib.pyplot as plt
import sklearn.preprocessing

# matplotlib settings for Jupyter notebooks only
%matplotlib inline

#### **2. Use scprep.io.load_10X to Import Three Matrices for Each Sample into a DataFrame (Reading the Data Will Take About 5 Minutes)**

Note: By default, `scprep.io.load_10X` loads the data as a sparse Pandas DataFrame [(**see Pandas docs**)](https://pandas.pydata.org/pandas-docs/stable/sparse.html) to maximize memory efficiency. However, this may be slower than working with dense matrices. To load dense matrices, pass `sparse=False` to `load_10X`. We use `gene_labels='both'` to see gene symbols while preserving the uniqueness of gene IDs.


In [3]:
# download_path: the directory that contains the scRNAseq/ folder shown in the layout above
sparse=True
T1 = scprep.io.load_10X(os.path.join(download_path, "scRNAseq", "T0_1A"), sparse=sparse, gene_labels='both')
T2 = scprep.io.load_10X(os.path.join(download_path, "scRNAseq", "T2_3B"), sparse=sparse, gene_labels='both')
T3 = scprep.io.load_10X(os.path.join(download_path, "scRNAseq", "T4_5C"), sparse=sparse, gene_labels='both')
T4 = scprep.io.load_10X(os.path.join(download_path, "scRNAseq", "T6_7D"), sparse=sparse, gene_labels='both')
T5 = scprep.io.load_10X(os.path.join(download_path, "scRNAseq", "T8_9E"), sparse=sparse, gene_labels='both')
T1.head()

Unnamed: 0_level_0,RP11-34P13.3 (ENSG00000243485),FAM138A (ENSG00000237613),OR4F5 (ENSG00000186092),RP11-34P13.7 (ENSG00000238009),RP11-34P13.8 (ENSG00000239945),RP11-34P13.14 (ENSG00000239906),RP11-34P13.9 (ENSG00000241599),FO538757.3 (ENSG00000279928),FO538757.2 (ENSG00000279457),AP006222.2 (ENSG00000228463),...,AC007325.2 (ENSG00000277196),BX072566.1 (ENSG00000277630),AL354822.1 (ENSG00000278384),AC023491.2 (ENSG00000278633),AC004556.1 (ENSG00000276345),AC233755.2 (ENSG00000277856),AC233755.1 (ENSG00000275063),AC240274.1 (ENSG00000271254),AC213203.1 (ENSG00000277475),FAM231B (ENSG00000268674)
AAACATACCAGAGG-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACATTGAAAGCA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACATTGAAGTGA-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACATTGGAGGTG-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACATTGGTTTCT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


#### **3. Library Size Filtering**

We filter out cells with very large or very small library sizes. For this dataset, library size is somewhat correlated with the sample, so we filter each sample separately. In this case, we exclude the top 20% and bottom 20% of cells from each sample. The second filter runs on the cells that survive the first, so its 75th percentile corresponds to roughly the 80th percentile of the original sample. A simpler, less conservative filtering method can also yield similar results.

In [4]:
scprep.plot.plot_library_size(T1, percentile=20)

<Figure size 640x480 with 1 Axes>

In [5]:
filtered_batches = []
for batch in [T1, T2, T3, T4, T5]:
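    # drop the 20% of cells with the smallest library sizes
    batch = scprep.filter.filter_library_size(batch, percentile=20, keep_cells='above')
    # drop the top 25% of the remaining cells (about the top 20% of the original sample)
    batch = scprep.filter.filter_library_size(batch, percentile=75, keep_cells='below')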
    filtered_batches.append(batch)
del T1, T2, T3, T4, T5 # removes objects from memory
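
The arithmetic behind the two thresholds can be checked on toy data. The sketch below uses plain NumPy on made-up library sizes (1 to 1000) and strict inequalities; how scprep treats cells exactly at the threshold is not modeled here.

```python
import numpy as np

library_sizes = np.arange(1, 1001)  # toy library sizes: 1, 2, ..., 1000

low = np.percentile(library_sizes, 20)
kept = library_sizes[library_sizes > low]         # drop the bottom 20%

high = np.percentile(kept, 75)                    # 75th percentile of the survivors
kept = kept[kept < high]                          # drop the top 25% of the survivors

high_original = np.percentile(library_sizes, 80)  # 80th percentile of the original sample
print(high, high_original, kept.size / library_sizes.size)
```

The two upper thresholds nearly coincide, and about 60% of the cells remain, which matches removing the top and bottom 20% of each sample.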

#### **4. Merge All Datasets and Create a Time Series Vector Representing Each Sample**

In [6]:
EBT_counts, sample_labels = scprep.utils.combine_batches(
    filtered_batches, 
    ["Day 00-03", "Day 06-09", "Day 12-15", "Day 18-21", "Day 24-27"],
    append_to_cell_names=True
)
del filtered_batches # removes objects from memory
EBT_counts.head()

Unnamed: 0,A1BG (ENSG00000121410),A1BG-AS1 (ENSG00000268895),A1CF (ENSG00000148584),A2M (ENSG00000175899),A2M-AS1 (ENSG00000245105),A2ML1 (ENSG00000166535),A2ML1-AS1 (ENSG00000256661),A2ML1-AS2 (ENSG00000256904),A3GALT2 (ENSG00000184389),A4GALT (ENSG00000128274),...,ZXDC (ENSG00000070476),ZYG11A (ENSG00000203995),ZYG11B (ENSG00000162378),ZYX (ENSG00000159840),ZZEF1 (ENSG00000074755),ZZZ3 (ENSG00000036549),bP-21264C1.2 (ENSG00000278932),bP-2171C21.3 (ENSG00000279501),bP-2189O9.3 (ENSG00000279579),hsa-mir-1253 (ENSG00000272920)
AAACATTGAAAGCA-1_Day 00-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACCGTGCAGAAA-1_Day 00-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACCGTGGAAGGC-1_Day 00-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACGCACCGGTAT-1_Day 00-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
AAACGCACCTATTC-1_Day 00-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


<a id='preprocessing'></a>

## **2. Data Preprocessing: Filtering, Normalization, and Transformation**

### **Filtering**
We filter the data in the following ways:

1. Filter based on library size (if we did not perform this before merging batches).
2. Remove genes expressed in relatively few cells.
3. Remove dead cells.

It is important to note that dead cells should be filtered after library size normalization, as library size may not necessarily correlate with cell status.

#### **Filtering I: Library Size Filtering**

We have previously performed library size filtering because there is a strong correlation between library size and our samples. However, if you want to perform a simpler filtering, you can run the following operation here:

`EBT_counts, sample_labels = scprep.filter.filter_library_size(EBT_counts, sample_labels, cutoff=2000)`

#### **Filtering II: Filtering Genes with Low Detection Rates**

We filter out genes that are expressed in fewer than 10 cells.

In [7]:
EBT_counts = scprep.filter.filter_rare_genes(EBT_counts, min_cells=10)

### **Normalization**

To correct for differences in library sizes, we divide the expression levels of each cell by its library size and then rescale by a common factor (the `rescale` argument), so that every cell ends up with the same total count.

In scprep, this operation is performed with `scprep.normalize.library_size_normalize()`.

In [8]:
EBT_counts = scprep.normalize.library_size_normalize(EBT_counts)
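
The operation itself is simple enough to write out by hand. The following is a minimal NumPy sketch on a two-cell toy matrix; the `rescale` value of 10000 is an assumed target library size for illustration, not a statement about what the call above used.

```python
import numpy as np

counts = np.array([[1.0, 1.0, 2.0],      # shallowly sequenced cell
                   [10.0, 10.0, 20.0]])  # same composition, sequenced 10x deeper
rescale = 10000                           # assumed target library size

library_size = counts.sum(axis=1, keepdims=True)  # total counts per cell
normalized = counts / library_size * rescale
print(normalized)
```

After normalization the two cells are identical, because they differ only in sequencing depth and not in relative expression.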

#### **Filtering III: Removing Dead Cells**

Dead cells are likely to have higher mitochondrial RNA expression levels than live cells. Therefore, we remove suspected dead cells by eliminating those with the highest average mitochondrial RNA expression levels.

First, let's take a look at the distribution of mitochondrial genes.

In [10]:
mito_genes = scprep.select.get_gene_set(EBT_counts, starts_with="MT-") # Get all mitochondrial genes. There are 14, FYI.
scprep.plot.plot_gene_set_expression(EBT_counts, genes=mito_genes, percentile=90)

<Figure size 640x480 with 1 Axes>

We can see that there is a sharp increase in mitochondrial RNA expression above the 90th percentile. Therefore, we filter out the cells above this threshold.

In [11]:
EBT_counts, sample_labels = scprep.filter.filter_gene_set_expression(
    EBT_counts, sample_labels, genes=mito_genes, 
    percentile=90, keep_cells='below')

### **Transformation**

In single-cell RNA sequencing analysis, the data is often subjected to log transformation. It is usually necessary to add a small value (a pseudocount) to avoid taking log(0). In this tutorial, we completely avoid this issue by using a square root transformation instead. The square root function has a shape similar to that of the log function and is well defined at zero.

In [12]:
EBT_counts = scprep.transform.sqrt(EBT_counts)
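
To see the difference between the two transformations, the sketch below applies both to a few toy counts using only the standard library. The pseudocount of 1 is an assumed, commonly used choice, needed only by the log transform.

```python
import math

counts = [0, 1, 4, 100]
pseudocount = 1.0  # assumed offset, needed only by the log transform

sqrt_values = [math.sqrt(c) for c in counts]
log_values = [math.log(c + pseudocount) for c in counts]

for c, s, l in zip(counts, sqrt_values, log_values):
    print(c, round(s, 3), round(l, 3))
```

Both transformations compress large counts, but the square root needs no extra parameter to handle zeros.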

<a id='embedding'></a>
## **3. Using PHATE to Embed Data**

### **3.1 Instantiating the PHATE Estimator**

The API of the PHATE model is similar to that of scikit-learn. First, you instantiate a PHATE estimator object using parameters suitable for the given dataset. Then, you use the `fit` and `fit_transform` methods to generate embeddings. For more information, please refer to the [**PHATE readthedocs page**](http://phate.readthedocs.io/).

We will use only the default parameters for now, but the following parameters can be adjusted (please read the documentation at [phate.readthedocs.io](https://phate.readthedocs.io/) for more information):

* `knn`: Number of nearest neighbors (default: 5). If your PHATE embeddings appear very disconnected, you can increase this value (e.g., set it to 20). If your dataset is very large (e.g., >100k cells), consider increasing `knn` as well.
* `decay`: Alpha decay (default: 40). Decreasing `decay` increases connectivity in the graph, while increasing `decay` reduces connectivity. This rarely needs adjustment. Set it to `None` for a k-nearest neighbors kernel.
* `t`: Operator power (default: 'auto'). This equals the number of smoothing operations performed on the data. It is automatically selected by default, but if your embeddings lack structure, you can increase it, or decrease it if the structure appears too compact.
* `gamma`: Information distance constant (default: 1). `gamma=1` gives the PHATE log potential, but other information distances can also be interesting. If most points seem concentrated in one part of the plot, you can try `gamma=0`.

If we were looking for more detailed structure, and expected some trajectories to be sparse, we could decrease `knn` from the default value of 5 and set `t` below the automatically selected value (19 in the run below). For single-cell RNA sequencing, if you are looking for subtle structures, you can set `knn` to 3 or 4, or set it to 30 or 40 if you have hundreds of thousands of cells. Reducing `decay` partially offsets the loss of connectivity caused by a smaller `knn`.
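
The effect of `decay` on connectivity can be illustrated with a one-sided alpha-decay kernel of the form exp(-(d/ε)^α), where ε plays the role of the distance to the `knn`-th nearest neighbor. This is a simplified sketch under that assumption: the function name and the values of ε and d are hypothetical, and PHATE's actual kernel construction includes further steps such as symmetrization.

```python
import math

def alpha_decay_affinity(distance, epsilon, decay):
    """One-sided alpha-decay kernel: exp(-(distance / epsilon) ** decay)."""
    return math.exp(-((distance / epsilon) ** decay))

epsilon = 1.0  # hypothetical distance to the knn-th nearest neighbor
for decay in (15, 40):
    inside = alpha_decay_affinity(0.9, epsilon, decay)   # slightly closer than epsilon
    outside = alpha_decay_affinity(1.1, epsilon, decay)  # slightly farther than epsilon
    print(decay, inside, outside)
```

With the smaller `decay`, points just beyond ε keep a noticeably larger affinity, which is why lowering `decay` increases connectivity in the graph.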

In [13]:
phate_operator = phate.PHATE(n_jobs=-2, random_state=42)
Y_phate = phate_operator.fit_transform(EBT_counts)

Calculating PHATE...
  Running PHATE on 16821 observations and 17845 variables.
  Calculating graph and diffusion operator...
    Calculating PCA...
    Calculated PCA in 43.01 seconds.
    Calculating KNN search...
    Calculated KNN search in 9.08 seconds.
    Calculating affinities...
    Calculated affinities in 0.90 seconds.
  Calculated graph and diffusion operator in 55.26 seconds.
  Calculating landmark operator...
    Calculating SVD...
    Calculated SVD in 2.47 seconds.
    Calculating KMeans...
    Calculated KMeans in 2.98 seconds.
  Calculated landmark operator in 6.91 seconds.
  Calculating optimal t...
    Automatically selected t = 19
  Calculated optimal t in 1.31 seconds.
  Calculating diffusion potential...
  Calculated diffusion potential in 0.20 seconds.
  Calculating metric MDS...
  Calculated metric MDS in 5.78 seconds.
Calculated PHATE in 69.49 seconds.


In [14]:
scprep.plot.scatter2d(Y_phate, c=sample_labels, figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="PHATE")

<Figure size 1200x800 with 1 Axes>

<a id='trajectory'></a>
## **4. Modeling Cellular Dynamic Transitions Using TrajectoryNet**

This section demonstrates how to compute cellular transition trajectories in gene space using TrajectoryNet.

This section works with `backward_trajectories.npy`, an array of shape `[timepoints, cells, pcs]`. The settings match those used in the paper; the model is trained with the following command:

```
python main.py --save [SAVE_DIR] --dataset EB-PCA --top_k_reg 0.1 --training_noise 0.0 --max_dim 5
```

This is essentially the default setting with a small amount of density regularization (referred to as `top_k_reg`). This will compute the model and save some checkpoints, with the final weights stored in `checkpt.pt`. The main parameters for the experiments conducted in the paper are `--top_k_reg` (density regularization) and `--vecint` (velocity regularization).

We then run:

```
python eval.py --save [SAVE_DIR] --dataset EB-PCA --top_k_reg 0.1 --training_noise 0.0 --max_dim 5
```

This will create `backward_trajectories.npy` in the `[SAVE_DIR]` directory from the saved model.

The trajectories are obtained by integrating the cells of the final time point backward to the first time point, sampled at 100 evenly spaced times. TrajectoryNet itself operates on a low-dimensional embedding of the data (here, the first 5 principal components); the resulting trajectories are projected back to gene space in the steps below.

### **4.1 Running TrajectoryNet for Trajectory Inference**

In [18]:
%%bash
cd ./TrajectoryNet-master/TrajectoryNet/
python main.py --save ../results/fig8_results/ --dataset EB-PCA --top_k_reg 0.1 --training_noise 0.0 --max_dim 5



In [20]:
%%bash
cd ./TrajectoryNet-master/TrajectoryNet/
python eval.py --save ../results/fig8_results/ --dataset EB-PCA --top_k_reg 0.1 --training_noise 0.0 --max_dim 5


integrating backwards


### **4.2 Importing the Results of Trajectory Inference**

In [15]:
zs = np.load('./TrajectoryNet-master/results/fig8_results/backward_trajectories.npy')

In [16]:
zs.shape

(100, 3332, 5)

The output `zs` from TrajectoryNet is an array of shape `(100, 3332, 5)`: 100 time points, the 3332 cells of the last time point, and 5 principal components per cell. To align these trajectories with the PHATE embedding, we take the principal components stored in `phate_operator.graph.data_nu` and compute their mean and standard deviation with `StandardScaler`. We then rescale `zs` to the scale and mean of the first 5 principal components and plot the trajectories of 200 cells from the last time point (Day 24-27).

In [17]:
scaler = sklearn.preprocessing.StandardScaler()
scaler.fit(phate_operator.graph.data_nu)

In [18]:
phate_operator.graph.data_nu.shape

(16821, 100)

In [19]:
zss = zs * scaler.scale_[:5] + scaler.mean_[:5]
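
The rescaling step simply inverts a per-component standardization. The sketch below reproduces it with plain NumPy on random stand-in data (not the real principal components), using the population standard deviation as `StandardScaler` does, and restoring only the first two components for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for PCA coordinates: 200 "cells", 3 components with different spreads
pcs = rng.normal(loc=3.0, scale=[5.0, 2.0, 1.0], size=(200, 3))

mean = pcs.mean(axis=0)
scale = pcs.std(axis=0)  # population standard deviation (ddof=0)
standardized = (pcs - mean) / scale

# Undo the standardization for the first 2 components only
restored = standardized[:, :2] * scale[:2] + mean[:2]
print(np.allclose(restored, pcs[:, :2]))
```

This only recovers the original coordinates if the trajectories were standardized with the same mean and scale in the first place, which is the assumption the rescaling of `zs` relies on.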

In [20]:
fig, ax = plt.subplots(1,1)
scprep.plot.scatter2d(phate_operator.graph.data_nu, c=sample_labels, figsize=(30,15), cmap="Spectral",
                      ticks=True, label_prefix="PC", ax=ax, title='Trajectories for 200 cells from last timepoint')

for i in range(200):
    ax.plot(zss[:,i,0], zss[:,i,1])

<Figure size 640x480 with 1 Axes>

In [21]:
trajectory_gene_space = np.dot(zss, phate_operator.graph.data_pca.components_[:5,:])

In [22]:
trajectory_gene_space.shape

(100, 3332, 17845)
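
Multiplying PCA coordinates by the component matrix maps them back to gene space. The sketch below shows the idea on a tiny random matrix, computing PCA through an SVD in NumPy. The toy data is built to have rank 2 so that two components reconstruct it exactly. Whether a mean must be added back depends on whether the decomposition centered the data, which this sketch does explicitly.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, n_pcs, n_times = 6, 4, 2, 3

# Rank-2 toy "expression" matrix, so 2 components capture it exactly
X = rng.normal(size=(n_cells, n_pcs)) @ rng.normal(size=(n_pcs, n_genes))
mean = X.mean(axis=0)

# PCA via SVD of the centered data
_, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
components = Vt[:n_pcs]                     # (n_pcs, n_genes)
scores = (X - mean) @ components.T          # (n_cells, n_pcs)

reconstructed = scores @ components + mean  # back to gene space
print(np.allclose(reconstructed, X))

# The same product works with a leading time axis
trajectory_pcs = np.stack([scores] * n_times)          # (n_times, n_cells, n_pcs)
trajectory_genes = np.dot(trajectory_pcs, components)  # (n_times, n_cells, n_genes)
print(trajectory_genes.shape)
```

With real data and only 5 components, the projection is an approximation of gene expression rather than an exact reconstruction. The full gene-space array is also large, so selecting the genes of interest before the product saves memory.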

In [23]:
genes = [x.split(' ')[0] for x in EBT_counts.columns]

In [24]:
eb_marker_genes = np.loadtxt('./TrajectoryNet-master/data/eb_genes.txt', dtype='str')

In [25]:
genes_index = {}

In [26]:
for gene in eb_marker_genes:
    genes_index[gene] = genes.index(gene)

In [27]:
trajectory_eb = trajectory_gene_space[:,:,np.array(list(genes_index.values()))]

In [28]:
# np.save('./trajectory_eb.npy', trajectory_eb)
np.save('./TrajectoryNet-master/results/fig8_results/trajectory_eb.npy', trajectory_eb)


In [29]:
trajectory_eb_magic = np.zeros_like(trajectory_eb)  # same shape as trajectory_eb: (100, 3332, 68)

In [30]:
m_op = magic.MAGIC()
for i in range(100):
    trajectory_eb_magic[i,:,:] = m_op.fit_transform(trajectory_eb[i,:,:])

Calculating MAGIC...
  Running MAGIC on 3332 cells and 68 genes.
  Calculating graph and diffusion operator...
    Calculating KNN search...
    Calculated KNN search in 0.43 seconds.
    Calculating affinities...
    Calculated affinities in 0.44 seconds.
  Calculated graph and diffusion operator in 0.88 seconds.
  Calculating imputation...
  Calculated imputation in 0.02 seconds.
Calculated MAGIC in 0.90 seconds.

In [31]:
# np.save('./trajectory_eb_magic.npy', trajectory_eb_magic)
np.save('./TrajectoryNet-master/results/trajectory_eb_magic.npy', trajectory_eb_magic)


### **4.3 Exploring Genes of Interest**

Next, we defined "endpoint genes" representing four different cell populations in this dataset. We ran MAGIC on the raw counts of these genes and visualized the estimated expression.

In [32]:
end_genes = ['PDGFRA ', 'HAND1', 'SOX17', 'ONECUT2', ]
end_points = ['Muscle', 'Cardiac', 'Endothelial', 'Neuronal',]

colors = dict(zip(*[end_genes, [plt.get_cmap('tab10')(i+1) for i in range(len(end_genes))]]))

In [33]:
other_genes = ['GATA6 ', 'SATB1', 'T ', 'EOMES', 'NANOG', 'TNNT2', 'DLX1', 'TBX18', 'MAP2 ']
genes_of_interest = [*other_genes, *end_genes]

genes_of_interest_end = scprep.select.get_gene_set(EBT_counts, starts_with=end_genes)
genes_of_interest_full = scprep.select.get_gene_set(EBT_counts, starts_with=genes_of_interest)
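
Some gene names in the lists above end with a space. Column labels have the form `SYMBOL (ENSEMBL_ID)`, so a trailing space restricts a prefix match to the exact symbol. The sketch below illustrates this with plain string matching, assuming `starts_with` behaves like an ordinary prefix test; the column labels use placeholder IDs, not real Ensembl identifiers.

```python
# Placeholder column labels in the "SYMBOL (ID)" format; the IDs are made up.
columns = ["T (ID_1)", "TBX18 (ID_2)", "TNNT2 (ID_3)", "MAP2 (ID_4)", "MAP2K1 (ID_5)"]

def match_prefix(columns, prefix):
    """Return the column names that start with `prefix`."""
    return [c for c in columns if c.startswith(prefix)]

print(match_prefix(columns, "T"))      # also matches TBX18 and TNNT2
print(match_prefix(columns, "T "))     # matches only the gene T
print(match_prefix(columns, "MAP2 "))  # excludes MAP2K1
```

Short symbols such as `T` are prefixes of many other gene names, so the trailing space is what keeps the selection to a single gene.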

In [34]:
genes_mask = EBT_counts.columns.isin(genes_of_interest_full)
genes = EBT_counts.columns[genes_mask]

In [35]:
inverse = np.dot(zss, phate_operator.graph.data_pca.components_[:5, genes_mask])
end_gene_indexes = [(np.where(genes_of_interest_full == gene)[0][0]) for gene in genes_of_interest_end]

In [36]:
m_op = magic.MAGIC()
m_op.graph = phate_operator.graph
EBT_magic = m_op.transform(EBT_counts, genes=genes_of_interest_full)

Calculating imputation...
Calculated imputation in 0.03 seconds.


In [37]:
fig, ax = plt.subplots(1,len(end_genes), figsize=(4*len(end_genes),4))
ax = ax.flatten()
for i in range(len(end_genes)):
    scprep.plot.scatter2d(Y_phate, 
                          c=EBT_magic[scprep.select.get_gene_set(EBT_counts, starts_with=end_genes[i])], 
                          ax=ax[i],
                          title='%s - %s' % (end_points[i], end_genes[i]),
                          ticks=[],
                         )

<Figure size 1600x400 with 8 Axes>

### **4.4 Plotting Cell Trajectories**

We defined `EBT_5` as the counts from the last time point and then selected the top 9 cells expressing each endpoint gene at this final time point. We then plotted the trajectories of these cells.

In [38]:
EBT_5 = EBT_counts[sample_labels == 'Day 24-27']

In [39]:
masks = {}
top_idxs = {}
for gene in end_genes:
    top_idx = np.array(EBT_5[scprep.select.get_gene_set(EBT_counts, starts_with=gene)]).flatten().argsort()[-9:]
    top_mask = np.array(pd.Series(range(3332)).isin(top_idx))
    masks[gene] = top_mask
    top_idxs[gene] = top_idx
    print(gene, top_idx)

PDGFRA  [1896 2642  488  748  664  697 1419  300  432]
HAND1 [1959 3195  936 2133  375  668  501 2484 1476]
SOX17 [2053 2595 2664 1988 2353 1129 2823  738  532]
ONECUT2 [2920  555 2283  718 1275 2125 1838 2129 2277]
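
The selection above combines `argsort` with a boolean mask. The sketch below shows the same pattern on a made-up expression vector of six cells, selecting the top three.

```python
import numpy as np

expression = np.array([0.1, 2.5, 0.0, 3.2, 1.7, 0.4])  # toy expression of one gene
n_top = 3

# Indices of the n_top largest values, ordered from smallest to largest
top_idx = expression.argsort()[-n_top:]
# Boolean mask over all cells, True for the selected ones
top_mask = np.isin(np.arange(expression.size), top_idx)

print(top_idx, top_mask)
```

The index array is convenient for picking trajectories out of `zss`, while the mask is convenient for subsetting the embedding for plotting.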


In [40]:
fig, ax = plt.subplots(1,1)
scprep.plot.scatter2d(Y_phate, c='Gray', alpha=0.1, ax=ax)
for i, gene in enumerate(end_genes):
    scprep.plot.scatter2d(Y_phate[sample_labels=='Day 24-27'][masks[gene]], 
                          ax=ax, c = colors[gene], label=end_points[i], ticks=[])
    
plt.legend()

<Figure size 640x480 with 1 Axes>

In [41]:
fig, ax = plt.subplots(1,1)
scprep.plot.scatter2d(phate_operator.graph.data_nu, c='Gray', alpha=0.1, ax=ax)
for i, gene in enumerate(end_genes):
    scprep.plot.scatter2d(phate_operator.graph.data_nu[sample_labels=='Day 24-27'][masks[gene]], ax=ax, 
                          label='%s - %s' % (end_points[i], end_genes[i]), c=colors[gene], ticks=[])

for gene in end_genes:
    for g in top_idxs[gene]:
        ax.plot(zss[:,g,0], zss[:,g,1], c=colors[gene])
plt.legend()

<Figure size 640x480 with 1 Axes>

### **4.5 Plotting Gene Expression Dynamics**

For these 36 cells (9 cells selected for each of the four endpoint genes), we plot the changes in gene expression across 100 time points and identify how different subsets of cells exhibit distinct dynamic behaviors.

In [42]:
fig, ax = plt.subplots(2,2, figsize=(6,6), sharex=True)
ax = ax.flatten()

for i, gene in enumerate(end_genes):
    for j, eg in enumerate(end_genes):
        for cell in top_idxs[eg]:
            ax[i].plot(np.linspace(1,5,100)[::-1], inverse[:,cell,end_gene_indexes[i]], c=colors[eg])
            ax[i].set_title(gene)
            ax[i].set_yticks([])
            ax[i].set_xticks(range(1,6))

<Figure size 600x600 with 4 Axes>