<a href="https://colab.research.google.com/github/dtabuena/Workshop/blob/main/Ref_Map/KZ_BuildRef.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scanpy --quiet
!pip install pybiomart --quiet
!pip install python-igraph --quiet
!pip install louvain --quiet
!pip install pynndescent --quiet
!pip install gprofiler-official --quiet

In [None]:
import h5py
import numpy as np
import scipy as sci
from matplotlib import pyplot as plt
import scanpy as sc
import tarfile
import os
import anndata as ad
import pandas as pd
import pybiomart
from tqdm import tqdm
import urllib.request
from IPython.display import clear_output
from matplotlib.pyplot import rc_context
import logging
import seaborn as sns
import gprofiler


logging.Logger('my_log').setLevel('INFO')

############# SPECIFIC CONFIG #############
working_dir = r"C:\Users\dennis.tabuena\Gladstone Dropbox\Dennis Tabuena\0_Projects\_SeqRef"
os.chdir(working_dir)
zalocusky_url = 'https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE167497&format=file'
geo_zalo_filename = 'zalocusky_indiv.tar'

import urllib
response = urllib.request.urlretrieve('https://raw.githubusercontent.com/dtabuena/Resources/main/Matplotlib_Config/Load_FS6.py','Load_FS6.py')
%run Load_FS6.py



In [None]:
def trim_key(k):
    """Convert a GEO .h5 file name into a sample code.

    The two floxed samples have names that do not follow the
    '<GSM>_<mouse>_<genotype>_<age>_<well>' pattern, so they are mapped to a
    hand-written code. Any other name just loses the
    '_raw_gene_bc_matrices_h5.h5' suffix.
    """
    special_cases = {
        'GSM5106175_YH_KZ03_01': 'GSM5106175_602_E3fKI_15_XX',
        'GSM5106176_YH_KZ03_03': 'GSM5106176_475_E4fKI_15_XX',
    }
    remapped = next((code for prefix, code in special_cases.items() if prefix in k), None)
    if remapped is not None:
        return remapped
    return k.replace('_raw_gene_bc_matrices_h5.h5', "")

def query_capitilaziation(gene, adata):
    """Return ``gene`` spelled as it appears in ``adata.var.index``.

    The match is case-insensitive; the first var name whose lower-case form
    equals ``gene.lower()`` wins. Non-string var names are skipped instead of
    aborting the whole lookup (the previous bare ``except`` hid such errors
    and every other exception).

    Parameters
    ----------
    gene : str
        Gene symbol in any capitalization (e.g. 'APOE').
    adata : AnnData-like
        Object whose ``.var.index`` holds the gene names.

    Returns
    -------
    str
        The matching var name, or ``gene + ' not_found'`` if none matches.
    """
    target = gene.lower()
    for name in adata.var.index:
        if isinstance(name, str) and name.lower() == target:
            return name
    return gene + ' not_found'

