# Multimodal Generative Modeling of Spatial Transcriptomics using VAEs and GNNs

To build a model architecture that integrates spatial transcriptomics, whole-slide images (WSIs), and metadata, we use a Graph Neural Network (GNN) as the spatial transcriptomics encoder, H-Optimus-0 as the WSI encoder, and a simple fully connected network (FCN) or embedding layer for the metadata. The encoder outputs are then combined into a shared latent space and decoded back to their respective domains.

In [None]:
pip install scanpy



In [None]:
pip install torch-geometric



In [None]:
import os
import glob
import re
import scanpy as sc
import json
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
# Mount Google Drive at /content/drive; the dataset paths configured below
# (data_dir) live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Dataset Preparation

In [None]:
# Set local directories for data reading (os.path.join avoids manual separator handling)
data_dir = "/content/drive/My Drive/Projects/Synthetic Spatial Transcriptomics/hest_data_xenium"
st_dir = os.path.join(data_dir, 'st')  # Spatial transcriptomics files (h5ad)
metadata_dir = os.path.join(data_dir, 'metadata')  # Metadata files (json)
thumbnails_dir = os.path.join(data_dir, 'thumbnails')  # Histology images (jpeg)

In [None]:
# Function to get all file identifiers from a directory
def get_file_identifiers(folder_path, extension):
    """Return the identifiers of the files in ``folder_path`` ending with ``extension``.

    The identifier is the file name with only the trailing ``extension``
    removed, so names containing extra dots are preserved
    ("sample.v2.h5ad" -> "sample.v2"). Splitting on the first dot would
    truncate them, and generate_paths would then rebuild paths to files that
    do not exist.

    Args:
        folder_path: directory to scan (non-recursive).
        extension: file-name suffix to match, e.g. ".h5ad".

    Returns:
        Sorted list of identifiers. Sorting makes dataset order reproducible,
        instead of depending on os.listdir's arbitrary ordering.
    """
    matching = sorted(name for name in os.listdir(folder_path) if name.endswith(extension))
    # Slice by length rather than name[:-len(extension)] so an empty extension is safe.
    return [name[: len(name) - len(extension)] for name in matching]

# Collect the sample identifiers available for each modality
st_identifiers, wsi_identifiers, metadata_identifiers = (
    get_file_identifiers(st_dir, '.h5ad'),
    get_file_identifiers(thumbnails_dir, '.jpeg'),
    get_file_identifiers(metadata_dir, '.json'),
)
metadata_identifiers = get_file_identifiers(metadata_dir, '.json')


In [None]:
# Create the paths dynamically based on the identifiers
def generate_paths(identifiers, st_folder, wsi_folder, metadata_folder):
    """Build aligned ST / WSI / metadata file paths for every identifier.

    Returns:
        A tuple of three lists (``.h5ad``, ``.jpeg`` and ``.json`` paths). The
        i-th entry of each list belongs to the i-th identifier.
    """
    def paths_for(folder, suffix):
        # One path per identifier inside ``folder`` with the given file suffix.
        return [os.path.join(folder, f"{identifier}{suffix}") for identifier in identifiers]

    return (
        paths_for(st_folder, ".h5ad"),
        paths_for(wsi_folder, ".jpeg"),
        paths_for(metadata_folder, ".json"),
    )

# Generate aligned per-sample paths. All three modalities are keyed by the ST
# identifiers, so every .h5ad file needs a thumbnail (.jpeg) and a metadata
# (.json) file with the same identifier.
st_paths, wsi_paths, metadata_paths = generate_paths(st_identifiers, st_dir, thumbnails_dir, metadata_dir)

# Fail early with a clear message, instead of later inside metadata
# preprocessing or in the middle of a DataLoader iteration.
missing_paths = [path for path in wsi_paths + metadata_paths if not os.path.exists(path)]
if missing_paths:
    raise FileNotFoundError(
        f"{len(missing_paths)} thumbnail/metadata file(s) missing for ST samples, e.g. {missing_paths[:5]}"
    )


