<a href="https://colab.research.google.com/github/dtabuena/Workshop/blob/main/RNA_Workshop/Explore_Nell2_Pathways_both.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scanpy --quiet
!pip install pybiomart --quiet
!pip install python-igraph --quiet
!pip install louvain --quiet
!pip install pynndescent --quiet


!pip install scipy

In [None]:
import h5py
import numpy as np
from matplotlib import pyplot as plt
import scanpy as sc
import tarfile
import os
import anndata as ad
import pandas as pd
import pybiomart
from tqdm import tqdm
import urllib.request
from IPython.display import clear_output
from matplotlib.pyplot import rc_context
import scipy
import logging
import itertools

In [None]:
!pip install goatools

In [None]:
import goatools

In [None]:
### quick import
import urllib
import json
from matplotlib import rcParams
def import_mpl_config(FS=6):
    """Fetch the shared Matplotlib style for font size ``FS`` and apply it.

    A stale local copy of ``mpl_config_FS{FS}.json`` in the working directory
    is removed first, the file is downloaded again from the
    dtabuena/Resources GitHub repository, and its JSON content is pushed into
    ``rcParams``.

    Returns the dictionary of settings that was loaded.
    """
    config_name = f'mpl_config_FS{FS}.json'
    local_path = f'./{config_name}'

    # Always start from a fresh download rather than a cached file.
    if os.path.isfile(local_path):
        os.remove(local_path)

    source_url = ('https://github.com/dtabuena/Resources/'
                  'raw//main/Matplotlib_Config/'
                  + config_name)
    urllib.request.urlretrieve(source_url, config_name)

    with open(local_path, 'r') as import_file:
        fig_config = json.load(import_file)
    rcParams.update(fig_config)
    return fig_config
_ = import_mpl_config()



In [None]:
def init_GO():
    # !pip install goatools --quiet


    '''Get Gene Lists and metadata from ncbi'''
    import os
    import urllib.request
    gene_list_url='https://raw.githubusercontent.com/dtabuena/Resources/main/GO%20Files/gene_result.txt'
    urllib.request.urlretrieve(gene_list_url, 'gene_result.txt')
    scripts_path = [p for p in os.environ['PATH'].split(';') if 'Scripts' in p][0]
    ncbi_path = os.path.join(scripts_path,'ncbi_gene_results_to_python.py')
    !python $ncbi_path -o genes_ncbi_mus_musculus_proteincoding.py gene_result.txt
    from genes_ncbi_mus_musculus_proteincoding import GENEID2NT as GeneID2nt_mus



    '''Get Key Funcs'''
    from goatools.base import download_go_basic_obo
    from goatools.base import download_ncbi_associations
    from goatools.obo_parser import GODag
    from goatools.anno.genetogo_reader import Gene2GoReader
    from goatools.goea.go_enrichment_ns import GOEnrichmentStudyNS

    '''Download Current Go Annotations'''
    obo_fname = download_go_basic_obo()
    fin_gene2go = download_ncbi_associations()
    obodag = GODag("go-basic.obo")




    '''Get Mapper from Symbol to Gene and Inv'''
    mapper = {}
    for key in GeneID2nt_mus:
        mapper[GeneID2nt_mus[key].Symbol] = GeneID2nt_mus[key].GeneID
    inv_map = {v: k for k, v in mapper.items()}



    '''Read NCBI's gene2go. Store annotations in a list of namedtuples '''
    objanno = Gene2GoReader(fin_gene2go, taxids=[10090])
    # Get namespace2association where:
    #    namespace is:
    #        BP: biological_process
    #        MF: molecular_function
    #        CC: cellular_component
    #    assocation is a dict:
    #        key: NCBI GeneID
    #        value: A set of GO IDs associated with that gene
    ns2assoc = objanno.get_ns2assc()


    '''Create a GO Object'''
    goeaobj = GOEnrichmentStudyNS(
            GeneID2nt_mus.keys(), # List of mouse protein-coding genes
            ns2assoc, # geneid/GO associations
            obodag, # Ontologies
            propagate_counts = False,
            alpha = 0.05, # default significance cut-off
            methods = ['fdr_bh']) # defult multipletest correction method


    ''' PASS '''
    num_terms = {}
    GO_items = []
    temp = goeaobj.ns2objgoea['BP'].assoc
    num_terms['biological_process']=len(temp)
    for item in temp:
        GO_items += temp[item]
    temp = goeaobj.ns2objgoea['CC'].assoc
    num_terms['cellular_component']=len(temp)
    for item in temp:
        GO_items += temp[item]
    temp = goeaobj.ns2objgoea['MF'].assoc
    num_terms['molecular_function']=len(temp)
    for item in temp:
        GO_items += temp[item]



    def go_it(test_genes):
        ''' Quick Access Function for doing the GO associations '''
        logging.info(f'input genes: {len(test_genes)}')
        mapped_genes = []
        for gene in test_genes:
            try:
                mapped_genes.append(mapper[gene])
            except:
                pass
        logging.info(f'mapped genes: {len(mapped_genes)}')
        goea_results_all = goeaobj.run_study(mapped_genes)
        goea_results_sig = [r for r in goea_results_all if r.p_fdr_bh < 0.05]
        GO = pd.DataFrame(list(map(lambda x: [x.GO, x.goterm.name, x.goterm.namespace, num_terms[x.goterm.namespace], x.p_uncorrected, x.p_fdr_bh,\
                    x.ratio_in_study[0], x.ratio_in_study[1], GO_items.count(x.GO), list(map(lambda y: inv_map[y], x.study_items)),\
                    ], goea_results_sig)), columns = ['GO', 'term', 'class', 'class_size', 'p', 'p_corr', 'n_genes',\
                                                        'n_study', 'n_go', 'study_genes'])
        GO = GO[GO.n_genes > 1]
        GO['enrich_ratio'] = ( GO['n_genes']/GO['n_study'] ) / ( GO['n_go']/GO['class_size'] )
        return GO


    all_GO = None
    goea_results_all = goeaobj.run_study([mapper['Apoe']])
    all_GO = pd.DataFrame(list(map(lambda x: [x.GO, x.goterm.name, x.goterm.namespace, num_terms[x.goterm.namespace], GO_items.count(x.GO),
                        ], goea_results_all)), columns = ['GO', 'term', 'class', 'class_size', 'n_go',])


    return go_it, num_terms ,goeaobj,goea_results_all,all_GO