def z_score(x, axis=-1):
    """Standardize ``x`` to zero mean and unit (population) std along ``axis``.

    Parameters
    ----------
    x : array_like
        Input values of any rank.
    axis : int, default -1
        Axis along which mean and std are computed.

    Returns
    -------
    numpy.ndarray
        Same shape as ``x``. Constant slices give NaN (0/0), as in numpy.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axis so the statistics broadcast back
    # along ``axis`` for any rank; without it a 2-D input with axis=-1 would
    # subtract the row means across columns (or fail to broadcast).
    mean = np.mean(x, axis=axis, keepdims=True)
    std = np.std(x, axis=axis, keepdims=True)
    return (x - mean) / std



In [None]:
def pull_gene_annots(csv_loc='./mmusculus_coding_noncoding.csv',
                     my_git='https://raw.githubusercontent.com/dtabuena/Resources/main/Genetics/mmusculus_coding_noncoding.csv',
                     biomart_name='mmusculus',
                     biomart_keys=["ensembl_gene_id", "chromosome_name","transcript_biotype","external_gene_name","peptide"]):
    """Load a gene annotation table and the list of protein-coding gene symbols.

    Sources are tried in order:
      1. the local CSV at ``csv_loc``;
      2. a CSV downloaded from ``my_git`` (saved to ``csv_loc``);
      3. a fresh BioMart query through scanpy (result saved to ``csv_loc``).

    Parameters
    ----------
    csv_loc : str
        Path of the cached annotation CSV (previously hard-coded).
    my_git : str
        URL of a pre-built copy of the CSV (previously hard-coded).
    biomart_name : str
        Dataset name passed to ``sc.queries.biomart_annotations``.
    biomart_keys : list of str
        NOTE(review): accepted for backward compatibility but not used. The
        query below uses a fixed attribute list without 'peptide', because the
        code needs 'ensembl_gene_id', 'transcript_biotype' and
        'external_gene_name'. Confirm before wiring this parameter through.

    Returns
    -------
    coding_list : list of str
        Gene symbols whose ``is_coding`` flag is True.
    annot_dd : pandas.DataFrame
        Annotation table indexed by ``external_gene_name``, with an
        ``is_coding`` boolean column.
    """
    if os.path.exists(csv_loc):
        print('Use local copy of musmus')
        annot_dd = pd.read_csv(csv_loc).set_index("external_gene_name")
    else:
        try:
            print('attempting to pull mus mus from git...')
            urllib.request.urlretrieve(my_git, csv_loc)
            annot_dd = pd.read_csv(csv_loc).set_index("external_gene_name")
        except Exception as err:  # best-effort: any download/parse failure falls back to BioMart
            print(f'git download failed: {err!r}')
            print('attempting to pull mus mus from biomart...')
            annot = sc.queries.biomart_annotations(
                biomart_name,
                ["ensembl_gene_id", "chromosome_name", "transcript_biotype", "external_gene_name"],
            ).set_index('ensembl_gene_id')
            # A gene ID can appear on several rows (presumably one per distinct
            # transcript biotype). Give every row of a duplicated ID all of its
            # biotypes joined by '__'; single-row IDs keep their value. One
            # groupby replaces the old per-ID scan over the whole index (O(n^2)).
            dup_mask = annot.index.duplicated(keep=False)
            if dup_mask.any():
                joined = (annot.loc[dup_mask, 'transcript_biotype']
                          .groupby(level=0).agg('__'.join))
                annot.loc[dup_mask, 'transcript_biotype'] = annot.index[dup_mask].map(joined).to_numpy()
            annot['is_coding'] = annot.transcript_biotype.str.contains('coding')
            annot_dd = annot.drop_duplicates().set_index("external_gene_name")
            annot_dd.to_csv(csv_loc)

    coding_list = annot_dd.index[annot_dd['is_coding']].to_list()
    return coding_list, annot_dd


def preprocess_andata10x(adata_og,pct_mito=0.25,min_genes=500,max_genes=2400,min_counts=500,max_counts=4500,keep_NC=False):
    """Annotate QC metrics on a 10x AnnData and return a filtered copy.

    ``adata_og`` is modified in place: ``var['mt']`` (names starting 'mt-'),
    ``var['coding']`` and the per-cell/per-gene columns written by
    ``sc.pp.calculate_qc_metrics`` (e.g. ``obs['pct_counts_mt']``) are added.

    Parameters
    ----------
    adata_og : AnnData
        Count matrix with gene symbols as ``var_names``.
    pct_mito : float
        Cells with ``pct_counts_mt`` >= this value are dropped. scanpy reports
        ``pct_counts_*`` as percentages, so 0.25 means 0.25 %.
    min_genes, max_genes, min_counts, max_counts : int
        Bounds applied by successive ``sc.pp.filter_cells`` calls.
    keep_NC : bool
        If False, genes not flagged as coding by ``pull_gene_annots`` are removed.

    Returns
    -------
    (adata_QC, adata_og)
        The filtered AnnData and the annotated original.
    """
    print('pulling gene annotations...')
    coding_list, _ = pull_gene_annots()
    coding_set = set(coding_list)  # O(1) membership instead of a list scan per gene
    adata_og.var['mt'] = adata_og.var_names.str.startswith('mt-')
    adata_og.var['coding'] = [gene in coding_set for gene in adata_og.var_names]
    sc.pp.calculate_qc_metrics(adata_og, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

    print('Filtering...')
    keep_cells = adata_og.obs.pct_counts_mt < pct_mito
    # .copy() materializes the subset so the in-place filters below work on a
    # real AnnData instead of implicitly converting a view.
    adata_QC = adata_og[keep_cells, :].copy()
    # Count exactly the cells the filter dropped (the old '>' count missed
    # cells sitting on the threshold).
    print(str(np.sum(~keep_cells)) + f' cells with >={pct_mito}% mt removed')

    if keep_NC:
        print('keep noncoding')
    else:
        adata_QC = adata_QC[:, adata_QC.var.coding].copy()
        print(str(np.sum(np.logical_not(adata_og.var.coding))) + ' non coding genes removed')

    sc.pp.filter_cells(adata_QC, min_genes=min_genes)
    sc.pp.filter_cells(adata_QC, max_genes=max_genes)
    sc.pp.filter_cells(adata_QC, min_counts=min_counts)
    sc.pp.filter_cells(adata_QC, max_counts=max_counts)
    # Quick QC view: total counts vs detected genes per retained cell.
    fig,ax=plt.subplots(1,figsize=(1.5,1.5))
    sc.pl.scatter(adata_QC, x='total_counts', y='n_genes_by_counts',ax=ax)

    return adata_QC,adata_og

def high_var_genes_dim_reduc(adata,min_mean = 0.25,max_mean = 4,min_disp=0.55):
    """Log-normalize, flag highly variable genes, run PCA and draw QC plots.

    Follows the published preprocessing: log-normalization with a scale
    factor of 10,000 (Seurat NormalizeData), then selection of highly
    dispersed genes (Seurat FindVariableGenes) over an average expression
    range of 0.25-4 with a minimum dispersion of 0.55, which the paper
    reports as yielding 2,197 genes.

    ``adata`` is modified in place and also returned:
    - ``.raw`` holds a snapshot of the input;
    - ``.X`` is normalized and log1p-transformed;
    - ``.var`` gets the HVG flags;
    - the PCA results are stored.
    """
    # Snapshot the incoming matrix before normalization.
    adata.raw = adata

    sc.pp.normalize_total(adata, target_sum=10000)
    sc.pp.log1p(adata)
    # Explicitly record the log1p metadata with base None (natural log).
    log_meta = {'base': None}
    adata.uns['log1p'] = log_meta
    print(adata.uns['log1p'])

    hvg_kwargs = dict(min_mean=min_mean, max_mean=max_mean, min_disp=min_disp)
    sc.pp.highly_variable_genes(adata, **hvg_kwargs)
    with rc_context({'figure.figsize': (1.5, 1.5)}):
        sc.pl.highly_variable_genes(adata)
    plt.tight_layout()
    hv_count = np.sum(adata.var['highly_variable'])
    print(hv_count, 'hv genes')

    # PCA, then a small scree plot of the first 25 variance ratios.
    sc.tl.pca(adata, svd_solver='arpack', n_comps=50)
    _, scree_ax = plt.subplots(figsize=(1, 1))
    first_ratios = adata.uns['pca']['variance_ratio'][:25]
    scree_ax.plot(first_ratios, 'ok', markersize=1)
    quiet_PCA_plots(adata, ['E_type', 'age_bin', 'mouse_ID'], pc_pairs=[(0, 1)])

    return adata

def quiet_PCA_plots(adata,key_list,figsize=(2,2),pc_pairs=[(0,1)]):
    """Scatter PCA coordinates, one panel per (PC pair, obs key).

    Rows follow ``pc_pairs`` and columns follow ``key_list``. Within a panel,
    each category of the key is drawn as its own scatter series. A legend is
    added when the key has fewer than 8 categories.

    Parameters
    ----------
    adata : AnnData
        Must have ``obsm['X_pca']`` and every ``key_list`` column in ``obs``.
    key_list : list of str
        ``obs`` columns used for coloring.
    figsize : (float, float)
        Size of one panel; the figure is scaled by the grid shape.
    pc_pairs : sequence of (int, int)
        Zero-based PC indices plotted against each other. Not mutated.

    Returns
    -------
    None
    """
    n_rows, n_cols = len(pc_pairs), len(key_list)
    # squeeze=False always returns a 2-D array of Axes, so ax[row, col] works
    # for every grid shape. With the default squeeze, a 1xN grid gives a 1-D
    # array and a 1x1 grid gives a bare Axes, and ax[ip, key_ind] then raised.
    fig, ax = plt.subplots(n_rows, n_cols,
                           figsize=(figsize[0] * n_cols, figsize[1] * n_rows),
                           squeeze=False)
    pcs = adata.obsm['X_pca']
    for ip, pair in enumerate(pc_pairs):
        for key_ind, key in enumerate(key_list):
            panel = ax[ip, key_ind]
            key_types = sorted(set(adata.obs[key]))
            for k in key_types:
                is_k = (adata.obs[key] == k).to_numpy()
                panel.scatter(pcs[is_k, pair[0]], pcs[is_k, pair[1]], s=2, marker='.',
                              linewidth=0, edgecolors=None, label=k)
            panel.set_xlabel(f'PC{pair[0]}')
            panel.set_ylabel(f'PC{pair[1]}')
            if len(key_types) < 8:
                panel.legend(key_types, loc='best', markerscale=3)
            panel.set_title(key)
    plt.tight_layout()
    return None


def umap_and_cluster(adata, n_neighbors=10, n_pcs=20,resolution=.6,plot_keys=['Cluster (nn)'],size = 1,to_plot=True):
    """Build a kNN graph on X_pca, Louvain-cluster it, then run PAGA and UMAP.

    The neighbors, Louvain and UMAP steps are all seeded with 42. Louvain
    labels are copied into ``obs['Cluster (nn)']``. ``adata`` is modified in
    place and returned. When ``to_plot`` is True, a UMAP colored by
    ``plot_keys`` is drawn.
    """
    seed = 42
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs, random_state=seed, use_rep='X_pca')
    sc.tl.louvain(adata, resolution=resolution, random_state=seed)
    sc.tl.paga(adata)
    sc.tl.umap(adata, random_state=seed)
    adata.obs['Cluster (nn)'] = adata.obs['louvain']

    if not to_plot:
        return adata
    with rc_context({'figure.figsize': (2.5, 2.5)}):
        sc.pl.umap(adata, add_outline=False, legend_loc='on data', color=plot_keys, size=size)
    return adata

def explore_umap(adata_GABA,key_list=[],size=1,legend_loc=None):
    """Draw small UMAPs of ``adata_GABA`` colored by each entry of ``key_list``.

    Uses the 'Purples' colormap with the lower color limit fixed at -0.5.
    """
    umap_kwargs = {
        'legend_loc': legend_loc,
        'color': key_list,
        'vmin': -0.5,
        'size': size,
        'cmap': 'Purples',
    }
    with rc_context({'figure.figsize': (1.5,1.5)}):
        sc.pl.umap(adata_GABA, **umap_kwargs)
        plt.tight_layout()




# def marker_analysis(adata):
#     #### Get Marker Genes
#     sc.tl.rank_genes_groups(adata, 'louvain', method='wilcoxon',key_added='maker_genes',pts=True,use_raw =True)
#     sc.tl.rank_genes_groups(adata, 'louvain', method='logreg',pts=True,use_raw =True)
#     maker_genes_df = pd.DataFrame(adata.uns['maker_genes']['names'])
#     maker_genes_df_LFC = pd.DataFrame(adata.uns['maker_genes']['logfoldchanges'])
#     top_30 = maker_genes_df[:30]
#     lfc_thresh=4
#     marker_genes = [m for m in set(top_30[maker_genes_df_LFC[:30]>lfc_thresh].values.flatten()) if isinstance(m,str)]
#     sc.pl.stacked_violin(adata, marker_genes, groupby='louvain');
#     return marker_genes






In [None]:

def show_clusters(adata_GABA, key, cmap='tab20', legend_fontsize='small',size = 0.4,ncol=1,figsize=(3.5, 2.5)):
    """Plot the UMAP colored by a numeric cluster column, with a legend panel.

    Parameters
    ----------
    adata_GABA : AnnData
        Needs ``obsm['X_umap']`` and ``obs[key]``.
    key : str
        ``obs`` column whose values can be cast to float (e.g. Louvain labels
        '0', '1', ...). Each distinct value gets one legend entry.
    cmap : str or Colormap
        Colormap used for both the points and the legend swatches.
    legend_fontsize : str or float
        Legend font size.
    size : float
        Scatter marker size.
    ncol : int
        Number of legend columns.
    figsize : (float, float)
        Figure size.

    Returns
    -------
    fig, (ax_scatter, ax_legend)
    """
    assert key in adata_GABA.obs, f"Key '{key}' not found in adata_GABA.obs."

    xy = adata_GABA.obsm['X_umap']

    fig, (ax_scatter, ax_legend) = plt.subplots(
        1, 2, figsize=figsize, gridspec_kw={'width_ratios': [5, 1]})

    cluster_vals = adata_GABA.obs[key].values.astype(float)
    points = ax_scatter.scatter(xy[:, 0], xy[:, 1], c=cluster_vals, cmap=cmap,
                                s=size, vmin=-0.5, edgecolor='none')

    # Square, tick-less panel with axis labels pinned to the lower-left corner.
    ax_scatter.set_aspect('equal', adjustable='box')
    ax_scatter.set_xlabel('umap 1', ha='left')
    ax_scatter.xaxis.set_label_coords(0, 0.0)
    ax_scatter.set_ylabel('umap 2', ha='left')
    ax_scatter.yaxis.set_label_coords(0.0, 0)
    ax_scatter.set_xticks([])
    ax_scatter.set_yticks([])

    # One proxy marker per cluster, colored through the scatter's own norm/cmap.
    legend_items = []
    for value in np.unique(cluster_vals):
        rgba = points.cmap(points.norm(value))
        proxy = plt.Line2D([], [], marker='o', linestyle='None', color=rgba,
                           markerfacecolor=rgba, markersize=4.5)
        legend_items.append((proxy, f'Cluster {int(value)}'))
    handles = [item[0] for item in legend_items]
    labels = [item[1] for item in legend_items]

    ax_legend.axis('off')
    ax_legend.legend(handles, labels, loc='center', fontsize=legend_fontsize,
                     markerscale=0.75, ncol=ncol, frameon=False)
    plt.tight_layout()
    plt.show()

    return fig, (ax_scatter, ax_legend)

# Example calls to the function
# show_clusters(adata_GABA, 'Cluster (nn)')


def show_gene_loading(adata, gene, cmap='Purples',size = 0.4,alpha=0.8,figsize=(3, 1.5), vmin=-.45, vmax=2.75):
    """Plot the UMAP colored by one gene's values in ``adata.X``.

    Parameters
    ----------
    adata : AnnData
        Needs ``obsm['X_umap']`` and ``gene`` in ``var_names``. ``.X`` may be
        sparse or dense.
    gene : str
        Gene to display; the title shows it upper-cased.
    cmap : str or Colormap
        Colormap for the expression values.
    size : float
        Scatter marker size.
    alpha : float
        Marker opacity.
    figsize : (float, float)
        Figure size.
    vmin, vmax : float
        Color-scale limits. The defaults keep the previously hard-coded -0.45 and 2.75.

    Returns
    -------
    fig, (ax_scatter, ax_cbar)
    """
    assert gene in adata.var_names, f"Gene '{gene}' not found in adata.var_names."

    umap_coords = adata.obsm['X_umap']

    # .X can be a scipy sparse matrix or a dense ndarray (sc.pp.scale produces
    # a dense matrix); only sparse matrices have .toarray().
    expr = adata[:, gene].X
    if hasattr(expr, 'toarray'):
        expr = expr.toarray()
    gene_expression = np.asarray(expr).flatten()

    fig = plt.figure(figsize=figsize)
    # Manually placed axes: scatter in the left half, a thin colorbar beside it.
    ax_scatter = fig.add_axes([0.0, 0.1, 0.5, 0.8])  # [left, bottom, width, height]
    ax_cbar = fig.add_axes([0.50, 0.1, 0.01, 0.8])

    scatter = ax_scatter.scatter(
        umap_coords[:, 0], umap_coords[:, 1],
        c=gene_expression, cmap=cmap, s=size, edgecolor='none', alpha=alpha, vmin=vmin, vmax=vmax)

    # Square, tick-less panel with axis labels pinned to the lower-left corner.
    ax_scatter.set_aspect('equal', adjustable='box')
    ax_scatter.set_xlabel('umap 1',ha='left')
    ax_scatter.xaxis.set_label_coords(0, 0.0)
    ax_scatter.set_ylabel('umap 2',ha='left')
    ax_scatter.yaxis.set_label_coords(0.0, 0)
    ax_scatter.set_xticks([])
    ax_scatter.set_yticks([])

    cbar = fig.colorbar(scatter, cax=ax_cbar, orientation='vertical')
    # NOTE(review): this label assumes .X holds scaled (z-scored) values — confirm per dataset.
    cbar.set_label('stdev/mean', rotation=270, labelpad=10)
    ax_scatter.set_title(gene.upper())

    plt.tight_layout()
    plt.show()

    return fig, (ax_scatter, ax_cbar)

# Example call to the function
# show_gene_loading(adata_GABA, 'Sst')

# cluster_fig, ax = show_clusters(adata_GABA, 'Cluster (nn)')
# cluster_fig.savefig("cluster_plot.svg", format="svg")
# cluster_fig.savefig("cluster_plot.jpeg", format="jpeg")
# show_gene_loading(adata_full, 'Prox1')

In [None]:

################# INITIALIZE DIRECTORY DOWNLOAD FROM GEO
# Download the GEO tarball of per-animal 10x matrices and unpack it into
# ./indiv_animal_results (relative to working_dir).
os.chdir(working_dir)
results_dir = './indiv_animal_results'
os.makedirs(results_dir, exist_ok=True)
tar_path = os.path.join(results_dir, geo_zalo_filename)
urllib.request.urlretrieve(zalocusky_url, tar_path)
# The context manager closes the archive even if extraction fails.
# NOTE(review): extractall trusts member paths inside the archive; on Python
# 3.12+ consider extractall(..., filter='data') to block path traversal.
with tarfile.open(tar_path) as my_tar:
    my_tar.extractall(results_dir)  # specify which folder to extract to
# for f in os.listdir('./indiv_animal_results'):
#     print(f)


In [None]:

################# Read, Combine, and Sample Split multiple 10x's
# Each GEO .h5 file is one sample. Metadata is parsed from the trimmed file
# name '<GSM>_<mouseID>_<E_type>_<age in months>_<well>' (trim_key remaps the
# two floxed samples into this form).

adata_dict = {}
# sorted(): os.listdir order is filesystem-dependent, and the concatenation
# order fixes cell order for every downstream seeded step (neighbors, louvain, umap).
for f in sorted(os.listdir('./indiv_animal_results')):
    if '.h5' not in f:
        continue
    a = sc.read_10x_h5('./indiv_animal_results/' + f)
    a.var_names_make_unique()
    gsm, mouse_id, e_type, age_months, well = trim_key(f).split("_")[:5]
    # Round the age up to a multiple of 5 months (e.g. 12 -> '15m', 15 -> '15m').
    a.obs['age_bin'] = str(int(np.ceil(int(age_months) / 5) * 5)) + 'm'
    a.obs['E_type'] = e_type
    a.obs['mouse_ID'] = mouse_id
    a.obs['well'] = well
    a.obs['GSM'] = gsm
    adata_dict[gsm] = a
adata = ad.concat(adata_dict, axis=0, label="Sample", index_unique="_")
# adata = adata[['fKI' not in t for t in adata.obs.E_type], :]
adata_dict = {}  # drop the per-sample references
clear_output()
adata.write_h5ad(filename='./kz_adata_raw.h5')


adata_QC = preprocess_andata10x(adata, keep_NC=True)[0]
adata_QC.write_h5ad(filename='./kz_adata_qc.h5')




In [None]:
############## QC Filter #####################
adata_full = adata_QC.copy()
age_dict = {'5m':'05m', '10m': '10m','15m': '15m','20m': '20m'}
adata_full.obs['age_bin'] = [ age_dict[a] for a in adata_full.obs['age_bin'] ]

sc.pp.filter_genes(adata_full, min_counts=50)

############## Dim Reduction #####################
adata_full = high_var_genes_dim_reduc(adata_full)
adata_full.write_h5ad(filename='./adata_full.h5')
print(adata_full.uns['log1p'])



In [None]:
############## Dim Reduction #####################
# adata_full = ad.read_h5ad('./adata_full.h5')
adata_full = umap_and_cluster(adata_full, n_neighbors=15, n_pcs=20,resolution=.6,to_plot=False)
fig, ax = show_clusters(adata_full, 'Cluster (nn)',ncol=2)
fig.savefig('full_umap.svg')

In [None]:
from matplotlib.colors import ListedColormap

# 36-color qualitative palette: all of tab20c followed by tab20b minus its
# first 4 colors. plt.get_cmap is used because matplotlib.cm.get_cmap
# (plt.cm.get_cmap) was removed in Matplotlib 3.9. ListedColormap is imported
# explicitly rather than relying on an `mpl` alias that is not imported here.
tab20b_colors = plt.get_cmap('tab20b').colors
tab20c_colors = plt.get_cmap('tab20c').colors
tab40_colors = np.concatenate([tab20c_colors, tab20b_colors[4:]])
tab40 = ListedColormap(tab40_colors)

show_clusters(adata_full, 'Cluster (nn)', cmap=tab40, ncol=1)
# Other markers inspected during exploration: Sst, Pvalb, Vip, Reln, Apoe, Gfap, Gad1, Gad2.
for gene in ['Syn1', 'Slc17a7', 'Prox1', 'Pdzd2']:
    fig, ax = show_gene_loading(adata_full, gene, size=.1, alpha=1)
    fig.savefig(f'{gene}.svg')


# show_gene_loading(adata_full, 'Gfap')
# show_gene_loading(adata_full, 'Gad1')
# show_gene_loading(adata_full, 'Gad2')




In [None]:
"""
Re-Filter Re Normalize
"""
# Thresholds for an optional second-pass filter; currently only referenced by
# the commented-out preprocessing that follows.
pct_mito=0.25
min_genes=50
max_genes=2400
min_counts=500
max_counts=4500

# Keep Louvain clusters 0, 2 and 4 of the full dataset.
# NOTE(review): these IDs depend on the clustering run above — confirm they
# still correspond to the intended (presumably dentate granule) populations.
# .copy() creates an independent AnnData rather than a view, so resetting
# obsm/uns below neither converts a view implicitly nor touches adata_full.
DGC_adata = adata_full[adata_full.obs['louvain'].isin(['0', '2', '4']), :].copy()
DGC_adata.obsm = {}
DGC_adata.uns = {}



# DGC_adata,DGC_adata_og = preprocess_andata10x(DGC_adata,keep_NC=False)
# display(DGC_adata)
# sc.pp.filter_cells(DGC_adata, min_genes=min_genes)
# display(DGC_adata)
# sc.pp.filter_cells(DGC_adata, max_genes=max_genes)
# display(DGC_adata)
# sc.pp.filter_cells(DGC_adata, min_counts=min_counts)
# display(DGC_adata)
# sc.pp.filter_cells(DGC_adata, max_counts=max_counts)
# display(DGC_adata)
# DGC_adata = high_var_genes_dim_reduc(DGC_adata)

In [None]:
# PCA on the subset, then scale every gene in .X.
# NOTE(review): sc.pp.scale runs *after* sc.tl.pca, so the PCA — and the
# neighbors/UMAP/Louvain that umap_and_cluster builds from X_pca — use unscaled
# data. Scaling only affects .X used by later plots. The usual scanpy order is
# scale -> pca; confirm this order is intended.
sc.tl.pca(DGC_adata, n_comps=50)
sc.pp.scale(DGC_adata)
DGC_adata = umap_and_cluster(DGC_adata, n_neighbors=25, n_pcs=30,resolution=.4,to_plot=False)
fig, ax = show_clusters(DGC_adata, 'Cluster (nn)')
fig.savefig('dgc_umap.svg')

In [None]:
# Scree plot of the subset's PCA, then checkpoint the subset to ./DGC_adata.h5.
sc.pl.pca_variance_ratio(DGC_adata)
DGC_adata.write_h5ad(filename='./DGC_adata.h5')

In [None]:
"""
SKIP TO HERE
"""

In [None]:
# DGC_adata = ad.read_h5ad('./DGC_adata.h5')

In [None]:
show_clusters(DGC_adata, 'Cluster (nn)')
# Marker genes overlaid on the re-clustered UMAP, in a fixed order.
for marker in ('Apoe', 'Syn1', 'Slc17a7', 'Prox1', 'Pdzd2'):
    show_gene_loading(DGC_adata, marker, size=0.5, alpha=0.5)


In [None]:
from scipy.stats import fisher_exact
from statsmodels.stats.multitest import multipletests

def calculate_enrichment_with_stats(adata, group_var, cluster_var, alpha=0.05):
    """Observed/expected enrichment of each group in each cluster, with Fisher tests.

    Builds the cluster x group contingency table of cell counts.
    - Enrichment is observed / expected, where expected is
      cluster_total * group_total / total (independence).
    - Each cell gets a two-sided Fisher's exact test on the 2x2 table
      [[in cluster & group, other clusters & group],
       [in cluster & other groups, neither]].
    - The p-values are Benjamini-Hochberg corrected across all cells.

    Parameters
    ----------
    adata : AnnData
        Cells with ``obs[group_var]`` and ``obs[cluster_var]``.
    group_var, cluster_var : str
        ``obs`` column names (columns and rows of the result, respectively).
    alpha : float
        Currently unused.

    Returns
    -------
    enrichment : pandas.DataFrame
        Clusters x groups observed/expected ratios.
    p_corrected_matrix : pandas.DataFrame
        Same shape; BH-adjusted p-values.
    """
    counts = pd.crosstab(adata.obs[cluster_var], adata.obs[group_var])
    n_total = counts.values.sum()

    row_sums = counts.sum(axis=1).values[:, None]  # cells per cluster
    col_sums = counts.sum(axis=0).values[None, :]  # cells per group
    expected = row_sums * col_sums / n_total
    enrichment = counts / expected

    def _fisher_p(r, c):
        both = counts.iloc[r, c]
        group_only = col_sums[0, c] - both
        cluster_only = row_sums[r, 0] - both
        neither = n_total - both - cluster_only - group_only
        _, p_value = fisher_exact(np.array([[both, group_only], [cluster_only, neither]]),
                                  alternative="two-sided")
        return p_value

    n_rows, n_cols = counts.shape
    # Row-major order, so the reshape below restores the table layout.
    raw_p = [_fisher_p(r, c) for r in range(n_rows) for c in range(n_cols)]

    adjusted = multipletests(raw_p, method="fdr_bh")[1]
    p_corrected_matrix = pd.DataFrame(
        np.array(adjusted).reshape(counts.shape),
        index=counts.index,
        columns=counts.columns,
    )
    return enrichment, p_corrected_matrix



In [None]:
def plot_helicopter(df, pvalues, title, cbar_label="Enrichment", ax=None,cmap="coolwarm"):
    """
    Plot an enrichment heatmap annotated with values and significance stars.

    Each cell shows the enrichment (2 decimals) above ``p_to_astk`` of the
    matching p-value. The color scale is symmetric about 1 (no enrichment).

    Args:
        df (pd.DataFrame): Enrichment values (rows are clusters).
        pvalues (pd.DataFrame): P-values matched to ``df`` by position
            (same shape and order).
        title (str): Plot title.
        cbar_label (str): Color bar label.
        ax (matplotlib.axes.Axes, optional): Axes to draw on. If None, a new
            2x1.5 figure is created, laid out and shown.
        cmap (str or Colormap): Heatmap colormap.

    Returns:
        matplotlib.axes.Axes: The axes containing the heatmap.
    """
    # Remember whether this call owns the figure. `ax` is reassigned below, so
    # the old end-of-function `ax is None` check could never be true.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(2, 1.5))

    # Diverging color limits centered on 1.
    vrange = abs(df.values - 1).max()
    vmin = 1 - vrange
    vmax = 1 + vrange

    # "<enrichment>\n<stars or 'ns'>" for every cell.
    annotations = pd.DataFrame(index=df.index, columns=df.columns, dtype=object)
    for i in range(df.shape[0]):
        for j in range(df.shape[1]):
            enrichment = f"{df.iloc[i, j]:.2f}"
            pvalue = p_to_astk(pvalues.iloc[i, j])
            annotations.iloc[i, j] = enrichment + "\n" + pvalue

    sns.heatmap(
        df,
        annot=annotations.values,
        fmt="",
        cmap=cmap,
        cbar_kws={"label": cbar_label},
        vmin=vmin,
        vmax=vmax,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel(None)
    ax.set_ylabel("Clusters")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    # Lay out and show only a figure this call created; a caller-provided
    # axes belongs to a figure the caller finishes itself.
    if owns_figure:
        plt.tight_layout()
        plt.show()

    return ax



In [None]:
def p_to_astk(p_val):
    """Convert a p-value into significance stars.

    One '*' is added per cutoff met (p <= 0.05, 0.01, 0.001, 0.0001), giving
    '*' to '****'. Returns 'ns' when no cutoff is met, which includes NaN.
    """
    n_stars = sum(1 for cutoff in (0.05, 0.01, 0.001, 0.0001) if p_val <= cutoff)
    return '*' * n_stars if n_stars else 'ns'


In [None]:
DGC_adata.obs['E_type_age_bin'] = DGC_adata.obs['E_type'].astype(str) + "_" + DGC_adata.obs['age_bin'].astype(str)
enrichment, corrected_p_values = calculate_enrichment_with_stats(DGC_adata, group_var="E_type_age_bin", cluster_var="louvain")
cmap='PRGn'

# Each figure has two heatmaps: ApoE3 columns (left) and ApoE4 columns (right).
# (figure size, column filter, output file stem)
panel_specs = [
    # E3 vs E4 enrichment over time (non-floxed samples only)
    ((4, 1.75), lambda col: 'fKI' not in col, 'E3E4_heli'),
    # Syn-Cre (floxed) comparison: every 15-month column
    ((2.5, 1.75), lambda col: '15m' in col, 'floxE4_heli'),
]
for panel_size, keep_col, stem in panel_specs:
    fig, ax = plt.subplots(1, 2, figsize=panel_size, width_ratios=(1, 1))
    for panel_ax, allele, allele_title in zip(ax, ('E3', 'E4'), ('ApoE3', 'ApoE4')):
        subset = [c for c in enrichment.columns if keep_col(c) and allele in c]
        plot_helicopter(enrichment[subset], corrected_p_values[subset], title=allele_title, ax=panel_ax, cmap=cmap)
    plt.tight_layout()
    fig.savefig(f'{stem}.svg')
    fig.savefig(f'{stem}.jpeg')



In [None]:
# One panel per cluster showing enrichment across age bins.
# - E3 (blue) and E4 (red) lines cover the non-floxed age bins.
# - The 15-month floxed (fKI) samples appear as hollow markers.
# - The last axis holds only the legend.
clusters = enrichment.index
size = 15
num_clust = len(clusters)
# Shared y-range as scalars over the whole table (np.min on a DataFrame
# reduces per column on many pandas versions, which set_ylim cannot take).
y_min = np.nanmin(enrichment.to_numpy())
y_max = np.nanmax(enrichment.to_numpy())
fig, ax = plt.subplots(1, num_clust + 1, figsize=(1 * (num_clust + 1), 1), squeeze=False)
ax = ax[0]
for i, c in enumerate(clusters):
    enrich = enrichment.loc[c, :]
    all_cols = enrich.index
    cols_1 = [col for col in all_cols if 'E3' in col and 'fKI' not in col]
    cols_2 = [col for col in all_cols if 'E4' in col and 'fKI' not in col]
    # E4 values are drawn at the E3 x positions (the same age bins, in the same order).
    ax[i].plot(cols_1, enrich[cols_1], 'b', label='E3', marker='o')
    # facecolor='none' draws hollow markers (facecolor=None used the default fill).
    # cols_1[2] is the third E3 age bin, expected to be 15m.
    ax[i].scatter(cols_1[2], enrich['E3fKI_15m'], s=size, marker='o', facecolor='none', edgecolor='b', label='E3fKI')
    ax[i].plot(cols_1, enrich[cols_2], 'r', label='E4', marker='o')
    ax[i].scatter(cols_1[2], enrich['E4fKI_15m'], s=size, marker='o', facecolor='none', edgecolor='r', label='E4fKI')
    # Exactly one tick label per plotted position (age bin only).
    ax[i].set_xticks(range(len(cols_1)))
    ax[i].set_xticklabels([col.replace('E3_', '') for col in cols_1], rotation=45)
    ax[i].set_title(f"Cluster {c}")
    ax[i].set_ylim([y_min, y_max])
    ax[i].axhline(1, color='k', linestyle=':')
handles, labels = ax[0].get_legend_handles_labels()
ax[-1].legend(handles, labels, loc='center')
ax[-1].axis("off")
ax[0].set_ylabel('Relative Enrichment')
plt.tight_layout()
fig.savefig('enrichment_lines.svg')
fig.savefig('enrichment_lines.jpeg')

show_clusters(DGC_adata, 'Cluster (nn)',figsize=(1.5,1.5))


In [None]:
# Show the current working directory (os.getcwd must be called; printing the
# bare name only shows the function object).
print(os.getcwd())

In [None]:
"""
Define Marker Genes
"""
# Hierarchical ordering of the Louvain groups. The result is read back from
# uns['dendrogram_louvain'] to order clusters in the next cell.
sc.tl.dendrogram(DGC_adata, groupby='louvain')
# One-vs-rest differential expression per Louvain cluster. Results are read
# back from uns['rank_genes_groups'] downstream.
sc.tl.rank_genes_groups(
    DGC_adata,
    groupby='louvain',
    method='wilcoxon',  # Wilcoxon rank-sum test for differential expression
    corr_method='benjamini-hochberg'  # Correct for multiple testing using FDR
)


In [None]:

# Clusters in dendrogram order (computed by sc.tl.dendrogram above).
ordered_clusters = DGC_adata.uns['dendrogram_louvain']['categories_ordered']

# Top-ranked genes per cluster: one column per cluster, top_n rows.
top_n = 5
ranked_genes = pd.DataFrame({
    cluster: DGC_adata.uns['rank_genes_groups']['names'][cluster][:top_n]
    for cluster in ordered_clusters
})
# Flatten column by column (cluster by cluster, in dendrogram order) and
# de-duplicate with dict.fromkeys, which keeps first-seen order. The previous
# list(set(...)) gave a different gene order each run, because str hashing is
# randomized per process.
top_genes_all_clusters = list(dict.fromkeys(ranked_genes.to_numpy().flatten(order='F').tolist()))
# Always include Apoe, without duplicating it if it is already a top gene.
if 'Apoe' not in top_genes_all_clusters:
    top_genes_all_clusters.append('Apoe')


sc.pl.rank_genes_groups_dotplot(DGC_adata, n_genes=top_n)


# show=False keeps the heatmap as the current, open figure so plt.savefig
# writes it. Otherwise scanpy shows it first, which can close it under the
# inline backend, leaving a blank figure to save.
sc.pl.heatmap(DGC_adata, var_names=top_genes_all_clusters, groupby="louvain", cmap="viridis", dendrogram=False, figsize=(6, 1.5), show=False)
plt.savefig('heatmap.svg')
plt.savefig('heatmap.jpeg')

fig = sc.pl.stacked_violin(DGC_adata, var_names=top_genes_all_clusters, groupby="louvain", swap_axes=False, dendrogram=False, figsize=(6, 1.5), return_fig=True)
fig.savefig('violins.svg')
plt.savefig('violins.jpeg')

fig = sc.pl.matrixplot( DGC_adata, var_names=top_genes_all_clusters, groupby="louvain", dendrogram=False, cmap="Blues", standard_scale="var", colorbar_title="column scaled\nexpression", figsize=(6,  1.5), return_fig=True)
fig.savefig('matix.svg')
plt.savefig('matix.jpeg')
# print(fig)

In [None]:
# Apoe distribution per Louvain cluster, drawn on an explicitly sized axis.
# NOTE(review): DGC_adata carries a .raw, so depending on the scanpy version's
# use_raw default this may plot raw rather than scaled values — pass use_raw
# explicitly if that matters.
fig,ax = plt.subplots(figsize=(2,1))
sc.pl.violin(
    DGC_adata,
    keys='Apoe',
    groupby="louvain",
    ax=ax
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def single_deg_volcano(
    deg_df,
    label_cutoff=0.1,
    title=None,
    comp_string=None,
    lock_y_max=True,
    special_genes=[],
    use_adj=False,
    marker_size = .5,
    min_p = 1e-30,
    maxlfc =10,
    ax=None,
    sig_neglog10p=4,
    min_abs_lfc=1):
    '''
    Plot a single volcano plot for a DE comparison.

    All genes are drawn in gray. Genes with -log10(p) > sig_neglog10p and
    |log2FC| >= min_abs_lfc are redrawn in red and labeled with their name.

    Parameters:
        deg_df (DataFrame): DE results indexed by gene name, with columns
            'logfoldchanges', 'pvals' and, if use_adj, 'pvals_adj'. Not modified.
        label_cutoff (float): Currently unused (kept for compatibility).
        title (str, optional): Plot title; '_' is rendered as a line break.
        comp_string (str, optional): Shown on its own line above the x label.
        lock_y_max (bool): If True, the y max is 1.1 x the largest absolute
            autoscaled y limit; otherwise it is 1.1 x the autoscaled upper limit.
        special_genes (list): Currently unused (kept for compatibility).
        use_adj (bool): Use 'pvals_adj' instead of 'pvals'.
        marker_size (float): Scatter marker size.
        min_p (float): Floor applied to p-values before -log10, so p == 0 is
            drawn at -log10(min_p) instead of infinity. Should be positive.
        maxlfc (float): Log fold changes are clipped to [-maxlfc, maxlfc].
        ax (Axes, optional): Axes to draw on; otherwise a new 1.75x1.75 figure.
        sig_neglog10p (float): -log10(p) threshold for highlighting (default 4).
        min_abs_lfc (float): |log2FC| threshold for highlighting (default 1).

    Returns:
        fig_volcano (Figure): Figure containing the plot.
        ax (Axes): Axes of the volcano plot.
    '''
    if ax is None:
        fig_volcano, ax = plt.subplots(1, 1, figsize=(1.75, 1.75))
    else:
        fig_volcano = ax.figure

    if title is not None:
        ax.set_title(title.replace('_', '\n'))

    genes = deg_df.index
    # clip() returns new Series. The previous masked assignments wrote through
    # to the caller's deg_df columns.
    pval = (deg_df['pvals_adj'] if use_adj else deg_df['pvals']).clip(lower=min_p)
    lfc = deg_df['logfoldchanges'].clip(lower=-maxlfc, upper=maxlfc)
    neg_log10_pval = -np.log10(pval)

    ax.scatter(lfc, neg_log10_pval, c='gray', s=marker_size)

    selected = ((neg_log10_pval > sig_neglog10p) & (lfc.abs() >= min_abs_lfc)).to_numpy()
    # Iterate positionally over plain arrays; Series[int] on a gene-name index
    # relies on pandas' deprecated positional fallback.
    for x, y, gene in zip(lfc.to_numpy()[selected], neg_log10_pval.to_numpy()[selected], genes[selected]):
        if not np.isnan(x * y):
            ax.scatter(x, y, c='r', s=marker_size)
            ax.text(x, y, gene, rotation=45, fontsize=4, color='r')

    # Symmetric x range around 0; y starts at 0.
    x_lim = ax.get_xlim()
    y_lim = ax.get_ylim()
    x_etr = np.max(np.abs(x_lim))
    y_etr = np.max(np.abs(y_lim))

    ax.set_xlim(-x_etr * 1.1, x_etr * 1.1)
    if lock_y_max:
        ax.set_ylim(0, y_etr * 1.1)
    else:
        ax.set_ylim(0, ax.get_ylim()[1] * 1.1)

    xlabel = 'Log2(fold change)' if comp_string is None else f'{comp_string}\nLog2(fold change)'
    ax.set_xlabel(xlabel)
    ax.set_ylabel('-log10(adj-pvalue)' if use_adj else '-log10(raw-pvalue)')
    ax.axhline(-np.log10(0.05), c='k', linestyle=":", linewidth=0.6, label='p=0.05')
    # Vertical reference at log2FC = 0 (previously mislabeled 'p=0.05').
    ax.axvline(0, c='k', linestyle="-", linewidth=0.6, label='lfc=0')

    fig_volcano.tight_layout()

    return fig_volcano, ax




In [None]:
marker_genes = DGC_adata.uns["rank_genes_groups"]  # per-cluster results from sc.tl.rank_genes_groups

for clus in (3, 0):
    # Each results field is a structured array with one named field per
    # group (the Louvain label). Select the cluster by name, not by position.
    group = str(clus)
    deg_df = pd.DataFrame(index=marker_genes['names'][group],
                          data={'scores': marker_genes['scores'][group],
                                'pvals': marker_genes['pvals'][group],
                                'pvals_adj': marker_genes['pvals_adj'][group],
                                'logfoldchanges': marker_genes['logfoldchanges'][group],
                                })
    deg_df.to_csv(f'Cluster_{clus}_deg_df.csv')
    fig_volcano, ax = plt.subplots(1, 1, figsize=(8, 4))
    # min_p is the floor applied before -log10, so it must be positive. With
    # 1e-30, p == 0 is drawn at y = 30; the previous -1e-30 disabled the floor
    # and sent such genes to infinity, off the plot.
    single_deg_volcano(deg_df, min_p=1e-30, maxlfc=4, ax=ax, title=f'Cluster {clus}')

In [None]:
"""
Go Analysis
"""


In [None]:
def query_go(query_list):
    """Run a g:Profiler GO Biological Process enrichment for mouse genes.

    Adds three columns to the g:Profiler table and returns it sorted by
    gene_ratio, highest first:
    - 'gene_ratio' = intersection_size / term_size;
    - 'p_value_corrected', a Benjamini-Hochberg correction of 'p_value';
    - 'sig' = p_value_corrected < 0.05.

    NOTE(review): g:Profiler's reported p_value may already be
    multiple-testing adjusted by its own threshold method. Applying BH on top
    could double-correct — confirm against the g:Profiler settings in use.
    """
    profiler = gprofiler.GProfiler(return_dataframe=True)
    res = profiler.profile(
        query=query_list,
        organism='mmusculus',  # mouse
        sources=['GO:BP'],     # GO Biological Process only
        no_evidences=True,     # skip evidence codes
    )
    res['gene_ratio'] = res['intersection_size'] / res['term_size']
    res = res.sort_values(by='gene_ratio', ascending=False)

    bh_adjusted = multipletests(res['p_value'], method='fdr_bh')[1]
    res['p_value_corrected'] = bh_adjusted
    res['sig'] = res['p_value_corrected'] < 0.05
    return res

def plot_go_enrichment(df, title, figsize=(4, 4), max_go=100,
                       save_path='GO_enrichment_dot_plot.svg'):
    """
    Plot GO enrichment results as a dot plot: gene ratio on the x-axis, GO term on
    the y-axis, color = -log10(p-value), dot size = intersection size. A narrow
    right-hand column holds the color bar (top) and a dot-size legend (bottom).

    Parameters:
    df (pd.DataFrame): GO enrichment results with columns:
                       'name' (GO term names), 'p_value' (p-values),
                       'intersection_size' (size of the gene intersection),
                       'gene_ratio' (ratio of genes associated with the GO term).
    title (str): Title for the plot.
    figsize (tuple): Figure size for the plot. Default is (4, 4).
    max_go (int): Maximum number of terms shown, highest gene_ratio first. Default is 100.
    save_path (str or None): File the figure is saved to (format inferred from the
                       extension). None skips saving. Default is 'GO_enrichment_dot_plot.svg'.

    Returns:
    tuple: (fig, axs) where axs is the 2x2 array from plt.subplots; axs[0, 0] and
           axs[1, 0] both refer to the main dot-plot axes (which spans the left
           column), axs[0, 1] is the color bar axes and axs[1, 1] the size legend.

    Raises:
    ValueError: if a required column is missing or df has no rows.
    """
    # Ensure the DataFrame contains the required columns
    required_columns = ['name', 'p_value', 'intersection_size', 'gene_ratio']
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"DataFrame must contain the following columns: {required_columns}")
    if df.empty:
        raise ValueError("DataFrame contains no GO terms to plot.")

    # Keep the max_go terms with the highest gene ratio.
    top_results = df.sort_values(by='gene_ratio', ascending=False).head(max_go)
    go_terms = top_results['name']
    gene_ratios = top_results['gene_ratio']
    intersection_size = top_results['intersection_size']

    # Color by -log10(p); clip so a p-value reported as 0 gives a finite value.
    p_values = np.clip(top_results['p_value'].to_numpy(dtype=float), np.finfo(float).tiny, None)
    neg_log_p_values = -np.log10(p_values)

    # Dot sizes: smallest intersection -> 10, largest -> 200, rounded to steps of 10.
    min_size, max_size = 10, 200
    min_intersection, max_intersection = intersection_size.min(), intersection_size.max()
    intersection_span = max_intersection - min_intersection

    def scale_sizes(values):
        """Map intersection sizes to marker areas (shared by the plot and its legend)."""
        values = np.asarray(values, dtype=float)
        if intersection_span == 0:
            # All terms share one intersection size: avoid 0/0 (NaN sizes), use the midpoint.
            scaled = np.full(values.shape, (min_size + max_size) / 2)
        else:
            scaled = min_size + (values - min_intersection) / intersection_span * (max_size - min_size)
        return np.round(scaled / 10) * 10

    sizes = scale_sizes(intersection_size)

    fig, axs = plt.subplots(2, 2, figsize=figsize, gridspec_kw={'width_ratios': [4, 0.35]})
    # Replace the two left-column axes with one axes spanning both rows on the SAME
    # GridSpec, so the main plot keeps the 4:0.35 width ratio and nothing overlaps.
    grid = axs[0, 0].get_gridspec()
    for placeholder in axs[:, 0]:
        placeholder.remove()
    ax_main = fig.add_subplot(grid[:, 0])
    axs[0, 0] = ax_main
    axs[1, 0] = ax_main

    scatter = ax_main.scatter(gene_ratios, go_terms, s=sizes,  # gene_ratio on x-axis, GO Term on y-axis
                              c=neg_log_p_values, cmap='cool', alpha=0.7, edgecolors='k', linewidth=0.5)
    ax_main.set_xlabel('Gene Ratio')
    ax_main.set_ylabel('GO Term')
    ax_main.set_title(title)
    ax_main.set_xlim(left=0)
    # Flip y-axis to show the most enriched terms at the top
    ax_main.invert_yaxis()

    # Color bar in the top-right cell.
    cbar = fig.colorbar(scatter, cax=axs[0, 1], orientation='vertical')
    cbar.set_label('-log10(p-value)')

    # Size legend in the bottom-right cell: a few intersection sizes between the observed
    # min and max, drawn with the same size mapping as the main plot.
    ax_dots = axs[1, 1]
    num_dots = 5
    legend_values = np.unique(
        np.round(np.linspace(min_intersection, max_intersection, num_dots)).astype(int))
    for i, value in enumerate(legend_values):
        ax_dots.scatter(1, i, s=scale_sizes([value])[0], color='black',
                        edgecolors='w', linewidth=0.5, zorder=5)
        ax_dots.text(2, i, f'{value}', ha='left', va='bottom', color='black')

    # Hide ticks and spines on the legend axes.
    ax_dots.set_xticks([])
    ax_dots.set_yticks([])
    for spine in ax_dots.spines.values():
        spine.set_visible(False)
    ax_dots.set_xlim(-1, 2)
    ax_dots.set_ylim(-1, num_dots)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()
    return (fig, axs)


def high_low_go_split(deg_df, p_thresh=0.05, lfc_thresh=.5):
    """
    Split a DEG table into significantly up- and down-regulated genes.

    Parameters
    ----------
    deg_df : pandas.DataFrame
        Must have 'logfoldchanges' and 'pvals_adj' columns.
    p_thresh : float
        Keep rows with pvals_adj strictly below this value (default 0.05).
        Previously this argument was ignored and 0.05 was always used.
    lfc_thresh : float
        Up-regulated: logfoldchanges > lfc_thresh;
        down-regulated: logfoldchanges < -lfc_thresh.

    Returns
    -------
    (high_df, low_df) : tuple of pandas.DataFrame
        The up- and down-regulated rows, each sorted by pvals_adj ascending.
    """
    significant = deg_df['pvals_adj'] < p_thresh

    high_df = deg_df[significant & (deg_df['logfoldchanges'] > lfc_thresh)]
    high_df = high_df.sort_values(by='pvals_adj', ascending=True)

    low_df = deg_df[significant & (deg_df['logfoldchanges'] < -lfc_thresh)]
    low_df = low_df.sort_values(by='pvals_adj', ascending=True)

    return high_df, low_df




In [None]:
# For every cluster in rank_genes_groups: build its DEG table, run GO:BP enrichment on
# the up-regulated genes, and save the resulting dot plot as SVG and JPEG.
for clus in range(len(marker_genes['names'][0])):
    deg_df = pd.DataFrame(
        index=[row[clus] for row in marker_genes['names']],
        data={field: [row[clus] for row in marker_genes[field]]
              for field in ('scores', 'pvals', 'pvals_adj', 'logfoldchanges')},
    )
    high_df, low_df = high_low_go_split(deg_df, lfc_thresh=1)
    if high_df.empty:
        # An empty query would fail (or waste a web request); nothing to enrich.
        print(f'Cluster {clus}: no up-regulated genes pass the thresholds; skipping GO query.')
        continue
    results = query_go(high_df.index.to_list())
    if results.empty:
        # plot_go_enrichment cannot draw an empty table.
        print(f'Cluster {clus}: no GO:BP terms returned; skipping plot.')
        continue
    (fig, axs) = plot_go_enrichment(results, f'Cluster{clus} Up Regulated', figsize=(6, 2), max_go=20)
    fig.savefig(f'GO_Cluster{clus} Up Regulated.svg')
    fig.savefig(f'GO_Cluster{clus} Up Regulated.jpeg')
    # Release the figure so memory does not grow with the number of clusters.
    plt.close(fig)


In [None]:
# Visualize the UMAP embedding of DGC_adata colored by the 'louvain' cluster labels.
sc.pl.umap(DGC_adata, color='louvain')


# Optional (disabled): PAGA graph over the Louvain clusters.
# `groups` must name the cluster column in `.obs`.
# sc.tl.paga(DGC_adata, groups='louvain')
#
# Plot the PAGA graph on its own small axes:
# fig, ax = plt.subplots(figsize=(1.5, 1.5))
# sc.pl.paga(DGC_adata, color='louvain', ax=ax)
# Or overlay the edges on the UMAP instead:
# sc.pl.umap(DGC_adata, color='louvain', edges=True)




In [None]:
import svgutils

In [None]:
# Show the working directory; the CSV/SVG/JPEG outputs above use relative paths, so they land here.
print(os.getcwd())

In [None]:
# row_0 =0
# row_1 =row_0+160
# row_2 =row_1+170
# row_3 = row_2+190
# row_4 = row_3+215
# dpi = 96
# width = 7.25*dpi
# length = 9.4*dpi
# scale_factor= 1.33
# weight='bold'
# label_size = 12
# my_figure = Figure(str(width), str(length),

#                     Panel(SVG(".svg"),
#                     Text("a", 10, 20, size=label_size,weight=weight,font='arial')
#                     ).scale(scale_factor).move(width*0, row_0),
#                     Panel(SVG("/content/DG_Type_II_GC_main_parameters_Clean/CrossVal_Data k=2 Centers.svg"),
#                     Text("b", 10, 20, size=label_size,weight=weight)
#                     ).scale(scale_factor).move(width*.25, row_0),
#                     Panel(SVG("/content/DG_Type_I_GC_main_parameters_Clean/CrossVal_Data k=2 Centers.svg"),
#                     Text("c", 10, 20, size=label_size,weight=weight)
#                     ).scale(scale_factor).move(width*.50, row_0),
#                     Panel(SVG("/content/CA1_main_parameters_Clean/CrossVal_Data k=2 Centers.svg"),
#                     Text("d", 10, 20, size=label_size,weight=weight)
#                     ).scale(scale_factor).move(width*.75, row_0),

In [None]:

# try: del adata
# except: None
# try: del adata_QC
# except: None
# try: del adata_dict
# except: None
# try: del adata_full
# except: None
# try: del adata
# except: None

# import gc
# gc.collect()

# %whos


In [None]:
# """
# Weighted gene co-expression network analysis (WGCNA)
# """
# # !pip install PyWGCNA
# from PyWGCNA import WGCNA

