In [1]:
# Set up autoreload for development
%load_ext autoreload
%autoreload 2

In this notebook, we'll train the SpatialDIVA model on the Valdeolivas et al. colorectal cancer dataset, which comprises 12 slides of colorectal cancer tissue from different patients.

For demonstration purposes, we'll restrict training to two slides. After training, we'll extract embeddings from the different latent subspaces of SpatialDIVA and analyze their covariance. We'll also examine how well SpatialDIVA corrects for batch effects.
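To make "analyze their covariance" concrete, here is a minimal sketch of the cross-covariance between two latent subspaces, assuming each subspace's embeddings have been extracted as an `(n_spots, dim)` NumPy array. The function `cross_covariance` and the arrays `z_expr` and `z_hist` are hypothetical random stand-ins, not part of the SpatialDIVA API.

```python
import numpy as np

def cross_covariance(z_a: np.ndarray, z_b: np.ndarray) -> np.ndarray:
    """Cross-covariance between two embedding matrices that share rows (spots).

    Entry (i, j) is the sample covariance between dimension i of z_a
    and dimension j of z_b, using the unbiased (n - 1) normalisation.
    """
    if z_a.shape[0] != z_b.shape[0]:
        raise ValueError("Both embeddings must have one row per spot")
    n = z_a.shape[0]
    a_centered = z_a - z_a.mean(axis=0)
    b_centered = z_b - z_b.mean(axis=0)
    return a_centered.T @ b_centered / (n - 1)

# Hypothetical stand-ins for two latent subspaces of the same 500 spots
rng = np.random.default_rng(0)
z_expr = rng.normal(size=(500, 8))
z_hist = rng.normal(size=(500, 4))

cov = cross_covariance(z_expr, z_hist)
print(cov.shape)  # (8, 4)
# Entries close to zero indicate the two subspaces are linearly decorrelated
print(np.abs(cov).max())
```

Because the random stand-ins are independent, every entry here is small; on real embeddings, large off-diagonal blocks would point to information shared between subspaces.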

We'll use the high-level SpatialDIVA API to train the model and extract embeddings.

Let's start by loading the necessary modules and the data.

In [2]:
import sys 
import os 
sys.path.append("..")

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc 
import anndata as ann 

from api import StDIVA

We can start by loading the AnnData objects for Valdeolivas et al.; in this case, we'll restrict it to the first two files, corresponding to two slides.

This data has been preprocessed: spot-level deconvolution has been performed, and pathologist annotations are included. UNI features have also been extracted from the histopathology patches corresponding to each ST spot. These features are stored in the `.obs` attribute of the AnnData object, with column names prefixed by "UNI" (e.g. `UNI-1015`).
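Since the features live alongside other metadata in `.obs`, they have to be picked out by name. A minimal sketch, assuming the columns follow the `UNI-<integer>` pattern seen in this notebook's column listing and that you pass a plain list such as `list(adata.obs.columns)`; the helper `uni_feature_columns` is hypothetical and not part of the StDIVA API.

```python
import re

def uni_feature_columns(columns, prefix="UNI"):
    """Return histology feature columns such as 'UNI-1', 'UNI-2', ... in numeric order."""
    pattern = re.compile(rf"^{re.escape(prefix)}-(\d+)$")
    matches = []
    for col in columns:
        m = pattern.match(col)
        if m:
            matches.append((int(m.group(1)), col))
    # Sort on the integer suffix so 'UNI-10' follows 'UNI-9' rather than 'UNI-1'
    return [col for _, col in sorted(matches)]

cols = ["array_row_x", "UNI-10", "UNI-2", "total_counts_x", "UNI-1"]
print(uni_feature_columns(cols))  # ['UNI-1', 'UNI-2', 'UNI-10']
```

Matching the full `UNI-<integer>` pattern is stricter than a substring test, which would also catch any unrelated column that happens to contain "UNI".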

In [3]:
# Load the first two slides of the Valdeolivas et al dataset
adata_path = "/scratch/hdd001/home/hmaan/visium_datasets/valdeolivas_hest_data"
# Sort the file names: os.listdir returns entries in arbitrary order
adata_files = sorted(f for f in os.listdir(adata_path) if f.endswith("processed.h5ad"))
adata_files = [os.path.join(adata_path, f) for f in adata_files]

adata_files_sub = adata_files[:2]
adatas = [sc.read_h5ad(adata_file) for adata_file in adata_files_sub]
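Which slides count as "the first two" depends on file ordering, and `os.listdir` makes no ordering guarantee. A minimal sketch of deterministic file discovery with `pathlib`, assuming the `processed.h5ad` suffix used above; the helper `find_processed_slides` and the toy file names are hypothetical.

```python
import tempfile
from pathlib import Path

def find_processed_slides(directory, suffix="processed.h5ad"):
    """List processed slide files in stable, alphabetical order."""
    # Path.iterdir (like os.listdir) yields entries in arbitrary order, so sort explicitly
    return sorted(
        str(p) for p in Path(directory).iterdir()
        if p.is_file() and p.name.endswith(suffix)
    )

# Demonstrate on a throwaway directory with toy file names
with tempfile.TemporaryDirectory() as tmp:
    for name in ["slide_B_processed.h5ad", "slide_A_processed.h5ad", "notes.txt"]:
        (Path(tmp) / name).touch()
    names = [Path(f).name for f in find_processed_slides(tmp)]

print(names)  # ['slide_A_processed.h5ad', 'slide_B_processed.h5ad']
```

With sorted paths, `[:2]` selects the same two slides on every machine and rerun.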

In [4]:
adatas[0].obs.columns

Index(['array_row_x', 'array_col_x', 'pxl_col_in_fullres',
       'pxl_row_in_fullres', 'in_tissue_x', 'pxl_row_in_fullres_old',
       'pxl_col_in_fullres_old', 'n_genes_by_counts_x',
       'log1p_n_genes_by_counts_x', 'total_counts_x',
       ...
       'UNI-1015', 'UNI-1016', 'UNI-1017', 'UNI-1018', 'UNI-1019', 'UNI-1020',
       'UNI-1021', 'UNI-1022', 'UNI-1023', 'UNI-1024'],
      dtype='object', length=1348)

As we can see, the UNI columns come last in `.obs`, after the other metadata columns. Each spot corresponds to one observation in the AnnData object.

The pathologist annotations per spot are stored in `.obs["Pathologist Annotation"]`.

The maximally represented cell type (after deconvolution) per spot is stored in `.obs["ST_celltype"]`.

The batch/sample label is stored in `.obs["sample"]`.

The positions of each spot on the slide are stored in `.obsm["spatial"]`.
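Because every spot carries a batch label, batch correction can later be gauged by how well spots from different slides intermix in embedding space. A minimal sketch of one simple neighbour-mixing score, assuming embeddings as an `(n_spots, dim)` array and batch labels such as `adata.obs["sample"].values`; this is a generic illustration, not necessarily the metric used in the SpatialDIVA paper, and the function name and data are hypothetical.

```python
import numpy as np

def cross_batch_neighbor_fraction(embeddings, batch_labels, k=10):
    """Mean fraction of each spot's k nearest neighbours (Euclidean) that come from another batch."""
    x = np.asarray(embeddings, dtype=float)
    labels = np.asarray(batch_labels)
    # Dense pairwise squared distances: fine for a few thousand spots, use a KD-tree beyond that
    sq_norms = np.sum(x ** 2, axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    np.fill_diagonal(sq_dists, np.inf)  # a spot is not its own neighbour
    nn_idx = np.argsort(sq_dists, axis=1)[:, :k]
    from_other_batch = labels[nn_idx] != labels[:, None]
    return float(from_other_batch.mean())

rng = np.random.default_rng(0)
labels = np.array(["slide_1"] * 100 + ["slide_2"] * 100)
# Batches far apart in embedding space (no mixing) vs drawn from one distribution (full mixing)
separated = np.vstack([rng.normal(size=(100, 5)), rng.normal(loc=100.0, size=(100, 5))])
mixed = rng.normal(size=(200, 5))

print(cross_batch_neighbor_fraction(separated, labels))  # 0.0
print(cross_batch_neighbor_fraction(mixed, labels))      # roughly 0.5
```

A score near 0 means each slide forms its own cluster; a score near the overall share of other-batch spots (about 0.5 for two equal slides) means the batches are well mixed.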

SpatialDIVA training requires the following information:

- The dimensionality of the expression counts (the number of highly variable genes; 2500 by default)
- The dimensionality of the UNI features
- The dimensionality of the pathologist annotations (total number of classes)
- The dimensionality of the batch/sample labels (total number of classes)
- The dimensionality of the ST cell types (total number of classes)
- The dimensionality of the neighborhood context for each spot (50-dimensional by default)

Let's extract these from the loaded data.

In [5]:
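from sklearn.preprocessing import LabelEncoder

counts_dim = 2500  # 2500 HVGs: the default used by the API's preprocessing step
uni_cols = [col for col in adatas[0].obs.columns if "UNI" in col]
hist_dim = len(uni_cols)

# Count cell-type classes across all loaded slides, not just the first one
cell_types = np.concatenate([a.obs["ST_celltype"].values for a in adatas])
y1_dim = len(np.unique(cell_types))
y2_dim = 50  # Neighborhood context dimensionality (default)

# Encode pathologist labels as integers to avoid string encoding and character issues,
# again pooling the labels from all loaded slides
path_labels = np.concatenate([a.obs["Pathologist Annotation"].values for a in adatas])
le = LabelEncoder()
path_labels = le.fit_transform(path_labels)

y3_dim = len(np.unique(path_labels))
d_dim = len(adatas)  # One batch/sample label per slide (two here)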
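Class counts only match what the model will see if the labels of every training slide are included: a class present on one slide but not another would otherwise be missed. A stdlib sketch of a shared label encoding across slides, using sorted class order as scikit-learn's `LabelEncoder` does; the helper `shared_label_encoding` and the toy labels are hypothetical.

```python
def shared_label_encoding(label_lists):
    """Map every label seen on any slide to an integer, in sorted class order."""
    classes = sorted(set().union(*label_lists))
    index = {label: i for i, label in enumerate(classes)}
    encoded = [[index[label] for label in labels] for labels in label_lists]
    return classes, encoded

# Toy annotations: slide_2 contains a class that slide_1 lacks
slide_1 = ["tumour", "stroma", "tumour"]
slide_2 = ["stroma", "normal mucosa"]
classes, encoded = shared_label_encoding([slide_1, slide_2])

print(classes)       # ['normal mucosa', 'stroma', 'tumour']
print(encoded)       # [[2, 1, 2], [1, 0]]
print(len(classes))  # 3, whereas slide_1 alone has only 2 classes
```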

We can explore the data further here, or load it into the StDIVA API. The API takes the file locations and performs the relevant preprocessing steps (via the function `adata_process`; see the documentation for details). It also handles model training and embedding extraction.
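Part of that preprocessing is reducing expression to 2500 highly variable genes (the `counts_dim` used above). The exact steps of `adata_process` are described in the documentation and are not reproduced here. As a rough illustration only, this sketch ranks genes by variance after library-size normalisation and `log1p`; it is a simplified, hypothetical stand-in, not scanpy's dispersion-based `highly_variable_genes` and not necessarily what the API does.

```python
import numpy as np

def top_variable_genes(counts, n_top=2500, target_sum=1e4):
    """Indices of the n_top genes with the highest variance after normalisation and log1p."""
    counts = np.asarray(counts, dtype=float)
    lib_size = counts.sum(axis=1, keepdims=True)
    lib_size[lib_size == 0] = 1.0  # avoid dividing by zero for empty spots
    normalised = np.log1p(counts / lib_size * target_sum)
    variances = normalised.var(axis=0)
    n_top = min(n_top, counts.shape[1])
    # argsort is ascending: take the last n_top indices, then reverse so the most variable come first
    return np.argsort(variances)[-n_top:][::-1]

# Synthetic spots x genes matrix with one deliberately variable gene
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(100, 3000))
counts[::2, 5] = 500  # gene 5 is highly expressed in half the spots...
counts[1::2, 5] = 0   # ...and absent in the other half

hvg_idx = top_variable_genes(counts, n_top=2500)
print(hvg_idx.shape)  # (2500,)
print(hvg_idx[0])     # 5
```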

Let's go ahead and initialize the SpatialDIVA API with the necessary parameters.

In [6]:
stdiva = StDIVA(
    counts_dim = counts_dim,
    hist_dim = hist_dim,
    y1_dim = y1_dim,
    y2_dim = y2_dim,
    y3_dim = y3_dim,
    d_dim = d_dim
)

Now we can add the data to the API as a list of file paths, each pointing to an AnnData object containing the information outlined above.

We'll use 90% of the combined slides for training, and 10% for validation (the default split).
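A minimal sketch of what a 90/10 split over the pooled spots of both slides looks like, assuming a random split at the spot level with a fixed seed; how the API actually shuffles, seeds, or stratifies its split is not shown in this notebook, and `train_val_split` is a hypothetical helper.

```python
import random

def train_val_split(n_spots, val_fraction=0.1, seed=0):
    """Shuffle spot indices and split them into training and validation sets."""
    indices = list(range(n_spots))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = int(round(n_spots * val_fraction))
    return indices[n_val:], indices[:n_val]

train_idx, val_idx = train_val_split(5000)
print(len(train_idx), len(val_idx))  # 4500 500
```

Because spots from both slides are pooled before shuffling, each slide contributes to both the training and validation sets in roughly its share of the spots.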

In [7]:
stdiva.add_data(
    adata = adata_files_sub,
    label_key_y1 = "ST_celltype",
    label_key_y3 = "Pathologist Annotation",
    hist_col_key = "UNI"
)

Processing data..


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


Creating dataloaders..
Done!


From here, we can go ahead and train the model. The default number of epochs is 100, with early stopping enabled by default. Let's use these parameters for training.
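Early stopping ends training before the epoch limit once the validation loss stops improving. The training log shows the model runs on PyTorch Lightning, whose `EarlyStopping` callback exposes `monitor`, `patience` and `min_delta`; the values StDIVA uses are not shown here. A simplified, hypothetical sketch of the patience logic:

```python
def early_stopping_epoch(val_losses, patience=5, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it runs to completion.

    Training stops once the validation loss has failed to improve on the best
    value by more than min_delta for `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None

losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73, 0.74]
print(early_stopping_epoch(losses, patience=3))  # 5
```

Note that a loss equal to the best so far (0.70 at epoch 5) does not count as an improvement, so the patience counter keeps running.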

In [8]:
stdiva.train()

/h/hmaan/.cache/pypoetry/virtualenvs/spatialdiva-PyiMLd3V-py3.9/lib/python3.9/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /h/hmaan/.cache/pypoetry/virtualenvs/spatialdiva-Pyi ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/h/hmaan/.cache/pypoetry/virtualenvs/spatialdiva-PyiMLd3V-py3.9/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install li

Starting training..



  | Name  | Type        | Params | Mode 
----------------------------------------------
0 | model | SpatialDIVA | 3.3 M  | train
----------------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.148    Total estimated model params size (MB)
110       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]
