In [1]:
import anndata as ad

import torch
import pytorch_lightning as pl
#from pytorch_lightning.callbacks.progress import RichProgressBar
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
#import torchmetrics

import utility
import starling

In [2]:
## load data (AnnData object) and run the initial clustering used to initialize STARLING
## utility.init_clustering supports 'GMM', 'KM' (k-means) or 'PG' (PhenoGraph); k is ignored for 'PG'
## the resulting cluster centroids, variances and labels (init_* fields) are set up for STARLING initialization
##
## provenance: roi4_input.h5ad appears to have been created from the full dataset with
##   adata = utility.init_clustering(ad.read_h5ad('sample_input.h5ad'), 'KM', k=10)
##   adata[adata.obs['sample'] == 4].write('roi4_input.h5ad')
INPUT_PATH = 'roi4_input.h5ad'
INIT_METHOD = 'KM'        # 'GMM', 'KM' or 'PG'
N_INIT_CLUSTERS = 10      # number of initial clusters (ignored for 'PG')
SEED = 0

## k-means / GMM initialization and the later model training draw from the python, numpy and torch
## RNGs; seed all of them so Restart & Run All reproduces the same clustering
## NOTE(review): if utility.init_clustering fixes its own random_state this only affects training — confirm
pl.seed_everything(SEED)
adata = utility.init_clustering(ad.read_h5ad(INPUT_PATH), INIT_METHOD, k=N_INIT_CLUSTERS)

In [3]:
## set up the STARLING model from adata, which carries the initial clustering results (init_* fields)
## computed above; st is later passed to trainer.fit, so it is presumably a LightningModule
st = starling.ST(adata)

In [4]:
## write training logs via TensorBoard into the 'log' directory
## (view them with `tensorboard --logdir log`)
log_tb = TensorBoardLogger(save_dir='log')

## end training early once the logged 'train_loss' stops decreasing;
## patience is left at the EarlyStopping default
cb_early_stopping = EarlyStopping(
    monitor='train_loss',
    mode='min',
    verbose=False,
)

In [6]:
## train STARLING for up to MAX_EPOCHS epochs; the early-stopping callback may end training sooner
## NOTE(review): the captured output of this cell reports `max_epochs=100` reached and execution
## count In[5] is missing, so that output predates the switch to 200 epochs — re-run to refresh it
MAX_EPOCHS = 200

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator='auto',
    devices='auto',
    callbacks=[cb_early_stopping],
    logger=[log_tb],
)
trainer.fit(st)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  v = prob_data_given_gamma_d1.T + log_delta[1] - prob_data                    ## p(d=1,gamma=[c,c']|data)
`Trainer.fit` stopped: `max_epochs=100` reached.


In [7]:
## retrieve STARLING results; this presumably adds the STARLING fields shown below
## (st_label, doublet_prob, max_assign_prob, st_* centroids, assignment_prob_matrix) to st.adata
st.result()

In [8]:
## summary of the AnnData object (cells x markers) with initial and STARLING annotations
st.adata

AnnData object with n_obs × n_vars = 13685 × 24
    obs: 'sample', 'id', 'x', 'y', 'area', 'area_convex', 'neighbor', 'init_label', 'st_label', 'doublet_prob', 'max_assign_prob'
    uns: 'init_exp_centroids', 'init_exp_variances', 'init_cell_size_centroids', 'init_cell_size_variances', 'assignment_prob_matrix', 'st_exp_centroids', 'st_cell_size_centroids'

In [9]:
## per-cell annotations: initial cluster label (init_label), STARLING label (st_label),
## doublet probability (doublet_prob) and maximum cluster-assignment probability (max_assign_prob)
st.adata.obs