# DGC_wgcna = WGCNA(anndata=DGC_adata)
# DGC_wgcna.geneExpr.to_df().head(5)


In [None]:
# DGC_wgcna.preprocess()

In [None]:
# DGC_wgcna.findModules()

In [None]:
# Sanity check: print the summary of DGC_adata before the co-expression analysis below.
print(DGC_adata)

In [None]:
# Number of genes flagged as highly variable (displayed as the cell output).
DGC_adata.var['highly_variable'].sum()

In [None]:
# Keep only the highly variable genes (variables) as an independent copy of the AnnData.
hvg_mask = DGC_adata.var['highly_variable']
DGC_adata_hvg = DGC_adata[:, hvg_mask].copy()

In [None]:
# Gene-gene Pearson correlation over the highly variable genes, shown as a heatmap
# after zeroing near-zero correlations.
norm_counts = DGC_adata_hvg.X
if hasattr(norm_counts, 'toarray'):
    # NOTE(review): presumably .X can be a scipy sparse matrix; np.corrcoef needs a dense array.
    norm_counts = norm_counts.toarray()
print(norm_counts.shape)
# Transposed so each row is one variable (gene column of the HVG subset) -> genes x genes matrix.
correlation_matrix_hvg = np.corrcoef(norm_counts.T)
fig, ax = plt.subplots(figsize=(3, 3))
correlation_matrix_hvg_filt = correlation_matrix_hvg.copy()
# Zero correlations whose MAGNITUDE is below 0.01. The abs must wrap the values, not the
# comparison: abs(corr < 0.01) is just the boolean mask and also zeroed every negative correlation.
correlation_matrix_hvg_filt[np.abs(correlation_matrix_hvg_filt) < 0.01] = 0
cbh = ax.imshow(correlation_matrix_hvg_filt, aspect='auto', vmax=0.2)
fig.colorbar(cbh, ax=ax, label="Correlation Coefficient")

