In this notebook, we'll perform conditional generation with the SpatialDIVA model on the pancreatic cancer data from Zhou et al.

The results of this analysis will differ somewhat from those reported in the paper. For quick inference and demonstration purposes, we use only 1000 samples (spots), and we train the full model on only a subset of the data (2 slides).
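The notebook does not show how the 1000 spots are selected. A minimal sketch, assuming uniform random sampling without replacement with a fixed seed for reproducibility (the helper `subsample_indices` is hypothetical, not part of SpatialDIVA):

```python
import numpy as np


def subsample_indices(n_total: int, n_keep: int = 1000, seed: int = 0) -> np.ndarray:
    """Return sorted indices of n_keep spots drawn uniformly without replacement."""
    if n_keep > n_total:
        raise ValueError(f"cannot keep {n_keep} of {n_total} spots")
    rng = np.random.default_rng(seed)
    # Sorting preserves the original spot order, which keeps AnnData slicing predictable
    return np.sort(rng.choice(n_total, size=n_keep, replace=False))


idx = subsample_indices(n_total=5000, n_keep=1000, seed=0)
print(idx.shape)  # (1000,)
```

On an AnnData object the indices would be applied as `adata[idx].copy()`; scanpy also offers `sc.pp.subsample(adata, n_obs=1000, random_state=0)`, which subsamples in place.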

The high-level and detailed procedures for performing conditional generation are outlined in the manuscript. In brief:

- We first train the SpatialDIVA model on the data.
- We then use the trained model to generate latent samples for each of the factors under consideration.
- Finally, we zero out all but one of the factors and generate samples that reflect only the variation of the remaining factor.
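The zero-out step can be illustrated with a toy model. This is a minimal sketch, not the SpatialDIVA implementation: the factor names (`z_y1`, `z_y2`, `z_y3`, `z_d`, plus a residual `z_x`, named here after the `y1_dim`...`d_dim` arguments used later), the latent sizes, the random linear `decode` function, and the Poisson sampling are all assumptions standing in for the trained encoders and count decoder.

```python
import numpy as np

rng = np.random.default_rng(0)
n_spots, n_genes, latent_dim = 8, 20, 4

# Hypothetical latent samples, one block per factor, as a trained encoder might produce
factors = ["z_y1", "z_y2", "z_y3", "z_d", "z_x"]
latents = {name: rng.normal(size=(n_spots, latent_dim)) for name in factors}

# Stand-in decoder: fixed random linear map followed by softplus, giving
# non-negative per-gene rates (a real model would use its trained decoder)
W = rng.normal(scale=0.5, size=(latent_dim * len(factors), n_genes))


def decode(z: np.ndarray) -> np.ndarray:
    return np.logaddexp(0.0, z @ W)  # softplus, numerically stable


def generate_for_factor(keep: str) -> np.ndarray:
    """Zero every latent block except `keep`, then decode and sample counts."""
    blocks = [latents[name] if name == keep else np.zeros_like(latents[name])
              for name in factors]
    rates = decode(np.concatenate(blocks, axis=1))
    return rng.poisson(rates)  # factor-specific synthetic counts


generated = {name: generate_for_factor(name) for name in factors}
print(generated["z_y1"].shape)  # (8, 20)
```

Poisson sampling is used only to keep the sketch short; count models for spatial transcriptomics often use negative binomial likelihoods instead.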

At the end of this procedure, we have transcriptomic counts specific to each factor, on which we can run downstream analyses such as differential expression analysis and clustering.
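As one example of a downstream analysis, the sketch below ranks genes by log2 fold change between two groups of generated spots (for instance, tumour vs. non-tumour spots in the counts generated for one factor). It is a simple pseudocount fold change on toy data, not a statistical test; the helper `rank_genes_by_log2fc` is hypothetical, and in practice a method such as scanpy's `sc.tl.rank_genes_groups` would be the usual choice.

```python
import numpy as np


def rank_genes_by_log2fc(counts, groups, target, reference, pseudocount=1.0):
    """Rank genes by log2 fold change of mean counts (target vs reference group)."""
    counts = np.asarray(counts, dtype=float)
    groups = np.asarray(groups)
    mean_t = counts[groups == target].mean(axis=0)
    mean_r = counts[groups == reference].mean(axis=0)
    lfc = np.log2((mean_t + pseudocount) / (mean_r + pseudocount))
    order = np.argsort(-lfc)  # most up-regulated in target first
    return order, lfc[order]


# Toy generated counts: 4 spots x 3 genes; gene 2 is higher in "tumor" spots
counts = np.array([[1, 0, 9], [2, 1, 7], [1, 1, 1], [2, 1, 1]])
groups = np.array(["tumor", "tumor", "normal", "normal"])
order, lfc = rank_genes_by_log2fc(counts, groups, "tumor", "normal")
print(order)  # [2 0 1]
```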

Let's start by loading the required libraries and the data.

In [4]:
%load_ext autoreload
%autoreload 2

import sys 
import os 
sys.path.append("..")

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc 
import anndata as ann 

from api import StDIVA

In [5]:
# Load the first two slides of the Zhou et al. dataset
# (sort the file list: os.listdir returns entries in arbitrary order)
adata_path = "/projects/scDisent/adata_tumor"
adata_files = sorted(f for f in os.listdir(adata_path) if f.endswith(".h5ad"))
adata_files = [os.path.join(adata_path, f) for f in adata_files]

adata_files_sub = adata_files[:2]
adatas = [sc.read_h5ad(adata_file) for adata_file in adata_files_sub]

In [6]:
from sklearn.preprocessing import LabelEncoder

counts_dim = adatas[0].shape[1]  # All genes are used (assumes both slides share the same gene set)
uni_cols = [col for col in adatas[0].obs.columns if "UNI" in col]
hist_dim = len(uni_cols)
# When getting unique values, combine the celltypes from both slides
y1_dim = len(np.unique(np.concatenate([adatas[0].obs["ST_celltype"].values, adatas[1].obs["ST_celltype"].values])))
y2_dim = 100 # 50 PCs for the ST data and 50 PCs for the UNI data - neighbourhood context 

# Transform path labels due to string encoding and character issues
# When getting unique values, combine the pathologist annotations from both slides
path_labels = np.concatenate([adatas[0].obs["is_tumor"].values, adatas[1].obs["is_tumor"].values]) 
le = LabelEncoder()
path_labels = le.fit_transform(path_labels)

y3_dim = len(np.unique(path_labels))
d_dim = 2 # For two slides
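The comment on `y2_dim` describes the neighbourhood context as 50 principal components of the ST data plus 50 of the UNI histology embeddings. One plausible way to build such a 100-dimensional context, assuming it is each modality's PCs averaged over a spot's k nearest spatial neighbours, is sketched below. SpatialDIVA's own preprocessing may differ; `neighbourhood_context`, `k`, the log1p transform, and the toy arrays are illustrative assumptions.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors


def neighbourhood_context(expr, hist, coords, n_pcs=50, k=6, seed=0):
    """Concatenate neighbour-averaged PCs of expression and histology features."""
    st_pcs = PCA(n_components=n_pcs, random_state=seed).fit_transform(np.log1p(expr))
    uni_pcs = PCA(n_components=n_pcs, random_state=seed).fit_transform(hist)
    # k + 1 neighbours because each spot is normally returned as its own nearest neighbour
    nn = NearestNeighbors(n_neighbors=k + 1).fit(coords)
    _, idx = nn.kneighbors(coords)
    neigh = idx[:, 1:]  # drop the spot itself
    st_ctx = st_pcs[neigh].mean(axis=1)
    uni_ctx = uni_pcs[neigh].mean(axis=1)
    return np.hstack([st_ctx, uni_ctx])  # shape (n_spots, 2 * n_pcs)


rng = np.random.default_rng(0)
n_spots = 200
expr = rng.poisson(2.0, size=(n_spots, 300)).astype(float)
hist = rng.normal(size=(n_spots, 64))
coords = rng.uniform(0, 100, size=(n_spots, 2))
context = neighbourhood_context(expr, hist, coords)
print(context.shape)  # (200, 100)
```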

In [7]:
stdiva = StDIVA(
    counts_dim = counts_dim,
    hist_dim = hist_dim,
    y1_dim = y1_dim,
    y2_dim = y2_dim,
    y3_dim = y3_dim,
    d_dim = d_dim,
    betas = [1, 1, 1, 1, 1, 1] # Default betas
)
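The `betas` list weights the model's KL-divergence terms. As a hedged illustration of how such weights usually enter a beta-VAE/DIVA-style objective, the sketch below adds a reconstruction term to a beta-weighted sum of Gaussian KL divergences. The assumption of one KL term per beta is ours, and for simplicity every prior here is a standard normal, whereas DIVA-style models use label-conditional priors for some factors; `gaussian_kl` and `weighted_objective` are hypothetical helpers.

```python
import numpy as np


def gaussian_kl(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims, averaged over the batch."""
    kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=1)
    return kl.mean()


def weighted_objective(recon_nll, posteriors, betas):
    """recon_nll + sum_i beta_i * KL_i, with one (mu, logvar) pair per latent factor."""
    if len(posteriors) != len(betas):
        raise ValueError("need exactly one beta per KL term")
    kls = [gaussian_kl(mu, logvar) for mu, logvar in posteriors]
    return recon_nll + sum(b * kl for b, kl in zip(betas, kls))


# With mu = 0 and logvar = 0 every posterior equals the prior, so every KL is 0
batch, dim = 4, 3
posteriors = [(np.zeros((batch, dim)), np.zeros((batch, dim))) for _ in range(6)]
print(weighted_objective(10.0, posteriors, betas=[1, 1, 1, 1, 1, 1]))  # 10.0
```

Raising an individual beta above 1 penalises that factor's divergence from its prior more strongly, which is the usual lever for encouraging more compact, disentangled latents.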

In [14]:
stdiva.add_data(
    adata = adata_files[0:2],
    label_key_y1 = "ST_celltype",
    label_key_y3 = "is_tumor",
    hist_col_key = "UNI",
    hvg = False # We are not using HVGs - we are using all the genes
)

Processing data..


  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")


Creating dataloaders..


In [21]:
next(iter(stdiva.train_loader))

[tensor([[ 0.0000,  0.0000,  0.0000,  ...,  1.4588, -0.9544, -0.8738],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.6525,  1.4684, -0.7221],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.9315,  1.6959,  0.0036],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ..., -0.4829, -0.7341,  0.2258],
         [ 0.0000,  0.0000,  0.0000,  ..., -1.1429, -0.4445,  1.6168],
         [ 0.0000,  0.0000,  0.0000,  ..., -1.1215,  1.5875,  1.6937]],
        dtype=torch.float64),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64),
 tensor([[-2.1983e+03, -1.6429e+01,  2.6632e+01,  ...,  3.7452e-01,
          -7.3386e-01,  1.6814e-01],
         [-1.3133e+02,  1.4320e+01,  2.0774e+00,  ..., -6.1705e-01,
          -2.4862e-02,  1.1912e-01],
         [-2.3178e+03, -5.9742e+01, 