<center>Applying STAligner for Breast-Cancer data set from 10X genomics

In [30]:
import STAligner
import os

import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
import scipy.sparse as sp
import scipy.linalg

import scipy
import networkx

import torch

used_device = torch.device('cuda')

In [31]:
pathout = "/data/kanferg/Sptial_Omics/SpatialOmicsToolkit/out_4"
path_age_58 = "/data/kanferg/Sptial_Omics/playGround/Data/Breast_Cancer/age_58/binned_outputs/square_016um"
path_age_76 = "/data/kanferg/Sptial_Omics/playGround/Data/Breast_Cancer/age_76/binned_outputs/square_016um"
path_list = [path_age_58,path_age_76]

In [27]:
def read_data(path_data):
    def parquet_to_csv(path):
        '''
        Converts a Parquet file to a CSV file if the CSV file does not already exist.
        '''
        file_path = os.path.join(path,'spatial/tissue_positions_list.csv')
        if not os.path.exists(file_path):
            df = pd.read_parquet(os.path.join(path,'spatial/tissue_positions.parquet'))
            # Write to a CSV file
            df.to_csv(os.path.join(path,'spatial/tissue_positions_list.csv'), index=False)
        return
    parquet_to_csv(path_data)
    andata = sc.read_visium(path=path_data,load_images=False)
    positions = pd.read_csv(os.path.join(path_data,'spatial/tissue_positions_list.csv'),index_col=0,)
    positions.columns = [
                "in_tissue",
                "array_row",
                "array_col",
                "pxl_col_in_fullres",
                "pxl_row_in_fullres",
            ]
    andata.obs = andata.obs.join(positions, how="left")
    andata.obsm["spatial"] = andata.obs[
                ["pxl_row_in_fullres", "pxl_col_in_fullres"]
            ].to_numpy()
    andata.obs.drop(
        columns=["pxl_row_in_fullres", "pxl_col_in_fullres"],
        inplace=True,
    )
    andata.obsm['spatial'] = np.array(andata.obsm['spatial'], dtype=np.float64)
    andata.var_names_make_unique()
    # for keeping the cell names for compitability with other programs
    andata.obs["obs_names"] = andata.obs_names
    return andata

### Function Summary: Construct Spatial Neighbor Networks - Cal_Spatial_Net

This function constructs spatial neighbor networks for the spots in an AnnData object based on either radius or nearest neighbors.

**Parameters:**
- `adata`: Input AnnData object.
- `rad_cutoff`: Radius cutoff for connecting spots when `model='Radius'`.
- `k_cutoff`: Number of nearest neighbors for connecting spots when `model='KNN'`.
- `max_neigh`: Maximum number of neighbors to consider.
- `model`: Network construction method. Options:
  - `'Radius'`: Connects spots within a specified radius (`rad_cutoff`).
  - `'KNN'`: Connects each spot to its `k_cutoff` nearest neighbors.
- `verbose`: If True, outputs detailed progress.

**Returns:**  
The spatial networks are saved in `adata.uns['Spatial_Net']`.


In [35]:
Batch_list = []
adj_list = []
section_ids = ['0','1']
for i,p in enumerate(path_list):
    adata = read_data(p)
    adata.obs_names = [x + '_' + section_ids[i] for x in adata.obs_names]
    STAligner.Cal_Spatial_Net(adata)  
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata,flavor = 'cell_ranger',n_top_genes = 2000)
    
    adata = adata[:, adata.var['highly_variable']]
    adj_list.append(adata.uns['adj'])
    Batch_list.append(adata)

Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


------Calculating spatial graph...
The graph contains 0 edges, 119082 cells.
0.0000 neighbors per cell on average.


Variable names are not unique. To make them unique, call `.var_names_make_unique`.
Variable names are not unique. To make them unique, call `.var_names_make_unique`.


------Calculating spatial graph...
The graph contains 0 edges, 175095 cells.
0.0000 neighbors per cell on average.


In [36]:
adata

View of AnnData object with n_obs × n_vars = 175095 × 2000
    obs: 'in_tissue', 'array_row', 'array_col', 'obs_names'
    var: 'gene_ids', 'feature_types', 'genome', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'spatial', 'Spatial_Net', 'adj', 'log1p', 'hvg'
    obsm: 'spatial'

In [37]:
adata_concat = ad.concat(Batch_list, label="slice_name", keys=section_ids)
adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')

In [38]:
adata_concat

AnnData object with n_obs × n_vars = 294177 × 982
    obs: 'in_tissue', 'array_row', 'array_col', 'obs_names', 'slice_name', 'batch_name'
    obsm: 'spatial'