Unnamed: 0,sample,id,x,y,area,area_convex,neighbor,init_label,st_label,doublet_prob,max_assign_prob
4_1,4,1,0.785714,7.785714,14,14,0,2,7,0.112484,0.887512
4_2,4,2,0.823529,22.294117,17,17,0,0,7,0.473098,0.480163
4_3,4,3,0.875000,79.500000,16,16,1,3,7,0.113046,0.886954
4_4,4,4,0.666667,270.500000,12,12,0,4,7,0.135497,0.863417
4_5,4,5,0.823529,279.294130,17,17,1,8,8,0.670179,0.258559
...,...,...,...,...,...,...,...,...,...,...,...
4_13681,4,13681,997.769200,754.500000,26,26,0,6,6,0.041566,0.958391
4_13682,4,13682,998.153900,127.615390,13,13,0,4,4,0.099704,0.900296
4_13683,4,13683,998.153900,160.000000,13,13,1,2,7,0.125444,0.874538
4_13684,4,13684,997.580600,242.580640,31,33,1,6,6,0.082560,0.914910


In [10]:
## initial expression centroids matrix (one row per initial cluster, one column per marker)
st.adata.uns['init_exp_centroids']

Unnamed: 0,SMA,ECadherin,Cytokeratin,HLADR,Vimentin,CD28,CD15,CD45RA,CD66b,CD20,...,CD45RO,CD3,IFNg,TCF1,CD14,CD56,PD1,CD45,PNAd,CD31
0,1.684391,0.886015,8.09444,25.361725,442.737976,0.282243,3.649838,7.582183,0.506833,7.642591,...,18.559362,9.796923,4.632633,2.798697,11.680065,11.290042,0.550894,6.407729,1.414149,2.968246
1,0.582677,1.010169,12.734038,98.11525,63.301514,0.430788,2.518308,19.834152,0.265202,66.438675,...,21.757957,6.288233,7.033791,2.249508,8.644648,11.094724,2.266236,9.218513,0.56937,2.125594
2,1.147718,0.941582,8.827766,22.844702,243.15271,0.389161,2.637199,8.94278,0.34561,8.240726,...,21.140459,12.160061,4.638205,3.545717,9.335161,11.44218,0.636373,7.058299,0.720294,2.512595
3,2.2955,0.977868,7.423985,17.288687,833.353455,0.128297,8.239302,4.585984,0.919958,4.300542,...,11.639784,5.816568,4.327111,1.921495,17.518074,11.824206,0.454397,4.889655,7.49891,5.65334
4,0.490979,5.145037,64.37674,10.880665,30.994431,0.2816,10.176011,3.283474,0.645825,4.881508,...,12.367717,3.4359,3.652158,1.692792,10.043571,16.236395,0.463179,5.425852,0.584392,2.533465
5,1.538811,0.888884,8.10193,27.096209,337.343231,0.34058,2.358004,8.528809,0.39428,8.774178,...,20.499004,10.979179,4.693236,3.13503,10.374477,11.685,0.593339,6.888461,0.860296,2.654646
6,0.614018,2.987323,34.57663,19.842567,149.990021,0.162152,157.635086,5.675559,6.64333,10.275104,...,22.590227,4.256676,5.424716,1.644864,7.496294,14.45098,0.44447,8.557816,0.872177,2.656443
7,2.120535,0.917578,7.943658,20.079594,590.420898,0.194268,11.786936,6.109267,0.906205,6.272161,...,15.895439,7.634552,4.54614,2.419936,13.591745,11.322026,0.502854,5.72912,3.623322,4.088745
8,0.804645,1.487135,15.040268,29.569954,144.369827,0.407416,4.806731,10.36199,0.389546,14.160599,...,21.400354,10.755391,4.870096,2.924923,8.583033,11.446869,0.886599,7.278268,0.629587,2.278311
9,0.716098,0.877864,9.917091,108.284187,205.566681,0.397186,0.949367,25.182022,0.289405,49.06678,...,20.043924,6.65224,6.123187,1.853001,7.937771,10.950381,0.933277,9.303548,0.607034,2.246101


In [11]:
## STARLING expression centroids matrix (one row per STARLING cluster, one column per marker)
st.adata.uns['st_exp_centroids']