# Switch to the Nell2 enrichment project folder (machine-specific absolute
# path) so the files init_GO() downloads with relative names land there, then
# build the GO helper and unpack the five objects it returns.
os.chdir("C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_Hyper+Crisper/_Nell2_enrichment/")
go_it, num_terms, goeaobj,goea_results_all,all_GO = init_GO()

In [None]:
def trim_key(k):
    """Return a short sample label for the key ``k``.

    Keys containing one of the two hard-coded sample identifiers are replaced
    by that sample's fixed relabelled name. Any other key is returned with
    every occurrence of ``_raw_gene_bc_matrices_h5.h5`` removed.
    """
    floxed_dict = {'GSM5106175_YH_KZ03_01':('E3fKI_Syn_Cre602_15m','GSM5106175_602_E3fKI_15_XX'),
                   'GSM5106176_YH_KZ03_03':('E4fKI_Syn_Cre475_15m','GSM5106176_475_E4fKI_15_XX')}
    for sample_id, (_, short_label) in floxed_dict.items():
        if sample_id in k:
            return short_label
    return k.replace('_raw_gene_bc_matrices_h5.h5', "")
def query_capitilaziation(gene,adata):
    """Look up ``gene`` in ``adata.var.index`` ignoring letter case.

    Returns the first index entry whose lower-cased text equals
    ``gene.lower()``, or ``gene + ' not_found'`` when nothing matches.
    An explicit search replaces the former bare ``except:``, which also
    swallowed unrelated errors such as KeyboardInterrupt.
    """
    wanted = gene.lower()
    for name in adata.var.index:
        # str() keeps non-string index entries from aborting the search.
        if str(name).lower() == wanted:
            return name
    return gene + ' not_found'
def z_score(x,axis=-1):
    """Standardise ``x`` along ``axis``: (x - mean) / std.

    ``keepdims=True`` keeps the reduced axis so the mean and standard
    deviation broadcast against ``x`` for any ``axis``. Without it the
    default ``axis=-1`` failed, or silently mis-broadcast on square input,
    for 2-D arrays. The population standard deviation (numpy default,
    ddof=0) is used. The result has the same shape as ``x``.
    """
    x = np.array(x)
    centre = np.mean(x, axis=axis, keepdims=True)
    spread = np.std(x, axis=axis, keepdims=True)
    return (x - centre) / spread



In [None]:
# Move to the Scanpy data folder (machine-specific absolute path) and load the
# AnnData object stored in 2023_11_07_KZ_anndata.h5ad.
os.chdir("C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_Seurat_Scanpy/Scanpy_data/")
adata = sc.read_h5ad('./2023_11_07_KZ_anndata.h5ad')

In [None]:
# Preview the per-cell annotations currently stored on the AnnData object.
display(adata.obs.head())

# Cluster labels come from a CSV whose 'Barcodes' column becomes the index.
meta_df = pd.read_csv('./kz_metadata.csv').set_index('Barcodes')

# Work on a copy so the loaded `adata` is left untouched.
adata_meta= adata.copy()
# Placeholder column; it is overwritten by the assignment on the next line.
adata_meta.obs["Cluster_ID"]=np.nan
# NOTE(review): presumably this aligns on the barcode index, leaving NaN for
# cells that are absent from the CSV - confirm.
adata_meta.obs["Cluster_ID"]= meta_df["Cluster_ID"]
display(adata_meta.obs.head())