In [39]:
def memory_efficient_block_diag(adj_list):
    """
    Constructs a block diagonal matrix from a list of sparse adjacency matrices in a memory-efficient way.
    
    Parameters:
        adj_list (list): A list of sparse matrices (scipy.sparse).

    Returns:
        scipy.sparse.csr_matrix: Block diagonal sparse matrix.
    """
    # Calculate the total size of the final matrix
    total_rows = sum(adj.shape[0] for adj in adj_list)
    total_cols = sum(adj.shape[1] for adj in adj_list)

    # Initialize sparse matrices for the data, row indices, and column indices
    data = []
    row_indices = []
    col_indices = []

    row_offset = 0
    col_offset = 0

    for adj in adj_list:
        # Convert sparse matrix to COO format for efficient indexing
        adj_coo = adj.tocoo()

        # Append the data and adjusted indices
        data.append(adj_coo.data)
        row_indices.append(adj_coo.row + row_offset)
        col_indices.append(adj_coo.col + col_offset)

        # Update offsets for the next block
        row_offset += adj.shape[0]
        col_offset += adj.shape[1]

    # Concatenate data and indices from all blocks
    data = np.concatenate(data)
    row_indices = np.concatenate(row_indices)
    col_indices = np.concatenate(col_indices)

    # Construct the final sparse matrix
    block_diag_matrix = sp.csr_matrix((data, (row_indices, col_indices)), shape=(total_rows, total_cols))
    return block_diag_matrix

In [40]:
# Example usage
adj_concat = memory_efficient_block_diag(adj_list)
adata_concat.uns['edgeList'] = np.nonzero(adj_concat)

In [41]:
adata_concat

AnnData object with n_obs × n_vars = 294177 × 982
    obs: 'in_tissue', 'array_row', 'array_col', 'obs_names', 'slice_name', 'batch_name'
    uns: 'edgeList'
    obsm: 'spatial'

### Function Summary: Train Graph Attention Auto-Encoder with Batch Correction

This function trains a graph attention auto-encoder on an AnnData object and performs batch correction in the embedding space using spot triplets across slices.

**Parameters:**
- `adata`: Input AnnData object.
- `hidden_dims`: Dimension of the encoder.
- `n_epochs`: Number of training epochs.
- `lr`: Learning rate for optimization.
- `key_added`: Key to store latent embeddings in `adata.obsm`.
- `gradient_clipping`: Applies gradient clipping during training.
- `weight_decay`: Regularization parameter for Adam optimizer.
- `margin`: Margin in triplet loss for batch correction; larger values enforce stronger correction.
- `iter_comb`: Order of pairwise slice integration (e.g., `(0, 1)` aligns slice 0 with slice 1 as the reference).
- `knn_neigh`: Number of nearest neighbors for constructing mutual nearest neighbors (MNNs).
- `device`: Specifies the computation device (e.g., CPU or GPU).

**Returns:**  
Updated AnnData object with batch-corrected latent embeddings.

In [42]:
iter_comb = [(1, 0)] ## Fix slice 0 as reference to align
adata_concat = STAligner.train_STAligner(adata_concat, verbose=True, knn_neigh = 10, iter_comb = iter_comb, margin=1,  device=used_device)

STAligner(
  (conv1): GATConv(982, 512, heads=1)
  (conv2): GATConv(512, 30, heads=1)
  (conv3): GATConv(30, 512, heads=1)
  (conv4): GATConv(512, 982, heads=1)
)
Pretrain with STAGATE...


100%|█████████████████████████████████████████████████████████████████████████████| 500/500 [01:48<00:00,  4.62it/s]


Train with STAligner...


  0%|                                                                                       | 0/500 [00:00<?, ?it/s]

Update spot triplets at epoch 500


  adj = nx.adjacency_matrix(G)
 20%|███████████████▍                                                             | 100/500 [00:30<01:27,  4.55it/s]

Update spot triplets at epoch 600


 40%|██████████████████████████████▊                                              | 200/500 [01:02<01:05,  4.56it/s]

Update spot triplets at epoch 700


 60%|██████████████████████████████████████████████▏                              | 300/500 [01:33<00:43,  4.58it/s]

Update spot triplets at epoch 800


 80%|█████████████████████████████████████████████████████████████▌               | 400/500 [02:04<00:21,  4.56it/s]

Update spot triplets at epoch 900


100%|█████████████████████████████████████████████████████████████████████████████| 500/500 [02:36<00:00,  3.20it/s]


In [51]:
andata_save = adata_concat.copy()
for column in andata_save.obs.columns:
    print(f"Column '{column}' has data type: {andata_save.obs[column].dtype}")

Column 'in_tissue' has data type: int64
Column 'array_row' has data type: int64
Column 'array_col' has data type: int64
Column 'obs_names' has data type: object
Column 'slice_name' has data type: category
Column 'batch_name' has data type: category


In [60]:
andata_save.obs['obs_names'] = andata_save.obs['obs_names'].astype("string")
andata_save.obs_names = andata_save.obs_names.astype(str)
andata_save.uns['edgeList'] = np.column_stack(andata_save.uns['edgeList'])

In [63]:
andata_save.write_h5ad(os.path.join(pathout, "andata_bc_BreastCancer_STAlign.h5ad"))

... storing 'obs_names' as categorical