In [None]:
# Preprocess metadata to create one-hot encoding maps
def _sorted_categories(values):
    """Sort category values, falling back to a type-aware key for mixed types."""
    try:
        return sorted(values)
    except TypeError:
        # Mixed, mutually non-comparable types (e.g. a JSON null next to strings).
        # Group by type name, then order by string form, so the order stays deterministic.
        return sorted(values, key=lambda v: (type(v).__name__, str(v)))


def preprocess_metadata(metadata_paths):
    """Build a ``{value: index}`` map for each categorical metadata field.

    Args:
        metadata_paths: iterable of paths to JSON metadata files.

    Returns:
        Dict mapping each field ("organ", "disease_state", "species",
        "tissue") to ``{value: index}``. Indices follow the sorted order of
        the values. A field missing from a file is recorded as "unknown".
    """
    unique_values = {"organ": set(), "disease_state": set(), "species": set(), "tissue": set()}

    for path in metadata_paths:
        with open(path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        for key, values in unique_values.items():
            values.add(metadata.get(key, "unknown"))

    return {
        key: {val: i for i, val in enumerate(_sorted_categories(values))}
        for key, values in unique_values.items()
    }

# One-hot encode metadata values
def load_metadata(metadata_path, encoding_maps):
    """One-hot encode the categorical fields of one metadata JSON file.

    Args:
        metadata_path: path to the sample's JSON metadata.
        encoding_maps: ``{field: {value: index}}``, as built by preprocess_metadata.

    Returns:
        1-D float32 tensor: the concatenation, in ``encoding_maps`` key order,
        of one one-hot vector per field, each of length
        ``len(encoding_maps[field])``. The dtype is float rather than the int64
        that ``one_hot`` returns, so the vector can feed ``torch.nn.Linear``
        directly.
    """
    with open(metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    onehot_tensors = []
    for key, value_to_index in encoding_maps.items():
        # NOTE(review): a value not present when encoding_maps was built falls
        # back to index 0, i.e. it is encoded as the first category.
        index = value_to_index.get(metadata.get(key, "unknown"), 0)
        onehot_tensors.append(
            torch.nn.functional.one_hot(torch.tensor(index), num_classes=len(value_to_index))
        )
    return torch.cat(onehot_tensors, dim=0).to(torch.float32)


In [None]:
def custom_collate_fn(batch):
    """Collate dataset samples into batched tensors.

    Args:
        batch: list of ``(st_tensor, wsi_image, metadata_onehot, edge_index)``
            tuples, as returned by SpatialTranscriptomicsDataset.__getitem__.

    Returns:
        ``(st, wsi, metadata, edge_index_list)`` where:
        - ``st`` is ``(B, max_spots, max_genes)``, zero-padded via pad_tensor;
        - ``wsi`` is the stack of ``ToTensor()``-converted images;
        - ``metadata`` is ``(B, metadata_dim)``, one row per sample (stacked,
          not concatenated, so the batch dimension is kept);
        - ``edge_index_list`` is the tuple of per-sample edge indices. Graphs
          differ in size, so they are not stacked.
    """
    # Imported here so this function does not rely on a later notebook cell
    # having imported torchvision.
    from torchvision import transforms

    st_data_list, wsi_images, metadata_list, edge_index_list = zip(*batch)

    # Pad every ST matrix to the largest spot/gene count in the batch, then stack
    max_spots = max(st.size(0) for st in st_data_list)
    max_genes = max(st.size(1) for st in st_data_list)
    padded_st_data_tensor = torch.stack(
        [pad_tensor(st, max_spots, max_genes) for st in st_data_list]
    )

    to_tensor = transforms.ToTensor()
    wsi_images_tensor = torch.stack([to_tensor(img) for img in wsi_images])  # Assuming images are PIL
    metadata_tensor = torch.stack(metadata_list)  # (B, metadata_dim)

    return padded_st_data_tensor, wsi_images_tensor, metadata_tensor, edge_index_list


def compute_edge_index(spatial_coords, radius=1.0):
    """Build an undirected radius graph over spot/cell coordinates.

    Args:
        spatial_coords: array-like of shape ``(num_spots, num_dims)``, e.g.
            ``adata.obsm['spatial']``.
        radius: connect two spots whose Euclidean distance is <= ``radius``.
            It must be in the same units as ``spatial_coords``. The default of
            1.0 is kept for backward compatibility.
            TODO: tune it — it is likely far too small for pixel/micron
            coordinates and can leave the graph with no edges.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``. Each neighbouring
        pair appears in both directions (i->j and j->i), which is how message
        passing layers such as GCNConv represent an undirected graph. Shape is
        ``(2, 0)`` when no pair is within ``radius``.
    """
    from scipy.spatial import KDTree

    coords = np.asarray(spatial_coords)
    # ndarray output avoids building a Python set of tuples; rows are (i, j) with i < j.
    pairs = KDTree(coords).query_pairs(r=radius, output_type='ndarray')
    pairs = np.asarray(pairs, dtype=np.int64).reshape(-1, 2)
    both_directions = np.concatenate([pairs, pairs[:, ::-1]], axis=0)
    return torch.from_numpy(np.ascontiguousarray(both_directions.T))

def pad_tensor(tensor, max_spots, max_genes):
    """Zero-pad a 2-D tensor at the bottom/right up to ``(max_spots, max_genes)``.

    A dimension that already reaches (or exceeds) its target is left as-is and
    never cropped. Dtype and device are preserved. When no padding is needed,
    the input tensor itself is returned.
    """
    n_rows, n_cols = tensor.size(0), tensor.size(1)
    target_rows = max(n_rows, max_spots)
    target_cols = max(n_cols, max_genes)

    if target_rows == n_rows and target_cols == n_cols:
        return tensor

    # Allocate the zero canvas once and copy the original values into its top-left corner
    padded = tensor.new_zeros((target_rows, target_cols))
    padded[:n_rows, :n_cols] = tensor
    return padded

class SpatialTranscriptomicsDataset(Dataset):
    """Map-style dataset pairing ST expression, a WSI thumbnail and metadata per sample.

    The three file lists are aligned by index: item ``idx`` combines
    ``st_files[idx]``, ``wsi_files[idx]`` and ``metadata_files[idx]``.
    """

    def __init__(self, st_files, wsi_files, metadata_files, encoding_maps):
        self.st_files = st_files
        self.wsi_files = wsi_files
        self.metadata_files = metadata_files
        self.encoding_maps = encoding_maps

    def __len__(self):
        return len(self.st_files)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(st_data_tensor, wsi_image, metadata_onehot, edge_index)``:
            - st_data_tensor: float32 tensor of ``adata.X`` (densified if sparse);
            - wsi_image: RGB PIL image resized to 224x224;
            - metadata_onehot: output of load_metadata for this sample;
            - edge_index: output of compute_edge_index on ``obsm['spatial']``.
        """
        # Load spatial transcriptomics data
        st_data = sc.read_h5ad(self.st_files[idx])

        # Densify sparse matrices with .toarray(), which returns an ndarray
        # (.todense() would return an np.matrix).
        expression = st_data.X
        if hasattr(expression, "toarray"):
            expression = expression.toarray()
        st_data_tensor = torch.tensor(np.asarray(expression), dtype=torch.float32)

        # Load spatial coordinates and compute edge index
        spatial_coords = st_data.obsm['spatial']
        edge_index = compute_edge_index(spatial_coords)

        # Load and resize the WSI thumbnail. The context manager closes the
        # file handle. convert() materialises the pixels first, so the
        # returned image stays valid after the file is closed.
        with Image.open(self.wsi_files[idx]) as img:
            wsi_image = img.convert("RGB").resize((224, 224))

        # Load metadata
        metadata_onehot = load_metadata(self.metadata_files[idx], self.encoding_maps)

        return st_data_tensor, wsi_image, metadata_onehot, edge_index


# Build categorical encodings from every metadata file, then wrap the samples
# in a shuffled DataLoader that pads batches with custom_collate_fn.
encoding_maps = preprocess_metadata(metadata_paths)
dataset = SpatialTranscriptomicsDataset(
    st_files=st_paths,
    wsi_files=wsi_paths,
    metadata_files=metadata_paths,
    encoding_maps=encoding_maps,
)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

# Encoder Models

## WSI Encoder (using Bioptimus/H-Optimus-0)

In [None]:
import timm
from torchvision import transforms
import torch

# Define the WSI encoder
class WSIEncoder(torch.nn.Module):
    """H-optimus-0 feature extractor for WSI thumbnails, run without gradients.

    The backbone is loaded from the Hugging Face Hub through timm. It is run
    under ``torch.no_grad()``, so it acts as a fixed feature extractor.
    """

    def __init__(self):
        super(WSIEncoder, self).__init__()
        # Load H-Optimus from Huggingface via timm
        self.model = timm.create_model(
            "hf-hub:bioptimus/H-optimus-0",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=False
        )
        # Preprocessing for PIL images: resize, convert to a [0, 1] CHW float
        # tensor, then normalize with the H-optimus-0 channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.707223, 0.578729, 0.703617),
                std=(0.211883, 0.230117, 0.177517)
            ),
        ])
        # Same preprocessing for inputs that are already tensors, such as the
        # ToTensor-converted batch built by custom_collate_fn. ToTensor is
        # skipped because it rejects tensor inputs.
        self.tensor_transform = transforms.Compose([
            transforms.Resize((224, 224), antialias=True),
            transforms.Normalize(
                mean=(0.707223, 0.578729, 0.703617),
                std=(0.211883, 0.230117, 0.177517)
            ),
        ])

    def _preprocess(self, img):
        """Turn one image (PIL, or float CHW tensor assumed in [0, 1]) into a normalized tensor."""
        if isinstance(img, torch.Tensor):
            return self.tensor_transform(img.float())
        return self.transform(img)

    def forward(self, x):
        """Return backbone features for a batch of images.

        Args:
            x: iterable of images. Accepts PIL images, or a ``(B, 3, H, W)``
                float tensor / sequence of CHW tensors with values in [0, 1].

        Returns:
            float32 tensor with one feature row per image, on the backbone's
            device (presumably ``(B, 1536)`` for H-optimus-0 — confirm).
        """
        # Follow wherever the module was moved instead of hard-coding "cuda"
        device = next(self.model.parameters()).device
        batch = torch.stack([self._preprocess(img).to(device) for img in x])

        # no_grad (rather than inference_mode) keeps the outputs ordinary tensors.
        # Inference tensors are restricted once they enter autograd-recorded ops,
        # and these features feed trainable layers downstream.
        # Mixed precision is only enabled on CUDA, as in the original usage.
        with torch.no_grad(), torch.autocast(
            device_type=device.type, dtype=torch.float16, enabled=device.type == "cuda"
        ):
            features = self.model(batch)

        # Return float32 so downstream float32 layers can consume the features directly
        return features.float()

# Instantiate the H-optimus-0 WSI encoder and place it on the GPU
wsi_encoder = WSIEncoder().to(device="cuda")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


## Spatial Transcriptomics Encoder (Graph Neural Network)

To be done:

- Normalize gene expression data
- Pre-define a list of genes (the union of all genes in the HEST-1K dataset) and map each sample's gene expression onto that list; this ensures genes are always fed to the model in the same order.

In [None]:
import torch_geometric
from torch_geometric.nn import GCNConv

class STEncoder(torch.nn.Module):
    """Two-layer GCN mapping per-spot expression vectors to per-spot embeddings.

    Each GCN layer is followed by a ReLU, so the returned embeddings are
    non-negative.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 -> ReLU over the graph given by ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index).relu()

# Example configuration (spatial transcriptomics as graph):
#   input_dim  = number of genes per spot (GCN input features)
#   hidden_dim = width of the first GCN layer
#   output_dim = size of each spot embedding
# NOTE(review): the training-loop printout in this notebook shows 10017 gene
# columns per padded batch, not 541 — confirm input_dim against the gene panel.
input_dim, hidden_dim, output_dim = 541, 128, 64
st_encoder = STEncoder(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)


## Metadata Encoder

In [None]:
class MetadataEncoder(torch.nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the one-hot metadata vector."""

    def __init__(self, input_dim, embedding_dim, output_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, embedding_dim)
        self.fc2 = torch.nn.Linear(embedding_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``; no output activation."""
        return self.fc2(self.fc1(x).relu())

# input_dim must equal the length of the vector produced by load_metadata. That
# vector concatenates one one-hot vector per field, so its length is the total
# number of categories across encoding_maps. A hard-coded 4 does not match.
metadata_encoder = MetadataEncoder(
    input_dim=sum(len(value_to_index) for value_to_index in encoding_maps.values()),
    embedding_dim=32,
    output_dim=16,
)


# Latent Space and Decoder

To be done:
- Implement decoders to reconstruct the WSI and spatial transcriptomics data

In [None]:
class VAE(torch.nn.Module):
    """Fuse spatial-transcriptomics, WSI and metadata embeddings into one latent vector.

    Despite the name, there is no variational component yet: no mean /
    log-variance heads and no reparameterization step. The model is a
    deterministic fusion layer followed by a 64-d output projection.

    Args:
        st_encoder: module called as ``st_encoder(x, edge_index)`` on a single
            graph, returning per-spot embeddings of width ``st_dim``.
        wsi_encoder: callable mapping a batch of images to ``(B, wsi_dim)`` features.
        metadata_encoder: module mapping ``(B, metadata_input_dim)`` to ``(B, metadata_dim)``.
        latent_dim: width of the fused latent layer.
        st_dim, wsi_dim, metadata_dim: embedding widths of the three encoders.
            ``wsi_dim`` defaults to 1536, the feature size documented for
            H-optimus-0.
    """

    def __init__(self, st_encoder, wsi_encoder, metadata_encoder, latent_dim=64,
                 st_dim=64, wsi_dim=1536, metadata_dim=16):
        super(VAE, self).__init__()
        self.st_encoder = st_encoder
        self.wsi_encoder = wsi_encoder
        self.metadata_encoder = metadata_encoder

        # Latent space over the concatenated per-sample embeddings
        self.latent_dim = latent_dim
        self.fc1 = torch.nn.Linear(st_dim + wsi_dim + metadata_dim, latent_dim)
        self.fc2 = torch.nn.Linear(latent_dim, 64)  # Placeholder output layer (no ST decoder yet)

    def _encode_st(self, st_data, edge_index):
        """Run the ST encoder graph by graph and mean-pool over spots -> ``(B, st_dim)``."""
        if st_data.dim() == 2:
            # A single graph: (num_spots, num_genes) with one edge_index tensor
            return self.st_encoder(st_data, edge_index).mean(dim=0, keepdim=True)

        if isinstance(edge_index, torch.Tensor):
            # One edge_index shared by every graph in the batch
            edge_index = [edge_index] * st_data.size(0)
        if len(edge_index) != st_data.size(0):
            raise ValueError(
                f"expected {st_data.size(0)} edge_index tensors, got {len(edge_index)}"
            )
        # NOTE(review): zero-padded spots are included in the mean; masking them
        # would need the true per-sample spot counts.
        pooled = [self.st_encoder(x, ei).mean(dim=0) for x, ei in zip(st_data, edge_index)]
        return torch.stack(pooled)

    def forward(self, st_data, wsi_data, metadata, edge_index):
        """Encode all modalities and return the fused ``(B, 64)`` output.

        Args:
            st_data: ``(num_spots, num_genes)`` for one sample, or
                ``(B, num_spots, num_genes)`` for a batch.
            wsi_data: batch of images accepted by ``wsi_encoder``.
            metadata: ``(metadata_input_dim,)`` for one sample, or ``(B, metadata_input_dim)``.
            edge_index: one ``(2, E)`` tensor (shared by all graphs), or a
                sequence of B such tensors.
        """
        # Encode spatial transcriptomics (GNN), pooled to one vector per sample
        st_embedding = self._encode_st(st_data, edge_index)

        # Encode histopathology images (H-Optimus); float() in case of fp16 autocast output
        wsi_embedding = self.wsi_encoder(wsi_data).float()

        # Encode metadata (float input for the Linear layers; add a batch dim for a single sample)
        metadata = metadata.float()
        if metadata.dim() == 1:
            metadata = metadata.unsqueeze(0)
        metadata_embedding = self.metadata_encoder(metadata)

        # Combine per-sample embeddings from all modalities: (B, st_dim + wsi_dim + metadata_dim)
        combined_embedding = torch.cat([st_embedding, wsi_embedding, metadata_embedding], dim=-1)

        # Latent space
        latent_space = torch.relu(self.fc1(combined_embedding))
        latent_space = self.fc2(latent_space)

        return latent_space  # Return latent space instead of reconstructed ST directly

# Assemble the fusion model and move it (and all encoder submodules) to the GPU
vae = VAE(
    st_encoder=st_encoder,
    wsi_encoder=wsi_encoder,
    metadata_encoder=metadata_encoder,
    latent_dim=64,
).to(device="cuda")


# Training loop

In [None]:
import torch.optim as optim

def adjust_edge_index(edge_index, current_size):
    """Drop edges that reference a node index >= ``current_size``.

    Args:
        edge_index: ``(2, num_edges)`` integer tensor or NumPy array. An empty
            input of any shape, including the 1-D ``(0,)`` tensor produced for
            a graph without edges, is treated as having no edges.
        current_size: number of valid nodes.

    Returns:
        ``(2, num_kept)`` ``torch.long`` tensor on the same device as
        ``edge_index``. Surviving edges keep their original order.
    """
    # Ensure edge_index is a tensor
    if isinstance(edge_index, np.ndarray):
        edge_index = torch.tensor(edge_index, dtype=torch.long)

    if edge_index.numel() == 0:
        return torch.empty((2, 0), dtype=torch.long, device=edge_index.device)

    edge_index = edge_index.to(torch.long)
    # Vectorised validity mask (both endpoints in range) instead of a Python loop
    # over edge_index.t().tolist(), which is slow for large graphs.
    keep = (edge_index < current_size).all(dim=0)
    return edge_index[:, keep].contiguous()

# Set up optimizer and loss function
optimizer = optim.Adam(vae.parameters(), lr=1e-4)
reconstruction_loss_fn = torch.nn.MSELoss()

# NOTE(review): the VAE's output layer (fc2) yields 64 features per sample, while
# the target below is the padded (batch, spots, genes) expression tensor. This
# MSE therefore cannot act as a per-gene reconstruction loss until an ST decoder
# producing that shape exists.

# Training loop
epochs = 100
for epoch in range(epochs):
    vae.train()
    running_loss = 0.0
    num_batches = 0
    for batch_idx, batch in enumerate(dataloader):
        st_data_tensor, wsi_images_tensor, metadata_tensor, edge_index_list = batch

        optimizer.zero_grad()

        # Filter each sample's edges to its node count and move them to the GPU.
        # adjust_edge_index accepts tensors directly, so no NumPy round-trip is needed.
        # NOTE(review): st.size(0) is the *padded* spot count, so this never drops
        # an edge built from the sample's own (smaller or equal) spot count.
        adjusted_edge_indices = [
            adjust_edge_index(edge_index, st.size(0)).to("cuda")
            for edge_index, st in zip(edge_index_list, st_data_tensor)
        ]

        # Shape diagnostics once, instead of on every optimisation step
        if epoch == 0 and batch_idx == 0:
            print("ST Data Shape:", st_data_tensor.shape)
            print("WSI Image Shape:", wsi_images_tensor.shape)
            print("Metadata Shape:", metadata_tensor.shape)
            print("Adjusted Edge Indices Shape:", [edge_index.shape for edge_index in adjusted_edge_indices])

        # Transfer once and reuse for both the forward pass and the loss target
        st_gpu = st_data_tensor.to("cuda")
        reconstructed_st = vae(
            st_gpu,
            wsi_images_tensor.to("cuda"),
            metadata_tensor.to("cuda"),
            adjusted_edge_indices
        )

        loss = reconstruction_loss_fn(reconstructed_st, st_gpu)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than only the last batch's.
    # Skip reporting if the loader yielded nothing, which would leave nothing to average.
    if epoch % 10 == 0 and num_batches > 0:
        print(f"Epoch {epoch}, Loss: {running_loss / num_batches}")

ST Data Shape: torch.Size([8, 11845, 10017])
WSI Image Shape: torch.Size([8, 3, 224, 224])
Metadata Shape: torch.Size([296])
Adjusted Edge Indices Shape: [torch.Size([2, 0]), torch.Size([2, 0]), torch.Size([2, 0]), torch.Size([2, 0]), torch.Size([2, 0]), torch.Size([2, 0]), torch.Size([2, 0]), torch.Size([2, 0])]
