In [1]:
import pandas as pd
import anndata as ad

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import utility
import starling

In [2]:
## Load the input AnnData object and run an initial clustering used to initialize STARLING.
## `utility.init_clustering` supports 'GMM' (Gaussian mixture), 'KM' (k-means) or 'PG' (PhenoGraph);
## `k` (number of clusters) is ignored for 'PG'.
## The resulting cluster centroids, variances and labels are attached to the returned AnnData object
## (presumably the `init_*` fields visible in `st.adata` further down).
SEED = 42                          # seed python/numpy/torch RNGs so clustering + training are reproducible
INPUT_PATH = 'sample_input.h5ad'   # input AnnData file
INIT_METHOD = 'KM'                 # one of 'GMM', 'KM', 'PG'
N_CLUSTERS = 10                    # number of initial clusters (ignored for 'PG')

## NOTE(review): seeding the global RNGs only makes k-means deterministic if
## init_clustering does not fix its own random_state internally — confirm in utility.
pl.seed_everything(SEED)
adata = utility.init_clustering(ad.read_h5ad(INPUT_PATH), INIT_METHOD, k=N_CLUSTERS)

In [3]:
## Set up the STARLING model from the AnnData object that carries the initial clustering
## (centroids, variances and labels produced by utility.init_clustering above).
st = starling.ST(adata)

In [4]:
## Early stopping: watch the logged 'train_loss' and stop once it no longer improves
## (mode='min' means lower is better; patience is left at Lightning's default).
cb_early_stopping = EarlyStopping(
    monitor='train_loss',
    mode='min',
    verbose=False,
)

## TensorBoard logger: training metrics are written under the local 'log' directory.
log_tb = TensorBoardLogger(save_dir='log')

In [5]:
## Train STARLING for at most 100 epochs; the early-stopping callback may end training sooner.
## 'auto' lets Lightning choose the accelerator (GPU if available) and device count.
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',
    devices='auto',
    callbacks=[cb_early_stopping],
    logger=[log_tb],
)
trainer.fit(st)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  v = prob_data_given_gamma_d1.T + log_delta[1] - prob_data                    ## p(d=1,gamma=[c,c']|data)


In [6]:
## Retrieve STARLING results.
## NOTE(review): this cell shows no output; presumably it writes the st_label / doublet_prob /
## doublet / max_assign_prob and st_* centroid fields seen in st.adata below — confirm.
st.result()

In [8]:
## Optionally persist the trained STARLING object. Off by default; set SAVE_MODEL = True to write it.
## torch.save pickles the whole object, so loading it back needs the `starling` package importable,
## and torch.load of such a file runs pickle — only load model files from trusted sources.
SAVE_MODEL = False
MODEL_PATH = 'model.pt'

if SAVE_MODEL:
    torch.save(st, MODEL_PATH)

In [9]:
## Summary of the AnnData object held by the STARLING model (its obs / uns / obsm / varm keys).
st.adata

AnnData object with n_obs × n_vars = 248881 × 24
    obs: 'sample', 'id', 'x', 'y', 'area', 'area_convex', 'neighbor', 'init_label', 'st_label', 'doublet_prob', 'doublet', 'max_assign_prob'
    uns: 'init_cell_size_centroids', 'init_cell_size_variances', 'st_cell_size_centroids'
    obsm: 'assignment_prob_matrix'
    varm: 'init_exp_centroids', 'init_exp_variances', 'st_exp_centroids'

In [10]:
## Per-cell annotations, including the initial cluster label (init_label), the STARLING label
## (st_label) and the doublet columns (doublet_prob, doublet, max_assign_prob).
st.adata.obs

Unnamed: 0,sample,id,x,y,area,area_convex,neighbor,init_label,st_label,doublet_prob,doublet,max_assign_prob
0_1,0,1,1.333333,9.481482,27,28,1,0,0,0.017456,0,0.982544
0_2,0,2,2.830509,16.203390,59,61,1,0,0,0.014936,0,0.985064
0_3,0,3,1.580645,45.741936,31,31,1,0,0,0.015740,0,0.984260
0_4,0,4,2.720930,21.744186,43,44,1,0,0,0.012079,0,0.987921
0_5,0,5,3.333333,26.916666,60,62,1,0,0,0.015560,0,0.984440
...,...,...,...,...,...,...,...,...,...,...,...,...
19_15741,19,15741,998.666700,467.500000,6,6,1,3,6,0.956219,1,0.043397
19_15742,19,15742,997.947400,742.315800,19,19,1,5,2,0.526092,1,0.396080
19_15743,19,15743,997.850000,748.200000,20,21,1,5,2,0.090487,0,0.909249
19_15744,19,15744,997.882300,787.352970,17,18,1,8,3,0.930707,1,0.058933