# Distinct cluster labels found in the metadata (set order is arbitrary).
all_cats = list(set(meta_df["Cluster_ID"]))
print(all_cats)

In [None]:
########### Sub Divide Clusters of Interest
# Return to the Nell2 enrichment project folder (machine-specific absolute
# path) and make sure the ./figures output directory exists.
os.chdir("C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_Hyper+Crisper/_Nell2_enrichment/")
if not os.path.exists('./figures'):
    os.makedirs('./figures')


# Cluster '01 Dentate Gyrus Granule Cells' is currently disabled.
# dgc_01_adata =  adata_meta[adata_meta.obs["Cluster_ID"] == '01 Dentate Gyrus Granule Cells'].copy()
# display(dgc_01_adata)
# Independent copies of the cells labelled as cluster 02 (dentate gyrus
# granule cells) and cluster 06 (CA2/CA3 pyramids).
dgc_02_adata =  adata_meta[adata_meta.obs["Cluster_ID"] == '02 Dentate Gyrus Granule Cells'].copy()
display(dgc_02_adata)
CA3_06_adata =  adata_meta[adata_meta.obs["Cluster_ID"] == '06 CA2/CA3 Pyramids'].copy()
display(CA3_06_adata)


In [None]:
import statsmodels

In [None]:
def find_corr_genes(adata,gene,pct=1,to_plot=True,fig_size =(1.75,1)):
    """Correlate every gene with ``gene`` across the cells of ``adata``.

    Parameters
    ----------
    adata : AnnData-like
        ``adata.X`` is the cells x genes matrix (sparse or dense) and
        ``adata.var['name']`` holds the gene names.
    gene : str
        Name of the query gene; must be present in ``adata.var['name']``
        (``ValueError`` from ``list.index`` otherwise).
    pct : float
        Percentile cut-off. Genes with Pearson R above the ``100 - pct``
        percentile are flagged "high", those below the ``pct`` percentile
        "low".
    to_plot : bool
        Draw into the two figures. They are created and returned either way.
    fig_size : tuple
        Size in inches of each figure.

    Returns
    -------
    results_dict : dict
        Arrays ``gene_corr``, ``gene_rank``, ``high_bool``, ``low_bool``,
        ``both_bool``; name collections ``high_names``, ``low_names``,
        ``both_names``; and ``table``, a DataFrame indexed by gene name and
        sorted by rank.
    fig_R_rank, fig_R_pval : matplotlib figures
        R versus percentile rank, and R versus -log10(FDR p-value).

    Notes
    -----
    NaN correlations are set to 0 before ranking and before the percentiles
    are computed. The query gene itself is included in both; it is only
    blanked in the rank plot.
    """
    # Import the submodules explicitly: a bare ``import scipy`` or
    # ``import statsmodels`` does not guarantee they are loaded.
    from scipy import sparse, stats
    from statsmodels.stats.multitest import fdrcorrection

    # Accept both sparse and already-dense expression matrices.
    X_ = adata.X.toarray() if sparse.issparse(adata.X) else np.asarray(adata.X)
    gene_names = adata.var['name']
    gene_ind = list(gene_names).index(gene)
    target_row = X_[:,gene_ind]
    num_genes= X_.shape[1]
    gene_corr = np.ones([num_genes])
    p_vals = np.zeros([num_genes])
    for g in range(num_genes):
        gene_corr[g],p_vals[g] = stats.pearsonr( target_row,X_[:,g])

    gene_corr[np.isnan(gene_corr)] = 0
    # Double argsort turns values into ranks; rank 0 is the highest R.
    gene_rank =np.argsort(np.argsort(-gene_corr))
    gene_corr_plot=gene_corr.copy()
    gene_corr_plot[gene_ind]=np.nan  # hide the query gene in the rank plot

    low,high = np.nanpercentile(gene_corr,[pct,100-pct])

    high_bool = gene_corr>high
    low_bool = gene_corr<low
    both_bool = np.logical_or(high_bool,low_bool)

    high_names=gene_names[high_bool]
    low_names=gene_names[low_bool]

    # Benjamini-Hochberg correction on the finite p-values only; NaN entries
    # keep their NaN p-value and a False rejection flag.
    FDR_p_vals = p_vals.copy()
    p_not_nan=np.logical_not(np.isnan(p_vals))
    FDR_bool = np.zeros_like(FDR_p_vals,dtype=bool)
    (FDR_bool[p_not_nan],
     FDR_p_vals[p_not_nan]) = fdrcorrection(
          p_vals[p_not_nan], alpha=0.05, method='indep', is_sorted=False)

    fig_R_rank, axr=plt.subplots(1,1,figsize=fig_size,dpi=300)
    fig_R_pval, axp=plt.subplots(1,1,figsize=fig_size,dpi=300)
    if to_plot:
        vmin = np.nanmin(gene_corr_plot)
        vmax = np.nanmax(gene_corr_plot)
        print(vmin,vmax)
        axr.scatter(gene_rank/num_genes*100,gene_corr_plot,s=1,
                      c=gene_corr_plot,cmap='coolwarm',vmin=vmin,vmax=vmax)

        # Percentile thresholds (each was previously drawn twice).
        axr.axhline(high,color='grey',linewidth=1)
        axr.axhline(low,color='gray',linewidth=1)

        axr.set_xlabel('Gene Percentile Rank')
        axr.set_ylabel('Pearson R')

        # Every gene in grey; genes passing FDR are over-plotted in red.
        axp.scatter(-np.log10(FDR_p_vals),gene_corr,s=1,color='grey')
        axp.scatter(-np.log10(FDR_p_vals)[FDR_bool], gene_corr[FDR_bool],
                      s=1, color='r')
        axp.set_xlabel('-np.log10(FDR p-value)')
        axp.set_ylabel('Pearson R')

    plt.show()

    results_dict = {'gene_corr':gene_corr,
                    'high_bool':high_bool,
                    'low_bool':low_bool,
                    'both_bool':both_bool,
                    'high_names':high_names,
                    'low_names':low_names,
                    'both_names': np.concatenate([high_names,low_names]),
                    'gene_rank':gene_rank}

    table = pd.DataFrame({'gene_name':gene_names,
                          'gene_rank':gene_rank,
                          'gene_corr':gene_corr,
                          'p_vals':p_vals,
                          'FDR_p_vals':FDR_p_vals,
                          'high_bool':high_bool,
                          'low_bool':low_bool,
                          'both_bool':both_bool,})
    table = table.set_index('gene_name')
    results_dict['table'] = table.sort_values('gene_rank', axis=0)
    return results_dict, fig_R_rank, fig_R_pval



