In [1]:
import pandas as pd
import anndata as ad

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping #ModelCheckpoint

import os
import sys

## location of the local STARLING checkout; set the STARLING_HOME environment
## variable to point at your own clone instead of editing this cell
## (the default keeps the path this notebook was originally written with)
starling_home = os.environ.get("STARLING_HOME", "/home/campbell/yulee/github/starling/")

## only add the path once, so re-running this cell does not pile up duplicate entries
if starling_home not in sys.path:
    sys.path.append(starling_home)

from starling import utility
from starling import starling

In [2]:
## fix the global RNG seeds (python `random`, numpy, torch) before the first stochastic step
## so that Restart & Run All is repeatable
## NOTE(review): this assumes utility.init_clustering and STARLING training draw from these
## global generators -- confirm against the starling sources
SEED = 0
pl.seed_everything(SEED, workers=True)

## load data (AnnData object)
## run a clustering method: the utility provides GMM, KM (KMeans) or PG (PhenoGraph); k can be ignored for PG
## the resulting arrays (cluster centroids, variances and labels) are set up for STARLING initialization
adata = utility.init_clustering(ad.read_h5ad('sample_input.h5ad'), 'KM', k=10)

In [3]:
## set up the STARLING model from the AnnData object prepared above;
## the initial clustering results carried by `adata` serve as the initializations
st = starling.ST(adata)

In [4]:
## TensorBoard logger: training results are written below the 'log' directory
log_tb = TensorBoardLogger(save_dir='log')

## early-stopping criterion: monitor 'train_loss', where lower is better (mode='min')
cb_early_stopping = EarlyStopping(
    monitor='train_loss',
    mode='min',
    verbose=False,
)

In [5]:
## train STARLING: at most 100 epochs, accelerator/devices picked automatically,
## with the early-stopping callback and the TensorBoard logger attached
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',
    devices='auto',
    callbacks=[cb_early_stopping],
    logger=[log_tb],
)
trainer.fit(st)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
2023-10-18 13:45:40.953396: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-18 13:45:41.046922: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-18 13:45:41.046962: E tensorflow/compiler/xla/stream_executor/cuda/cuda_ff

Training: 0it [00:00, ?it/s]

  v = prob_data_given_gamma_d1.T + log_delta[1] - prob_data                    ## p(d=1,gamma=[c,c']|data)
`Trainer.fit` stopped: `max_epochs=100` reached.


In [6]:
## retrieve STARLING results
## NOTE(review): presumably this is what fills the st_* / doublet / assignment fields
## shown on st.adata in the cells below -- confirm against starling.ST.result
st.result()

In [7]:
## the st object can be saved
## note: this pickles the whole object (not just a state_dict), so loading it back
## with torch.load requires the starling package to be importable
torch.save(st, 'model.pt')

In [8]:
## summary of the AnnData object held by the model (obs / uns / obsm / varm fields)
st.adata

AnnData object with n_obs × n_vars = 13685 × 24
    obs: 'sample', 'id', 'x', 'y', 'area', 'area_convex', 'neighbor', 'init_label', 'st_label', 'doublet_prob', 'doublet', 'max_assign_prob'
    uns: 'init_cell_size_centroids', 'init_cell_size_variances', 'st_cell_size_centroids'
    obsm: 'assignment_prob_matrix'
    varm: 'init_exp_centroids', 'init_exp_variances', 'st_exp_centroids'

In [9]:
## observation-level table: one row per observation, including the init_label, st_label,
## doublet_prob, doublet and max_assign_prob columns
st.adata.obs

Unnamed: 0,sample,id,x,y,area,area_convex,neighbor,init_label,st_label,doublet_prob,doublet,max_assign_prob
4_1,4,1,0.785714,7.785714,14,14,0,2,1,0.120919,0,0.879054
4_2,4,2,0.823529,22.294117,17,17,0,6,1,0.468598,0,0.471755
4_3,4,3,0.875000,79.500000,16,16,1,4,1,0.121058,0,0.878942
4_4,4,4,0.666667,270.500000,12,12,0,5,1,0.147741,0,0.850225
4_5,4,5,0.823529,279.294130,17,17,1,3,3,0.640757,1,0.242266
...,...,...,...,...,...,...,...,...,...,...,...,...
4_13681,4,13681,997.769200,754.500000,26,26,0,7,7,0.046767,0,0.953133
4_13682,4,13682,998.153900,127.615390,13,13,0,5,5,0.108821,0,0.891179
4_13683,4,13683,998.153900,160.000000,13,13,1,2,1,0.128506,0,0.871477
4_13684,4,13684,997.580600,242.580640,31,33,1,7,7,0.100225,0,0.894979


