# Getting started with Starling (ST)


In [1]:
%pip install biostarling
%pip install lightning_lite

import anndata as ad
import pandas as pd
import torch
from starling import starling, utility
from lightning_lite import seed_everything
import pytorch_lightning as pl
import numpy as np
from torch.utils.data import DataLoader


Collecting biostarling
  Downloading biostarling-0.1.4-py3-none-any.whl.metadata (5.0 kB)
Collecting flowsom<0.2.0,>=0.1.1 (from biostarling)
  Downloading FlowSom-0.1.1.tar.gz (6.8 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting phenograph<2.0.0,>=1.5.7 (from biostarling)
  Downloading PhenoGraph-1.5.7-py3-none-any.whl.metadata (8.4 kB)
Collecting pytorch-lightning<3.0.0,>=2.3.3 (from biostarling)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting scanpy<2.0.0,>=1.10.2 (from biostarling)
  Downloading scanpy-1.10.4-py3-none-any.whl.metadata (9.3 kB)
Collecting FlowCytometryTools (from flowsom<0.2.0,>=0.1.1->biostarling)
  Downloading FlowCytometryTools-0.5.1-py3-none-any.whl.metadata (7.9 kB)
Collecting minisom (from flowsom<0.2.0,>=0.1.1->biostarling)
  Downloading MiniSom-2.3.3.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting leidenalg>=0.8.2 (from phenograph<2.0.0,>=1.5.7->biostarling)
  Downloading l

## Setting seed for everything


In [2]:
# Seed all RNGs through Lightning (workers=True also seeds dataloader workers).
# The call returns the seed, which is echoed as this cell's output.
seed_everything(seed=10, workers=True)

INFO:lightning_lite.utilities.seed:Global seed set to 10


10

## Loading annData objects


In [3]:
anndata_train = ad.read_h5ad("train_adata.h5ad")
# Starling reads expressions from .X: store the arcsinh-transformed
# (cofactor 5) raw intensities from the "exprs" layer there.
anndata_train.X = np.arcsinh(anndata_train.layers["exprs"] / 5.0)

# Integer codes for the string cell-type labels, plus a code -> name lookup.
anndata_train.obs["cell_cat_labels"] = (
    anndata_train.obs["cell_labels"]
    .astype("category")
    .cat.codes
    .values
)
labels_map = (
    anndata_train.obs
    .set_index("cell_cat_labels")["cell_labels"]
    .to_dict()
)
labels = anndata_train.obs["cell_cat_labels"]

# Initial k-means clustering with k=20; Starling refines these clusters.
adata = utility.init_clustering("KM", anndata_train, k=20, labels=labels)


- The input AnnData object should contain a cell-by-protein matrix of segmented single-cell expression profiles in `.X`. Optionally, cell size information can also be provided as a column of the `.obs` DataFrame; in this case `model_cell_size` should be set to `True` and the column name given in the `cell_size_col_name` argument.
- Users may want to arcsinh-transform the protein expressions in the \*.h5ad file (for example, `sample_input.h5ad`); this notebook uses `arcsinh(x / 5)`.
- `utility.py` provides an easy setup of the initial clustering with GMM, KM (k-means) or PG (PhenoGraph).
- Default settings are applied to each method.
- k can be omitted when PG is used.


## Setting initializations


The example below uses the default parameter settings, which are based on benchmarking results (more details in the manuscript).


In [4]:
# Build the Starling model from the initially-clustered AnnData object,
# relying on the package defaults for every other parameter.
st = starling.ST(adata)

  torch.tensor(self.adata.obs[self.cell_size_col_name])


The available parameters are:

- adata: AnnData object of the sample
- dist_option (default: 'T'): 'T' for Student-t (df=2) and 'N' for Normal (Gaussian)
- singlet_prop (default: 0.6): the anticipated proportion of cells free of segmentation errors
- model_cell_size (default: True): True to incorporate cell size in the model, False otherwise
- cell_size_col_name (default: 'area'): name of the cell-size column in the `adata.obs` DataFrame
- model_zplane_overlap (default: True): True to model z-plane overlap when cell size is modelled, False otherwise
  Note: if model_cell_size is False, model_zplane_overlap is ignored
- model_regularizer (default: 1): regularizer term imposed on the synthetic-doublet loss (BCE)
- learning_rate (default: 1e-3): learning rate of the Adam optimizer used to train STARLING

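Equivalent to the above example, with every default spelled out:
```python
st = starling.ST(
    adata,
    dist_option='T',
    singlet_prop=0.6,
    model_cell_size=True,
    cell_size_col_name='area',
    model_zplane_overlap=True,
    model_regularizer=1,
    learning_rate=1e-3,
)
```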


## Setting training log


Once training starts, a new directory 'log' will be created.

In [5]:
## log training results via tensorboard
# Event files are written under the local "log" directory.
log_tb = pl.loggers.TensorBoardLogger(save_dir="log")

Training progress can be viewed in TensorBoard (e.g. `tensorboard --logdir log`). See the PyTorch Lightning logger documentation (https://lightning.ai/docs/pytorch/stable/api_references.html#loggers) for other supported loggers.


## Setting early stopping criterion


In [6]:
## set early stopping criterion
# Stop training once the monitored "train_loss" metric stops decreasing (mode="min").
cb_early_stopping = pl.callbacks.EarlyStopping(monitor="train_loss", mode="min", verbose=False)

Training loss is monitored.


## Training Starling


In [7]:
## train ST
# Fit the model with the early-stopping callback and TensorBoard logger
# created in the previous cells.
st.train_and_fit(
    callbacks=[cb_early_stopping],
    logger=[log_tb],
)

INFO:lightning_lite.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:lightning_lite.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:lightning_lite.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type | Params | Mode
-------------------------------------
-------------------------------------
0         Trainable params
0         Non-trainable params
0         Total params
0.000     Total estimated model params size (MB)
0         Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

  prob_data_given_gamma_d1.T + log_delta[1] - prob_data


## Appending STARLING results to the annData object


In [8]:
## retrieve starling results
# Returns an AnnData object (displayed further below) carrying Starling's
# labels, doublet probabilities and centroids.
result = st.result()

## The following information can be retrieved from the AnnData object:

- st.adata.varm['init_exp_centroids'] -- initial expression cluster centroids (P x C matrix)
- st.adata.varm['st_exp_centroids'] -- ST expression cluster centroids (P x C matrix)
- st.adata.uns['init_cell_size_centroids'] -- initial cell size centroids if STARLING models cell size
- st.adata.uns['st_cell_size_centroids'] -- ST cell size centroids if STARLING models cell size
- st.adata.obsm['assignment_prob_matrix'] -- cell assignment probabilities (N x C matrix)
- st.adata.obsm['gamma_assignment_prob_matrix'] -- gamma assignment probabilities of two cells (N x C x C matrix)
- st.adata.obs['doublet'] -- doublet indicator
- st.adata.obs['doublet_prob'] -- doublet probabilities
- st.adata.obs['init_label'] -- initial assignments
- st.adata.obs['st_label'] -- ST assignments
- st.adata.obs['max_assign_prob'] -- ST maximum assignment probabilities

_N: # of cells; C: # of clusters; P: # of proteins_


## Saving the model


In [9]:
## st object can be saved
# NOTE: torch.save pickles the whole ST object. Reloading it needs the starling
# package importable and, on recent torch versions,
# torch.load("model.pt", weights_only=False); only load files you trust.
torch.save(st, "model.pt")

`model.pt` will be saved in the current working directory (by default, the directory containing this notebook).


## Showing STARLING results


In [10]:
# Summary of the AnnData object: its obs/var/uns/obsm/varm/layers keys.
display(result)

AnnData object with n_obs × n_vars = 253433 × 40
    obs: 'image', 'sample_id', 'ObjectNumber', 'Pos_X', 'Pos_Y', 'area', 'major_axis_length', 'minor_axis_length', 'eccentricity', 'width_px', 'height_px', 'acquisition_id', 'SlideId', 'Study', 'Box.Description', 'Position', 'SampleId', 'Indication', 'BatchId', 'SubBatchId', 'ROI', 'ROIonSlide', 'includeImage', 'flag_no_cells', 'flag_no_ROI', 'flag_total_area', 'flag_percent_covered', 'small_cell', 'celltypes', 'flag_tumor', 'PD1_pos', 'Ki67_pos', 'cleavedPARP_pos', 'GrzB_pos', 'tumor_patches', 'distToCells', 'CD20_patches', 'Batch', 'cell_labels', 'classifier', 'cell_cat_labels', 'init_label', 'st_label', 'doublet_prob', 'doublet', 'max_assign_prob'
    var: 'channel', 'use_channel', 'marker'
    uns: 'init_cell_size_centroids', 'init_cell_size_variances', 'st_cell_size_centroids'
    obsm: 'assignment_prob_matrix', 'gamma_assignment_prob_matrix'
    varm: 'init_exp_centroids', 'init_exp_variances', 'st_exp_centroids'
    layers: 'exprs

One can then easily perform further analyses, such as co-occurrence or enrichment analysis.


In [11]:
# Per-cell metadata, now including the Starling columns init_label, st_label,
# doublet_prob, doublet and max_assign_prob.
result.obs

Unnamed: 0,image,sample_id,ObjectNumber,Pos_X,Pos_Y,area,major_axis_length,minor_axis_length,eccentricity,width_px,...,CD20_patches,Batch,cell_labels,classifier,cell_cat_labels,init_label,st_label,doublet_prob,doublet,max_assign_prob
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_1,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,1,300.846154,0.692308,13,6.094800,2.780135,0.889904,600,...,,Batch20191023,MacCD163,v1,6,9,9,0.014333,0,0.985667
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_3,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,3,26.982143,0.928571,56,21.520654,3.368407,0.987675,600,...,,Batch20191023,Mural,v1,7,2,0,0.999878,1,0.000122
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_5,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,5,309.083333,0.750000,12,5.294329,2.862220,0.841267,600,...,,Batch20191023,DC,v1,4,9,14,0.255307,0,0.743813
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_7,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,7,431.916667,0.750000,12,5.294329,2.862220,0.841267,600,...,,Batch20191023,Tumor,v1,11,9,14,0.988802,1,0.011100
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_8,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01...,8,116.931034,1.206897,29,9.216670,4.112503,0.894932,600,...,,Batch20191023,Tumor,v1,11,5,0,0.999991,1,0.000009
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2713,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2713,596.548387,596.709677,31,6.857501,5.700162,0.555928,600,...,,Batch20220908,Mural,v3,7,4,0,0.153569,0,0.846214
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2715,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2715,180.300000,597.400000,20,6.484816,3.840203,0.805803,600,...,,Batch20220908,Mural,v3,7,2,2,0.011500,0,0.988500
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2721,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2721,48.370370,598.111111,27,10.732613,3.134663,0.956397,600,...,,Batch20220908,CD8,v3,3,6,10,0.941535,1,0.053996
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2722,IMMUcan_Batch20220908_S-220729-00002_002.tiff,IMMUcan_Batch20220908_S-220729-00002_002,2722,207.969697,598.060606,33,12.864691,3.228974,0.967988,600,...,,Batch20220908,Mural,v3,7,2,2,0.010262,0,0.989738


For each cell, Starling provides a doublet probability and the cluster assignment the cell would receive if it were a singlet.


## Showing initial expression centroids:


In [12]:
## initial expression centroids (p x c) matrix
# Rows are indexed by result.var_names (proteins); columns are the initial clusters.
pd.DataFrame(result.varm["init_exp_centroids"], index=result.var_names)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.295385,0.040726,0.033524,0.048627,0.04255,0.032858,0.041984,0.031854,0.085768,0.04705,0.034844,0.069225,0.057223,0.050255,0.064835,0.039726,0.036379,0.034857,0.033007,0.066797
1,0.560624,0.609609,0.405913,0.598308,0.576092,0.475896,0.504753,0.421326,0.681874,0.437516,0.479586,0.692169,0.574404,0.419985,0.576361,0.460929,0.49948,0.380119,0.47463,0.580024
2,0.116681,0.075613,0.20637,0.078157,0.227811,0.051397,0.105165,0.027739,0.160618,0.1006,0.086347,0.076624,0.097981,0.03026,0.134632,0.042369,0.120207,0.057443,0.063212,0.050871
3,0.154653,0.095507,0.06654,0.167863,0.152986,0.060975,0.325791,0.052691,0.480484,0.317296,0.124786,0.256894,0.191715,0.054583,0.454213,0.175659,0.160498,0.058657,0.080934,0.137897
4,0.095495,0.094285,0.051146,0.114582,0.252947,0.055007,0.120812,0.047075,0.353003,0.062731,0.068576,0.206099,0.222428,0.045222,0.203666,0.050528,0.084738,0.029115,0.07606,0.112451
5,0.361603,0.235463,0.189961,0.534303,0.318711,0.205662,0.63445,0.180275,0.722695,0.633981,0.41889,0.445609,0.671301,0.191448,0.712819,0.489509,0.492731,0.247176,0.619016,0.355135
6,0.163571,0.14507,0.093837,0.178739,0.260349,0.090671,0.259387,0.076195,0.415628,0.118706,0.200964,0.257899,0.426357,0.088432,0.248794,0.109189,0.248393,0.067547,0.235919,0.173858
7,0.854543,0.090813,0.092899,0.196872,0.124417,0.108279,0.157909,0.30023,0.352194,0.150944,0.15449,0.201327,0.183821,0.609968,0.25284,0.300455,0.142361,0.31686,0.166997,0.712662
8,0.130582,0.160335,0.068882,0.163456,0.217025,0.107571,0.156198,0.083868,0.305202,0.086324,0.123481,0.252245,0.481467,0.074406,0.20106,0.081778,0.125626,0.05375,0.493236,0.144303
9,0.150384,0.052071,0.056115,0.07832,0.097065,0.034907,0.286982,0.027974,0.343123,0.370636,0.096677,0.114641,0.219669,0.038048,0.44036,0.116245,0.139202,0.035698,0.097335,0.080694


There are 20 centroids (columns) since we set k = 20 for the k-means (KM) initialization earlier.


## Showing Starling expression centroids:


In [13]:
## starling expression centroids (p x c) matrix
# Same layout as the initial centroids: proteins x clusters, after Starling training.
pd.DataFrame(result.varm["st_exp_centroids"], index=result.var_names)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.08831,0.035699,0.028571,0.050209,0.031487,0.030066,0.034919,0.0179,0.067762,0.037712,0.035996,0.080057,0.06386,0.033445,0.038874,0.034771,0.0275,0.029019,0.032581,0.036391
1,0.598753,0.648708,0.405355,0.616516,0.415206,0.525414,0.507858,0.362126,0.674223,0.434506,0.50529,0.800926,0.697469,0.417665,0.482432,0.41792,0.472023,0.36635,0.46477,0.446327
2,0.112488,0.060168,0.18346,0.033847,0.131609,0.025227,0.117041,0.005792,0.087386,0.06164,0.020982,0.015534,0.086602,0.002582,0.061546,0.01824,0.066411,0.026034,0.029711,0.011038
3,0.08659,0.065735,0.03062,0.127968,0.063406,0.044933,0.315055,0.045011,0.474662,0.29206,0.254307,0.294497,0.44808,0.030169,0.390466,0.067154,0.055204,0.042705,0.065874,0.075075
4,0.109527,0.078129,0.025266,0.132621,0.058232,0.045363,0.110859,0.052861,0.264341,0.041421,0.05329,0.254396,0.33086,0.028461,0.084549,0.035792,0.037134,0.019478,0.089424,0.066537
5,0.27266,0.169274,0.114482,0.321584,0.06754,0.141315,0.642051,0.076426,0.696266,0.621752,0.58512,0.37694,0.751391,0.106304,0.703294,0.34874,0.366075,0.270755,0.636915,0.285969
6,0.194038,0.138597,0.063439,0.199337,0.052791,0.088558,0.286381,0.068617,0.279989,0.099484,0.190478,0.318179,0.477134,0.052495,0.145583,0.082827,0.188424,0.047905,0.323836,0.114952
7,0.48861,0.042316,0.03332,0.14718,0.173466,0.194011,0.092874,0.207793,0.243327,0.074105,0.107063,0.149886,0.149899,0.500532,0.188164,0.240737,0.070945,0.341735,0.095888,0.359077
8,0.147131,0.141088,0.044631,0.193073,0.132363,0.083221,0.138287,0.084947,0.22127,0.0604,0.066855,0.298639,0.350271,0.046864,0.113043,0.061624,0.075832,0.036312,0.478569,0.107337
9,0.068045,0.035319,0.023265,0.059379,0.035061,0.02591,0.296779,0.019947,0.272475,0.34366,0.150946,0.090967,0.319505,0.016162,0.381804,0.03778,0.04032,0.018996,0.049083,0.03885


From here, one can annotate each cluster centroid with a cell type.


## Showing Assignment Distributions:


In [14]:
## assignment distributions (n x c matrix)
# Rows are cells (indexed like result.obs), columns are clusters.
pd.DataFrame(result.obsm["assignment_prob_matrix"], index=result.obs.index)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_1,5.775801e-10,1.980849e-21,3.408138e-14,4.152050e-22,3.624247e-18,1.051834e-18,1.144711e-12,2.711477e-28,5.721248e-23,9.856669e-01,8.833747e-10,3.183581e-36,7.591824e-32,2.099520e-25,1.036578e-07,2.906274e-12,2.955773e-10,1.652950e-21,6.722211e-17,2.169940e-17
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_3,1.218731e-04,1.433469e-11,3.104155e-12,7.389566e-11,7.296896e-11,4.680176e-14,4.304200e-15,5.481113e-25,7.952722e-21,1.151607e-11,1.376607e-15,6.538259e-28,4.037570e-30,2.119317e-24,5.831701e-15,1.235262e-11,1.521163e-07,5.346448e-23,1.258591e-15,2.091331e-09
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_5,8.739931e-12,2.182692e-23,5.830727e-25,3.937041e-13,2.295950e-30,2.594017e-26,7.831545e-04,1.314567e-37,2.529398e-05,7.154425e-05,3.451181e-08,1.184412e-19,1.199568e-13,7.898838e-35,7.438125e-01,3.656435e-20,3.322358e-13,6.842759e-33,1.015738e-12,8.976816e-18
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_7,1.150198e-07,3.721155e-17,7.240515e-27,1.619404e-08,2.121338e-33,3.983535e-23,8.858858e-05,4.423387e-32,8.659900e-06,2.876451e-10,8.757416e-07,2.119820e-14,1.486312e-15,8.175927e-36,1.109952e-02,3.129511e-15,2.173388e-12,3.252204e-36,2.230272e-09,2.265843e-10
IMMUcan_batch20191023_10032145-THOR-VAR-TIS-01-IMC-01_002.tiff_8,8.565061e-06,1.592968e-12,1.351129e-16,1.634774e-09,2.542595e-11,6.746577e-15,1.513820e-11,6.416272e-23,9.462639e-14,1.095931e-08,1.086477e-11,2.761245e-24,1.533386e-24,4.859027e-24,1.641538e-09,1.294544e-11,2.096248e-08,1.066399e-23,1.599284e-15,4.787739e-08
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2713,8.462138e-01,6.320227e-18,3.457147e-09,6.422610e-16,1.952411e-29,2.307193e-12,7.923821e-11,1.929915e-21,1.263615e-14,1.800345e-04,2.893519e-05,3.143585e-23,1.064585e-24,9.960212e-21,9.549359e-09,4.440447e-08,8.426731e-06,5.152920e-18,2.543493e-11,1.835984e-14
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2715,1.193234e-17,4.828956e-26,9.885001e-01,1.666410e-33,3.093451e-24,3.852478e-18,1.684177e-30,2.321469e-27,1.084113e-40,1.133461e-14,1.621414e-23,5.086610e-51,1.444299e-48,1.472443e-14,1.559118e-28,1.534491e-16,2.850451e-12,3.059428e-14,2.311130e-26,1.724169e-27
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2721,8.778254e-07,2.866262e-22,3.392628e-22,3.582203e-18,1.167909e-41,5.370554e-21,2.263284e-03,3.438187e-35,3.916748e-09,1.545232e-05,5.399572e-02,2.025976e-26,1.738577e-18,5.390520e-35,2.188053e-03,1.493285e-16,1.325114e-06,1.949205e-33,1.057401e-07,1.663007e-19
IMMUcan_Batch20220908_S-220729-00002_002.tiff_2722,1.078441e-13,1.373117e-20,9.897379e-01,1.266124e-28,1.344325e-19,4.437907e-13,8.100045e-27,1.945736e-25,8.248012e-36,2.330622e-14,9.501648e-21,3.285013e-45,5.936807e-44,3.311825e-12,1.895773e-24,1.029122e-14,3.989773e-08,4.919938e-11,9.954169e-24,2.121528e-22


Currently, each cell is assigned the cluster with the maximum probability among all possible clusters. However, such labels can be unreliable when the highest and second-highest probabilities are very close.

## Assign labels to clusters


In [19]:
# run logistic regression on top of cluster probabilities
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

# Features: each cell's Starling cluster-assignment probabilities;
# target: the annotated cell-type labels.
X_train = result.obsm["assignment_prob_matrix"]
y_train = result.obs["cell_labels"]
clf = LogisticRegression(random_state=0, max_iter=1000).fit(X_train, y_train)

# NOTE: this report is computed on the training cells themselves, so it
# measures fit rather than generalisation (see the held-out test below).
y_pred = clf.predict(X_train)
# zero_division=0 reports 0.0 for classes that are never predicted (the value
# the default already uses) without emitting UndefinedMetricWarning noise.
print(classification_report(y_train, y_pred, zero_division=0))

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           B       1.00      0.00      0.00      3731
         BnT       0.44      0.28      0.35      6493
         CD4       0.47      0.01      0.02     13238
         CD8       0.47      0.59      0.52     22722
          DC       1.00      0.00      0.00      4921
       HLADR       0.00      0.00      0.00      3925
    MacCD163       0.59      0.56      0.57     15288
       Mural       0.80      0.39      0.53     20537
          NK       0.00      0.00      0.00      1112
  Neutrophil       0.41      0.29      0.34      7386
        Treg       0.00      0.00      0.00      6381
       Tumor       0.71      0.98      0.82    138266
         pDC       0.00      0.00      0.00      1561
      plasma       0.16      0.00      0.00      7872

    accuracy                           0.67    253433
   macro avg       0.43      0.22      0.22    253433
weighted avg       0.62      0.67      0.59    253433



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


## Testing

In [20]:
# Load held-out cells and apply the same arcsinh(x / 5) transform used for the
# training data, so .X is on the scale the model was trained on.
anndata_test = ad.read_h5ad("test_adata.h5ad")
anndata_test.X = np.arcsinh(anndata_test.layers['exprs'] / 5.0)


In [25]:
def test(trained_model, test_adata, threshold: float = 0.5):
        """Test the trained model on the test data.

        :param threshold: minimum threshold for singlet probability
        """
        S = test_adata.obs['area']
        model_pred_loader = DataLoader(
            utility.ConcatDataset([test_adata.X, S]), batch_size=1000, shuffle=False
        )

        singlet_prob, singlet_assig_prob, gamma_assig_prob = utility.predict(
            model_pred_loader,
            trained_model.model_params,
            trained_model.dist_option,
            trained_model.model_cell_size,
            trained_model.model_zplane_overlap,
            threshold,
        )

        test_adata.obs["st_label"] = np.array(
            singlet_assig_prob.max(1).indices
        )  ##p(z=c|d=1)
        test_adata.obs["doublet_prob"] = 1 - np.array(singlet_prob)
        test_adata.obs["doublet"] = 0
        test_adata.obs.loc[test_adata.obs["doublet_prob"] > 0.5, "doublet"] = 1
        test_adata.obs["max_assign_prob"] = np.array(singlet_assig_prob.max(1).values)

        test_adata.obsm["assignment_prob_matrix"] = np.array(singlet_assig_prob)
        test_adata.obsm["gamma_assignment_prob_matrix"] = np.array(gamma_assig_prob)
        c = trained_model.model_params["log_mu"].detach().exp().cpu().numpy()

        test_adata.varm[
            "st_exp_centroids"
        ] = c.T  # pd.DataFrame(c, columns=test_adata.var_names)

        if trained_model.model_cell_size:
            test_adata.uns["st_cell_size_centroids"] = (
                trained_model.model_params["log_psi"]
                .reshape(-1, 1)
                .detach()
                .exp()
                .cpu()
                .numpy()
                .T
            )


        return test_adata

In [26]:
# Adds Starling's outputs to anndata_test in place; test_result is that same object.
test_result = test(st, anndata_test)

  return tuple(d[i] for d in self.datasets)


In [31]:
# Cell-by-cluster assignment probabilities for the test cells; these are the
# features for the logistic-regression classifier fitted on the training cells.
test_df = pd.DataFrame(test_result.obsm["assignment_prob_matrix"], index=test_result.obs.index)

In [28]:
# Preview the test assignment-probability matrix.
test_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
IMMUcan_Batch20191023_S-190805-00002_006.tiff_1,1.199576e-05,2.324981e-13,3.511556e-07,1.297566e-17,2.988778e-10,3.025840e-16,1.960271e-12,4.891818e-30,4.395652e-18,9.826600e-01,1.622901e-07,3.565726e-32,1.497411e-25,2.497936e-21,2.623666e-08,1.219858e-07,1.167062e-06,1.195426e-14,1.702750e-16,1.797768e-15
IMMUcan_Batch20191023_S-190805-00002_006.tiff_2,2.221720e-12,2.842817e-20,3.588599e-04,4.163888e-24,1.163561e-15,1.145841e-10,6.643460e-27,4.731213e-18,4.915772e-32,3.290197e-14,5.549986e-17,6.519130e-41,4.962291e-40,4.488931e-09,1.295718e-22,9.453038e-06,4.725510e-12,7.599697e-01,1.185310e-22,9.002831e-15
IMMUcan_Batch20191023_S-190805-00002_006.tiff_3,6.752328e-05,3.853947e-10,3.718559e-05,5.033108e-15,1.295779e-01,4.438860e-09,5.421951e-17,4.811211e-19,1.500312e-22,6.728992e-07,6.528135e-09,9.635234e-30,3.154611e-32,1.269332e-15,9.674398e-13,4.912334e-01,1.231685e-04,2.474383e-09,1.251684e-13,6.437899e-09
IMMUcan_Batch20191023_S-190805-00002_006.tiff_4,9.738394e-08,1.498377e-14,4.605784e-04,9.050819e-20,2.544416e-10,3.107262e-09,1.015815e-20,6.256985e-21,1.935917e-28,7.172126e-08,1.218143e-12,2.534199e-36,2.656511e-35,1.253602e-13,2.008478e-17,9.791808e-04,8.207771e-09,5.556397e-08,4.734657e-19,3.953055e-13
IMMUcan_Batch20191023_S-190805-00002_006.tiff_5,3.541753e-09,4.529734e-17,9.117110e-06,4.055979e-21,7.226432e-09,4.407496e-12,2.849258e-24,8.927606e-19,4.480402e-31,1.280554e-12,4.060306e-17,1.458828e-40,8.679117e-40,1.551544e-13,2.195382e-20,2.906789e-06,8.929939e-09,2.808113e-08,1.712842e-19,4.613040e-14
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2396,8.802162e-12,8.678885e-20,5.005523e-01,8.316892e-26,2.642470e-16,2.646337e-04,9.787987e-28,7.973989e-12,2.111861e-33,2.888232e-11,1.056863e-18,2.932654e-43,1.886405e-43,3.618618e-01,3.316650e-21,1.490459e-04,1.267083e-09,2.008607e-06,7.023956e-22,6.139020e-13
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2397,1.315479e-06,2.586633e-10,3.787558e-05,8.899209e-19,2.181616e-16,8.958791e-01,7.723554e-24,4.791259e-13,3.586591e-28,2.074148e-14,4.718556e-19,2.998018e-36,5.597044e-37,3.217429e-07,4.325683e-19,6.649108e-09,5.249570e-09,2.148005e-13,6.100022e-20,9.620820e-12
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2398,8.858963e-02,1.471692e-06,3.274291e-08,4.629064e-11,5.361191e-14,4.663840e-04,8.605269e-16,1.456481e-11,6.938308e-19,1.043114e-06,5.032007e-13,1.497922e-27,2.187974e-31,5.226727e-14,8.559448e-10,2.221987e-04,5.165349e-08,6.655552e-17,2.553764e-14,4.408730e-03
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2399,3.047618e-03,6.259003e-08,6.628579e-23,2.679024e-10,5.871555e-30,2.318083e-12,1.861389e-08,1.712335e-25,5.674882e-11,3.518206e-06,6.399436e-09,1.085164e-24,8.000903e-22,1.197022e-29,1.241621e-02,9.724147e-14,1.488022e-12,3.368170e-30,5.985888e-13,1.402090e-09


In [33]:
# One row per test cell: class probabilities from the classifier (one column
# per cell type), followed by the annotated and the predicted label.
test_results = pd.DataFrame(
    clf.predict_proba(test_df),
    columns=clf.classes_,
    index=test_df.index,
).assign(
    true_label=test_result.obs['cell_labels'],  # aligned on the cell index
    predicted_label=clf.predict(test_df),
)
test_results

Unnamed: 0,B,BnT,CD4,CD8,DC,HLADR,MacCD163,Mural,NK,Neutrophil,Treg,Tumor,pDC,plasma,true_label,predicted_label
IMMUcan_Batch20191023_S-190805-00002_006.tiff_1,0.009736,2.135655e-05,2.108830e-02,0.000184,0.108737,0.046533,6.845353e-01,0.012659,4.517568e-03,0.002583,0.000484,0.102112,0.001924,0.004885,DC,MacCD163
IMMUcan_Batch20191023_S-190805-00002_006.tiff_2,0.006702,5.274262e-04,7.184278e-03,0.025902,0.000713,0.010703,1.109649e-03,0.006995,2.791446e-03,0.011508,0.000545,0.924455,0.000462,0.000404,Tumor,Tumor
IMMUcan_Batch20191023_S-190805-00002_006.tiff_3,0.008990,2.702483e-04,3.160680e-03,0.000273,0.004490,0.022480,3.374063e-04,0.040289,3.712207e-03,0.011657,0.000285,0.901778,0.000496,0.001782,Tumor,Tumor
IMMUcan_Batch20191023_S-190805-00002_006.tiff_4,0.019752,2.545545e-02,5.117524e-02,0.085205,0.019888,0.022029,5.142217e-02,0.108504,5.644628e-03,0.036473,0.020977,0.509223,0.006035,0.038215,Tumor,Tumor
IMMUcan_Batch20191023_S-190805-00002_006.tiff_5,0.019694,2.565105e-02,5.128452e-02,0.085960,0.019885,0.021935,5.183192e-02,0.108207,5.627455e-03,0.036448,0.021118,0.508023,0.006045,0.038290,Tumor,Tumor
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2396,0.029004,3.873121e-04,9.677042e-03,0.008933,0.002628,0.060335,1.119069e-03,0.023899,2.430786e-03,0.040566,0.000584,0.810066,0.001290,0.009081,Tumor,Tumor
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2397,0.000008,2.904297e-07,2.506896e-07,0.000002,0.000015,0.000019,8.018480e-08,0.000079,4.044196e-07,0.000020,0.000001,0.999852,0.000002,0.000001,Tumor,Tumor
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2398,0.021528,1.676872e-02,4.846414e-02,0.045161,0.022866,0.028394,3.017866e-02,0.155457,7.378670e-03,0.064843,0.015422,0.470487,0.008587,0.064464,Tumor,Tumor
IMMUcan_Batch20220908_S-220715-00002_002.tiff_2399,0.020225,2.365727e-02,5.205278e-02,0.079193,0.021113,0.022593,5.435410e-02,0.108012,5.902727e-03,0.038079,0.019632,0.509080,0.006342,0.039765,Tumor,Tumor


In [35]:
from sklearn.metrics import classification_report

# Held-out evaluation of the logistic-regression labels on the test cells.
# zero_division=0 keeps the default 0.0 score for never-predicted classes but
# silences the UndefinedMetricWarning noise.
print(classification_report(test_results['true_label'], test_results['predicted_label'], zero_division=0))

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           B       1.00      0.00      0.01      2767
         BnT       0.47      0.58      0.52      3341
         CD4       0.55      0.01      0.02      6139
         CD8       0.40      0.53      0.46      6890
          DC       0.00      0.00      0.00      2048
       HLADR       0.00      0.00      0.00      1627
    MacCD163       0.64      0.55      0.59      6482
       Mural       0.78      0.41      0.53      8977
          NK       0.00      0.00      0.00       498
  Neutrophil       0.26      0.27      0.27      2633
        Treg       0.00      0.00      0.00      1903
       Tumor       0.64      0.98      0.78     41456
         pDC       0.00      0.00      0.00       843
      plasma       0.40      0.00      0.00      4199

    accuracy                           0.61     89803
   macro avg       0.37      0.24      0.23     89803
weighted avg       0.56      0.61      0.52     89803



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [36]:
# Save per-cell class probabilities plus true/predicted labels as TSV.
# Keep the index (the cell identifiers) as a "cell_id" column so rows can be
# joined back to the AnnData object; index=False would discard them.
test_results.to_csv('starling-lr_predictions.tsv', sep='\t', index_label='cell_id')