In [None]:
import scarches as sca
#import torch
import scanpy as sc
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import os
import time
import pickle
import itertools

# Load data

In [None]:
proc_dir='/data/gpfs/projects/punim2121/Atherosclerosis/xenium_data/processed_data/baysor_processed_output'
scale_param='10'
avg_assignment_conf_thr=0.75

## Load the filtered, batch-corrected AnnData object of each Xenium panel, keyed by panel name.
#  The file name encodes the Baysor scale parameter and the average assignment-confidence threshold.
adata_dict={
    panel_name: sc.read_h5ad(
        os.path.join(
            proc_dir,
            f'filtered_batch_corr_{panel_name}_cells_scale_{scale_param}_asg_conf_{avg_assignment_conf_thr}.h5ad',
        )
    )
    for panel_name in ['Panel1','Panel2']
}

# Functions for DGE

In [None]:
from pybiomart import Server,Dataset
import itertools


import functools


@functools.lru_cache(maxsize=None)
def _query_biomart_hgnc_lookup(dataset='hsapiens_gene_ensembl'):
    """Download the Ensembl gene ID / HGNC symbol table of `dataset` from Ensembl BioMart.

    Cached per dataset name for the lifetime of the kernel, so repeated mappings do not
    re-download the full table. Treat the returned DataFrame as read-only (it is shared).
    Columns are BioMart display names: 'Gene stable ID' and 'HGNC symbol'.
    """
    server = Server(host="http://www.ensembl.org/")
    mart = server['ENSEMBL_MART_ENSEMBL']
    return mart[dataset].query(attributes=['ensembl_gene_id', 'hgnc_symbol'])


def map_hugo_to_ensembl(hugo_symbols, dataset='hsapiens_gene_ensembl', lookup_table=None):
    """Map HGNC (HUGO) gene symbols to all matching Ensembl gene IDs.

    Parameters
    ----------
    hugo_symbols : iterable of str
        Symbols to map; the result has one key per distinct symbol, in input order.
    dataset : str, default 'hsapiens_gene_ensembl'
        BioMart dataset to query (only used when `lookup_table` is None).
    lookup_table : pandas.DataFrame, optional
        Pre-fetched table with columns 'HGNC symbol' and 'Gene stable ID'. When given,
        no BioMart query is made.

    Returns
    -------
    dict
        symbol -> list of Ensembl IDs (in lookup-table row order). Symbols that are not
        found map to ``[np.nan]``.
    """
    # Materialise once: the symbols are iterated twice below (isin + dict build).
    hugo_symbols = list(hugo_symbols)
    if lookup_table is None:
        lookup_table = _query_biomart_hgnc_lookup(dataset)

    symbol_col, ensembl_col = 'HGNC symbol', 'Gene stable ID'
    hits = lookup_table.loc[lookup_table[symbol_col].isin(hugo_symbols), [symbol_col, ensembl_col]]

    # Single pass over the matching rows instead of re-scanning the table per symbol.
    ids_by_symbol = {}
    for symbol, ensembl_id in zip(hits[symbol_col], hits[ensembl_col]):
        ids_by_symbol.setdefault(symbol, []).append(ensembl_id)

    ## Dictionary matching each HUGO symbol to its ENSEMBL IDs; [np.nan] if not in the lookup table
    return {hugo: ids_by_symbol.get(hugo, [np.nan]) for hugo in hugo_symbols}


###=======================
## Run scanpy's DGE (preferably with the Wilcoxon method, but any method accepted by
#  sc.tl.rank_genes_groups works) and return the dataframe containing DGE between the
#  groups in the `groupby_coln` column of adata.obs.
def run_scanpy_dge(adata,groupby_coln,de_method,layer='norm_by_area_all'):
    """Rank genes for every group of ``adata.obs[groupby_coln]`` with scanpy.

    Parameters
    ----------
    adata : AnnData
        Data to test. Modified in place: ``adata.X`` is replaced by ``adata.layers[layer]``
        (same object, no copy) and the ranking is stored in ``adata.uns[f'dea_{groupby_coln}']``.
    groupby_coln : str
        Column of ``adata.obs`` whose groups are compared.
    de_method : str
        Passed as ``method`` to ``sc.tl.rank_genes_groups`` (e.g. 'wilcoxon').
    layer : str or None, default 'norm_by_area_all'
        Layer copied into ``adata.X`` before testing; None tests the current ``adata.X``.

    Returns
    -------
    pandas.DataFrame
        Output of ``sc.get.rank_genes_groups_df`` for all groups, with columns renamed to the
        DESeq2-style names used in this notebook: 'cluster', 'gene', 'log2FC', 'pvalue', 'padj'.
    """
    if layer is not None:
        # NOTE: overwrites the caller's adata.X; later cells reading adata.X see this layer.
        # NOTE(review): with the default use_raw, scanpy tests adata.raw when it is set — confirm
        # adata.raw is None, otherwise this layer is not what gets tested.
        adata.X=adata.layers[layer]

    result_key=f'dea_{groupby_coln}'
    sc.tl.rank_genes_groups(adata, groupby=groupby_coln, method=de_method, key_added=result_key)
    dge_df=sc.get.rank_genes_groups_df(adata, group=None, key=result_key)

    ## Rename column names to the same as DESeq2 output for easier code handling
    return dge_df.rename(columns={'group':'cluster','logfoldchanges':'log2FC','pvals':'pvalue','pvals_adj':'padj','names':'gene'})