In [None]:
def calculate_scale_free_fit(correlation_matrix, beta, signed=False):
    """
    Calculate the scale-free topology fit index (R^2) for a soft-thresholding power.

    The adjacency is |correlation| ** beta; each gene's degree is its column sum.
    Degrees are rounded to integers and counted, and log10(count) is regressed on
    log10(degree) with a straight line. The R^2 of that fit is returned.

    Parameters
    ----------
    correlation_matrix : array-like, 2-D
        Gene-gene correlation matrix.
    beta : float
        Soft-thresholding power.
    signed : bool, optional
        If True, return -sign(slope) * R^2 so that fits with a positive slope
        (not scale-free) come out negative. Default False (plain R^2).

    Returns
    -------
    float
        R^2 (or signed R^2). NaN when fewer than two distinct positive degrees
        are observed or the degree counts do not vary, since R^2 is undefined then.
    """
    # Raise the correlation matrix to the power of beta to create the adjacency matrix
    adjacency_matrix = np.abs(np.asarray(correlation_matrix, dtype=float)) ** beta

    # Degree = sum of connection strengths for each gene (column)
    degree = np.sum(adjacency_matrix, axis=0)

    # Frequency of each rounded degree
    degree_freq = np.bincount(np.round(degree).astype(int))
    degree_values = np.arange(len(degree_freq))

    # Skip empty bins and degree 0, whose log10 would be -inf.
    valid = (degree_freq > 0) & (degree_values > 0)
    log_x = np.log10(degree_values[valid])
    log_y = np.log10(degree_freq[valid])
    if log_x.size < 2:
        return np.nan

    # np.polyfit returns the line's coefficients [slope, intercept] -- not R^2 --
    # so the goodness of fit is computed explicitly from the residuals.
    slope, intercept = np.polyfit(log_x, log_y, 1)
    ss_res = np.sum((log_y - (slope * log_x + intercept)) ** 2)
    ss_tot = np.sum((log_y - log_y.mean()) ** 2)
    if ss_tot == 0:
        return np.nan
    r_squared = 1.0 - ss_res / ss_tot

    if signed:
        return float(-np.sign(slope) * r_squared)
    return float(r_squared)