# Run the correlation analysis for Nell2 and Apoe in both subsets (dentate
# gyrus cluster 02 and CA2/CA3 cluster 06) with a 5th/95th percentile cut-off,
# and save each pair of returned figures as SVG under ./figures.
pct = 5
(dgc_02_nell_results,
 dgc_nell2_fig_R_rank,
 dgc_nell2_fig_R_pval) = find_corr_genes(dgc_02_adata,'Nell2',pct=pct)
dgc_nell2_fig_R_rank.savefig('./figures/dgc_nell2_fig_R_rank.svg',format='svg')
dgc_nell2_fig_R_pval.savefig('./figures/dgc_nell2_fig_R_pval.svg',format='svg')

(dgc_02_apoe_results,
 dgc_apoe_fig_R_rank,
 dgc_apoe_fig_R_pval) = find_corr_genes(dgc_02_adata,'Apoe',pct=pct)
dgc_apoe_fig_R_rank.savefig('./figures/dgc_apoe_fig_R_rank.svg',format='svg')
dgc_apoe_fig_R_pval.savefig('./figures/dgc_apoe_fig_R_pval.svg',format='svg')

(CA3_06_nell_results,
 CA3_nell2_fig_R_rank,
 CA3_nell2_fig_R_pval) = find_corr_genes(CA3_06_adata,'Nell2',pct=pct)
CA3_nell2_fig_R_rank.savefig('./figures/CA3_nell2_fig_R_rank.svg',format='svg')
CA3_nell2_fig_R_pval.savefig('./figures/CA3_nell2_fig_R_pval.svg',format='svg')

(CA3_06_apoe_results,
 CA3_apoe_fig_R_rank,
 CA3_apoe_fig_R_pval) = find_corr_genes(CA3_06_adata,'Apoe',pct=pct)
CA3_apoe_fig_R_rank.savefig('./figures/CA3_apoe_fig_R_rank.svg',format='svg')
CA3_apoe_fig_R_pval.savefig('./figures/CA3_apoe_fig_R_pval.svg',format='svg')


In [None]:

def write_tables(res,prefix):
    """Export the high, low and combined correlation hits of ``res`` as CSV.

    ``res['table']`` is filtered by its ``high_bool``, ``low_bool`` and
    ``both_bool`` columns in turn. Each subset is written to
    ``<prefix>_high_corr_names.csv``, ``<prefix>_low_corr_names.csv`` and
    ``<prefix>_both_corr_names.csv`` respectively. Returns ``None``.
    """
    table = res['table']
    for tag in ('high', 'low', 'both'):
        subset = table[table[tag + '_bool']]
        subset.to_csv(prefix + '_' + tag + '_corr_names.csv')
    return None

# Write the high/low/both CSV tables for all four analyses into the Nell2
# enrichment project folder (machine-specific absolute path).
os.chdir("C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_Hyper+Crisper/_Nell2_enrichment/")
write_tables(CA3_06_nell_results,'CA3_06_nell')
write_tables(CA3_06_apoe_results,'CA3_06_apoe')
write_tables(dgc_02_nell_results,'dgc_02_nell')
write_tables(dgc_02_apoe_results,'dgc_02_apoe')