###=======================
## Run DEseq2's DGE and return the dataframe containing DGE between the clusters in 'groupby_cluster_coln' column
#  - as running DEseq2 can take a long time, run it only for the clusters possibly containing multiple celltypes (==clusters_to_compare)
#  ==> clusters_to_compare: list of cluster names to compare in a "cluster vs. rest' fashion
#  - show_r_output: show output of R or not
def run_deseq2_dge(adata,leiden_clust_coln,clusters_to_compare,show_r_output):
    import tempfile
    import os
    import subprocess
    import pandas as pd
    import scanpy as sc
    with tempfile.TemporaryDirectory() as tmpdirname:
        ifn, ofn, rfn = [os.path.join(tmpdirname, e) for e in ["in.h5ad", "out.csv","script.R"]]
        adata.write_h5ad(ifn,compression='gzip')
        
        clusters_to_compare_str = 'c("{}")'.format('", "'.join(clusters_to_compare))
        
        rcmd=f'''.libPaths("/data/gpfs/projects/punim2121/R_libs/4.2.1");\
                library(DESeq2);\
                library(muscat);\
                library(SingleCellExperiment);\
                library(zellkonverter);\
                get_leiden_marker_genes=function(dat,leiden_clust_colname,coldata,col_to_sum_by,cols_to_correct_for,clusters_to_compare){{;\
                    de_result_list=list();\
                        for (leiden_clust in sort(clusters_to_compare)){{;\
                        print(paste0("Leiden cluster:",leiden_clust));\
                        leid_clust_dat=dat[,dat@colData[,leiden_clust_colname]==leiden_clust];\
                        rest_dat=dat[,dat@colData[,leiden_clust_colname]!=leiden_clust];\
                        leid_clust_pb=aggregateData(leid_clust_dat, assay="raw_counts", fun="sum", by=c(col_to_sum_by));\
                        rest_pb=aggregateData(rest_dat, assay="raw_counts", fun="sum", by=c(col_to_sum_by));\
                        leid_clust_pb_coln=paste(colnames(assay(leid_clust_pb)), "_clust", sep = "");\
                        rest_pb_coln=paste(colnames(assay(rest_pb)), "_rest", sep = "");\
                        coldata_clust=coldata[colnames(assay(leid_clust_pb)),];\
                        coldata_rest=coldata[colnames(assay(rest_pb)),];\
                        contr_data=cbind(assay(leid_clust_pb),assay(rest_pb));\
                        contr_data_coln=c(leid_clust_pb_coln,rest_pb_coln);\
                        colnames(contr_data)=contr_data_coln;\
                        contr_data_coldata=rbind(coldata_clust,coldata_rest);\
                        rownames(contr_data_coldata)=contr_data_coln;\
                        contr_data_coldata[,"contrast"]=c(rep("clust",length(leid_clust_pb_coln)),rep("rest",length(rest_pb_coln)));\
                        contr_data_coldata$contrast=factor(contr_data_coldata$contrast, levels=c("rest","clust"));\
                        contr_data_coldata=contr_data_coldata[colnames(contr_data),];\
                    tryCatch({{;\
                        design_form=as.formula(paste(c(~1,cols_to_correct_for,"contrast"),collapse="+"));\
                        dge=DESeqDataSetFromMatrix(contr_data, colData=contr_data_coldata, design=design_form);\
                        dds=DESeq(dge);\
                        }}, error = function(e) {{;\
                        message("An error occurred while processing Leiden cluster ", leiden_clust, e$message);\
                        if (grepl("every gene contains at least one zero, cannot compute log geometric means",e$message)){{;\
                        zero_rows=apply(contr_data, 1, function(row) all(row == 0));\
                        contr_data=contr_data[!zero_rows, ];\
                        contr_data=contr_data + 1 ;\
                        design_form=as.formula(paste(c(~1,cols_to_correct_for,"contrast"),collapse="+"));\
                        dge=DESeqDataSetFromMatrix(contr_data, colData=contr_data_coldata,design=design_form);\
                        dds=DESeq(dge);\
                        }} else {{stop(e$message)}};\
                        return (dds)}});\
                res=results(dds, contrast=c("contrast","clust","rest"));\
                deseq_out=as.data.frame(res@listData, row.names = res@rownames);\
                df=deseq_out[,c("log2FoldChange","pvalue","padj")];\
                df$gene=rownames(deseq_out);\
                colnames(df)=c("log2FC","pvalue","padj","gene");\
                df[,"cluster"]=leiden_clust;\
                df=df[(!is.na(df$padj)),];\
                df=df[order(df$padj),];\
                de_result_list[[leiden_clust]]=df;\
                }};\
                return (de_result_list);\
                }};\
                adata=zellkonverter::readH5AD(file="{ifn}",reader="python");\
                meta=adata@colData;\
                coldata=aggregate(x=meta[,c("patient","condition")],by=list(original_sample=meta$original_sample),FUN=unique);\
                rownames(coldata)=coldata$original_sample;\
                leiden_clust_colname="{leiden_clust_coln}";\
                col_to_sum_by="original_sample";\
                cols_to_correct_for=c("patient");\
                clusters_to_compare={clusters_to_compare_str};\
                marker_df_list=get_leiden_marker_genes(adata,leiden_clust_colname,coldata,col_to_sum_by,cols_to_correct_for,clusters_to_compare);\
                concatenated_df <- do.call(rbind, marker_df_list);\
                rownames(concatenated_df) <- NULL;\
                write.csv(concatenated_df,file="{ofn}");'''#.format(ifn,leiden_clust_coln,ofn)
            

        #rcmd = rcmd.replace('\\', '\\\\')
        with open(rfn, "w") as f:
            f.write(rcmd)
        print("Running DEseq2 DGE ...")
        #proc = subprocess.Popen(["Rscript", rfn,], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        proc=subprocess.Popen(f"module load foss/2022a R/4.2.1;R < {rfn} --no-save",
                              shell=True,stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        while not proc.poll():
            ## To print output of R script
            output = proc.stdout.readline()
            if show_r_output==True:
                if output:
                    try:
                        print(output.decode("utf-8"))
                    except UnicodeDecodeError as e:
                        print(f"Error decoding output: {e}")
                        print("Raw output:", output)   
            c = proc.stdout.read(1)
            if not c:
                break
                
        dge_df=pd.read_csv(ofn,index_col=0)
    
    return dge_df

