## Example 1: Developing Chicken Heart
This notebook illustrates how to train SpaDOT on the [developing chicken heart](https://doi.org/10.1038/s41467-021-21892-z) dataset. This spatiotemporal dataset was measured with 10x Visium at four stages: Day 4, Day 7, Day 10 and Day 14. On this dataset, SpaDOT accurately identifies valvulogenesis: at Day 14, a valve splits into the atrioventricular valve and the semilunar valve.

For your convenience, you can download the processed data [here](https://www.dropbox.com/scl/fi/xklj0dxkd2wz10ahgbwg1/ChickenHeart.h5ad?rlkey=06245qjhv4ohij5530a1az91c&dl=0). If you would like to see the preprocessing steps, please expand the section below:

<details>
<summary>Click to expand</summary>

First, we downloaded the spatial transcriptomics data from [GSE149457](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE149457) and selected the following files:
```
GSM4502482_chicken_heart_spatial_RNAseq_D4_filtered_feature_bc_matrix.h5
GSM4502483_chicken_heart_spatial_RNAseq_D7_filtered_feature_bc_matrix.h5
GSM4502484_chicken_heart_spatial_RNAseq_D10_filtered_feature_bc_matrix.h5
GSM4502485_chicken_heart_spatial_RNAseq_D14_filtered_feature_bc_matrix.h5
```

Second, we downloaded the spatial coordinates from the analysis code accompanying the paper on [GitHub](https://github.com/madhavmantri/chicken_heart/tree/master/data/chicken_heart_spatial_RNAseq_processed):


```
chicken_heart_spatial_RNAseq_D4_tissue_positions_list.csv
chicken_heart_spatial_RNAseq_D7_tissue_positions_list.csv
chicken_heart_spatial_RNAseq_D10_tissue_positions_list.csv
chicken_heart_spatial_RNAseq_D14_tissue_positions_list.csv
```
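These position files appear to use the Space Ranger `tissue_positions_list.csv` layout. The sketch below assumes the standard headerless six columns: barcode, in_tissue, array_row, array_col, pxl_row_in_fullres, pxl_col_in_fullres. It is a minimal parser that keeps only in-tissue spots. The processed copies in the paper's repository may differ, so check the columns before relying on it. The barcodes in the usage example are hypothetical.

```python
import csv
import io
from typing import Dict, TextIO, Tuple


def read_tissue_positions(handle: TextIO, in_tissue_only: bool = True) -> Dict[str, Tuple[float, float]]:
    """Map each spot barcode to its (pxl_row, pxl_col) full-resolution pixel coordinates.

    Assumes the six-column Space Ranger layout:
    barcode, in_tissue, array_row, array_col, pxl_row_in_fullres, pxl_col_in_fullres.
    """
    positions = {}
    for row in csv.reader(handle):
        if not row or row[0] == "barcode":  # skip blank lines and an optional header row
            continue
        barcode, in_tissue = row[0], int(row[1])
        if in_tissue_only and in_tissue != 1:  # drop spots outside the tissue
            continue
        positions[barcode] = (float(row[4]), float(row[5]))
    return positions


# Usage with two hypothetical rows: only the in-tissue spot is kept.
example = io.StringIO(
    "AAACAAGTATCTCCCA-1,1,50,102,7475,8501\n"
    "AAACACCAATAACTGC-1,0,59,19,8553,2788\n"
)
print(read_tissue_positions(example))  # {'AAACAAGTATCTCCCA-1': (7475.0, 8501.0)}
```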

Third, we used the script `process_ChickenHeart.py` provided [here](https://github.com/marvinquiet/SpaDOT/blob/main/analyses/process_ChickenHeart.py) to integrate the four samples into a single AnnData object. The time point of each spot is stored in the observations (`obs`) under `timepoint`, encoded as integers: `0`, `1`, `2` and `3` indicate Day 4, Day 7, Day 10 and Day 14, respectively. The spatial coordinates are stored as a NumPy array under the key `spatial` in the multi-dimensional observation annotations (`obsm`).

Running `process_ChickenHeart.py` produces the file `ChickenHeart.h5ad`.
</details>
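The layout described in the preprocessing section can be sketched with plain NumPy. The per-day expression matrices are stacked row-wise, each spot gets its integer time-point code, and the coordinates are stacked in the same order. In AnnData terms these arrays would become `adata.X`, `adata.obs["timepoint"]` and `adata.obsm["spatial"]`. This is a hypothetical helper, not the code in `process_ChickenHeart.py`. It assumes that every day already shares the same gene set; the real script presumably handles gene matching itself.

```python
import numpy as np

# Day labels mapped to the integer codes stored in obs["timepoint"] (0 = Day 4, ..., 3 = Day 14).
DAY_TO_CODE = {"D4": 0, "D7": 1, "D10": 2, "D14": 3}


def stack_timepoints(samples):
    """Concatenate per-day data into one matrix with aligned time-point codes and coordinates.

    samples: dict mapping a day label (e.g. "D4") to (expr, coords), where expr has
             shape (n_spots, n_genes) and coords has shape (n_spots, 2).
    Returns (X, timepoint, spatial), ordered from Day 4 to Day 14.
    """
    ordered = sorted(samples, key=DAY_TO_CODE.__getitem__)  # sort days chronologically
    X = np.vstack([samples[d][0] for d in ordered])
    spatial = np.vstack([samples[d][1] for d in ordered]).astype(float)
    timepoint = np.concatenate(
        [np.full(samples[d][0].shape[0], DAY_TO_CODE[d]) for d in ordered]
    )
    return X, timepoint, spatial
```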

### Step 1: Perform data preprocessing and feature selection

After obtaining `ChickenHeart.h5ad`, we perform data preprocessing and feature selection.

```python
import os
from argparse import Namespace
import SpaDOT

result_dir = './ChickenHeart_output'
os.makedirs(result_dir, exist_ok=True)  # create the output directory if it does not exist

# --- create args
preprocess_args = Namespace(
    data='./ChickenHeart.h5ad',  # integrated AnnData produced by the preprocessing steps above
    prefix='preprocessed_',      # prepended to the output file name
    feature_selection=True,
    output_dir=result_dir
)
SpaDOT.preprocess(preprocess_args)
```

```text
using default random_seed 1448145, will run SCT without randomness
gene-cell umi shape (10521, 747), n_genes 2000 n_cells 747
get_model_pars finished, cost 5.468864679336548 seconds
ksmooth finished, cost 0.6347630023956299 seconds
reg_model_pars finished, cost 1.5023093223571777 seconds
pearson_residual cost 1.5920424461364746 seconds
umi_corrected cost 0.968726634979248 seconds
scale data cost 0.0821387767791748 seconds
Timepoint: 0, Number of cells: 747, Number of genes: 10521
## ===== SPARK-X INPUT INFORMATION ==== 
## number of total samples: 747
## number of total genes: 10521
## Running with 4 cores
## Testing With Projection Kernel
## Testing With Gaussian Kernel 1
## Testing With Gaussian Kernel 2
## Testing With Gaussian Kernel 3
## Testing With Gaussian Kernel 4
## Testing With Gaussian Kernel 5
## Testing With Cosine Kernel 1
## Testing With Cosine Kernel 2
## Testing With Cosine Kernel 3
## Testing With Cosine Kernel 4
## Testing With Cosine Kernel 5
Time taken for mixture kernels: 3.01 seconds
```


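Step 2 loads the preprocessed file from `output_dir`. Judging from the file name used there (`preprocessed_ChickenHeart.h5ad`), Step 1 appears to name its output by prepending `prefix` to the input file name. The hypothetical helper below assumes that convention and checks that the file exists before you start training.

```python
import os


def expected_preprocessed_path(data: str, prefix: str, output_dir: str) -> str:
    """Hypothetical helper: path of the file Step 2 reads, assuming SpaDOT.preprocess
    writes `prefix + basename(data)` into `output_dir`."""
    return os.path.join(output_dir, prefix + os.path.basename(data))


path = expected_preprocessed_path('./ChickenHeart.h5ad', 'preprocessed_', './ChickenHeart_output')
if not os.path.exists(path):
    print(f"{path} not found; run Step 1 to completion before training.")
```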

### Step 2: Train SpaDOT to obtain latent representations

After preprocessing, we train the SpaDOT model to obtain latent representations. We recommend training on a GPU, which takes roughly 5 to 6 minutes (about 350 seconds in the run below). Training on a CPU is also possible but takes significantly longer.
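Instead of hard-coding `device='cuda:0'`, you could pick the device at run time. The sketch below uses PyTorch's `torch.cuda.is_available()` and falls back to `'cpu'` when no GPU is visible or PyTorch is not importable. The helper name `pick_device` is hypothetical.

```python
def pick_device(preferred: str = "cuda:0") -> str:
    """Return `preferred` if PyTorch can see a CUDA GPU, otherwise fall back to "cpu"."""
    try:
        import torch
    except ImportError:  # PyTorch not installed in this environment
        return "cpu"
    return preferred if torch.cuda.is_available() else "cpu"


print(pick_device())  # "cuda:0" on a GPU machine, "cpu" otherwise
```

The result can then be passed as `device=pick_device()` when building `train_args`.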

```python
train_args = Namespace(
    data=os.path.join(result_dir, 'preprocessed_ChickenHeart.h5ad'),  # output of Step 1
    output_dir=result_dir,
    prefix="",
    config=None,      # use the default configuration
    save_model=True,
    device='cuda:0'   # use 'cpu' if no GPU is available (much slower)
)
SpaDOT.train(train_args)
```

```text
Loading data...
Preparing data...
Calculating spatial graph...
The graph contains 4482 edges, 747 cells.
6.0000 neighbors per cell on average.
Calculating spatial graph...
The graph contains 23592 edges, 1966 cells.
12.0000 neighbors per cell on average.
Calculating spatial graph...
The graph contains 22992 edges, 1916 cells.
12.0000 neighbors per cell on average.
Calculating spatial graph...
The graph contains 23604 edges, 1967 cells.
12.0000 neighbors per cell on average.
Training model...
SpaDOT(
  (SVGPEncoder): SVGPEncoder(
    (SVGP_encoder_net): Sequential(
      (0): Linear(in_features=2954, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): Linear(in_features=256, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.01)
    )
    (SVGP_fc): Linear(in_features=64, out_features=20, bias=True)
  )
  (GATEncoder): GATEncoder(
    (gat1): GATConv(2954, 512, heads=4)
    (gat2): GATConv(2048, 512, heads=4)
    (gat3): GATConv(2048, 512, heads=4)
    (GAT_fc): Linear(in_features=512, out_features=20, bias=True)
  )
  (decoder): Decoder(
    (decoder_net): Sequential(
      (0): Linear(in_features=20, out_features=64, bias=True)
      (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): Linea
Epoch 1: Training time: 3 seconds, ELBO: 199.92934340, Recon loss: 1980.31486427, SVGP KL loss: -144.38993206, GAT KL loss: 2619.26694914, Alignment loss: 16.35930274, Kmeans loss: 0.00000000, OT loss: 0.00000000
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 11: Training time: 3 seconds, ELBO: 183.69589750, Recon loss: 1803.56546095, SVGP KL loss: -9.39387559, GAT KL loss: 3897.85068257, Alignment loss: 7.02857143, Kmeans loss: 22.46709194, OT loss: 0.00000000
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 21: Training time: 3 seconds, ELBO: 180.51586263, Recon loss: 1776.48100437, SVGP KL loss: -2.61144509, GAT KL loss: 4351.56473364, Alignment loss: 3.30476795, Kmeans loss: 21.02128920, OT loss: 0.00000000
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 31: Training time: 3 seconds, ELBO: 179.30501430, Recon loss: 1764.36556343, SVGP KL loss: -1.12587345, GAT KL loss: 4354.95360031, Alignment loss: 3.14776449, Kmeans loss: 21.18186142, OT loss: 0.00000000
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 41: Training time: 3 seconds, ELBO: 178.65540986, Recon loss: 1757.43223922, SVGP KL loss: -1.40420062, GAT KL loss: 4460.16931497, Alignment loss: 2.97084869, Kmeans loss: 21.69084135, OT loss: 0.00000000
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 51: Training time: 3 seconds, ELBO: 178.24418488, Recon loss: 1751.28202528, SVGP KL loss: -2.85507334, GAT KL loss: 4485.64796494, Alignment loss: 2.64995224, Kmeans loss: 20.43639373, OT loss: 0.35878296
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 61: Training time: 3 seconds, ELBO: 178.00210147, Recon loss: 1748.51838869, SVGP KL loss: -1.02684434, GAT KL loss: 4881.82801740, Alignment loss: 2.80988362, Kmeans loss: 20.18780617, OT loss: 0.36231082
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 71: Training time: 3 seconds, ELBO: 177.15708616, Recon loss: 1741.50812431, SVGP KL loss: -2.00642915, GAT KL loss: 4782.77606446, Alignment loss: 2.46534532, Kmeans loss: 19.47674959, OT loss: 0.33378664
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 81: Training time: 3 seconds, ELBO: 176.84784008, Recon loss: 1738.66490035, SVGP KL loss: -0.92498846, GAT KL loss: 4920.20155377, Alignment loss: 2.35207746, Kmeans loss: 19.21405369, OT loss: 0.33271677
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Epoch 91: Training time: 2 seconds, ELBO: 176.50096695, Recon loss: 1735.18748160, SVGP KL loss: -1.74584502, GAT KL loss: 5085.28900723, Alignment loss: 2.45251149, Kmeans loss: 18.85693263, OT loss: 0.34274548
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
OT iter 0
OT iter 1
OT iter 2
Training finished...
Training time: 350 seconds.
Model saved to ./output
```


The latent representations are saved as `latent.h5ad` in `output_dir`.
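The spatial-graph statistics in the training log follow "edges = cells × neighbors per cell": 747 × 6 = 4482, 1966 × 12 = 23592, 1916 × 12 = 22992 and 1967 × 12 = 23604. That pattern is consistent with a graph that links each spot to a fixed number of neighbors. A minimal sketch of such a k-nearest-neighbor graph is shown below, in the `(2, E)` `edge_index` layout used by PyTorch Geometric. It assumes plain Euclidean distances on the spot coordinates. SpaDOT's own graph construction may differ.

```python
import numpy as np


def knn_edge_index(coords: np.ndarray, k: int) -> np.ndarray:
    """Directed k-nearest-neighbor graph on spot coordinates, as a (2, n * k) edge_index.

    Row 0 holds source spots and row 1 their neighbors; self-loops are excluded.
    Uses a dense distance matrix, which is fine for a few thousand spots.
    """
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)               # a spot is not its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]  # indices of the k closest spots per row
    sources = np.repeat(np.arange(coords.shape[0]), k)
    return np.vstack([sources, neighbors.ravel()])


rng = np.random.default_rng(0)
edge_index = knn_edge_index(rng.uniform(size=(747, 2)), k=6)
print(edge_index.shape)  # (2, 4482), i.e. 6 neighbors per spot
```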

### Step 3: Infer spatial domains and domain dynamics

Once training finishes, we can infer spatial domains and their dynamics across time points. If we have prior knowledge of the number of domains at each time point (e.g., from the original study), we pass one value per time point, in time order, via `n_clusters`:

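```python
analyze_args = Namespace(
    data=os.path.join(result_dir, 'latent.h5ad'),  # latent representations from Step 2
    prefix="",
    output_dir=result_dir,
    n_clusters=[5, 7, 7, 6]  # number of domains for Day 4, Day 7, Day 10 and Day 14
)
SpaDOT.analyze(analyze_args)
```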



The spatial domain images are saved in `output_dir` as `{tp}_domains.png`, and the dot plots showing domain transitions between consecutive time points are saved as `transition_dotplot_{tp_i}_{tp_i+1}.png`.

#### Spatial domains for the developing chicken heart
| Timepoint | Day 4 | Day 7 | Day 10 | Day 14 | 
|-----------|-------|-------|--------|--------|
| Spatial Domains | ![Day 4](output/0_domains.png) | ![Day 7](output/1_domains.png) | ![Day 10](output/2_domains.png) | ![Day 14](output/3_domains.png) | 
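To illustrate what `n_clusters=[5, 7, 7, 6]` means, the sketch below clusters the latent representations of each time point separately. It uses that time point's own domain count, in time order 0 to 3. It assumes that domains come from k-means run independently per time point; Step 4 describes choosing the cluster number from the WSS of KMeans, which suggests k-means is involved. The deterministic farthest-point initialization is chosen here for reproducibility and is not necessarily what SpaDOT uses. Both function names are hypothetical.

```python
import numpy as np


def kmeans(Z, k, n_iter=100):
    """Lloyd's k-means with deterministic farthest-point initialization.

    Returns (labels, wss), where wss is the within-cluster sum of squares.
    """
    # Start from the first row, then repeatedly add the point farthest from all chosen centers.
    centers = [Z[0]]
    for _ in range(1, k):
        d2 = np.min([((Z - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(Z[d2.argmax()])
    centers = np.array(centers, dtype=float)

    for _ in range(n_iter):
        d2 = ((Z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        # Move each center to the mean of its points (keep it if the cluster is empty).
        new_centers = np.array([
            Z[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers

    d2 = ((Z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    labels = d2.argmin(axis=1)
    wss = float(d2[np.arange(len(Z)), labels].sum())
    return labels, wss


def domains_per_timepoint(latent, timepoint, n_clusters):
    """Cluster each time point on its own; n_clusters[t] is the domain count for code t."""
    labels = np.empty(len(latent), dtype=int)
    for t, k in enumerate(n_clusters):
        mask = timepoint == t
        labels[mask], _ = kmeans(latent[mask], k)
    return labels
```

With the chicken heart settings, `domains_per_timepoint(Z, tp, [5, 7, 7, 6])` would split Day 4 into 5 domains, Day 7 and Day 10 into 7, and Day 14 into 6.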

#### Domain dynamics analysis

| Timepoint | Day 4 --> Day 7 | Day 7 --> Day 10 | Day 10 --> Day 14 | 
|-----------|-----------------|------------------|-------------------|
| OT transition | ![Day 4to7](output/transition_dotplot_0_1.png) | ![Day 7to10](output/transition_dotplot_1_2.png) |  ![Day 10to14](output/transition_dotplot_2_3.png) |
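The transition dot plots summarize how domains at one time point map onto domains at the next. The training log reports optimal-transport (OT) iterations, so one plausible reading is this: a spot-level transport plan between consecutive time points is aggregated over pairs of domains. The sketch below shows that aggregation under assumed inputs (a coupling matrix plus integer domain labels). It is an illustration, not necessarily how SpaDOT builds its dot plots.

```python
import numpy as np


def domain_transition_matrix(coupling, labels_src, labels_tgt, n_src, n_tgt):
    """Summarize a spot-level transport plan as domain-to-domain transition proportions.

    coupling[i, j]: mass moved from spot i at time t to spot j at time t+1.
    labels_src / labels_tgt: integer domain labels of the source / target spots.
    Entry (a, b) of the result is the share of domain a's mass sent to domain b.
    """
    S = np.eye(n_src)[labels_src]   # (n_spots_src, n_src) one-hot domain membership
    G = np.eye(n_tgt)[labels_tgt]   # (n_spots_tgt, n_tgt) one-hot domain membership
    T = S.T @ coupling @ G          # total mass from each source domain to each target domain
    row_mass = T.sum(axis=1, keepdims=True)
    return np.divide(T, row_mass, out=np.zeros_like(T), where=row_mass > 0)
```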


### Step 4 (Optional): Infer domains and dynamics using adaptive methods

In exploratory studies, the number of domains is generally unknown. In that case, we set `n_clusters=None` (equivalently, omit the `--n_clusters` option) so that SpaDOT chooses the number of clusters adaptively with the elbow method.

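```python
analyze_args = Namespace(
    data=os.path.join(result_dir, 'latent.h5ad'),  # latent representations from Step 2
    prefix="",
    output_dir=result_dir,
    n_clusters=None  # choose the number of domains per time point with the elbow method
)
SpaDOT.analyze(analyze_args)
```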



For each time point, SpaDOT then plots the within-cluster sum of squares (WSS) of KMeans against the number of clusters, ranging from 5 to 20, detects the elbow point, and uses the corresponding number of clusters.
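The exact elbow rule SpaDOT applies is not described here. As a hedged illustration, a common heuristic rescales both axes of the WSS curve to [0, 1]. It then picks the k whose point lies farthest from the straight line (chord) joining the first and last points. The function name `elbow_point` and the example WSS values are hypothetical.

```python
import numpy as np


def elbow_point(ks, wss):
    """Return the k whose (k, WSS) point is farthest from the chord joining the curve's endpoints."""
    x = np.asarray(ks, dtype=float)
    y = np.asarray(wss, dtype=float)
    # Rescale both axes to [0, 1] so the distance is not dominated by the WSS scale.
    x = (x - x[0]) / (x[-1] - x[0])
    y = (y - y.min()) / (y.max() - y.min())
    p0, p1 = np.array([x[0], y[0]]), np.array([x[-1], y[-1]])
    direction = (p1 - p0) / np.linalg.norm(p1 - p0)
    rel = np.stack([x, y], axis=1) - p0
    # Perpendicular distance to the chord = |2-D cross product| with the unit direction.
    dist = np.abs(rel[:, 0] * direction[1] - rel[:, 1] * direction[0])
    return int(ks[int(dist.argmax())])


# Hypothetical WSS curve over k = 5..20 that drops steeply until k = 8, then flattens.
ks = list(range(5, 21))
wss = [100 - 25 * (k - 5) if k <= 8 else 25 - (k - 8) for k in ks]
print(elbow_point(ks, wss))  # 8
```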

#### Output WSS curves and spatial domains

| Timepoint | Day 4 | Day 7 | Day 10 | Day 14 | 
|-----------|-------|-------|--------|--------|
| WSS vs. number of clusters | ![Day 4](output/0_WSS_vs_Clusters.png) | ![Day 7](output/1_WSS_vs_Clusters.png) | ![Day 10](output/2_WSS_vs_Clusters.png) | ![Day 14](output/3_WSS_vs_Clusters.png) | 
| Spatial Domains | ![Day 4](output/0_domains_adaptive.png) | ![Day 7](output/1_domains_adaptive.png) | ![Day 10](output/2_domains_adaptive.png) | ![Day 14](output/3_domains_adaptive.png) | 

#### Output OT analysis

| Timepoint | Day 4 --> Day 7 | Day 7 --> Day 10 | Day 10 --> Day 14 | 
|-----------|-----------------|------------------|-------------------|
| OT transition | ![Day 4to7](output/transition_dotplot_0_1_adaptive.png) | ![Day 7to10](output/transition_dotplot_1_2_adaptive.png) |  ![Day 10to14](output/transition_dotplot_2_3_adaptive.png) |