def sweep_beta_values(correlation_matrix, beta_range):
    """
    Evaluate the scale-free fit index for every beta in `beta_range`.

    Returns a list of (beta, fit) tuples in the order of `beta_range`. If the fit
    for a beta raises any exception, a message is printed and NaN is recorded.
    """
    def fit_or_nan(beta):
        # Best-effort: one failing power must not abort the whole sweep.
        try:
            return calculate_scale_free_fit(correlation_matrix, beta)
        except Exception as err:
            print(f"Failed for beta={beta}: {err}")
            return np.nan

    return [(beta, fit_or_nan(beta)) for beta in beta_range]

# Sweep candidate soft-thresholding powers (0.4 upward in steps of 0.2, below 5) over
# the filtered HVG correlation matrix and plot the fit index for each.
beta_values = np.round(np.arange(0.4, 5, .2), 1)
beta_results = sweep_beta_values(correlation_matrix_hvg_filt, beta_values)

# Unpack the (beta, fit) pairs into two parallel tuples.
betas, r_squareds = zip(*beta_results)

fig_sft, ax_sft = plt.subplots(figsize=(2, 2))
ax_sft.plot(betas, r_squareds, marker='o', linestyle='-', color='b')
ax_sft.set_xlabel('Soft-Thresholding Power (Beta)', fontsize=12)
ax_sft.set_ylabel('Scale-Free Fit Index (R^2)', fontsize=12)
ax_sft.set_title('Scale-Free Topology Fit Index vs Beta', fontsize=14)
ax_sft.grid(True)
plt.show()

