In [1]:
import pandas as pd
import anndata as ad

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping #ModelCheckpoint

import sys
## NOTE(review): machine-specific absolute path, presumably the local checkout that
## provides the `starling` package imported below -- adjust it before running elsewhere
sys.path.append("/home/campbell/yulee/github/st/")

from starling import utility
from starling import starling

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## fix the global random seeds (python, numpy, torch) so the clustering initialization
## and the STARLING training in the later cells are repeatable on Restart & Run All
## NOTE(review): assumes init_clustering draws from these global RNGs -- confirm;
## 42 is an arbitrary choice, any fixed integer works
pl.seed_everything(42, workers=True)

## load data (AnnData object)
## run a clustering method: the utility provides GMM, KM (KMeans) or PG (PhenoGraph); k can be ignored for PG
## the resulting arrays (cluster centroids, variances and labels) are set up for STARLING initialization
adata = utility.init_clustering(ad.read_h5ad('sample_input.h5ad'), 'KM', k=10)

  super()._check_params_vs_input(X, default_n_init=10)


In [3]:
## set up STARLING from the AnnData object returned by utility.init_clustering,
## which carries the clustering initializations
st = starling.ST(adata)

  self.S = torch.tensor(self.adata.obs[self.cell_size_col_name]) if self.model_cell_size == 'Y' else None


In [4]:
## early-stopping criterion: monitor the logged 'train_loss' and treat lower as better
cb_early_stopping = EarlyStopping(
    monitor='train_loss',
    mode='min',
    verbose=False,
)

## log training results to the 'log' directory via TensorBoard
log_tb = TensorBoardLogger(save_dir='log')

In [5]:
## train STARLING: at most 100 epochs, accelerator/devices picked automatically,
## using the early-stopping callback and the TensorBoard logger
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',
    devices='auto',
    callbacks=[cb_early_stopping],
    logger=[log_tb],
)
trainer.fit(st)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
2023-10-18 20:12:55.392903: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-18 20:12:55.456729: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-18 20:12:55.456762: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-18 20:12:55.456795: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: At

Epoch 0:  33%|███▎      | 9/27 [00:01<00:02,  8.97it/s, v_num=2, train_loss_step=88.60]

  v = prob_data_given_gamma_d1.T + log_delta[1] - prob_data                    ## p(d=1,gamma=[c,c']|data)


Epoch 99: 100%|██████████| 27/27 [00:00<00:00, 37.73it/s, v_num=2, train_loss_step=58.20, train_loss_epoch=58.40]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 27/27 [00:00<00:00, 36.78it/s, v_num=2, train_loss_step=58.20, train_loss_epoch=58.40]


In [6]:
## retrieve STARLING results (they become available in st.adata, see the note below)
st.result()

## Note: the following information can be retrieved from the AnnData object (`st.adata`):
   - st.adata.varm['init_exp_centroids'] -- initial expression cluster centroids (P x C matrix)
   - st.adata.varm['st_exp_centroids'] -- STARLING expression cluster centroids (P x C matrix)
   - st.adata.uns['init_cell_size_centroids'] -- initial cell size centroids if STARLING models cell size
   - st.adata.uns['st_cell_size_centroids'] -- STARLING cell size centroids if STARLING models cell size
   - st.adata.obsm['assignment_prob_matrix'] -- cell assignment distributions (N x C matrix)
   - st.adata.obs['doublet'] -- doublet indicator
   - st.adata.obs['doublet_prob'] -- doublet probabilities
   - st.adata.obs['init_label'] -- initial assignments
   - st.adata.obs['st_label'] -- STARLING assignments
   - st.adata.obs['max_assign_prob'] -- STARLING max probabilities of assignments
      - N: # of cells; C: # of clusters; P: # of proteins

In [7]:
## the st object can be saved to disk
## NOTE(review): this stores the whole Python object rather than only its weights, so
## loading it back presumably requires the starling package to be importable -- confirm
torch.save(st, 'model.pt')

In [8]:
## show a summary of the AnnData object held by the st object
st.adata

AnnData object with n_obs × n_vars = 13685 × 24
    obs: 'sample', 'id', 'x', 'y', 'area', 'area_convex', 'neighbor', 'init_label', 'st_label', 'doublet_prob', 'doublet', 'max_assign_prob'
    uns: 'init_cell_size_centroids', 'init_cell_size_variances', 'st_cell_size_centroids'
    obsm: 'assignment_prob_matrix'
    varm: 'init_exp_centroids', 'init_exp_variances', 'st_exp_centroids'

In [9]:
## per-cell annotation table (one row per cell) of the AnnData object
st.adata.obs

Unnamed: 0,sample,id,x,y,area,area_convex,neighbor,init_label,st_label,doublet_prob,doublet,max_assign_prob
4_1,4,1,0.785714,7.785714,14,14,0,0,2,0.101385,0,0.898610
4_2,4,2,0.823529,22.294117,17,17,0,9,2,0.469688,0,0.471184
4_3,4,3,0.875000,79.500000,16,16,1,3,2,0.101763,0,0.898237
4_4,4,4,0.666667,270.500000,12,12,0,5,2,0.127615,0,0.870983
4_5,4,5,0.823529,279.294130,17,17,1,6,6,0.642768,1,0.287857
...,...,...,...,...,...,...,...,...,...,...,...,...
4_13681,4,13681,997.769200,754.500000,26,26,0,8,8,0.045701,0,0.953996
4_13682,4,13682,998.153900,127.615390,13,13,0,5,5,0.103581,0,0.896419
4_13683,4,13683,998.153900,160.000000,13,13,1,0,2,0.115575,0,0.884402
4_13684,4,13684,997.580600,242.580640,31,33,1,8,8,0.117628,0,0.872245