# Load the DE genes found in bulk RNA-seq and scRNA-seq
- Map them to Ensembl IDs and keep the ones present in the respective Xenium panel

In [None]:
## MATCH HGNC SYMBOLS WITH ENSEMBL GENE IDS
## For each gene list, add one column per Xenium panel holding the first Ensembl ID of the gene
#  that is present in that panel's adata.var.index (NaN when none of its IDs is on the panel).

data_dir='/data/gpfs/projects/punim2121/Atherosclerosis/xenium_data/bulk_sc_de_genes_fatemeh'
bulk_sc_intersect=(pd.read_csv(os.path.join(data_dir,'BULK_SC_intersect.csv'))
                   .rename(columns={'Genes_sign_BULK_CS':'hgnc_id'}))
de_genes_stab=(pd.read_csv(os.path.join(data_dir,'DE_genes_stability.csv'))
               .rename(columns={'rn':'hgnc_id'}))

## Keep only the genes significant in unstable plaques vs. control. `.copy()` makes this an
#  independent frame so the panel columns added below do not trigger SettingWithCopyWarning.
de_genes_stab=de_genes_stab.loc[de_genes_stab['significant_class']=='Significant in Unstable_Pl vs Cont'].copy()

## The HGNC -> Ensembl mapping does not depend on the panel: query it once per gene list.
gene_lists=[bulk_sc_intersect,de_genes_stab]
hgnc_to_ensembl=[map_hugo_to_ensembl(df['hgnc_id'].tolist()) for df in gene_lists]