In [11]:
## Initial expression centroids (p markers x c clusters) from the initial clustering.
## Rows are labelled by marker name; columns are the default integer cluster indices 0..c-1.
pd.DataFrame(
    data=st.adata.varm['init_exp_centroids'],
    index=st.adata.var_names,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
SMA,0.517156,1.699549,1.401393,0.837417,2.249879,0.389939,0.578545,1.220393,0.430752,1.972398
ECadherin,1.350959,0.896008,0.87394,0.94447,0.955277,0.751412,0.797466,0.883485,0.749275,0.90087
Cytokeratin,8.052315,6.792813,6.707622,6.156868,6.922022,3.921097,7.060788,6.590886,5.730522,6.953934
HLADR,16.432276,36.09523,32.246449,28.147923,25.175598,50.318062,85.425095,25.225525,99.814774,30.25058
Vimentin,27.391647,327.172089,233.177948,87.576767,688.376709,32.774612,154.772171,156.192459,62.970264,460.542084
CD28,0.185546,0.344323,0.382348,0.31469,0.199818,0.245096,0.361132,0.394774,0.349374,0.280397
CD15,0.496149,2.24351,1.480684,1.007158,7.978024,0.1778,0.464624,1.423639,0.318655,3.92165
CD45RA,8.540041,11.646623,12.227525,12.11894,8.24003,21.702637,26.274673,11.443995,29.486994,9.714458
CD66b,0.129094,0.443218,0.346029,0.226207,0.860531,0.139007,0.258825,0.296441,0.201028,0.585565
CD20,14.533892,21.601505,22.656601,23.735542,13.370874,59.387054,70.293106,20.061493,98.111107,16.168745


In [12]:
## STARLING expression centroids (p markers x c clusters), for comparison with the initial ones above.
## Rows are labelled by marker name; columns are the default integer cluster indices 0..c-1.
pd.DataFrame(
    data=st.adata.varm['st_exp_centroids'],
    index=st.adata.var_names,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
SMA,0.330248,0.468764,0.435604,0.472868,1.234269,0.301021,0.388124,0.335251,0.606703,0.531159
ECadherin,0.489942,3.731304,0.709027,0.653751,0.910811,0.538053,0.577533,2.976378,0.638224,0.908354
Cytokeratin,1.640574,29.385609,5.508084,6.75917,4.546995,2.328938,2.341408,9.58766,6.188167,5.053919
HLADR,3.307306,25.334877,57.960484,80.58329,10.297035,45.589912,18.783724,9.757232,17.784489,42.889027
Vimentin,19.496674,60.325493,67.690819,105.285957,239.195923,25.070839,44.290356,12.685467,183.382309,185.185379
CD28,0.061556,0.2481,0.488354,0.276518,0.103352,0.160766,0.231705,0.129516,0.577252,0.296092
CD15,0.048695,0.301525,0.168835,0.193829,0.22619,0.0814,0.070687,0.061203,0.212655,0.220284
CD45RA,1.806237,6.973409,20.140131,29.276237,4.295873,23.040325,10.484491,3.204302,8.761522,16.202665
CD66b,0.071479,0.205296,0.182092,0.209422,0.288032,0.101231,0.106532,0.115367,0.228846,0.296554
CD20,1.397495,10.704766,63.293991,72.71711,4.170593,56.15966,18.01721,4.026489,8.210855,36.733887


In [13]:
## Cluster assignment probabilities (n cells x c clusters), one row per cell.
## NOTE(review): in the output below each row sums to roughly 1 - doublet_prob, and its max equals
## max_assign_prob, so these look like singlet assignment probabilities — confirm with the starling docs.
pd.DataFrame(
    data=st.adata.obsm['assignment_prob_matrix'],
    index=st.adata.obs.index,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0_1,9.825439e-01,5.640182e-25,4.326533e-28,5.903906e-25,2.248838e-17,4.872199e-15,4.344103e-19,7.436649e-12,5.114796e-28,5.699817e-27
0_2,9.850645e-01,2.789916e-23,2.639702e-28,8.098672e-24,2.814131e-16,1.113136e-13,1.051325e-18,1.097845e-09,3.118753e-28,1.936553e-26
0_3,9.842599e-01,1.374766e-25,1.606225e-28,3.006982e-25,5.884639e-19,7.159246e-16,4.047903e-19,8.699858e-15,4.273549e-28,5.292633e-28
0_4,9.879214e-01,1.911630e-25,2.747166e-29,8.463129e-26,9.505211e-18,4.914228e-16,3.383641e-20,5.903980e-13,4.841845e-29,1.880539e-27
0_5,9.844403e-01,1.056937e-24,1.724376e-28,5.062158e-25,5.850785e-17,1.229685e-14,1.283830e-18,1.601740e-11,3.356994e-28,5.129321e-27
...,...,...,...,...,...,...,...,...,...,...
19_15741,8.225721e-12,2.477125e-06,3.293754e-09,1.460000e-06,3.777934e-04,3.071592e-09,4.339716e-02,2.039251e-11,2.452743e-07,1.690620e-06
19_15742,2.256883e-21,1.820714e-11,3.960795e-01,2.857733e-11,2.816453e-12,1.646022e-10,3.451484e-05,9.457275e-15,7.779357e-02,6.707842e-07
19_15743,1.262840e-24,6.389169e-16,9.092492e-01,2.381645e-12,1.120136e-17,1.466586e-11,5.981603e-07,2.983129e-21,2.633129e-04,5.393832e-10
19_15744,2.945093e-18,2.371946e-11,9.533409e-03,5.893282e-02,3.894834e-13,2.977400e-04,5.287429e-04,4.001405e-14,7.580420e-11,2.033738e-07