In [10]:
## initial expression centroids as a (p x c) matrix: rows are labelled with var_names
pd.DataFrame(
    data=st.adata.varm['init_exp_centroids'],
    index=st.adata.var_names,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
SMA,0.584032,2.136922,1.156366,0.805381,2.303237,0.491908,1.708158,0.613496,1.533131,0.717891
ECadherin,1.004117,0.919413,0.940801,1.466276,0.977251,5.129963,0.886993,2.994215,0.887764,0.875708
Cytokeratin,12.682053,7.92418,8.768068,14.814448,7.393099,64.138718,8.129685,34.674114,8.098983,9.934929
HLADR,98.171906,20.072334,22.868526,29.464016,17.12813,11.080463,25.202549,19.969992,27.026192,108.989105
Vimentin,64.546646,592.572998,245.200851,146.257721,835.084229,31.13269,445.17749,149.663452,339.704529,208.9254
CD28,0.431907,0.193538,0.387714,0.406106,0.129673,0.281655,0.280278,0.163998,0.340512,0.395799
CD15,2.434009,12.067995,2.636322,4.722517,8.092181,10.216784,3.63525,157.637726,2.389326,0.882139
CD45RA,19.88493,6.068324,8.933743,10.337213,4.602188,3.349854,7.578948,5.666444,8.490561,25.269001
CD66b,0.260032,0.918311,0.347324,0.385715,0.916277,0.648941,0.50878,6.662968,0.395284,0.28733
CD20,66.191742,6.188631,8.214635,13.963896,4.332294,5.031889,7.66364,10.196347,8.726942,49.235191


In [11]:
## STARLING expression centroids as a (p x c) matrix: rows are labelled with var_names
pd.DataFrame(
    data=st.adata.varm['st_exp_centroids'],
    index=st.adata.var_names,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
SMA,0.626866,0.900065,1.973233,0.647311,2.93618,0.447092,0.79785,0.528065,2.566199,0.662092
ECadherin,0.68256,0.810908,2.164109,0.780709,1.120742,5.586512,0.392296,3.256017,1.012677,0.620682
Cytokeratin,9.622328,7.038451,7.088001,7.585248,8.179035,70.551765,3.138251,35.526558,7.876649,9.205347
HLADR,88.720276,10.64693,12.326794,19.711575,12.725853,7.773327,8.857308,34.737801,20.593437,60.313816
Vimentin,107.669998,376.578735,334.727844,293.988281,871.673828,26.460754,121.359421,130.812881,397.19278,188.111313
CD28,0.352102,0.049722,0.195985,0.425148,0.093689,0.243654,0.114702,0.150064,0.261936,0.244742
CD15,0.395785,3.483162,1.183142,0.533324,13.048368,7.080815,1.28971,93.362602,1.856901,0.424741
CD45RA,23.84458,3.171917,12.381987,7.46128,2.991127,1.357749,4.34867,8.442175,6.442048,16.366781
CD66b,0.19596,0.318177,0.327747,0.292196,1.029042,0.213198,0.118466,3.851811,0.35262,0.225679
CD20,58.921364,2.2273,5.572648,4.604578,2.860218,1.840962,4.957422,14.008409,6.535254,29.884245


In [12]:
## assignment distributions as an (n x c) matrix: rows are labelled with the obs index
pd.DataFrame(
    data=st.adata.obsm['assignment_prob_matrix'],
    index=st.adata.obs.index,
)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
4_1,2.151747e-08,8.790541e-01,3.475121e-15,8.703498e-06,3.039447e-17,1.573994e-07,1.800253e-05,1.009620e-11,2.906209e-15,2.212197e-10
4_2,4.366947e-05,4.717552e-01,6.896898e-10,5.960279e-02,1.516439e-15,7.062534e-09,3.428305e-10,2.149167e-09,1.812662e-11,8.305608e-10
4_3,4.941552e-16,8.789423e-01,2.015450e-13,1.406913e-10,1.252216e-08,1.046348e-11,4.055715e-12,3.340457e-12,3.038856e-14,9.631621e-16
4_4,3.091711e-07,8.502253e-01,4.019638e-16,1.762074e-03,1.354273e-14,2.701313e-04,1.315503e-06,3.899719e-10,1.450138e-13,2.077522e-09
4_5,4.273313e-08,2.339694e-02,3.325816e-10,2.422661e-01,6.168263e-13,3.250750e-07,9.357876e-02,6.037313e-13,1.096507e-08,6.803629e-07
...,...,...,...,...,...,...,...,...,...,...
4_13681,4.512264e-15,4.413084e-08,1.171007e-15,2.736252e-13,6.506645e-16,1.000425e-04,2.266050e-14,9.531330e-01,3.630863e-17,4.700534e-18
4_13682,3.705795e-13,4.801065e-10,1.012208e-16,1.786756e-10,6.303121e-21,8.911785e-01,6.054635e-17,2.096511e-08,5.535319e-18,1.102232e-13
4_13683,2.787279e-08,8.714774e-01,3.055476e-18,1.674799e-05,1.886464e-16,7.509011e-08,3.715073e-11,6.842821e-12,5.545540e-16,5.933008e-13
4_13684,2.485049e-13,1.210574e-08,2.006384e-17,8.983221e-10,3.805069e-17,4.795779e-03,7.990278e-14,8.949791e-01,7.642492e-18,3.262482e-15
