In [None]:
import os
import warnings

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import Data
from tqdm import tqdm
# Suppress all UserWarnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# 1. Load and Visualize the Spatial Transcriptomics Data

In [None]:
# Sample identifiers to process. The OSCC dataset is used as an example;
# replace these with the IDs of your own Visium samples.
sample_ids = [
    "GSM6339632_s2",
    "GSM6339638_s8",
    "GSM6339631_s1",
    "GSM6339635_s5",
    "GSM6339637_s7",
    "GSM6339634_s4",
    "GSM6339636_s6",
    "GSM6339639_s9",
    "GSM6339641_s11",
    "GSM6339642_s12",
    "GSM6339633_s3",
    "GSM6339640_s10",
]

In [None]:
# Directory that stores all the raw spatial transcriptomics files; change this
# to your actual path.
data_dir = "./data"
# Dataset label used in plot titles and figure file names (the IDs above are OSCC).
dataset_name = "OSCC"
position_columns = ["barcode", "in_tissue", "array_row", "array_col",
                    "pxl_row_in_fullres", "pxl_col_in_fullres"]
pixel_columns = ["pxl_row_in_fullres", "pxl_col_in_fullres"]
# Gene list of the first sample. Gene_Names.npy is shared by all samples, so
# every later sample must match it exactly (same genes, same order).
reference_gene_names = None

for sample_id in sample_ids:
    # Load the 10x Visium count matrix (spots x genes).
    adata = sc.read_10x_h5(os.path.join(data_dir, f"{sample_id}_filtered_feature_bc_matrix.h5"))
    adata.var_names_make_unique()
    # Densify once (it was previously densified three times per sample).
    counts = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)

    # Load spatial coordinates (tissue_positions_list has no header row).
    positions = pd.read_csv(
        os.path.join(data_dir, f"{sample_id}_tissue_positions_list.csv.gz"), header=None
    )
    positions.columns = position_columns

    # Unicode (not object) arrays so the .npy files load without allow_pickle.
    barcodes = np.array(list(adata.obs_names))
    gene_names = np.array(list(adata.var_names))

    # Look up each spot's full-resolution pixel position in one vectorised pass
    # instead of scanning the whole positions table per spot. For repeated
    # barcodes the first row wins, as before; a missing barcode raises KeyError.
    spatial = (
        positions.drop_duplicates("barcode")
        .set_index("barcode")
        .loc[barcodes, pixel_columns]
        .to_numpy(dtype=float)
    )

    if reference_gene_names is None:
        reference_gene_names = gene_names
    elif not np.array_equal(gene_names, reference_gene_names):
        # Otherwise the later cell would silently mislabel this sample's columns.
        raise ValueError(
            f"Gene names of {sample_id} differ from the first sample; "
            "a single Gene_Names.npy cannot describe all samples."
        )

    # Save the ST data and spatial coordinates for downstream usage.
    np.save(f"{sample_id}_SpatialTranscriptomics.npy", counts)
    np.save(f"{sample_id}_Barcodes.npy", barcodes)
    np.save(f"{sample_id}_SpatialCoordinates.npy", spatial)
    np.save("Gene_Names.npy", gene_names)

    # Plot spatial coordinates (x = pixel row, y = pixel column, as saved).
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(spatial[:, 0], spatial[:, 1], s=10, alpha=0.5, c="gray")
    ax.tick_params(labelsize=12)
    ax.set_xlabel("Spatial X", fontsize=14)
    ax.set_ylabel("Spatial Y", fontsize=14)
    ax.grid(True)
    ax.set_title(f"Visium {dataset_name} ID: {sample_id}, Data: {counts.shape}", fontsize=14)
    fig.savefig(f"Visium {dataset_name} ID {sample_id}.pdf", bbox_inches="tight")
    plt.show()
    plt.close(fig)  # free the figure; one is created per sample in this loop

# 2. Generate Subgraphs to be fed into SAGE-FM

In [None]:
# Gene panel modeled by the foundation model; each sample's expression matrix
# is re-ordered onto these genes (missing genes zero-filled) further below.
gene_panel_file = "Matched_Gene_Names.npy"
genes = np.load(gene_panel_file, allow_pickle=True)

