<a href="https://colab.research.google.com/github/dtabuena/Workshop/blob/main/Ref_Map/KZ_BuildRef.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scanpy --quiet
!pip install pybiomart --quiet
!pip install python-igraph --quiet
!pip install louvain --quiet
!pip install pynndescent --quiet


In [None]:
import h5py
import numpy as np
import scipy as sci
from matplotlib import pyplot as plt
import scanpy as sc
import tarfile
import os
import anndata as ad
import pandas as pd
import pybiomart
from tqdm import tqdm
import urllib.request
from IPython.display import clear_output
from matplotlib.pyplot import rc_context
from scipy import stats as st
import logging
logging.Logger('my_log').setLevel('INFO')

############# SPECIFIC CONFIG #############
working_dir = r"C:\Users\dennis.tabuena\Gladstone Dropbox\Dennis Tabuena\0_Projects\_KZ_All"
os.chdir(working_dir)
zalocusky_url = 'https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE167497&format=file'
geo_zalo_filename = 'zalocusky_indiv.tar'

import urllib
response = urllib.request.urlretrieve('https://raw.githubusercontent.com/dtabuena/Resources/main/Matplotlib_Config/Load_FS6.py','Load_FS6.py')
%run Load_FS6.py



In [None]:
def trim_key(k):
    """Convert a 10x file name into its sample code.

    File names containing one of the two floxed-sample prefixes are mapped to
    a fixed sample code; any other name simply has the
    '_raw_gene_bc_matrices_h5.h5' suffix removed.
    """
    # prefix -> (descriptive label, sample code); only the sample code is returned
    floxed_samples = {
        'GSM5106175_YH_KZ03_01': ('E3fKI_Syn_Cre602_15m', 'GSM5106175_602_E3fKI_15_XX'),
        'GSM5106176_YH_KZ03_03': ('E4fKI_Syn_Cre475_15m', 'GSM5106176_475_E4fKI_15_XX'),
    }
    for prefix, (_label, sample_code) in floxed_samples.items():
        if prefix in k:
            return sample_code
    return k.replace('_raw_gene_bc_matrices_h5.h5', "")

def query_capitilaziation(gene,adata):
    """Return the entry of ``adata.var.index`` that matches ``gene`` ignoring case.

    Parameters:
        gene (str): Gene symbol in any capitalization.
        adata: Object exposing ``var.index`` (an iterable of gene names).

    Returns:
        The first matching name exactly as stored in ``adata.var.index``, or
        ``gene + ' not_found'`` when no name matches.
    """
    target = gene.lower()
    for name in adata.var.index:
        # str() keeps a stray non-string name from hiding every later match
        # (previously a bare `except` turned that into "not_found").
        if str(name).lower() == target:
            return name
    return gene + ' not_found'

