# Visualizing the pre- and post-denoising results of markers identified by STAGATE across seven regions

In [1]:
# author: Yuhan Jia (jiayuhan21@mails.ucas.ac.cn) and Yiyang Zhang (zhangyiyang1328@163.com)
import scanpy as sc
import anndata
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import os
import tqdm
import glob

# import scvi
from matplotlib import rcParams

# Global matplotlib defaults applied to every figure produced by this notebook.
rcParams.update({
    'pdf.fonttype': 42,  # enables correct plotting of text for PDFs
    'font.sans-serif': "Arial",
    'font.size': 12,
})

In [2]:
import sys
# Make the local STAGATE_pyG checkout importable. The path is relative to the
# notebook's working directory, and it must be appended before the import below.
sys.path.append('./Algorithm/STAGATE_pyG-main/STAGATE_pyG-main/')

import STAGATE_pyG

In [3]:
# Identifiers of the twelve sections processed below; each one is also the name
# of the section's sub-folder under the data root.
# NOTE(review): names look like ST_<age>_<genotype>_<replicate ids> — confirm.
section_ids = [
    'ST_3M_AD_1_1', 'ST_3M_AD_1_2', 'ST_3M_WT_1_1', 'ST_3M_WT_1_2',
    'ST_6M_AD_2_1', 'ST_6M_AD_2_2', 'ST_6M_WT_2_1', 'ST_6M_WT_2_2',
    'ST_15M_AD_2_1', 'ST_15M_AD_2_2', 'ST_15M_WT_2_1', 'ST_15M_WT_2_2',
]

In [4]:
# Root folder that contains one sub-folder per section id.
# NOTE(review): machine-specific absolute path — change this when running elsewhere.
DATA_ROOT = 'D:/Users/zyzhang/Alzheimer/st/'

# One filtered AnnData object per entry of `section_ids`, in the same order.
Batch_list = []

for section_id in section_ids:
    print(section_id)

    # The path to Data
    input_dir0 = os.path.join(DATA_ROOT, section_id)
    adata = sc.read_visium(path=input_dir0,
                           count_file='filtered_feature_bc_matrix.h5',
                           load_images=True)
    adata.var_names_make_unique()

    # `if_choose.csv` is indexed by spot barcode; keep the rows flagged with 1.
    if_choose_file = pd.read_csv(os.path.join(input_dir0, 'if_choose.csv'),
                                 index_col=0,
                                 sep=',')
    barcodes_for_choose = if_choose_file[if_choose_file['if_choose'] == 1].index

    # Subsetting returns a view that keeps the full, unfiltered object alive;
    # copy so that only the selected spots stay in memory for each section.
    adata = adata[barcodes_for_choose].copy()

    Batch_list.append(adata)

ST_3M_AD_1_1


  utils.warn_names_duplicates("var")


ST_3M_AD_1_2
ST_3M_WT_1_1
ST_3M_WT_1_2
ST_6M_AD_2_1
ST_6M_AD_2_2
ST_6M_WT_2_1
ST_6M_WT_2_2
ST_15M_AD_2_1
ST_15M_AD_2_2
ST_15M_WT_2_1
ST_15M_WT_2_2


In [None]:
def _gene_max(matrix, col):
    """Return the largest value in one column of a dense or scipy-sparse matrix.

    Only the requested column is densified. NaN entries are ignored.
    """
    column = matrix[:, col]
    if hasattr(column, 'toarray'):
        column = column.toarray()
    return np.nanmax(np.asarray(column))