In [None]:
def generate_augmented_subgraphs(adata, n_neighbors=15):
    """Build one spatial k-nearest-neighbour subgraph per spot of a sample.

    For every spot, its ``n_neighbors`` nearest spots in ``adata.obsm["spatial"]``
    form a fully connected directed subgraph without self-loops. The query spot
    is normally one of them, since it is part of the fitted data; with
    duplicated coordinates this is not guaranteed.

    Parameters
    ----------
    adata : AnnData
        Sample with expression matrix ``adata.X`` (spots x genes, dense or
        scipy-sparse) and coordinates ``adata.obsm["spatial"]`` (spots x 2).
    n_neighbors : int, default 15
        Number of spots in each subgraph. Must not exceed the number of spots;
        ``NearestNeighbors`` raises ``ValueError`` otherwise.

    Returns
    -------
    list of torch_geometric.data.Data
        One graph per spot, in ``adata`` order, with attributes:

        - ``x``: ``log1p(adata.X)`` rows of the neighbours, float32.
        - ``edge_index``: shape ``(2, k*(k-1))``, int64.
        - ``coords``: neighbour coordinates, float32.
    """
    expression = adata.X
    if hasattr(expression, "toarray"):
        # scipy sparse -> dense; np.log1p is not reliable on sparse matrices.
        expression = expression.toarray()
    features = np.log1p(np.asarray(expression))

    spatial_coords = np.asarray(adata.obsm["spatial"])

    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="ball_tree").fit(spatial_coords)
    _, neighbor_indices = nbrs.kneighbors(spatial_coords)

    # Every subgraph has the same local topology: a complete directed graph on
    # k nodes without self-loops. Edges are ordered (0,1), (0,2), ..., (1,0),
    # (1,2), ..., the same order the former nested loop produced. Building it
    # once replaces an O(k^3) Python loop per spot, and it keeps shape (2, 0)
    # rather than a 1-D tensor when k == 1.
    k = neighbor_indices.shape[1]
    local = np.arange(k)
    src, dst = np.meshgrid(local, local, indexing="ij")
    off_diagonal = src != dst
    edge_index_template = torch.as_tensor(
        np.stack([src[off_diagonal], dst[off_diagonal]]), dtype=torch.long
    )

    subgraphs = []
    for neighbors in neighbor_indices:
        subgraphs.append(
            Data(
                x=torch.tensor(features[neighbors], dtype=torch.float),
                # Clone so graphs do not share one storage buffer (torch.save
                # would preserve the sharing and in-place edits would leak).
                edge_index=edge_index_template.clone(),
                coords=torch.tensor(spatial_coords[neighbors], dtype=torch.float),
            )
        )

    return subgraphs

In [None]:
# Rebuild each sample from the saved .npy files, align it to the model's gene
# panel (`genes`), and save its per-spot neighbourhood subgraphs.
output_dir = "./gnn_subgraphs/OSCC"
os.makedirs(output_dir, exist_ok=True)  # torch.save does not create directories

# Gene names saved by the loading cell. That file is rewritten per sample, so
# this assumes every sample shares the same gene list and order.
var_names = np.load("Gene_Names.npy", allow_pickle=True)

subgraph_shape_list = []

for slice_id in np.unique(sample_ids):  # one pass per sample/patient (sorted order)
    print("Sample ID: ", slice_id)
    X = np.load(f"{slice_id}_SpatialTranscriptomics.npy")
    spatial = np.load(f"{slice_id}_SpatialCoordinates.npy")
    barcodes = np.load(f"{slice_id}_Barcodes.npy")
    assert X.shape[0] == spatial.shape[0]
    assert spatial.shape[1] == 2
    assert X.shape[0] == barcodes.shape[0]
    assert X.shape[1] == len(var_names), "Gene_Names.npy does not match this sample's matrix"

    adata_raw = sc.AnnData(X=X)
    adata_raw.obs_names = barcodes
    adata_raw.var_names = var_names
    adata_raw.var_names_make_unique()  # get_indexer below requires unique names

    print("Number of Spots: ", adata_raw.shape[0])

    # Place each panel gene (in `genes` order) into its column; panel genes
    # absent from this sample stay 0. get_indexer returns -1 for missing genes.
    panel_positions = adata_raw.var_names.get_indexer(genes)
    present = panel_positions >= 0
    missing_gene_number = int((~present).sum())

    all_gene_matrix = np.zeros((adata_raw.shape[0], len(genes)))
    all_gene_matrix[:, present] = adata_raw.X[:, panel_positions[present]]

    print(f"Maximum Count: {np.max(all_gene_matrix)}")

    adata_one_sample = anndata.AnnData(
        X=all_gene_matrix,
        obs=pd.DataFrame(index=adata_raw.obs_names),
        obsm={"spatial": spatial},
        var=pd.DataFrame(index=genes),
    )

    print("Missing Gene Number: ", missing_gene_number)

    # Generate subgraphs (15 spatial neighbours per spot).
    subgraphs = generate_augmented_subgraphs(adata_one_sample, n_neighbors=15)
    subgraph_shape_list.append(len(subgraphs))

    # Save all subgraphs of this sample as a single .pt file.
    torch.save(subgraphs, os.path.join(output_dir, f"{slice_id}.pt"))