In [None]:
logging.basicConfig(filename='example.log',level=logging.DEBUG)
def get_onto_from_dict(dict_results):
    """Attach GO enrichment of the combined gene list to ``dict_results``.

    Logs how many genes are in ``high_names`` and ``low_names``, runs
    ``go_it`` on ``both_names`` and stores the result under ``'go_both'``.
    The input dictionary is modified in place and also returned.
    """
    n_high, n_low = (len(dict_results[key]) for key in ('high_names', 'low_names'))
    logging.info(f"{n_high} high genes and {n_low} low genes")
    # Separate enrichment of the high-only and low-only lists is disabled:
    # dict_results['go_high'] = go_it(dict_results['high_names'])
    # dict_results['go_low'] = go_it(dict_results['low_names'])
    dict_results['go_both'] = go_it(dict_results['both_names'])
    return dict_results


# Add the 'go_both' enrichment table to each of the four result dictionaries
# (get_onto_from_dict also updates them in place).
CA3_06_nell_results = get_onto_from_dict(CA3_06_nell_results)
CA3_06_apoe_results = get_onto_from_dict(CA3_06_apoe_results)
dgc_02_nell_results = get_onto_from_dict(dgc_02_nell_results)
dgc_02_apoe_results = get_onto_from_dict(dgc_02_apoe_results)



In [None]:
# Collect the combined (high + low) GO enrichment table of each analysis
# under a display label.
go_group_dict={'CA3 Nell2':CA3_06_nell_results['go_both'],
               'CA3 ApoE':CA3_06_apoe_results['go_both'],
               'DGC Nell2':dgc_02_nell_results['go_both'],
               'DGC ApoE':dgc_02_apoe_results['go_both'],
               }

# Keep only rows whose 'class' contains 'biological_process' and index each
# table by GO ID. Only the values of existing keys are replaced, so updating
# the dict while iterating over it is safe.
for k,v in go_group_dict.items():
    in_bio = ['biological_process' in c for c in v['class']]
    go_group_dict[k] = v[in_bio].set_index('GO')
    # go_group_dict[k] = list(v['GO'])


In [None]:
def membership_dict_to_df(group_dict: dict) -> tuple:
    """Tabulate which groups each item belongs to.

    Parameters
    ----------
    group_dict : dict
        Maps a group label to a pandas DataFrame (for example an enrichment
        table) whose index holds that group's items.

    Returns
    -------
    mebership_df : pandas.DataFrame
        One row per distinct item, in first-seen order, with one True/False
        column per group plus ``emb_combo``, a numeric code for the item's
        membership combination.
    embed_combos : dict
        Look-up from a membership tuple (one bool per group, in
        ``group_dict`` order) to its numeric code. Codes are ordered by the
        number of memberships, so the all-False tuple is 0 and the all-True
        tuple is the largest code.
    """
    group_names = list(group_dict.keys())

    # dict.fromkeys de-duplicates while keeping first-seen order; set() made
    # the row order vary from run to run.
    inicies = list(dict.fromkeys(x for v in group_dict.values() for x in v.index))
    mebership_df = pd.DataFrame(index=inicies,columns=group_names)
    mebership_df['emb_combo'] = np.nan

    possible_combos = list(itertools.product([True,False], repeat=len(group_names)))
    combo_sum = [np.sum(c) for c in possible_combos]
    # A stable sort keeps the tie order among equal-size combos, and hence
    # the codes, identical on every platform.
    order = np.argsort(combo_sum, kind='stable')
    possible_combos = [possible_combos[i] for i in order]
    embed_combos = {c:i for i,c in enumerate(possible_combos)}

    # Hash sets built once give O(1) membership tests per row.
    index_sets = [set(v.index) for v in group_dict.values()]
    for r in mebership_df.index:
        row_bool = [r in members for members in index_sets]
        mebership_df.loc[r,group_names] = row_bool
        mebership_df.loc[r,'emb_combo'] = embed_combos[tuple(row_bool)]
    return mebership_df, embed_combos