In [10]:
## initial expression centroids (P x C matrix), rows labelled with the variable names
pd.DataFrame(st.adata.varm['init_exp_centroids'], index=st.adata.var_names)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
SMA,1.148126,0.582641,2.120535,2.2955,1.538811,0.491241,0.804634,0.716098,0.614018,1.684391
ECadherin,0.941576,1.007733,0.917578,0.977868,0.888884,5.141973,1.484624,0.877864,2.987322,0.886015
Cytokeratin,8.830395,12.708248,7.943659,7.423984,8.101929,64.333679,15.00185,9.91709,34.57663,8.09444
HLADR,22.848648,98.138229,20.079594,17.288685,27.096207,10.944695,29.556576,108.284195,19.842569,25.361725
Vimentin,243.190369,63.304062,590.420959,833.353394,337.343201,31.089279,144.486023,205.566696,149.990051,442.737976
CD28,0.389227,0.430784,0.194268,0.128297,0.34058,0.281808,0.407358,0.397186,0.162152,0.282243
CD15,2.638944,2.450385,11.786937,8.239302,2.358004,10.230668,4.80124,0.949367,157.635086,3.649838
CD45RA,8.942804,19.837212,6.109267,4.585984,8.528809,3.308713,10.354916,25.182024,5.675559,7.582182
CD66b,0.345698,0.26031,0.906205,0.919958,0.39428,0.649864,0.38948,0.289405,6.643331,0.506833
CD20,8.239733,66.463295,6.27216,4.300542,8.774179,4.924405,14.146482,49.06678,10.275104,7.642591


In [11]:
## STARLING expression centroids (P x C matrix), rows labelled with the variable names
pd.DataFrame(st.adata.varm['st_exp_centroids'], index=st.adata.var_names)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
SMA,1.856169,0.625896,0.896344,2.931808,2.362362,0.452633,0.647195,0.601289,0.533174,2.265134
ECadherin,1.752379,0.684874,0.794378,1.14053,0.889083,5.442095,0.780747,0.604517,3.296637,0.603578
Cytokeratin,7.144927,9.611339,6.761675,8.458087,7.473826,68.80127,7.591178,8.195832,35.796188,5.524339
HLADR,12.59522,87.814018,10.817248,12.579467,21.29423,8.191945,19.641893,54.762936,35.732903,13.885956
Vimentin,337.241364,109.927078,368.312866,880.669189,376.222473,30.647604,294.651825,236.557831,128.00354,278.096436
CD28,0.233382,0.351048,0.052545,0.092156,0.274304,0.245081,0.428253,0.227249,0.157632,0.166838
CD15,1.410196,0.403816,3.482865,12.992934,1.880894,7.510751,0.5378,0.487119,96.615601,2.834917
CD45RA,8.541049,23.630945,3.249801,3.004163,6.540343,1.463255,7.479087,14.056887,8.352734,4.733887
CD66b,0.308428,0.19793,0.311546,1.042953,0.342445,0.220183,0.291976,0.240959,3.998243,0.264742
CD20,5.040942,58.27911,2.418272,2.840937,6.650702,1.946869,4.579585,25.466352,14.674814,4.598309


In [12]:
## assignment distributions (N x C matrix), rows labelled with the cell index of st.adata.obs
pd.DataFrame(st.adata.obsm['assignment_prob_matrix'], index = st.adata.obs.index)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
4_1,2.163412e-15,1.773794e-08,8.986104e-01,2.040122e-17,2.429650e-16,1.083843e-07,4.588316e-06,5.576983e-10,2.515836e-12,2.480487e-11
4_2,6.017434e-11,7.704447e-05,4.711837e-01,1.344293e-15,5.758044e-12,1.331391e-08,5.905133e-02,1.948176e-09,9.966562e-10,4.357041e-10
4_3,2.507708e-13,9.686772e-16,8.982366e-01,1.918956e-08,2.456857e-15,2.488170e-11,1.697231e-10,9.049102e-15,3.842503e-12,5.649181e-11
4_4,2.983992e-15,3.285619e-07,8.709832e-01,7.537796e-15,1.184206e-14,2.643209e-04,1.137312e-03,4.081606e-09,2.515627e-10,2.103558e-10
4_5,9.762505e-10,7.301825e-08,6.933164e-02,5.665508e-13,3.726940e-09,4.396118e-07,2.878572e-01,9.553643e-06,5.794974e-13,3.342430e-05
...,...,...,...,...,...,...,...,...,...,...
4_13681,7.815181e-16,8.655882e-15,9.156883e-08,9.431755e-16,6.159585e-18,3.033172e-04,3.750620e-13,2.296102e-17,9.539959e-01,2.469924e-14
4_13682,9.959414e-17,3.521862e-13,2.581485e-10,4.352986e-21,5.634993e-19,8.964192e-01,1.198390e-10,2.443304e-13,6.812414e-09,2.314731e-17
4_13683,1.193154e-17,7.985612e-08,8.844023e-01,2.219167e-16,1.028393e-16,2.023451e-07,2.220071e-05,1.502397e-11,6.707016e-12,2.380328e-12
4_13684,2.605712e-17,4.341093e-13,3.295740e-08,5.432328e-17,1.501018e-18,1.012646e-02,1.204077e-09,1.190830e-14,8.722455e-01,1.866526e-14