def _plot_marker_grid(adata, genes, vmax_by_gene, out_path, layer=None, ncols=7):
    """Draw one spatial panel per gene on an `ncols`-wide grid and save it.

    Parameters
    ----------
    adata : AnnData object passed to ``sc.pl.spatial``.
    genes : list of gene names, plotted row by row in the given order.
    vmax_by_gene : mapping gene name -> upper limit of the colour scale.
    out_path : file the figure is written to.
    layer : layer to plot; ``None`` plots ``adata.X``.
    ncols : number of panels per row.
    """
    # Ceiling division: a gene count that is an exact multiple of `ncols`
    # must not produce an extra, empty row. Keep at least one row.
    nrows = max(1, -(-len(genes) // ncols))

    with mpl.rc_context({'axes.facecolor': 'black'}):
        # squeeze=False keeps `axs` two-dimensional even when nrows == 1.
        fig, axs = plt.subplots(
            nrows=nrows, ncols=ncols,
            figsize=(4.5 * (ncols + 1) + 2, 2.5 * nrows + 1),
            squeeze=False,
        )
        for ax, gene in zip(axs.flat, genes):
            sc.pl.spatial(
                adata,
                cmap="rainbow",
                color=gene,
                layer=layer,
                size=1.3,
                img_key="hires",
                ax=ax,
                show=False,
                frameon=False,
                # cap the colour scale at the per-gene limit shared by both figures
                vmax=vmax_by_gene[gene],
            )
        # Hide the panels of the last row that received no gene.
        for ax in axs.flat[len(genes):]:
            ax.set_axis_off()

    fig.savefig(out_path)
    # Release the figure; otherwise every section/region figure stays in memory.
    plt.close(fig)


for m in tqdm.tqdm(range(len(Batch_list))):
    # Work on a copy so the loaded sections in Batch_list stay untouched.
    adata = Batch_list[m].copy()
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

    # Build and summarise the spatial graph, then train STAGATE.
    # The reconstructed expression is read back below from
    # adata.layers['STAGATE_ReX'] (presumably written because of
    # save_reconstrction=True — confirm against STAGATE_pyG).
    STAGATE_pyG.Cal_Spatial_Net(adata, rad_cutoff=300)
    STAGATE_pyG.Stats_Spatial_Net(adata)

    adata = STAGATE_pyG.train_STAGATE(adata, save_reconstrction=True)

    ntop = 56
    for region in glob.glob('./results/Scanpy_SVG_Merge_regions/*.csv'):
        marker = pd.read_csv(region, index_col=0)
        # Top `ntop` genes of this region, ranked by log fold change.
        Specific_Genes = marker.sort_values('logfoldchanges', ascending=False)['names'][:ntop].tolist()
        # Region name = file name without directory and extension
        # (works with both '/' and '\\' path separators).
        region_name = os.path.splitext(os.path.basename(region))[0]

        # Per-gene colour ceiling: the smaller of the raw and the denoised
        # maximum, so the "before" and "after" panels share one scale.
        gene_to_expmax = {}
        for gene in Specific_Genes:
            col = adata.var_names.get_loc(gene)
            gene_to_expmax[gene] = min(_gene_max(adata.layers['STAGATE_ReX'], col),
                                       _gene_max(adata.X, col))

        out_dir = f'./results/STAGATE_Denoising_7region_markers/{section_ids[m]}/{region_name}'
        os.makedirs(out_dir, exist_ok=True)

        # Raw (normalised, log-transformed) expression.
        _plot_marker_grid(adata, Specific_Genes, gene_to_expmax,
                          f'{out_dir}/before_Denoising.pdf')

        # STAGATE-reconstructed expression.
        _plot_marker_grid(adata, Specific_Genes, gene_to_expmax,
                          f'{out_dir}/after_Denoising.pdf',
                          layer='STAGATE_ReX')

  0%|                                                                                                                                                                                      | 0/12 [00:00<?, ?it/s]

------Calculating spatial graph...
The graph contains 14016 edges, 2405 cells.
5.8279 neighbors per cell on average.
Size of Input:  (2405, 32285)



  0%|                                                                                                                                                                                    | 0/1000 [00:00<?, ?it/s][A
  0%|▏                                                                                                                                                                         | 1/1000 [00:04<1:18:35,  4.72s/it][A
  0%|▎                                                                                                                                                                           | 2/1000 [00:04<33:45,  2.03s/it][A
  0%|▌                                                                                                                                                                           | 3/1000 [00:04<19:08,  1.15s/it][A
  0%|▋                                                                                                                                         