In [None]:
def upset_plot(group_dict, figsize=(3,2), exclude_all_none=True,write_name=None):
    """Draw an UpSet-style overview of how the groups' members overlap.

    Parameters
    ----------
    group_dict : dict
        Group label -> pandas object whose index holds the group's members,
        as expected by ``membership_dict_to_df``.
    figsize : tuple
        Figure size in inches.
    exclude_all_none : bool
        Drop the "member of no group" combination from the plot and from the
        returned look-up.
    write_name : str or None
        When given, the figure is also saved as ``./<write_name>.svg``.

    Returns
    -------
    fig, mebership_df, embed_combos
        The figure plus the two objects from ``membership_dict_to_df``.

    Notes
    -----
    Sets ``rcParams['figure.autolayout'] = True`` globally as a side effect.
    """
    # Use the argument; this previously read the module-level go_group_dict.
    mebership_df, embed_combos = membership_dict_to_df(group_dict)
    possible_combos = list(embed_combos.keys())
    if exclude_all_none: #exclude members that do not intersect any group
        possible_combos = [c for c in possible_combos if any(c)]
        embed_combos.pop((False,) * len(group_dict), None)

    fig,ax=plt.subplots(2,2,figsize=figsize,width_ratios=[3,.5],height_ratios=[3,1],dpi=300)
    null_ax = ax[0,1]
    null_ax.axis('off')
    combo_ax = ax[1,0]
    overlap_ax = ax[0,0]
    set_size_ax = ax[1,1]

    # Exact comparison; the old substring test also dropped any group whose
    # name happened to be a fragment of 'emb_combo'.
    group_names = [c for c in mebership_df.columns if str(c) != 'emb_combo']

    # --- Dots: one column per combination, one row per group ---
    combo_matrix = np.array(possible_combos, dtype=bool).reshape(
        len(possible_combos), len(group_names))
    true_xy = np.where(combo_matrix)
    false_xy = np.where(np.logical_not(combo_matrix))
    dot_size=12
    combo_ax.scatter(true_xy[0],true_xy[1],color='k',s=dot_size)
    combo_ax.scatter(false_xy[0],false_xy[1],color='lightgrey',s=dot_size)
    combo_ax.set_yticks(range(len(group_names)),group_names)
    combo_ax.set_xticks([])
    combo_ax.set_ylim([-.5,len(group_names)-.5])
    # Join the first and last member dot of every multi-group combination.
    for row, combo in enumerate(combo_matrix):
        members = np.flatnonzero(combo)
        if members.size > 1:
            combo_ax.plot([row,row],[members[0],members[-1]],color='k',linewidth=.5 )

    # --- Overlap bars ---
    # Count each combination's code directly so bar i always sits above dot
    # column i. Histogram binning over the observed min/max shifted the bars
    # when an extreme combination was empty or the all-False combo was kept.
    codes = mebership_df['emb_combo'].to_numpy(dtype=float)
    intersection_counts = [int(np.sum(codes == embed_combos[c]))
                           for c in possible_combos]
    overlap_ax.bar(np.arange(len(possible_combos)),intersection_counts,color='k')
    overlap_ax.set_xlim(combo_ax.get_xlim() )
    overlap_ax.set_xticks([])
    overlap_ax.set_ylabel('Intersection (#)')

    # --- Group sizes ---
    set_size_ax.barh( list(group_dict.keys()), [len(v) for v in group_dict.values()],color='k'   )
    set_size_ax.set_yticks([])
    set_size_ax.set_xlabel('Group Size (#)')

    from matplotlib import rcParams
    rcParams.update({'figure.autolayout': True})

    if write_name is not None:
        fig.savefig(f'./{write_name}.svg',format='svg',dpi=300,bbox_inches="tight")
    return fig, mebership_df, embed_combos


# Draw the UpSet overview of the four GO-term sets and save it twice: via
# write_name in the working directory and explicitly under ./figures.
upset_go_fig, mebership_df,embed_combos = upset_plot(go_group_dict,write_name='_Correlation_GO_Terms')
upset_go_fig.savefig('./figures/upset_go_fig.svg',format='svg')

# Append the term annotations from all_GO for every GO ID in the membership
# table. Index.isin replaces a per-row scan of a freshly built list.
df = all_GO.copy().set_index('GO')
df = df[df.index.isin(mebership_df.index)]
mebership_df = pd.concat([mebership_df, df], axis=1)


mebership_df.to_csv('./nell_apoe_CA_DG_go_term_intersections.csv')


In [None]:
import matplotlib as mpl

In [None]:
def top_gos(go_df_dict,select_keys,number=15):
    """Pick the best-ranked GO terms found in exactly the selected groups.

    Parameters
    ----------
    go_df_dict : dict
        Group label -> enrichment DataFrame indexed by GO ID, with a
        ``p_corr`` column.
    select_keys : collection of str
        Labels of the groups of interest. Only terms present in all of these
        and in none of the other groups are considered.
    number : int
        Maximum number of terms to return.

    Returns
    -------
    select_gos : list
        Up to ``number`` GO IDs with the smallest sum of squared per-group
        ``p_corr`` ranks, in membership-table order (not sorted by rank).
    sel_go_df_dict : dict
        Copies of the selected groups' tables with an added ``rank`` column
        (0 = smallest ``p_corr`` within that table).

    The tables in ``go_df_dict`` are left untouched. Previously the ``rank``
    column was written into the caller's DataFrames.
    """
    mebership_df, embed_combos = membership_dict_to_df(go_df_dict)
    # Code of the combination "in every selected group, in no other group".
    membership_tuple = tuple(k in select_keys for k in go_df_dict.keys())
    combo = embed_combos[membership_tuple]
    all_member_gos = mebership_df[mebership_df['emb_combo']==combo].index.to_list()

    # Copy so the rank column does not leak into the input tables.
    sel_go_df_dict = {key: val.copy() for key, val in go_df_dict.items()
                      if key in select_keys}
    for val in sel_go_df_dict.values():
        val['rank'] = np.argsort(np.argsort(np.array(val['p_corr'])))

    rank_list = [  [val.loc[g,'rank'] for val in sel_go_df_dict.values()]     for g in all_member_gos]
    rank_magnitudes = [  np.sum(np.array(vec)**2) for vec in  rank_list]

    # Double argsort gives each term's position in the magnitude ordering.
    magnitude_ranks = np.argsort(np.argsort(rank_magnitudes))
    select_gos = list(np.array(all_member_gos)[magnitude_ranks<number])

    return select_gos,sel_go_df_dict