Unnamed: 0,SMA,ECadherin,Cytokeratin,HLADR,Vimentin,CD28,CD15,CD45RA,CD66b,CD20,...,CD45RO,CD3,IFNg,TCF1,CD14,CD56,PD1,CD45,PNAd,CD31
0,2.093377,0.583514,5.217323,12.99554,258.45285,0.160151,2.710453,4.529563,0.245545,4.471069,...,9.964396,5.520292,2.897931,1.393679,8.023635,6.068335,0.273581,3.679453,1.419795,2.077661
1,0.625944,0.685225,9.619546,87.240852,109.653114,0.351776,0.415323,23.249788,0.198709,57.552998,...,18.366541,3.870985,6.629697,1.657478,7.091712,9.78292,0.803155,8.980189,0.552725,2.090986
2,1.824534,1.727772,7.311969,12.80131,337.802917,0.234337,1.374708,8.127346,0.306861,5.070455,...,11.74671,6.166959,4.220049,8.54158,12.625811,27.64657,0.475406,5.887822,0.66555,5.33874
3,2.65569,1.107005,8.373694,15.416349,727.6474,0.092781,13.118151,3.256447,0.956178,3.054418,...,10.352385,4.438157,4.526082,1.953285,17.908558,12.636078,0.375228,4.25745,10.403078,8.086592
4,0.446937,5.55705,70.772942,7.763813,27.329767,0.242918,7.134166,1.386255,0.213565,1.838054,...,10.252459,1.860309,3.390498,1.466654,10.013685,16.761089,0.30822,5.06825,0.555401,2.58181
5,2.327031,0.881263,7.30818,21.387983,377.527863,0.274181,1.868226,6.600942,0.343449,6.63911,...,16.97884,9.334948,4.094649,2.686027,11.720682,12.609639,0.472785,5.729698,0.982065,3.103856
6,0.528175,3.215037,35.432358,33.933472,128.27948,0.154416,99.455994,7.807944,4.091126,13.505425,...,20.322618,3.306753,5.600163,1.496524,8.166835,16.360945,0.432769,8.582721,0.654417,2.647938
7,0.889283,0.796868,6.80752,10.737956,369.016022,0.053176,3.250176,3.255893,0.310507,2.381728,...,9.522055,2.925043,3.643824,1.434653,10.464071,8.626833,0.272783,3.912809,0.765465,2.729759
8,0.648694,0.779432,7.600148,19.832525,293.098328,0.428573,0.547194,7.460526,0.293096,4.619991,...,23.012684,14.650786,4.889534,3.197329,9.508484,11.141891,0.560937,7.70974,0.620639,2.237002
9,0.599464,0.606886,8.205359,55.675297,235.810577,0.228356,0.486989,14.23198,0.238609,25.971575,...,14.910774,7.612883,4.746088,1.71247,7.458127,7.743265,0.46192,6.256056,0.496675,1.780692


In [12]:
## cluster assignment distributions (n x c matrix: n cells by c clusters)
st.adata.uns['assignment_prob_matrix']

array([[7.29533010e-11, 1.71617751e-08, 1.69775576e-15, ...,
        8.87512216e-01, 3.63796700e-06, 4.08523140e-10],
       [3.77818148e-10, 8.54152315e-05, 3.89025882e-11, ...,
        4.80163408e-01, 4.66533873e-02, 1.34144247e-09],
       [4.95239385e-11, 1.06862311e-15, 2.01771674e-13, ...,
        8.86953501e-01, 1.57483661e-10, 6.52981339e-15],
       ...,
       [4.13032667e-12, 7.98493614e-08, 9.71292377e-18, ...,
        8.74537530e-01, 1.80948306e-05, 9.45445640e-12],
       [1.26703666e-14, 1.87165258e-13, 8.44877548e-18, ...,
        9.92575559e-09, 4.72952896e-10, 3.44527269e-15],
       [4.36804427e-11, 8.11209566e-06, 3.04226110e-09, ...,
        8.52871457e-01, 2.14068054e-03, 1.89613556e-09]])