for panel in ['Panel1','Panel2']:
    print(panel)

    adata=adata_dict[panel]
    panel_genes=set(adata.var.index)   # set: O(1) membership instead of scanning a list per ID

    for df,mapping in zip(gene_lists,hgnc_to_ensembl):
        first_panel_id={hgnc_name: next((ensem_id for ensem_id in ensem_ids if ensem_id in panel_genes), np.nan)
                        for hgnc_name,ensem_ids in mapping.items()}
        # Column built in one go (object dtype), so strings are never written into a float column.
        df[panel]=df['hgnc_id'].map(first_panel_id)


# Run DGE

In [None]:
## Results of every DGE method, keyed as dge_df_dict[panel][de_method].
dge_df_dict={}

## DGE methods
de_methods=['deseq2','wilcoxon']

for panel in ['Panel1','Panel2'][:1]:
    print(panel)
    adata=adata_dict[panel]
    dge_df_dict[panel]={}

    ## FIXME: this should be stable vs. unstable; it currently compares the groups in 'condition'.
    subclust_coln='condition'

    for de_method in de_methods:
        print(de_method)

        if de_method=='deseq2':
            # DESeq2 compares each group against the rest, for every group in the column.
            show_r_output=False
            clusters_to_compare=adata.obs[subclust_coln].unique().tolist()
            dge_df=run_deseq2_dge(adata,subclust_coln,clusters_to_compare,show_r_output)
        else:
            dge_df=run_scanpy_dge(adata,subclust_coln,de_method)

        dge_df_dict[panel][de_method]=dge_df
            
        

# Extract genes of interest from the DGE results
- dysregulated genes shared between bulk RNA-seq and scRNA-seq
- genes that are significant only in unstable plaques

In [None]:
## For each panel / DGE method / gene list, keep the DGE rows of the genes of interest
#  (Ensembl IDs present on the panel) that belong to the 'D' group of the comparison.
dge_df_overlap_dict={}

for panel in ['Panel1','Panel2'][:1]:
    print(panel)
    dge_df_overlap_dict[panel]={}

    adata=adata_dict[panel]

    for de_method in de_methods:
        dge_df=dge_df_dict[panel][de_method]
        in_group_d=dge_df['cluster']=='D'

        overlaps_by_list={}
        for df,df_name in [(bulk_sc_intersect,'bulk_sc_intersect'),(de_genes_stab,'de_genes_stab')]:
            xen_genes=df[panel].dropna().tolist()
            overlaps_by_list[df_name]=dge_df[dge_df['gene'].isin(xen_genes) & in_group_d]

        dge_df_overlap_dict[panel][de_method]=overlaps_by_list
            
        
            
        
        

In [None]:
## Panel1 DESeq2 results: the row with the lowest padj for each gene
dge_df_dict['Panel1']['deseq2'].sort_values(by='padj').drop_duplicates(subset='gene')

In [None]:
## Panel1 Wilcoxon results: the row with the lowest padj for each gene
dge_df_dict['Panel1']['wilcoxon'].sort_values(by='padj').drop_duplicates(subset='gene')

In [None]:
## Panel1 DESeq2 results restricted to the bulk/scRNA-seq overlap genes, sorted by padj
dge_df_overlap_dict['Panel1']['deseq2']['bulk_sc_intersect'].sort_values(by='padj')

In [None]:
## Panel1 Wilcoxon results restricted to the bulk/scRNA-seq overlap genes, sorted by padj
dge_df_overlap_dict['Panel1']['wilcoxon']['bulk_sc_intersect'].sort_values(by='padj')

# Plot the spatial distribution of the genes of interest (bulk/scRNA-seq overlap and unstable-plaque genes)

In [240]:
import squidpy as sq


## Column-wise quantile of the non-zero entries of a 2-D matrix
#  => used to set the range of the colour bars of the scatterplots below.
def calc_matrix_column_quantiles(matrix,quantile):
    """Return, for each column of ``matrix``, the ``quantile`` of its non-zero values.

    Zeros are ignored so that sparse (mostly zero) columns still get a useful range.
    ``quantile`` is passed straight to ``np.quantile`` (so it must lie in [0, 1]).
    Columns without any non-zero entry yield ``np.nan``. Returns a list with one value
    per column.
    """
    column_quantiles = []
    for column in matrix.T:
        nonzero_values = column[column != 0]
        if len(nonzero_values) > 0:
            column_quantiles.append(np.quantile(nonzero_values, quantile))
        else:
            column_quantiles.append(np.nan)
    return column_quantiles