def plot_go_bubbles( sel_go_df_dict, select_gos, vmin=0, vmax=20,figsize=(2,2),hidden_legend=False):
    """Bubble plot of the selected GO terms, one panel per group.

    Parameters
    ----------
    sel_go_df_dict : dict
        Group label -> DataFrame indexed by GO ID with ``enrich_ratio``,
        ``n_genes``, ``p_corr`` and ``term`` columns. Must not be empty.
    select_gos : list
        GO IDs to plot; each must be present in every table.
    vmin, vmax : float
        Colour-scale limits for -log10(p_corr).
    figsize : tuple
        Figure size in inches.
    hidden_legend : bool
        Hide the dot-size legend and the colour bar.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``sel_go_df_dict`` is empty.

    In each panel x is the fold enrichment, bubble area the number of study
    genes and colour the -log10 corrected p-value. Term labels are taken from
    the last table and shown on the first panel only.
    """
    if not sel_go_df_dict:
        raise ValueError('sel_go_df_dict must contain at least one table')

    cmap = mpl.colormaps.get_cmap('cividis')

    num_dfs = len(sel_go_df_dict)
    fig = plt.figure(layout=None ,figsize=figsize,dpi=300,tight_layout=True)
    gs = fig.add_gridspec(nrows=2, ncols=num_dfs+1,width_ratios=[1]*num_dfs+[.3])

    # Axes order: one panel per table, then dot-size legend, then colour bar.
    ax = [fig.add_subplot(gs[:2, a]) for a in range(num_dfs)]
    ax.append(fig.add_subplot(gs[0, -1]))
    ax.append(fig.add_subplot(gs[1, -1]))

    go_list = list(select_gos)
    y_pos = np.arange(len(go_list))
    for df_i, (key,df) in  enumerate(sel_go_df_dict.items()):
        # One vectorised scatter per panel instead of one call per term.
        x = df.loc[go_list,'enrich_ratio'].to_numpy(dtype=float)
        size = df.loc[go_list,'n_genes'].to_numpy(dtype=float)
        color_level = -np.log10(df.loc[go_list,'p_corr'].to_numpy(dtype=float))
        cbr = ax[df_i].scatter(x,y_pos,s=size,c=color_level,cmap=cmap,vmin=vmin,vmax=vmax)
        ax[df_i].set_title(key)

    ### Labeling
    # Term names come from the last table, as before, but explicitly rather
    # than through a leaked loop variable.
    label_df = list(sel_go_df_dict.values())[-1]
    terms = [label_df.loc[g,'term'] for g in go_list]
    ax[0].set_yticks(range(len(terms)), terms)
    ax[0].set_ylabel('Biological Function\nGene Ontology Terms')
    for a in ax[1:-2]:
        a.set_yticks(ax[0].get_yticks())
        a.set_yticklabels([])


    ### Colorscale
    plt.colorbar(cbr,cax=ax[-1])
    ax[-1].set_ylabel('-log10(p-value)')


    ### Dot Scale
    sizes = [5,10,20,50]
    n=len(sizes)
    ax[-2].scatter(np.ones(n),np.linspace(0,1,n,endpoint=False),s=sizes,color='k')
    ax[-2].text(1.2,.5,'Gene #',rotation=90,ha='center',va='center')
    ax[-2].axis('off')
    for i,s in enumerate(sizes):
        ax[-2].text(1,i/n,f"     {s}",ha='left',va='center')
    ax[-2].set_ylim(-.2,1.2)


    ### Formating
    x_lim_list=list()

    for a in ax[:-2]:
        x_lim_list.append(a.get_xlim())
        a.set_xticks(np.arange(0,8,2))
        a.set_ylim(-.75,len(go_list)-.25)
        a.set_xlabel('Fold Enrichment')
        a.grid(visible=True)
        a.set_axisbelow(True)
    # Give every panel the same x-range, rounded outwards to whole numbers.
    x_lim_list = np.stack(x_lim_list)
    common_lim = (np.floor(np.min(x_lim_list[:,0])),np.ceil(np.max(x_lim_list[:,1])))
    logging.debug('common x-limits: %s', common_lim)
    for a in ax[:-2]:
        a.set_xlim(common_lim)

    if hidden_legend:
        for a in ax[-2:]:
            a.set_visible(False)
    plt.tight_layout()
    return fig