def z_score(x,axis=-1):
    """Standardize ``x`` along ``axis``: (x - mean) / std.

    Parameters:
        x: Array-like input; converted to a NumPy array.
        axis: Axis along which mean and (population) standard deviation are
            computed. Defaults to the last axis.

    Returns:
        numpy.ndarray with the same shape as ``x``.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis so the statistics broadcast against
    # x for any dimensionality (without it a (n, m) input fails on axis=-1).
    mean = np.mean(x, axis=axis, keepdims=True)
    std = np.std(x, axis=axis, keepdims=True)
    return (x - mean) / std



In [None]:
def pull_gene_annots(csv_loc='./mmusculus_coding_noncoding.csv',
                     my_git='https://raw.githubusercontent.com/dtabuena/Resources/main/Genetics/mmusculus_coding_noncoding.csv',
                     biomart_name='mmusculus',
                     biomart_keys=["ensembl_gene_id", "chromosome_name","transcript_biotype","external_gene_name","peptide"]):
    """Load the gene annotation table and list the coding genes.

    Sources are tried in order: the local CSV at ``csv_loc``, a download of
    ``my_git`` saved to ``csv_loc``, and finally a BioMart query whose result
    is written to ``csv_loc``.

    Parameters:
        csv_loc (str): Path of the cached annotation CSV.
        my_git (str): URL of a pre-built annotation CSV.
        biomart_name (str): Organism name passed to the BioMart query.
        biomart_keys (list): NOTE(review): not used. The BioMart query keeps
            its fixed four-attribute list because this default also contains
            "peptide", which the query never requested.

    Returns:
        tuple: (coding_list, annot_dd) - gene names whose 'is_coding' flag is
        true, and the annotation DataFrame indexed by 'external_gene_name'.
    """
    if os.path.exists(csv_loc):
        print( 'Use local copy of musmus')
        annot_dd = pd.read_csv(csv_loc).set_index("external_gene_name")
    else:
        try:
            print( 'attempting to pull mus mus from git...')
            urllib.request.urlretrieve(my_git, csv_loc)
            annot_dd = pd.read_csv(csv_loc).set_index("external_gene_name")
        except Exception:
            # Any download/parse failure falls back to BioMart (best effort),
            # but KeyboardInterrupt/SystemExit are no longer swallowed.
            print('attempting to pull mus mus from biomart...')
            annot = sc.queries.biomart_annotations(biomart_name,["ensembl_gene_id", "chromosome_name","transcript_biotype","external_gene_name"],).set_index('ensembl_gene_id')
            uniq_inds = list(set(list(annot.index)))
            for r in tqdm(uniq_inds):
                match_bool = annot.index.str.contains(r)
                if np.sum(match_bool)>1:
                    # Collapse the biotypes of a gene listed more than once
                    # into a single '__'-joined string.
                    new_val ='__'.join(list(annot.loc[r,'transcript_biotype']))
                    annot.at[r,'transcript_biotype']=new_val
            annot['is_coding']= annot.transcript_biotype.str.contains('coding')
            annot_dd = annot.drop_duplicates().set_index("external_gene_name")
            annot_dd.to_csv(csv_loc)

    coding_list = annot_dd.index[ annot_dd['is_coding'] ].to_list()
    return coding_list, annot_dd




def preprocess_andata10x(adata_og,pct_mito=0.25,min_genes=500,max_genes=2400,min_counts=500,max_counts=4500):
    """Annotate QC metrics on ``adata_og`` and return a filtered copy.

    ``adata_og`` is modified in place: 'mt' and 'coding' columns are added to
    ``.var`` and scanpy QC metrics are computed. The filtered copy keeps cells
    with ``pct_counts_mt < pct_mito``, keeps coding genes only, and applies
    the gene-count and total-count bounds. A scatter of total counts against
    detected genes is drawn for the filtered data.

    Parameters:
        adata_og: AnnData object to annotate.
        pct_mito: Upper bound (exclusive) on ``obs.pct_counts_mt``.
        min_genes, max_genes: Bounds passed to ``sc.pp.filter_cells``.
        min_counts, max_counts: Bounds passed to ``sc.pp.filter_cells``.

    Returns:
        tuple: (adata_QC, adata_og) - the filtered copy and the annotated input.
    """
    print('pulling gene annotations...')
    coding_list, _ = pull_gene_annots()
    coding_genes = set(coding_list)  # set membership instead of scanning a list per gene
    adata_og.var['mt'] = adata_og.var_names.str.startswith('mt-')
    adata_og.var['coding'] = [gene in coding_genes for gene in adata_og.var_names]
    sc.pp.calculate_qc_metrics(adata_og, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

    adata_QC = adata_og.copy()

    print('Filtering...')
    keep_cells = adata_QC.obs.pct_counts_mt < pct_mito
    adata_QC = adata_QC[keep_cells, :]
    # Report the cells actually dropped (the old message counted the kept ones).
    print(str(np.sum(~keep_cells)) + f' cells with >{pct_mito}% removed')
    adata_QC = adata_QC[:, adata_QC.var.coding]
    print(str(np.sum(np.logical_not(adata_og.var.coding))) + ' non coding genes removed')
    sc.pp.filter_cells(adata_QC, min_genes=min_genes)
    sc.pp.filter_cells(adata_QC, max_genes=max_genes)
    sc.pp.filter_cells(adata_QC, min_counts=min_counts)
    sc.pp.filter_cells(adata_QC, max_counts=max_counts)
    fig,ax=plt.subplots(1,figsize=(1.5,1.5))
    sc.pl.scatter(adata_QC, x='total_counts', y='n_genes_by_counts',ax=ax)

    return adata_QC,adata_og

def high_var_genes_dim_reduc(adata,min_mean = 0.25,max_mean = 4,min_disp=0.55):
    ''' Log-normalize ``adata`` in place, flag highly variable genes and run PCA.

    Method description quoted from the reference publication:
    "The gene expression matrices were then log-normalized with a scale factor of 10,000,
    using the Seurat NormalizeData function57,58. Highly dispersed genes were selected using
    the Seurat FindVariableGenes function57,58,filtering for an average expression range of
    0.25–4 and a minimum dispersion of 0.55, resulting in a list of 2,197 genes."

    Parameters:
        adata: AnnData object; modified in place and also returned.
        min_mean, max_mean, min_disp: Thresholds forwarded to
            ``sc.pp.highly_variable_genes``.

    Returns:
        The same ``adata`` with ``.raw`` set, normalized/log1p ``X``,
        highly-variable-gene flags and a 50-component PCA.
    '''
    # Store the object in .raw before X is normalized and log-transformed.
    adata.raw = adata
    sc.pp.normalize_total(adata, target_sum=10000)
    sc.pp.log1p(adata)
    # NOTE(review): overwrites the log1p metadata with an explicit base of
    # None; presumably a workaround for downstream readers of
    # uns['log1p']['base'] - confirm.
    adata.uns['log1p'] = {'base': None}
    print(adata.uns['log1p'])
    sc.pp.highly_variable_genes(adata, min_mean=min_mean, max_mean=max_mean, min_disp=min_disp)
    with rc_context({'figure.figsize': (1.5, 1.5)}):
        sc.pl.highly_variable_genes(adata)
    plt.tight_layout()
    print(np.sum(adata.var['highly_variable']),'hv genes')


    #### PCA
    sc.tl.pca(adata, svd_solver='arpack',n_comps=50)
    # Scree-style plot of the variance ratio of the first 25 components.
    fig,ax=plt.subplots(figsize=(1,1))
    ax.plot(adata.uns['pca']['variance_ratio'][:25],'ok',markersize=1)
    # Requires 'E_type', 'age_bin' and 'mouse_ID' columns in adata.obs.
    quiet_PCA_plots(adata,['E_type','age_bin','mouse_ID'],pc_pairs=[(0,1),(2,3)])

    return adata

def quiet_PCA_plots(adata,key_list,figsize=(2,2),pc_pairs=[(0,1)]):
    """Scatter cells in PCA space, one panel per (PC pair, obs key).

    Rows correspond to ``pc_pairs`` and columns to ``key_list``; within a
    panel every distinct value of the obs column gets its own scatter series.
    A legend is drawn only when the key has fewer than 8 distinct values.

    Parameters:
        adata: AnnData object with ``obsm['X_pca']`` and the obs columns in
            ``key_list``.
        key_list (list): obs column names to color by.
        figsize (tuple): Size of a single panel (width, height).
        pc_pairs (list): Zero-based (x, y) component index pairs; axis labels
            use the same zero-based numbers.

    Returns:
        None
    """
    n_rows, n_cols = len(pc_pairs), len(key_list)
    # squeeze=False always returns a 2-D axes array, so ax[row, col] is valid
    # even with a single key or a single PC pair (the previous squeezed array /
    # list wrapper raised on that indexing).
    fig,ax=plt.subplots(n_rows,n_cols,figsize=(figsize[0]*n_cols,figsize[1]*n_rows),squeeze=False)
    for ip,pair in enumerate(pc_pairs):
        for key_ind,key in enumerate(key_list):
            panel = ax[ip,key_ind]
            key_types = sorted(set(adata.obs[key]))
            for k in key_types:
                is_k = adata.obs[key]==k
                panel.scatter(adata.obsm['X_pca'][is_k,pair[0]],adata.obsm['X_pca'][is_k,pair[1]],s=2,marker='.',linewidth=0,edgecolors=None,label=k)
            panel.set_xlabel(f'PC{pair[0]}')
            panel.set_ylabel(f'PC{pair[1]}')
            if len(key_types)<8: panel.legend(key_types,loc='best',markerscale=3)
            panel.set_title(key)
    plt.tight_layout()
    return None


def umap_and_cluster(adata, n_neighbors=10, n_pcs=20,resolution=.6,plot_keys=['Cluster (nn)'],size = 1,to_plot=True):
    """Build the neighbor graph, Louvain-cluster, run PAGA and UMAP on ``adata``.

    The Louvain labels are copied to ``obs['Cluster (nn)']``. Every stochastic
    step uses the fixed seed 42. When ``to_plot`` is true the UMAP is drawn
    colored by ``plot_keys``.

    Returns:
        The same ``adata`` object, modified in place.
    """
    seed = 42
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, random_state=seed)
    sc.tl.louvain(adata, resolution=resolution, random_state=seed)
    sc.tl.paga(adata)
    sc.tl.umap(adata, random_state=seed)
    adata.obs['Cluster (nn)'] = adata.obs['louvain']

    if not to_plot:
        return adata

    with rc_context({'figure.figsize': (2.5, 2.5)}):
        sc.pl.umap(adata, add_outline=False, legend_loc='on data', color=plot_keys, size=size)
    return adata

def explore_umap(adata_GABA,key_list=[],size=1,legend_loc=None):
    """Draw small UMAP panels of ``adata_GABA`` colored by each key in ``key_list``."""
    umap_options = dict(
        legend_loc=legend_loc,
        color=key_list,
        vmin=-0.5,
        size=size,
        cmap='Purples',
    )
    with rc_context({'figure.figsize': (1.5,1.5)}):
        sc.pl.umap(adata_GABA, **umap_options)
        plt.tight_layout()




def marker_analysis(adata):
    """Rank genes per Louvain cluster and plot the strongest markers.

    Runs a Wilcoxon ranking (stored under ``uns['maker_genes']``) and a
    logistic-regression ranking (default key), both on ``.raw``. Markers are
    the genes among each cluster's top 30 Wilcoxon hits whose log fold change
    exceeds 4; they are shown as a stacked violin plot grouped by 'louvain'.

    Returns:
        list: Unique marker gene names (unordered).
    """
    n_top = 30
    lfc_thresh = 4
    shared_options = dict(pts=True, use_raw=True)

    #### Get Marker Genes
    sc.tl.rank_genes_groups(adata, 'louvain', method='wilcoxon', key_added='maker_genes', **shared_options)
    sc.tl.rank_genes_groups(adata, 'louvain', method='logreg', **shared_options)

    wilcoxon = adata.uns['maker_genes']
    top_names = pd.DataFrame(wilcoxon['names'])[:n_top]
    top_lfc = pd.DataFrame(wilcoxon['logfoldchanges'])[:n_top]

    # Masking leaves NaN where the threshold is not met; keep only the names.
    passing = top_names[top_lfc > lfc_thresh]
    marker_genes = [name for name in set(passing.values.flatten()) if isinstance(name, str)]

    sc.pl.stacked_violin(adata, marker_genes, groupby='louvain')
    return marker_genes






In [None]:

def show_clusters(adata_GABA, key, cmap='tab20', legend_fontsize='small'):
    """Plot the UMAP colored by the numeric cluster labels in ``obs[key]``.

    The figure has two axes: the scatter itself and a separate, frameless
    legend axis with one entry per distinct cluster value.

    Returns:
        tuple: (fig, (ax_scatter, ax_legend))
    """
    assert key in adata_GABA.obs, f"Key '{key}' not found in adata_GABA.obs."

    marker_area = 0.4
    xy = adata_GABA.obsm['X_umap']
    cluster_ids = adata_GABA.obs[key].values.astype(float)

    # Wide scatter panel on the left, narrow legend panel on the right
    fig, (ax_scatter, ax_legend) = plt.subplots(
        1, 2, figsize=(3.5, 2.5), gridspec_kw={'width_ratios': [5, 1]}
    )

    scatter = ax_scatter.scatter(
        xy[:, 0], xy[:, 1],
        c=cluster_ids, cmap=cmap, s=marker_area, vmin=-0.5, edgecolor='none',
    )

    # Square axes, corner-anchored labels, no ticks
    ax_scatter.set_aspect('equal', adjustable='box')
    ax_scatter.set_xlabel('umap 1',ha='left')
    ax_scatter.xaxis.set_label_coords(0, 0.0)
    ax_scatter.set_ylabel('umap 2',ha='left')
    ax_scatter.yaxis.set_label_coords(0.0, 0)
    ax_scatter.set_xticks([])
    ax_scatter.set_yticks([])

    # One legend marker per cluster, using the color the scatter assigned to it
    cluster_values = np.unique(cluster_ids)
    swatches = [scatter.cmap(scatter.norm(value)) for value in cluster_values]
    handles = [
        plt.Line2D([], [], marker='o', linestyle='None', color=swatch,
                   markerfacecolor=swatch, markersize=4.5)
        for swatch in swatches
    ]
    labels = [f'Cluster {int(value)}' for value in cluster_values]

    ax_legend.axis('off')
    ax_legend.legend(handles, labels, loc='center', fontsize=legend_fontsize, markerscale=0.75,
                     ncol=1, frameon=False)
    plt.tight_layout()
    plt.show()

    return fig, (ax_scatter, ax_legend)

# Example calls to the function
# show_clusters(adata_GABA, 'Cluster (nn)')


def show_gene_loading(adata_GABA, gene, cmap='Purples'):
    """Plot the UMAP colored by the expression of a single gene.

    Parameters:
        adata_GABA: AnnData object with ``obsm['X_umap']``.
        gene (str): Name present in ``adata_GABA.var_names``.
        cmap (str): Matplotlib colormap name.

    Returns:
        tuple: (fig, (ax_scatter, ax_cbar))
    """
    assert gene in adata_GABA.var_names, f"Gene '{gene}' not found in adata_GABA.var_names."

    umap_coords = adata_GABA.obsm['X_umap']
    size = 0.5

    # Expression values for the gene. X may be sparse or dense; only sparse
    # matrices need (and provide) .toarray(), so a dense X no longer crashes.
    expression = adata_GABA[:, gene].X
    if hasattr(expression, 'toarray'):
        expression = expression.toarray()
    gene_expression = np.asarray(expression).flatten()

    # Create the figure
    fig = plt.figure(figsize=(3, 1.5))

    # Manually placed axes: [left, bottom, width, height]
    ax_scatter = fig.add_axes([0.0, 0.1, 0.5, 0.8])
    ax_cbar = fig.add_axes([0.50, 0.1, 0.01, 0.8])  # directly right of the scatter

    # Color limits are fixed so panels for different genes share one scale
    scatter = ax_scatter.scatter(
        umap_coords[:, 0], umap_coords[:, 1],
        c=gene_expression, cmap=cmap, s=size, edgecolor='none',vmin=-.45,vmax=2.75)

    # Square axes, corner-anchored labels, no ticks
    ax_scatter.set_aspect('equal', adjustable='box')
    ax_scatter.set_xlabel('umap 1',ha='left')
    ax_scatter.xaxis.set_label_coords(0, 0.0)
    ax_scatter.set_ylabel('umap 2',ha='left')
    ax_scatter.yaxis.set_label_coords(0.0, 0)
    ax_scatter.set_xticks([])
    ax_scatter.set_yticks([])

    # Colorbar
    # NOTE(review): the label text is fixed; confirm it matches the units of X.
    cbar = fig.colorbar(scatter, cax=ax_cbar, orientation='vertical')
    cbar.set_label('stdev/mean', rotation=270, labelpad=10)
    ax_scatter.set_title(gene.upper())

    # Adjust layout and show the plot
    plt.tight_layout()
    plt.show()

    return fig, (ax_scatter, ax_cbar)

# Example call to the function
# show_gene_loading(adata_GABA, 'Sst')

# cluster_fig, ax = show_clusters(adata_GABA, 'Cluster (nn)')
# cluster_fig.savefig("cluster_plot.svg", format="svg")
# cluster_fig.savefig("cluster_plot.jpeg", format="jpeg")

In [None]:

################# INITIALIZE DIRECTORY DOWNLOAD FROM GEO

# try: os.makedirs(working_dir)
# except: None
os.chdir(working_dir)
# exist_ok tolerates a pre-existing folder without hiding other failures
# (e.g. permission errors) the way the bare try/except did.
os.makedirs('./indiv_animal_results', exist_ok=True)
urllib.request.urlretrieve(zalocusky_url, './indiv_animal_results/'+geo_zalo_filename)
# NOTE(review): extractall trusts the member paths inside the downloaded
# archive; consider an extraction filter if the source is not fully trusted.
# The context manager closes the archive even if extraction fails.
with tarfile.open('./indiv_animal_results/'+geo_zalo_filename) as my_tar:
    my_tar.extractall('./indiv_animal_results') # specify which folder to extract to
# for f in os.listdir('./indiv_animal_results'):
#     print(f)


In [None]:
import h5py
import scanpy as sc

# Location of the Cell Ranger HDF5 matrix to inspect (edit as needed)
file_path = 'raw_matix.h5'

# Load the 10x matrix into an AnnData object
adata = sc.read_10x_h5(file_path)

# Summarize what was imported
print(adata)
print("Number of cells: {}".format(adata.n_obs))
print("Number of genes: {}".format(adata.n_vars))

In [None]:

################# Read, Combine, and Sample Split multiple 10x's

# Load every extracted 10x .h5 file and tag its cells with sample metadata
# parsed from the underscore-separated sample code returned by trim_key:
#   [0] GSM accession, [1] mouse ID, [2] E_type, [3] age, [4] well
adata_dict = {}
for f in os.listdir('./indiv_animal_results'):
    if '.h5' in f:
        a = sc.read_10x_h5('./indiv_animal_results/'+f)
        a.var_names_make_unique()
        sample_code = trim_key(f)
        # Age is rounded UP to the next multiple of 5 and suffixed with 'm'
        a.obs['age_bin'] = str(int(np.ceil( int(sample_code.split("_")[3])/5)*5))+'m'
        a.obs['E_type'] = sample_code.split("_")[2]
        a.obs['mouse_ID'] = sample_code.split("_")[1]
        a.obs['well'] = sample_code.split("_")[4]
        a.obs['GSM'] = sample_code.split("_")[0]
        adata_dict[sample_code.split("_")[0]] = a
# Stack all samples cell-wise; the dict key (GSM) is recorded in obs['Sample']
# and appended to each cell name with '_' to keep names unique.
adata = ad.concat(adata_dict,axis = 0,label="Sample",index_unique="_")
adata.obs.E_type
# adata = adata[['fKI' not in t for t in adata.obs.E_type], :]
# Drop the per-sample references now that they are concatenated
adata_dict = {}
clear_output()
adata.write_h5ad(filename='./kz_adata_raw.h5')


# QC filtering; only the filtered object (first return value) is kept
adata_QC = preprocess_andata10x(adata)[0]
adata_QC.write_h5ad(filename='./kz_adata_qc.h5')


############## GABAERGIC Filter #####################
adata_GABA = adata_QC.copy()
# Re-label age bins: zero-pad 5m and merge 15m and 20m into '15m+'
age_dict = {'5m':'05m', '10m': '10m','15m': '15m+','20m': '15m+'}
adata_GABA.obs['age_bin'] = [ age_dict[a] for a in adata_GABA.obs['age_bin'] ]

# A cell is marker-positive when its z-scored expression exceeds the threshold
adata_GABA.obs['Gad1_pos'] = z_score(sc.get.obs_df(adata_GABA,'Gad1'))>0.5
adata_GABA.obs['Gad2_pos'] = z_score(sc.get.obs_df(adata_GABA,'Gad2'))>0.5
# NOTE(review): a threshold of -100 looks like it effectively disables the
# Syn1 requirement - confirm this is intentional.
adata_GABA.obs['Syn1_pos'] = z_score(sc.get.obs_df(adata_GABA,'Syn1'))>-100
# GABAergic = (Gad1 OR Gad2 positive) AND Syn1 positive
is_gaba = np.logical_and( np.logical_or(adata_GABA.obs['Gad1_pos'] , adata_GABA.obs['Gad2_pos'] ), adata_GABA.obs['Syn1_pos'])


adata_GABA.obs['Gabaergic'] = is_gaba
adata_GABA = adata_GABA[ adata_GABA.obs['Gabaergic'] ,:]
# Drop genes with fewer than 50 total counts among the retained cells
sc.pp.filter_genes(adata_GABA, min_counts=50)

############## Dim Reduction #####################
adata_GABA = high_var_genes_dim_reduc(adata_GABA)
adata_GABA.write_h5ad(filename='./kz_adata_GABA.h5')


print(adata_GABA.uns['log1p'])

In [None]:
# ############## Dim Reduction #####################
# # adata_GABA = high_var_genes_dim_reduc(adata_GABApub
# # adata_GABA.write_h5ad(filename='./kz_adata_GABA.h5')

# adata_GABA = ad.read_h5ad('./kz_adata_GABA.h5')
# adata_GABA = umap_and_cluster(adata_GABA, n_neighbors=15, n_pcs=20,resolution=.6,to_plot=False)
# show_clusters(adata_GABA, 'Cluster (nn)')

In [None]:
# Cluster and embed first: show_clusters / show_gene_loading read
# obs['Cluster (nn)'] and obsm['X_umap'], both of which are created by
# umap_and_cluster. Calling the plots before it fails on a fresh run.
adata_GABA = umap_and_cluster(adata_GABA, n_neighbors=15, n_pcs=20,resolution=.6,to_plot=False)

show_clusters(adata_GABA, 'Cluster (nn)')
show_gene_loading(adata_GABA, 'Sst')
show_gene_loading(adata_GABA, 'Pvalb')
show_gene_loading(adata_GABA, 'Vip')
show_gene_loading(adata_GABA, 'Reln')
show_gene_loading(adata_GABA, 'Apoe')
show_gene_loading(adata_GABA, 'Syn1')
show_gene_loading(adata_GABA, 'Gfap')


explore_umap(adata_GABA,['Cluster (nn)','Sst', 'Pvalb'],legend_loc='on data')
explore_umap(adata_GABA,['Cluster (nn)','E_type', 'age_bin'])
explore_umap(adata_GABA,['Cluster (nn)','Vip', 'Reln'])
explore_umap(adata_GABA,['Cluster (nn)','Gfap', 'Syn1'])
explore_umap(adata_GABA,['Cluster (nn)','Apoe','E_type'])



In [None]:
# Notebook display of the AnnData summary (no side effects)
adata_GABA

In [None]:
# Alphabetically sorted gene names containing 'kcn' (case-insensitive)
kcn_list = sorted(
    name for name in adata_GABA.var.index if 'kcn' in str(name).lower()
)


In [None]:
def deg_set_up(adata,keys_of_interest):
    """Add an obs column that joins several obs columns with '_'.

    The new column is named by joining ``keys_of_interest`` with '_' and holds
    the row-wise concatenation of those columns (as strings).

    Parameters:
        adata: Object with a DataFrame-like ``obs``; modified in place.
        keys_of_interest (list): obs column names, in concatenation order.

    Returns:
        tuple: (adata, combo_key) - the same object and the new column name.
    """
    combo_key = '_'.join(keys_of_interest)
    combined = adata.obs[keys_of_interest[0]]
    for key in keys_of_interest[1:]:
        # Cast both sides to str: the first column used to be left as-is, so
        # a numeric first key failed on the .str accessor.
        # NOTE: missing values in the first key now join as 'nan', matching
        # how the later keys were already treated.
        combined = combined.astype(str).str.cat(adata.obs[key].astype(str),sep='_')
    adata.obs[combo_key] = combined
    return adata,combo_key


def deg_analysis(adata_temp,new_file_path,analysis_pairs,group_key,adata_name='volcano',log2fc_extrema=[-15,15],label_cutoff=1,to_plot=True):
    '''
    Run pairwise Wilcoxon DEG tests and write one CSV per comparison.

    For every ``name: (group, reference)`` entry of ``analysis_pairs`` the
    genes are ranked on ``.raw`` using ``group_key`` as the grouping column;
    the table for ``group`` is limited to the ``log2fc_extrema`` range,
    indexed by gene name and saved as ``<adata_name>_<name>.csv`` inside
    ``new_file_path``. ``label_cutoff`` and ``to_plot`` are accepted but not
    used here.

    Returns (adata_temp, {name: DataFrame}).
    '''
    adata_temp.raw = adata_temp
    lfc_min, lfc_max = log2fc_extrema[0], log2fc_extrema[1]

    deg_df_dict = {}
    for key, pair in analysis_pairs.items():
        target, reference = pair[0], pair[1]
        sc.tl.rank_genes_groups(
            adata_temp,
            groupby=group_key,
            groups=pair,
            reference=reference,
            key_added=key,
            method='wilcoxon',
            tie_correct=True,
            use_raw=True,
        )
        table = sc.get.rank_genes_groups_df(
            adata_temp,
            group=target,
            key=key,
            pval_cutoff=1,
            log2fc_min=lfc_min,
            log2fc_max=lfc_max,
        ).set_index('names')
        deg_df_dict[key] = table

        csv_path = os.path.join(new_file_path, '_'.join([adata_name, key]) + '.csv')
        print(csv_path)
        table.to_csv(csv_path)
    return adata_temp, deg_df_dict

def deg_volc_plot(deg_df_dict,gene_list=[],label_cutoff=0.1,suptitle='',lock_y_max=True,special_genes=[],use_adj=False):
    ''' Plot one volcano panel per comparison in ``deg_df_dict``.

    Parameters:
        deg_df_dict (dict): name -> DataFrame indexed by gene with
            'logfoldchanges', 'pvals' and (if ``use_adj``) 'pvals_adj'.
        gene_list (list): Only these genes are drawn (black points).
        label_cutoff (float): Genes with p-value below this are text-labelled.
        suptitle (str): Figure title.
        lock_y_max (bool): Share one y maximum across panels when true.
        special_genes (list): Genes redrawn and labelled in red; they must
            also be in ``gene_list`` to be drawn.
        use_adj (bool): Use adjusted instead of raw p-values.

    Returns:
        tuple: (figure, axes) - axes is a list when there is a single panel.
    '''
    fig_volcano,ax=plt.subplots(1,len(deg_df_dict),figsize=(len(deg_df_dict)*1.5,2))
    if len(deg_df_dict) == 1:  ax=[ax]
    shown_genes = set(gene_list)
    for key_ind,(key,df) in enumerate(deg_df_dict.items()):
        panel = ax[key_ind]
        panel.set_title(key.replace('_','\n'))
        genes= df.index
        # Plain arrays make lfc[i] / pval[i] explicitly positional; integer
        # keys on a gene-name-indexed Series rely on a deprecated fallback.
        lfc = df['logfoldchanges'].to_numpy()
        pval = (df['pvals_adj'] if use_adj else df['pvals']).to_numpy()

        neg_log10_pval = -np.log10(pval)
        for i,g in enumerate(genes):
            if str(g) not in shown_genes:
                continue
            if np.isnan(lfc[i]*neg_log10_pval[i]):
                continue
            panel.scatter(lfc[i],neg_log10_pval[i],c='k',s=3)
            if pval[i]<label_cutoff:
                panel.text(lfc[i],neg_log10_pval[i],genes[i],rotation=45,fontsize=4)
            if g in special_genes:
                panel.scatter(lfc[i],neg_log10_pval[i],c='r',s=3)
                panel.text(lfc[i],neg_log10_pval[i],genes[i],rotation=45,fontsize=4,color='r')
    x_lim=[]
    y_lim=[]
    for a in ax:
        x_lim.append(a.get_xlim())
        y_lim.append(a.get_ylim())

    # Largest absolute extent over all panels, used for symmetric x limits
    x_etr = np.max(np.abs(x_lim))
    y_etr = np.max(np.abs(y_lim))

    for a in ax:
        a.set_xlim(-x_etr*1.1,x_etr*1.1)
        if lock_y_max: a.set_ylim(0,y_etr*1.1)
        else: a.set_ylim(0,a.get_ylim()[1]*1.1)
        a.set_xlabel('Log2 Fold Change')
        if use_adj: a.set_ylabel('-log10(adj-pvalue)')
        else: a.set_ylabel('-log10(raw-pvalue)')
        a.axhline(-np.log10(0.05),c='k',linestyle=":",linewidth=.6,label='p=0.05')

    fig_volcano.suptitle(suptitle)
    plt.tight_layout()
    return fig_volcano,ax

def deg_heatmap(deg_dict,gene_list=[],ax=None,lfcmax = 2.5,ax_cbar=None,suptitle=''):
    """Draw a gene x comparison heatmap of log fold changes with significance marks.

    Parameters:
        deg_dict (dict): name -> DataFrame indexed by gene with
            'logfoldchanges' and 'pvals' columns.
        gene_list (list): Genes to show, one row each; genes missing from a
            table leave that cell empty (NaN).
        ax: Axes to draw on. When None a new figure with a heatmap axis and a
            colorbar axis is created (and ``ax_cbar`` is replaced by the new one).
        lfcmax (float): Symmetric color limit for the 'bwr' colormap.
        ax_cbar: Axes for the colorbar; when None one is attached to ``ax``.
        suptitle (str): Figure title.

    Returns:
        tuple: (current figure, heatmap axes, heatmap image)
    """
    gene_list = list(gene_list)
    if ax is None:
        fig,axs = plt.subplots(1,2,figsize=(1+.5*len(deg_dict.keys()),len(gene_list)/14),width_ratios=[9,.5])
        ax=axs[0]
        ax_cbar = axs[1]
    ax.grid(visible=False)

    # NaN marks cells with no value for that gene/comparison
    heat_map_lfc = np.full((len(gene_list),len(deg_dict)), np.nan)
    for key_id,deg_df in enumerate(deg_dict.values()):
        for i_gene,gene in enumerate(gene_list):
            try:
                heat_map_lfc[i_gene,key_id] = deg_df.loc[gene,'logfoldchanges']
            except (KeyError, ValueError, TypeError):
                # gene absent (or not a single scalar) in this table: leave NaN
                continue
    cbar = ax.pcolorfast(heat_map_lfc,vmin=-lfcmax,vmax=lfcmax,cmap='bwr')
    ax.set_yticks(np.arange(len(gene_list))+0.5, labels=gene_list,rotation=0)
    ax.tick_params(length=0)

    fmt_xlabels = [xl.replace('_','\n') for xl in deg_dict.keys()]
    ax.set_xticks(np.arange(len(deg_dict.keys()))+0.5, labels=fmt_xlabels,rotation=0)
    ax.grid(visible=False)

    # Overlay a significance tag on every cell with p < 0.1
    for key_id,deg_df in enumerate(deg_dict.values()):
        for i_gene,gene in enumerate(gene_list):
            try:
                p = deg_df.loc[gene,'pvals']
                if p<.1:
                    ax.text(key_id+.5,i_gene+.5,pval_to_star(p),ha='center',va='top',rotation=0)
            except (KeyError, ValueError, TypeError, OverflowError):
                # best effort: skip cells whose p-value is missing or unusable
                continue

    ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

    # Exactly one colorbar: the old try/except fallback drew a second one when
    # ax was supplied without ax_cbar.
    if ax_cbar is not None:
        plt.colorbar(cbar,cax=ax_cbar)
        ax_cbar.set_ylabel('log2(Fold Change)')
    else:
        plt.colorbar(cbar,ax=ax).ax.set_ylabel('log2(Fold Change)')
    plt.suptitle(suptitle)
    plt.tight_layout()
    return plt.gcf(), ax, cbar

def pval_to_star(p,specifics={(.05,.1):'\'',(.05,1):'ns'},max_star=5):
    """Convert a p-value into a significance tag.

    Parameters:
        p (float): p-value.
        specifics (dict): (bound, bound) -> tag; the first range with
            ``low < p <= high`` wins. Defaults tag (0.05, 0.1] with "'" and
            (0.05, 1] with 'ns'.
        max_star (int): Maximum number of stars returned.

    Returns:
        str: A tag from ``specifics``; '*' for 0.01 < p <= 0.05; otherwise
        one star per whole unit of -log10(p), capped at ``max_star``.
    """
    for bounds,tag in specifics.items():
        if np.min(bounds) < p <= np.max(bounds): return tag
    # -log10(0) is infinite and int(inf) raises OverflowError: treat an exact
    # zero as maximally significant instead.
    if p == 0: return '*'*max_star
    # Upper bound is 0.05 (was 0.5), so non-significant values are not starred
    # when custom `specifics` do not cover them.
    if 0.01 < p <= 0.05: return '*'
    nlogp=-np.log10(p)
    return ('*'*int(nlogp))[:max_star]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def single_deg_volcano(
    deg_df,
    gene_list=[],
    label_cutoff=0.1,
    title=None,
    comp_string=None,
    lock_y_max=True,
    special_genes=[],
    use_adj=False,
    marker_size = .5
    )   :
    '''
    Plot a single volcano plot for a DE comparison.

    Parameters:
        deg_df (DataFrame): DataFrame of DE results with gene names as index and columns including
                            'logfoldchanges', 'pvals', and optionally 'pvals_adj'.
        gene_list (list, optional): Genes to draw (black points); genes outside it are not plotted.
        label_cutoff (float, optional): Genes with a p-value below this are text-labelled (default 0.1).
        title (str, optional): Title for the plot; underscores become line breaks. No title when None.
        comp_string (str, optional): Comparison description placed above the x-axis label.
        lock_y_max (bool, optional): If True, scale the y max from the larger absolute y limit;
                                     otherwise from the current upper limit.
        special_genes (list, optional): Genes redrawn and labelled in red; they must also be in gene_list.
        use_adj (bool, optional): If True, use adjusted p-values; otherwise, use raw p-values (default is False).
        marker_size (float, optional): Scatter marker size.

    Returns:
        fig_volcano (Figure): Matplotlib figure object for the volcano plot.
        ax (Axes): Matplotlib axes object for the volcano plot.
    '''

    # Initialize the figure and axis for the volcano plot
    fig_volcano, ax = plt.subplots(1, 1, figsize=(1.75, 1.75))

    # The default title is None, which has no .replace(); only set a real title
    if title is not None:
        ax.set_title(title.replace('_', '\n'))

    # Plain arrays make lfc[i] / pval[i] explicitly positional; integer keys on
    # a gene-name-indexed Series rely on a deprecated fallback.
    genes = deg_df.index
    lfc = deg_df['logfoldchanges'].to_numpy()
    pval = (deg_df['pvals_adj'] if use_adj else deg_df['pvals']).to_numpy()
    neg_log10_pval = -np.log10(pval)

    # Plot each gene of gene_list, marking special_genes in red
    shown_genes = set(gene_list)
    for i, g in enumerate(genes):
        if str(g) not in shown_genes:
            continue
        if np.isnan(lfc[i] * neg_log10_pval[i]):
            continue
        ax.scatter(lfc[i], neg_log10_pval[i], c='k', s=marker_size)
        if pval[i] < label_cutoff:
            ax.text(lfc[i], neg_log10_pval[i], genes[i], rotation=45, fontsize=4)
        if g in special_genes:
            ax.scatter(lfc[i], neg_log10_pval[i], c='r', s=marker_size)
            ax.text(lfc[i], neg_log10_pval[i], genes[i], rotation=45, fontsize=4, color='r')

    # Symmetric x limits around zero; y starts at zero
    x_lim = ax.get_xlim()
    y_lim = ax.get_ylim()
    x_etr = np.max(np.abs(x_lim))
    y_etr = np.max(np.abs(y_lim))

    ax.set_xlim(-x_etr * 1.1, x_etr * 1.1)
    if lock_y_max:
        ax.set_ylim(0, y_etr * 1.1)
    else:
        ax.set_ylim(0, ax.get_ylim()[1] * 1.1)

    if comp_string is not None:
        xlabel = f'{comp_string}\nLog2(fold change)'
    else:
        xlabel = 'Log2(fold change)'
    ax.set_xlabel(xlabel)
    ax.set_ylabel('-log10(adj-pvalue)' if use_adj else '-log10(raw-pvalue)')
    ax.axhline(-np.log10(0.05), c='k', linestyle=":", linewidth=0.6, label='p=0.05')
    # Zero-fold-change reference line (previously mislabelled 'p=0.05')
    ax.axvline(0, c='k', linestyle="-", linewidth=0.6, label='log2FC=0')

    plt.tight_layout()

    return fig_volcano, ax




In [None]:
def get_gene_cluster(adata,gene):
    """Return the Louvain cluster that holds most of the cells highly expressing ``gene``.

    Cells count as high expressers when the z-scored expression of ``gene``
    exceeds 2. The most frequent ``obs['louvain']`` label (as int) among them
    is returned; ties resolve to the smallest cluster id.

    Raises:
        ValueError: If no cell passes the expression threshold.
    """
    gene_bool = z_score(sc.get.obs_df(adata,gene))>2
    gene_cluster_list = np.array(adata.obs['louvain'][gene_bool]).astype(int)
    if gene_cluster_list.size == 0:
        raise ValueError(f"No cells with z-scored '{gene}' expression > 2; cannot pick a cluster.")
    # NumPy-based mode: scipy.stats.mode returns scalars in recent SciPy
    # releases, which broke the previous `[0][0]` indexing.
    cluster_ids, counts = np.unique(gene_cluster_list, return_counts=True)
    gene_clust = cluster_ids[np.argmax(counts)]
    return gene_clust


# Cluster ids used below as the "SST" and "PV" clusters: for each marker gene,
# the cluster returned by get_gene_cluster.
sst_clust = get_gene_cluster(adata_GABA,'Sst')
pv_clust = get_gene_cluster(adata_GABA,'Pvalb')

print('sst_clust',sst_clust)
print('pv_clust',pv_clust)

In [None]:
sst_clust
# Show the log1p metadata, overwrite it with an explicit base of None, and
# show it again.
# NOTE(review): presumably needed before rank_genes_groups - confirm.
print(adata_GABA.uns['log1p'])
adata_GABA.uns['log1p'] = {'base': None}
print(adata_GABA.uns['log1p'])

In [None]:
# Notebook display of the distinct E_type labels present
set(adata_GABA.obs.E_type)

In [None]:
# Derive the age bins here so this cell does not depend on `ages` having been
# defined elsewhere first (it was used before any assignment).
ages = sorted(set(adata_GABA.obs['age_bin']))
# name -> (group, reference) labels in the '<age_bin>_<E_type>' format
analysis_pairs_34 = { a+'_E4vE3':(a+"_E4", a+"_E3") for a in ages}
analysis_pairs_4f4 = {'15m+_E4vE4fKI':("15m+_E4", "15m+_E4fKI")}
analysis_pairs = analysis_pairs_34 | analysis_pairs_4f4
print(analysis_pairs)

In [None]:
###### SST Cluster KCN DEGs #####
path='./DEG_Reuslts/KCN/SST_Cluster_DEG_KCN_4v3_across_ages/'
# exist_ok tolerates a pre-existing folder without hiding other OS errors
os.makedirs(path, exist_ok=True)
# Subset first, then copy: yields an independent object without duplicating
# the full dataset just to take a view of it.
adata_sst_cluster = adata_GABA[adata_GABA.obs['Cluster (nn)'].astype(int)==sst_clust,:].copy()
adata_sst_cluster.var['is_KCN'] = [g in kcn_list for g in adata_sst_cluster.var.index]
# Grouping label '<age_bin>_<E_type>' matched by the analysis_pairs entries
adata_sst_cluster.obs['age_geno'] = adata_sst_cluster.obs['age_bin'].str.cat(adata_sst_cluster.obs['E_type'],sep='_')
ages = sorted(list(set(adata_sst_cluster.obs['age_bin'])))
print(analysis_pairs)
adata_name = 'SST_cluster_KCNs'
# adata_name is now actually passed, so the CSVs carry this prefix instead of
# the default 'volcano'.
sst_cluster_deg_data = deg_analysis(adata_sst_cluster,path,analysis_pairs,group_key = 'age_geno',adata_name=adata_name,log2fc_extrema=[-15,15],to_plot=False)
single_deg_volcano(sst_cluster_deg_data[1]['15m+_E4vE3'],gene_list=kcn_list,title='SST Cluster 15-20m+\nApoE4 vs ApoE3 ',special_genes=['Kcnt2'])

single_deg_volcano(sst_cluster_deg_data[1]['15m+_E4vE4fKI'],gene_list=kcn_list,title='SST Cluster 15-20m+\nApoE4 vs nEKO',special_genes=['Kcnt2'])

###### PV Cluster KCN DEGs #####
path='./DEG_Reuslts/KCN/PV_Cluster_DEG_KCN_4v3_across_ages/'
os.makedirs(path, exist_ok=True)
adata_pv_cluster = adata_GABA[adata_GABA.obs['Cluster (nn)'].astype(int)==pv_clust,:].copy()
adata_pv_cluster.var['is_KCN'] = [g in kcn_list for g in adata_pv_cluster.var.index]
adata_pv_cluster.obs['age_geno'] = adata_pv_cluster.obs['age_bin'].str.cat(adata_pv_cluster.obs['E_type'],sep='_')
ages = sorted(list(set(adata_pv_cluster.obs['age_bin'])))
adata_name = 'pv_cluster_KCNs'
pv_cluster_deg_data = deg_analysis(adata_pv_cluster,path,analysis_pairs,group_key = 'age_geno',adata_name=adata_name,log2fc_extrema=[-15,15],to_plot=False)
single_deg_volcano(pv_cluster_deg_data[1]['15m+_E4vE3'],gene_list=kcn_list,title='PV Cluster 15-20m+\nApoE4 vs ApoE3',special_genes=['Kcnt2'])
single_deg_volcano(pv_cluster_deg_data[1]['15m+_E4vE4fKI'],gene_list=kcn_list,title='PV Cluster 15-20m+\nApoE4 vs nEKO',special_genes=['Kcnt2'])

In [None]:
########### MULTIMODAL DEGS #######################
#### Clusters
path = './DEG_Reuslts/KCN/Cluster_DEGs_multivar/'
# exist_ok tolerates a pre-existing folder without hiding other OS errors
os.makedirs(path, exist_ok=True)
# Cells of either the SST or the PV cluster; subset first, then copy, so the
# full dataset is not duplicated just to take a view of it.
in_sst_or_pv = np.logical_or(adata_GABA.obs['Cluster (nn)'].astype(int)==sst_clust,adata_GABA.obs['Cluster (nn)'].astype(int)==pv_clust)
adata_sstpv_clusters = adata_GABA[in_sst_or_pv,:].copy()
# Combined grouping label '<age_bin>_<louvain>_<E_type>'
keys_of_interest = ['age_bin','louvain','E_type']
adata_sstpv_clusters,combo_key = deg_set_up(adata_sstpv_clusters,keys_of_interest)
# name -> (group, reference) in the combined-label format
analysis_pairs = {'15m_SST_E4 vs E3_(genotype)': (f'15m+_{sst_clust}_E4', f'15m+_{sst_clust}_E3'),
                #   'SST_E4_15m vs 5m_(age)': (f'15m+_{sst_clust}_E4', f'05m_{sst_clust}_E4'),
                #   'SST_E4_15m vs 10m_(age)': (f'15m+_{sst_clust}_E4', f'10m_{sst_clust}_E4'),
                  '15m_E4_SST vs PV_(celltype)': (f'15m+_{sst_clust}_E4', f'15m+_{pv_clust}_E4')}
multimodal_cluster_deg_data = deg_analysis(adata_sstpv_clusters,path,analysis_pairs,group_key = combo_key,log2fc_extrema=[-15,15],to_plot=False)
single_deg_volcano(multimodal_cluster_deg_data[1]['15m_E4_SST vs PV_(celltype)'],gene_list=kcn_list,title='15-20m+ ApoE4 SST vs PV Cluster',special_genes=['Kcnt2'])


In [None]:
"""
Figure Plotting and Saving
"""
# Every figure below is written twice into the current working directory:
# once as SVG and once as JPEG.

# Cluster plots
cluster_fig, ax = show_clusters(adata_GABA, 'Cluster (nn)')
cluster_fig.savefig("cluster_plot.svg", format="svg")
cluster_fig.savefig("cluster_plot.jpeg", format="jpeg")

# Per-gene expression overlays on the UMAP
sst_cluster_fig, ax = show_gene_loading(adata_GABA, 'Sst')
sst_cluster_fig.savefig("sst_cluster_plot.svg", format="svg")
sst_cluster_fig.savefig("sst_cluster_plot.jpeg", format="jpeg")

pv_cluster_fig, ax = show_gene_loading(adata_GABA, 'Pvalb')
pv_cluster_fig.savefig("pv_cluster_plot.svg", format="svg")
pv_cluster_fig.savefig("pv_cluster_plot.jpeg", format="jpeg")

vip_cluster_fig, ax = show_gene_loading(adata_GABA, 'Vip')
vip_cluster_fig.savefig("vip_cluster_plot.svg", format="svg")
vip_cluster_fig.savefig("vip_cluster_plot.jpeg", format="jpeg")

reln_cluster_fig, ax = show_gene_loading(adata_GABA, 'Reln')
reln_cluster_fig.savefig("reln_cluster_plot.svg", format="svg")
reln_cluster_fig.savefig("reln_cluster_plot.jpeg", format="jpeg")

Lamp5_cluster_fig, ax = show_gene_loading(adata_GABA, 'Lamp5')
Lamp5_cluster_fig.savefig("lamp5_cluster_plot.svg", format="svg")
Lamp5_cluster_fig.savefig("lamp5_cluster_plot.jpeg", format="jpeg")

Apoe_cluster_fig, ax = show_gene_loading(adata_GABA, 'Apoe')
Apoe_cluster_fig.savefig("apoe_cluster_plot.svg", format="svg")
Apoe_cluster_fig.savefig("apoe_cluster_plot.jpeg", format="jpeg")

Syn1_cluster_fig, ax = show_gene_loading(adata_GABA, 'Syn1')
Syn1_cluster_fig.savefig("syn1_cluster_plot.svg", format="svg")
Syn1_cluster_fig.savefig("syn1_cluster_plot.jpeg", format="jpeg")

Gfap_cluster_fig, ax = show_gene_loading(adata_GABA, 'Gfap')
Gfap_cluster_fig.savefig("gfap_cluster_plot.svg", format="svg")
Gfap_cluster_fig.savefig("gfap_cluster_plot.jpeg", format="jpeg")

Kcnt2_cluster_fig, ax = show_gene_loading(adata_GABA, 'Kcnt2')
Kcnt2_cluster_fig.savefig("kcnt2_cluster_plot.svg", format="svg")
Kcnt2_cluster_fig.savefig("kcnt2_cluster_plot.jpeg", format="jpeg")

# Volcano plots
# Index [1] of each *_deg_data tuple is the {comparison name: DataFrame} dict
# returned by deg_analysis.
sst_volc, ax = single_deg_volcano(sst_cluster_deg_data[1]['15m+_E4vE3'], gene_list=kcn_list, title='SST Cluster 15-20m+',comp_string='ApoE4 vs ApoE3', special_genes=['Kcnt2'])
sst_volc.savefig("sst_volcano_plot.svg", format="svg")
sst_volc.savefig("sst_volcano_plot.jpeg", format="jpeg")

pv_volc, ax = single_deg_volcano(pv_cluster_deg_data[1]['15m+_E4vE3'], gene_list=kcn_list, title='PV Cluster 15-20m+',comp_string='ApoE4 vs ApoE3', special_genes=['Kcnt2'])
pv_volc.savefig("pv_volcano_plot.svg", format="svg")
pv_volc.savefig("pv_volcano_plot.jpeg", format="jpeg")

sst_pv_volc, ax = single_deg_volcano(multimodal_cluster_deg_data[1]['15m_E4_SST vs PV_(celltype)'], gene_list=kcn_list, title='15-20m+ ApoE4' ,comp_string='SST vs PV Cluster', special_genes=['Kcnt2'])
sst_pv_volc.savefig("sst_pv_volcano_plot.svg", format="svg")
sst_pv_volc.savefig("sst_pv_volcano_plot.jpeg", format="jpeg")



In [None]:
!pip install svgutils


# from svgutils.compose import *
from svgutils import compose as svguc



# Page geometry (pixels at 96 dpi) and shared panel styling
dpi = 96
width = 7.25*dpi/2
length = 9.4*dpi
scale_factor= 1.33
weight='bold'
label_size = 12

# Vertical offsets of the four panel rows
row_1 = 0
row_2 = row_1+240
row_3 = row_2+160
row_4 = row_3+160


def _labelled_panel(svg_file, letter, x_offset, y_offset):
    """Build one scaled, positioned panel: an SVG file plus its bold letter label."""
    graphic = svguc.SVG(svg_file)
    label = svguc.Text(letter, 0, 10, size=label_size,weight=weight,font='arial')
    return svguc.Panel(graphic, label).scale(scale_factor).move(x_offset, y_offset)


# (svg file, panel letter, x offset, y offset), in drawing order
panel_specs = [
    ("cluster_plot.svg", "a", width*0, row_1),
    ("sst_cluster_plot.svg", "b", width*0.0, row_2),
    ("pv_cluster_plot.svg", "c", width*0.5, row_2),
    ("sst_volcano_plot.svg", "d", width*0.0, row_3),
    ("pv_volcano_plot.svg", "e", width*0.5, row_3),
    ("sst_pv_volcano_plot.svg", "f", width*0, row_4),
    ("kcnt2_cluster_plot.svg", "g", width*0.5, row_4),
]

primary_figure_layout_1 = svguc.Figure(
    str(width), str(length),
    *[_labelled_panel(*spec) for spec in panel_specs]
)

display(primary_figure_layout_1)
primary_figure_layout_1.save("primary_figure_layout_1.svg")