## Example 1: Developing Chicken Heart
This notebook illustrates how to train SpaDOT on the [developing chicken heart](https://doi.org/10.1038/s41467-021-21892-z) dataset. This spatiotemporal dataset was measured with 10x Visium at four developmental stages: Day 4, Day 7, Day 10, and Day 14. On this dataset, SpaDOT accurately identifies valvulogenesis, in which a valve splits into the atrioventricular valve and the semilunar valve at Day 14.

For your convenience, you can download the processed data [here](https://www.dropbox.com/scl/fi/xklj0dxkd2wz10ahgbwg1/ChickenHeart.h5ad?rlkey=06245qjhv4ohij5530a1az91c&dl=0). If you would like to see the preprocessing steps, please expand the section below:

<details>
<summary>Click to expand</summary>

First, we downloaded the spatial transcriptomics data from [GSE149457](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE149457) and selected the following files:
```
GSM4502482_chicken_heart_spatial_RNAseq_D4_filtered_feature_bc_matrix.h5
GSM4502483_chicken_heart_spatial_RNAseq_D7_filtered_feature_bc_matrix.h5
GSM4502484_chicken_heart_spatial_RNAseq_D10_filtered_feature_bc_matrix.h5
GSM4502485_chicken_heart_spatial_RNAseq_D14_filtered_feature_bc_matrix.h5
```

Second, we downloaded the spatial coordinates from the analysis code shared by the paper's authors on [GitHub](https://github.com/madhavmantri/chicken_heart/tree/master/data/chicken_heart_spatial_RNAseq_processed):

```
chicken_heart_spatial_RNAseq_D4_tissue_positions_list.csv
chicken_heart_spatial_RNAseq_D7_tissue_positions_list.csv
chicken_heart_spatial_RNAseq_D10_tissue_positions_list.csv
chicken_heart_spatial_RNAseq_D14_tissue_positions_list.csv
```
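
Each `tissue_positions_list.csv` file has one row per Visium spot. As a minimal sketch, assuming the headerless Space Ranger v1 column order (`barcode, in_tissue, array_row, array_col, pxl_row_in_fullres, pxl_col_in_fullres`), the in-tissue pixel coordinates could be read with the standard library as shown below. The function `read_tissue_positions` is hypothetical, and this tutorial does not say which coordinate columns `process_ChickenHeart.py` actually keeps.

```python
import csv
import io


def read_tissue_positions(handle):
    """Return {barcode: (pxl_row, pxl_col)} for spots covered by tissue.

    Assumes the headerless Space Ranger v1 column order:
    barcode, in_tissue, array_row, array_col, pxl_row_in_fullres, pxl_col_in_fullres.
    """
    coords = {}
    for row in csv.reader(handle):
        if not row:
            continue  # skip blank lines
        barcode, in_tissue = row[0], row[1]
        if in_tissue != "1":
            # keep only in-tissue spots; a header row, if present, is skipped too
            continue
        coords[barcode] = (float(row[4]), float(row[5]))
    return coords


# Two made-up rows in the assumed format (not real data)
example = io.StringIO(
    "SPOT_A-1,1,50,102,7237,8468\n"
    "SPOT_B-1,0,59,19,8151,2846\n"
)
positions = read_tissue_positions(example)
print(positions)
```

The barcodes in the resulting dictionary could then be matched against the spot barcodes of the corresponding `filtered_feature_bc_matrix.h5` file.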

Third, we preprocessed the data with the script `process_ChickenHeart.py` provided [here](https://github.com/marvinquiet/SpaDOT/blob/main/analyses/process_ChickenHeart.py), which integrates the four samples into a single AnnData object. In the AnnData observations (`obs`), the column `timepoint` encodes the four time points: `0`, `1`, `2`, and `3` indicate Day 4, Day 7, Day 10, and Day 14, respectively. The spatial coordinates are stored as a NumPy array under the key `spatial` in the AnnData multi-dimensional observation annotations (`obsm`).
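
To make the `timepoint` encoding concrete, here is a small sketch, not taken from `process_ChickenHeart.py`, that maps the day tag in each file name to its code. The helper `timepoint_from_filename` and the `_D<day>_` pattern it searches for are assumptions based only on the file names listed above.

```python
import re

# Day -> `timepoint` code, as described above (Day 4 -> 0, ..., Day 14 -> 3)
DAY_TO_TIMEPOINT = {4: 0, 7: 1, 10: 2, 14: 3}


def timepoint_from_filename(filename):
    """Extract the day tag (e.g. '_D10_') from a file name and map it to 0-3."""
    match = re.search(r"_D(\d+)_", filename)
    if match is None:
        raise ValueError(f"no day tag found in {filename!r}")
    return DAY_TO_TIMEPOINT[int(match.group(1))]


files = [
    "GSM4502482_chicken_heart_spatial_RNAseq_D4_filtered_feature_bc_matrix.h5",
    "GSM4502483_chicken_heart_spatial_RNAseq_D7_filtered_feature_bc_matrix.h5",
    "GSM4502484_chicken_heart_spatial_RNAseq_D10_filtered_feature_bc_matrix.h5",
    "GSM4502485_chicken_heart_spatial_RNAseq_D14_filtered_feature_bc_matrix.h5",
]
for f in files:
    print(f, "->", timepoint_from_filename(f))
```

The same pattern also matches the coordinate files, e.g. `chicken_heart_spatial_RNAseq_D14_tissue_positions_list.csv` maps to `3`.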

Running `process_ChickenHeart.py` produces the file `ChickenHeart.h5ad`.
</details>

### Step 1: Perform data preprocessing and feature selection

After obtaining `ChickenHeart.h5ad`, we perform data preprocessing and feature selection.

```python
import os
from argparse import Namespace

# --- for neat output
import warnings
# Suppress all UserWarnings
warnings.filterwarnings("ignore", category=UserWarning)
# Suppress all FutureWarnings
warnings.filterwarnings("ignore", category=FutureWarning)

import SpaDOT

# Create the output directory if it does not exist yet
result_dir = './ChickenHeart_output'
os.makedirs(result_dir, exist_ok=True)

# --- create args
preprocess_args = Namespace(
    data='./ChickenHeart.h5ad',   # processed input data
    prefix='preprocessed_',       # prepended to the output file name (-> preprocessed_ChickenHeart.h5ad)
    feature_selection=True,       # perform feature selection during preprocessing
    output_dir=result_dir
)
SpaDOT.preprocess(preprocess_args)
```

```
using default random_seed 1448145, will run SCT without randomness
gene-cell umi shape (10521, 747), n_genes 2000 n_cells 747
get_model_pars finished, cost 5.463442325592041 seconds
ksmooth finished, cost 0.41453099250793457 seconds
reg_model_pars finished, cost 1.293031930923462 seconds
pearson_residual cost 1.378955602645874 seconds
umi_corrected cost 0.19406485557556152 seconds
scale data cost 0.07704281806945801 seconds
Timepoint: 0, Number of cells: 747, Number of genes: 10521
## ===== SPARK-X INPUT INFORMATION ==== 
## number of total samples: 747
## number of total genes: 10521
## Running with 4 cores
## Testing With Projection Kernel
## Testing With Gaussian Kernel 1
## Testing With Gaussian Kernel 2
## Testing With Gaussian Kernel 3
## Testing With Gaussian Kernel 4
## Testing With Gaussian Kernel 5
## Testing With Cosine Kernel 1
## Testing With Cosine Kernel 2
## Testing With Cosine Kernel 3
## Testing With Cosine Kernel 4
## Testing With Cosine Kernel 5
Time taken for mixtu
```
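
The log above suggests that preprocessing first normalizes the counts with an SCTransform-style negative binomial model (`get_model_pars`, `reg_model_pars`, `pearson_residual`) and then runs SPARK-X to test genes for spatial variability. As background only, and not SpaDOT's own code, the sketch below computes the sctransform Pearson residual of a count `x` given a fitted mean `mu` and overdispersion `theta`, `r = (x - mu) / sqrt(mu + mu^2 / theta)`, clipped to ±`sqrt(n_cells)` as sctransform does by default. Whether SpaDOT's implementation uses the same clipping is an assumption.

```python
import math


def pearson_residual(x, mu, theta, n_cells):
    """Negative binomial Pearson residual, clipped to +/- sqrt(n_cells)."""
    variance = mu + mu * mu / theta  # NB variance with overdispersion theta
    r = (x - mu) / math.sqrt(variance)
    bound = math.sqrt(n_cells)      # sctransform-style default clipping range
    return max(-bound, min(bound, r))


print(pearson_residual(x=5, mu=5.0, theta=10.0, n_cells=747))    # observed == expected -> 0.0
print(pearson_residual(x=200, mu=1.0, theta=10.0, n_cells=747))  # large outlier, clipped to sqrt(747)
```

Clipping keeps a handful of extreme counts from dominating the scaled data that downstream steps see.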

### Step 2: Train SpaDOT to obtain latent representations

After preprocessing, we train the SpaDOT model to obtain latent representations. We recommend running the training on a GPU, which typically takes about 5 minutes. Training on a CPU is also possible but will require significantly more time.
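
If the same notebook may run on machines with and without a GPU, the device string can be chosen at run time. This is a minimal sketch assuming that SpaDOT runs on PyTorch (the model summary printed during training is a PyTorch module) and accepts `'cpu'` as a device, as the comment in the code suggests; `pick_device` is a hypothetical helper, not part of SpaDOT.

```python
def pick_device(preferred="cuda:0"):
    """Return `preferred` if PyTorch sees a CUDA GPU, otherwise fall back to 'cpu'."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not importable, so no GPU check is possible
    return preferred if torch.cuda.is_available() else "cpu"


device = pick_device()
print(device)
```

You could then pass `device=pick_device()` in `train_args` instead of hard-coding `'cuda:0'`.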

```python
train_args = Namespace(
    data=os.path.join(result_dir, 'preprocessed_ChickenHeart.h5ad'),  # output of Step 1
    output_dir=result_dir,
    prefix="",
    config=None,      # use the default configuration
    save_model=True,
    device='cuda:0'   # a GPU is recommended; 'cpu' also works but is much slower
)
SpaDOT.train(train_args)
```

```
Loading data...
Preparing data...
Calculating spatial graph...
The graph contains 4482 edges, 747 cells.
6.0000 neighbors per cell on average.
Calculating spatial graph...
The graph contains 23592 edges, 1966 cells.
12.0000 neighbors per cell on average.
Calculating spatial graph...
The graph contains 22992 edges, 1916 cells.
12.0000 neighbors per cell on average.
Calculating spatial graph...
The graph contains 23604 edges, 1967 cells.
12.0000 neighbors per cell on average.
Training model...
SpaDOT(
  (SVGPEncoder): SVGPEncoder(
    (SVGP_encoder_net): Sequential(
      (0): Linear(in_features=2954, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): Linear(in_features=256, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.01)
    )
    (SVGP_fc): Linear(in_fea

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1: Training time: 2 seconds, ELBO: 199.92934340, Recon loss: 1980.31486427, SVGP KL loss: -144.38993204, GAT KL loss: 2619.26694918, Alignment loss: 16.35930274, Kmeans loss: 0.00000000, OT loss: 0.00000000

 10%|█         | 10/100 [00:28<04:07,  2.75s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 11: Training time: 2 seconds, ELBO: 183.68761158, Recon loss: 1803.41816353, SVGP KL loss: -9.40707687, GAT KL loss: 3903.62009371, Alignment loss: 7.02553735, Kmeans loss: 22.52879484, OT loss: 0.00000000

 20%|██        | 20/100 [00:59<04:19,  3.24s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 21: Training time: 2 seconds, ELBO: 180.46315527, Recon loss: 1776.28481833, SVGP KL loss: -2.49016685, GAT KL loss: 4335.06513675, Alignment loss: 3.30482818, Kmeans loss: 20.70684101, OT loss: 0.00000000

 30%|███       | 30/100 [01:32<03:49,  3.28s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 31: Training time: 2 seconds, ELBO: 179.41187039, Recon loss: 1764.10983242, SVGP KL loss: -1.42515902, GAT KL loss: 4444.60038891, Alignment loss: 3.51083552, Kmeans loss: 22.05343555, OT loss: 0.00000000

 40%|████      | 40/100 [02:05<03:16,  3.28s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 41: Training time: 2 seconds, ELBO: 178.55708155, Recon loss: 1757.15024718, SVGP KL loss: -2.67722837, GAT KL loss: 4415.97568480, Alignment loss: 2.82303327, Kmeans loss: 21.18155933, OT loss: 0.00000000

 50%|█████     | 50/100 [02:37<02:43,  3.28s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 51: Training time: 3 seconds, ELBO: 178.21011323, Recon loss: 1750.13058767, SVGP KL loss: -3.01617111, GAT KL loss: 4643.15698161, Alignment loss: 2.87018294, Kmeans loss: 20.59555884, OT loss: 0.38616459

 60%|██████    | 60/100 [03:12<02:16,  3.42s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 61: Training time: 3 seconds, ELBO: 177.63253556, Recon loss: 1745.47938430, SVGP KL loss: -2.26789193, GAT KL loss: 4785.65195298, Alignment loss: 2.63114068, Kmeans loss: 19.92806750, OT loss: 0.35011112

 70%|███████   | 70/100 [03:42<01:27,  2.91s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 71: Training time: 2 seconds, ELBO: 177.12075609, Recon loss: 1740.57220245, SVGP KL loss: -1.64303954, GAT KL loss: 4911.65964914, Alignment loss: 2.42036297, Kmeans loss: 19.32665238, OT loss: 0.39766835

 80%|████████  | 80/100 [04:11<00:56,  2.84s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 81: Training time: 2 seconds, ELBO: 177.39990779, Recon loss: 1742.40926097, SVGP KL loss: -0.93593225, GAT KL loss: 4871.13185899, Alignment loss: 2.84462279, Kmeans loss: 19.82739471, OT loss: 0.40466676

 90%|█████████ | 90/100 [04:43<00:33,  3.38s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 91: Training time: 3 seconds, ELBO: 176.66554751, Recon loss: 1736.56405506, SVGP KL loss: -0.86964505, GAT KL loss: 4999.42434799, Alignment loss: 2.32719651, Kmeans loss: 18.97556704, OT loss: 0.37892322

100%|██████████| 100/100 [05:18<00:00,  3.18s/it]

OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Training finished...
Training time: 318 seconds.
Model saved to ./ChickenHeart_output
```

The latent representations are saved as `latent.h5ad` in `ChickenHeart_output/`.
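
The training log above reports edge counts that are exactly a whole multiple of the cell counts (4482 = 6 × 747 for the first graph and 12 × cells for the other three), which is what a directed k-nearest-neighbor graph produces. Whether SpaDOT builds its spatial graph this way is an assumption; the brute-force sketch below (hypothetical helper `knn_edges`, synthetic coordinates) only illustrates that edge arithmetic.

```python
import math


def knn_edges(coords, k):
    """Directed k-nearest-neighbor edges (i, j) over 2-D coordinates, by brute force."""
    edges = []
    for i, (xi, yi) in enumerate(coords):
        # distances from spot i to every other spot, nearest first
        dists = sorted(
            (math.hypot(xi - xj, yi - yj), j)
            for j, (xj, yj) in enumerate(coords)
            if j != i
        )
        edges.extend((i, j) for _, j in dists[:k])
    return edges


# Edge and cell counts copied from the training log
logged = [(4482, 747), (23592, 1966), (22992, 1916), (23604, 1967)]
for n_edges, n_cells in logged:
    print(f"{n_edges} / {n_cells} = {n_edges / n_cells:.4f} neighbors per cell")

# Toy 5 x 5 grid of spots (synthetic coordinates)
grid = [(x, y) for x in range(5) for y in range(5)]
edges = knn_edges(grid, k=4)
print(len(edges), len(edges) / len(grid))  # every spot gets exactly k out-edges -> 100 4.0
```

Brute force is O(n²) per graph, which is fine for a few thousand spots; a KD-tree would be the usual choice for larger datasets.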

### Step 3: Infer spatial domains and domain dynamics

Once training finishes, we can infer spatial domains and domain dynamics. If the number of domains at each time point is known in advance (e.g., from the original study), we can pass it through `n_clusters`, one value per time point (Day 4, Day 7, Day 10, and Day 14):

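```python
analyze_args = Namespace(
    data=os.path.join(result_dir, 'latent.h5ad'),  # latent representations from Step 2
    prefix="",
    output_dir=result_dir,
    n_clusters=[5, 7, 7, 6]  # number of domains for Day 4, Day 7, Day 10, Day 14
)
SpaDOT.analyze(analyze_args)
```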

The spatial domain plots are saved in `ChickenHeart_output/` as `{tp}_domains.png`, and the dot plots showing domain transitions between consecutive time points are saved as `transition_dotplot_{tp}_{tp+1}.png`.
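
Based only on the file names shown in this tutorial, a hypothetical helper like the one below can list the images that `SpaDOT.analyze` is expected to write for a given `prefix` and report any that are missing. This naming scheme is inferred from the tables on this page, not from an official API; the WSS plots appear only in the adaptive run (`n_clusters=None`).

```python
import os


def expected_outputs(prefix="", n_timepoints=4, with_wss=False):
    """File names inferred from this tutorial's tables (not an official SpaDOT API)."""
    names = [f"{prefix}{tp}_domains.png" for tp in range(n_timepoints)]
    # one transition dot plot per pair of consecutive time points
    names += [f"{prefix}transition_dotplot_{tp}_{tp + 1}.png" for tp in range(n_timepoints - 1)]
    if with_wss:
        names += [f"{prefix}{tp}_WSS_vs_Clusters.png" for tp in range(n_timepoints)]
    return names


def missing_outputs(output_dir, **kwargs):
    """Return the expected file names that do not exist in `output_dir`."""
    return [
        name for name in expected_outputs(**kwargs)
        if not os.path.exists(os.path.join(output_dir, name))
    ]


print(expected_outputs())
print(missing_outputs("./ChickenHeart_output", prefix="adaptive_", with_wss=True))
```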

#### Spatial domains for the developing chicken heart
| Timepoint | Day 4 | Day 7 | Day 10 | Day 14 | 
|-----------|-------|-------|--------|--------|
| Spatial Domains | ![Day 4](ChickenHeart_output/0_domains.png) | ![Day 7](ChickenHeart_output/1_domains.png) | ![Day 10](ChickenHeart_output/2_domains.png) | ![Day 14](ChickenHeart_output/3_domains.png) | 
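
Because `n_clusters` takes one value per time point, the latent representation of each time point appears to be clustered separately, and Step 4 states that KMeans is involved. The sketch below runs plain Lloyd's k-means with random restarts on each time point using its own number of clusters. It uses toy 2-D points instead of the real latent space, and the helper `kmeans` is illustrative rather than SpaDOT's implementation, which may differ in initialization, restarts, and distance.

```python
import random


def _sq_dist(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _lloyd(points, k, rng, n_iter):
    """One run of Lloyd's algorithm; returns (labels, WSS)."""
    centers = [list(p) for p in rng.sample(points, k)]  # k input points as initial centers
    labels = None
    for _ in range(n_iter):
        # assignment step: each point joins its nearest center
        new_labels = [min(range(k), key=lambda c: _sq_dist(p, centers[c])) for p in points]
        if new_labels == labels:
            break  # assignments stopped changing: converged
        labels = new_labels
        # update step: move each center to the mean of its members (empty clusters keep their center)
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centers[c] = [sum(dim) / len(members) for dim in zip(*members)]
    wss = sum(_sq_dist(p, centers[lab]) for p, lab in zip(points, labels))
    return labels, wss


def kmeans(points, k, n_init=10, n_iter=100, seed=0):
    """Lloyd's k-means with `n_init` random restarts; keeps the run with the lowest WSS."""
    rng = random.Random(seed)
    runs = [_lloyd(points, k, rng, n_iter) for _ in range(n_init)]
    return min(runs, key=lambda run: run[1])


# Toy "latent" points for two time points (synthetic, 2-D for readability)
latent_by_tp = {
    0: [(0.0, 0.0), (0.0, 1.0), (10.0, 10.0), (10.0, 11.0)],
    1: [(0.0, 0.0), (0.5, 0.0), (5.0, 5.0), (5.5, 5.0), (10.0, 0.0), (10.5, 0.0)],
}
n_clusters = [2, 3]  # one value per time point, like n_clusters=[5, 7, 7, 6] above
for tp, k in zip(sorted(latent_by_tp), n_clusters):
    labels, wss = kmeans(latent_by_tp[tp], k)
    print(f"timepoint {tp}: labels={labels}, WSS={wss:.2f}")
```

Restarts matter because a single Lloyd run can settle in a local optimum, for example by splitting one dense region while merging two others.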

#### Domain dynamics analysis

| Timepoint | Day 4 --> Day 7 | Day 7 --> Day 10 | Day 10 --> Day 14 | 
|-----------|-----------------|------------------|-------------------|
| OT transition | ![Day 4to7](ChickenHeart_output/transition_dotplot_0_1.png) | ![Day 7to10](ChickenHeart_output/transition_dotplot_1_2.png) |  ![Day 10to14](ChickenHeart_output/transition_dotplot_2_3.png) |
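
The dot plots summarize how domains at one time point map onto domains at the next. This tutorial does not show how they are computed. One common way to summarize an optimal transport (OT) coupling between the cells of consecutive time points is to sum the transported mass over each (source domain, target domain) pair and normalize each source row, as in the sketch below. The function `domain_transitions` and the coupling values are made up for illustration.

```python
def domain_transitions(coupling, src_labels, tgt_labels):
    """Aggregate a cell-to-cell OT coupling into a source-domain x target-domain table.

    Each row is normalized to sum to 1, so entry [a][b] is the fraction of
    domain a's transported mass that lands in domain b at the next time point.
    """
    src_domains = sorted(set(src_labels))
    tgt_domains = sorted(set(tgt_labels))
    table = {a: {b: 0.0 for b in tgt_domains} for a in src_domains}
    for i, row in enumerate(coupling):
        for j, mass in enumerate(row):
            table[src_labels[i]][tgt_labels[j]] += mass
    for a in src_domains:
        total = sum(table[a].values())
        if total > 0:
            for b in tgt_domains:
                table[a][b] /= total
    return table


# Made-up coupling between 3 cells at time t (rows) and 3 cells at time t+1 (columns)
coupling = [
    [0.2, 0.1, 0.0],
    [0.0, 0.1, 0.2],
    [0.0, 0.0, 0.4],
]
src = ["A", "A", "B"]  # domain labels at time t
tgt = ["X", "X", "Y"]  # domain labels at time t+1
print(domain_transitions(coupling, src, tgt))
```

In a dot plot, each entry of this table could set the size or color of the dot at (source domain, target domain).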


### Step 4 (Optional): Infer domains and dynamics using adaptive methods

In exploratory studies, the number of domains is generally unknown. In that case, we can set `n_clusters=None` (i.e., omit the `--n_clusters` option) and use the Elbow method to detect the number of clusters adaptively.

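```python
analyze_args = Namespace(
    data=os.path.join(result_dir, 'latent.h5ad'),
    prefix="adaptive_",   # prepended to all output file names
    output_dir=result_dir,
    n_clusters=None       # let the Elbow method choose the number of domains
)
SpaDOT.analyze(analyze_args)
```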

For each time point, SpaDOT computes the within-cluster sum of squares (WSS) of KMeans for numbers of clusters ranging from 5 to 20 and plots it. We detect the elbow point of each curve and select the corresponding number of clusters.
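
The tutorial does not state which elbow criterion SpaDOT applies. As one common heuristic, assumed here rather than taken from SpaDOT, the sketch below rescales both axes to [0, 1] and picks the number of clusters whose (k, WSS) point lies farthest from the straight line joining the first and last points of the curve. The function `elbow_point` and the WSS values are hypothetical.

```python
import math


def elbow_point(ks, wss):
    """Pick the k whose (k, WSS) point is farthest from the chord joining the curve's endpoints."""
    x0, y0, x1, y1 = ks[0], wss[0], ks[-1], wss[-1]
    if y0 == y1:
        return ks[0]  # flat curve: no elbow to detect
    best_k, best_d = ks[0], -1.0
    for k, w in zip(ks, wss):
        # rescale both axes to [0, 1]; the chord then runs from (0, 1) to (1, 0)
        x = (k - x0) / (x1 - x0)
        y = (w - y1) / (y0 - y1)
        d = abs(x + y - 1) / math.sqrt(2)  # distance to the line x + y = 1
        if d > best_d:
            best_k, best_d = k, d
    return best_k


ks = list(range(5, 21))  # 5 to 20 clusters, as in the text
# Made-up WSS values that drop quickly up to k = 8 and then flatten out
wss = [1000, 600, 350, 200, 180, 165, 152, 141, 132, 124, 117, 111, 106, 102, 99, 97]
print(elbow_point(ks, wss))  # -> 8
```

Rescaling matters: without it, the much larger WSS scale would dominate the distance and bias the choice toward the smallest k.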

#### Output WSS curves and spatial domains

| Timepoint | Day 4 | Day 7 | Day 10 | Day 14 | 
|-----------|-------|-------|--------|--------|
| WSS vs. number of clusters | ![Day 4](ChickenHeart_output/adaptive_0_WSS_vs_Clusters.png) | ![Day 7](ChickenHeart_output/adaptive_1_WSS_vs_Clusters.png) | ![Day 10](ChickenHeart_output/adaptive_2_WSS_vs_Clusters.png) | ![Day 14](ChickenHeart_output/adaptive_3_WSS_vs_Clusters.png) | 
| Spatial Domains | ![Day 4](ChickenHeart_output/adaptive_0_domains.png) | ![Day 7](ChickenHeart_output/adaptive_1_domains.png) | ![Day 10](ChickenHeart_output/adaptive_2_domains.png) | ![Day 14](ChickenHeart_output/adaptive_3_domains.png) | 

#### Output OT analysis

| Timepoint | Day 4 --> Day 7 | Day 7 --> Day 10 | Day 10 --> Day 14 | 
|-----------|-----------------|------------------|-------------------|
| OT transition | ![Day 4to7](ChickenHeart_output/adaptive_transition_dotplot_0_1.png) | ![Day 7to10](ChickenHeart_output/adaptive_transition_dotplot_1_2.png) |  ![Day 10to14](ChickenHeart_output/adaptive_transition_dotplot_2_3.png) |