In [None]:


# Bubble figure for GO terms found in all four analyses (legend hidden).
select_gos, sel_go_df_dict = top_gos(go_group_dict,select_keys=['CA3 Nell2','CA3 ApoE','DGC ApoE','DGC Nell2'])
both_bubble_fig = plot_go_bubbles( sel_go_df_dict, select_gos, figsize=(5,2),hidden_legend=True)
both_bubble_fig.savefig('./figures/both_bubble_fig.svg',format='svg',bbox_inches="tight")

# Terms found in both Nell2 analyses but in neither ApoE analysis (legend shown).
select_gos, sel_go_df_dict = top_gos(go_group_dict,select_keys=['CA3 Nell2','DGC Nell2'])
nell2_bubble_fig = plot_go_bubbles( sel_go_df_dict, select_gos, figsize=(4,2))
nell2_bubble_fig.savefig('./figures/nell2_bubble_fig.svg',format='svg',bbox_inches="tight")

# Terms found in both ApoE analyses but in neither Nell2 analysis (legend hidden).
select_gos, sel_go_df_dict = top_gos(go_group_dict,select_keys=['CA3 ApoE','DGC ApoE'])
apoe_bubble_fig = plot_go_bubbles( sel_go_df_dict, select_gos, figsize=(4,1),hidden_legend=True)
apoe_bubble_fig.savefig('./figures/apoe_bubble_fig.svg',format='svg',bbox_inches="tight")

Creating Figure Layouts

In [None]:
!pip install svgutils

In [None]:
# Assemble the individual figures into one multi-panel SVG layout.
# NOTE(review): `svc` is never imported in the visible source; presumably
# `import svgutils.compose as svc` is intended (svgutils is pip-installed in
# the preceding cell) - confirm and add the import.
# Layout constants: the width/length factors of 72 suggest inches converted to
# points/pixels (TODO confirm); 'rows' holds the cumulative vertical offsets
# of the panel rows. 'left margin' and 'scale' are not used below.
fig_params = {'left margin':0,
              'fig_width':7.05*72,
              'fig_length':9*72,
              'lettering_size': 12,
              'lettering_wt': 'bold',
              'rows':np.cumsum([12,80,80,150,80]),
              'scale':1
               }

# p=svc.Panel(SVG('./figure/'),
#         Text("A", 20, 20, size=label_size,weight=weight)
#         ).move(width*0-10, row[0])

# Each panel pairs a figure with its letter label and is moved to
# (fraction of figure width, row offset).
panel_list=list()



### Rank Corr
panel_list.append(svc.Panel(svc.MplFigure(CA3_nell2_fig_R_rank),
                            svc.Text("A", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.00, fig_params['rows'][0]))

panel_list.append(svc.Panel(svc.MplFigure(dgc_nell2_fig_R_rank),
                            svc.Text("B", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.25, fig_params['rows'][0]))

panel_list.append(svc.Panel(svc.MplFigure(CA3_apoe_fig_R_rank),
                            svc.Text("C", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.00, fig_params['rows'][1]))

panel_list.append(svc.Panel(svc.MplFigure(dgc_apoe_fig_R_rank),
                            svc.Text("D", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.25, fig_params['rows'][1]))

# UpSet overview of the GO-term overlaps.
panel_list.append(svc.Panel(svc.MplFigure(upset_go_fig),
                            svc.Text("E", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.50, fig_params['rows'][0]))

# GO bubble plots.
panel_list.append(svc.Panel(svc.MplFigure(nell2_bubble_fig),
                            svc.Text("F", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.20, fig_params['rows'][2]))

panel_list.append(svc.Panel(svc.MplFigure(apoe_bubble_fig),
                            svc.Text("G", 0, 0, size=fig_params['lettering_size'],weight=fig_params['lettering_wt'])
                            ).move(fig_params['fig_width']*0.20, fig_params['rows'][3]))

# Panel H is loaded from the SVG file saved earlier rather than from the
# in-memory figure object.
panel_list.append(
    svc.Panel(svc.SVG('./figures/both_bubble_fig.svg'),
              svc.Text("H", 0, 0, size=fig_params['lettering_size'],
                       weight=fig_params['lettering_wt'])
              ).move(fig_params['fig_width']*0.00, fig_params['rows'][4]))

# Compose all panels on one canvas, add a 72 x 72 grid, save, and display the
# layout as the cell output.
layout = svc.Figure(f"{fig_params['fig_width']}px",
                    f"{fig_params['fig_length']}px",
                    *panel_list,
                    svc.Grid(72, 72))
layout.save("./figures/GO_Figure_Layout.svg")
layout

In [None]:
rcParams