In [None]:
# Pick the soft-thresholding power with the best scale-free fit from the sweep above,
# then inspect the degree distribution of the resulting network.
betas, r_squareds = zip(*beta_results)
# nanargmax: sweep_beta_values records NaN for betas whose fit failed, and np.argmax
# would return the position of the first NaN rather than the best fit.
optimal_beta_index = np.nanargmax(r_squareds)
optimal_beta = betas[optimal_beta_index]

print(f"Optimal Beta: {optimal_beta}")
# NOTE(review): WGCNA usually takes the lowest power whose signed R^2 reaches ~0.8-0.9
# rather than the maximum; confirm which selection rule is intended here.

# Step 2: Soft-threshold the same filtered HVG correlation matrix the beta sweep used.
correlation_matrix = correlation_matrix_hvg_filt
adjacency_matrix = np.abs(correlation_matrix) ** optimal_beta

# Step 3: Zero out very weak connections.
threshold = 0.0001  # Example threshold, adjust as needed
adjacency_matrix[adjacency_matrix < threshold] = 0

# Step 4: Degree (connectivity) of each gene = column sum of the adjacency matrix.
degree_distribution = np.sum(adjacency_matrix, axis=0)

# Remove zero degrees (genes with no connections)
degree_distribution = degree_distribution[degree_distribution > 0]

# Step 5: Frequency of each degree (rounded to integers, the same binning the
# scale-free fit uses), on log-log axes. Previously the plot showed sorted degree
# against rank, which did not match the Degree/Frequency axis labels.
degree_counts = np.bincount(np.round(degree_distribution).astype(int))
degree_values = np.arange(degree_counts.size)
observed = (degree_counts > 0) & (degree_values > 0)

plt.figure(figsize=(2, 2))
plt.loglog(degree_values[observed], degree_counts[observed], marker='o', linestyle='none', color='blue')

plt.title(f'Log-Log Plot of Degree Distribution for Optimal Beta = {optimal_beta}')
plt.xlabel('Degree (Log Scale)')
plt.ylabel('Frequency (Log Scale)')
plt.grid(True)
plt.tight_layout()
plt.show()