## For every panel, gene list and sample: one scatter panel per gene of interest (cells at
#  their x/y position, coloured by expression, colour range = 30th-95th percentile of the
#  non-zero values in that sample), saved as <data_dir>/<gene list name>/<sample>.png.
for panel in ['Panel1','Panel2']:
    print(panel)
    adata=adata_dict[panel]

    # Loop-invariant per panel: spatial coordinates and sample list.
    adata.obsm["spatial"] = adata.obs[["x", "y"]].to_numpy(copy=True)
    sample_names_per_panel=adata.obs['original_sample'].unique().tolist()

    for df,df_name in zip([bulk_sc_intersect,de_genes_stab],['bulk_sc_intersect','de_genes_stab']):

        # Genes with an Ensembl ID on this panel, and their HGNC names for the subplot titles.
        on_panel=df[panel].notna()
        genes_of_interest=df.loc[on_panel,panel].tolist()
        titles=df.loc[on_panel,'hgnc_id'].tolist()
        if not genes_of_interest:
            print(f'No {df_name} genes on {panel}; skipping plots.')
            continue

        ncols=1
        nrows=int(np.ceil(len(genes_of_interest)/ncols))
        out_dir=os.path.join(data_dir,df_name)
        os.makedirs(out_dir,exist_ok=True)

        for sample_name in sample_names_per_panel:
            print(sample_name)

            # Subset once per sample instead of once per gene and per statistic.
            in_sample=adata.obs['original_sample']==sample_name
            sample_adata=adata[in_sample]
            sample_obs=adata.obs[in_sample]

            fig=plt.figure(figsize=(ncols*8.5,nrows*7))
            fig.suptitle(sample_name,fontweight='bold',y=0.92)

            for n,(col,title) in enumerate(zip(genes_of_interest,titles)):
                ax=fig.add_subplot(nrows,ncols,n+1)

                # Densify explicitly: the `.A` shortcut no longer exists on SciPy sparse matrices.
                expr=sample_adata[:,col].X
                expr=expr.toarray() if hasattr(expr,'toarray') else np.asarray(expr)

                vmin=calc_matrix_column_quantiles(expr,0.3)[0]
                vmax=calc_matrix_column_quantiles(expr,0.95)[0]
                if np.isnan(vmin) or np.isnan(vmax):
                    # Gene not expressed in this sample: fixed range so the colour bar stays valid.
                    vmin,vmax=0.0,1.0

                sns.scatterplot(data=sample_obs,x='x',y='y',ax=ax,
                                hue=expr.flatten(),
                                palette='viridis',hue_norm=(vmin,vmax),
                                s=2,legend=False)

                ax.set_title(title)
                ax.grid(False)

                # A colour bar replaces seaborn's (discrete) hue legend.
                sm=plt.cm.ScalarMappable(cmap="viridis",norm=plt.Normalize(vmin,vmax))
                sm.set_array([])
                fig.colorbar(sm,ax=ax)

            fig.savefig(os.path.join(out_dir,sample_name+'.png'),bbox_inches='tight')
            plt.close(fig)
    
    

Panel1
Panel1_P1_D
Panel1_P1_H
Panel1_P2_D
Panel1_P2_H
Panel1_P3_D
Panel1_P3_H
Panel1_P4_D
Panel1_P4_H
Panel1_P5_D
Panel1_P6_D
Panel1_P7_D
Panel1_P8_D
Panel1_P9_D
Panel1_P10_D
Panel1_P11_D
Panel1_P12_D
Panel1_P1_D
Panel1_P1_H
Panel1_P2_D
Panel1_P2_H
Panel1_P3_D
Panel1_P3_H
Panel1_P4_D
Panel1_P4_H
Panel1_P5_D
Panel1_P6_D
Panel1_P7_D
Panel1_P8_D
Panel1_P9_D
Panel1_P10_D
Panel1_P11_D
Panel1_P12_D
Panel2
Panel2_P1_D
Panel2_P1_H
Panel2_P2_D
Panel2_P2_H
Panel2_P3_D
Panel2_P3_H
Panel2_P4_D
Panel2_P4_H
Panel2_P5_D
Panel2_P6_D
Panel2_P7_D
Panel2_P8_D
Panel2_P9_D
Panel2_P10_D
Panel2_P11_D
Panel2_P12_D
Panel2_P1_D
Panel2_P1_H
Panel2_P2_D
Panel2_P2_H
Panel2_P3_D
Panel2_P3_H
Panel2_P4_D
Panel2_P4_H
Panel2_P5_D
Panel2_P6_D
Panel2_P7_D
Panel2_P8_D
Panel2_P9_D
Panel2_P10_D
Panel2_P11_D
Panel2_P12_D
