By Hattie Chung (contact: hchung@broadinstitute.org) 

This notebook contains helper functions used in all analysis notebooks (for both HeLa and mouse hippocampus data). Functions from this notebook are imported using nbimporter.

## load packages

In [None]:
import scipy.stats as stats
import statsmodels.formula.api as smf
import statsmodels.api as sm
import warnings
warnings.filterwarnings("ignore")
import statsmodels.stats as sms
import time
import numpy as np
import pandas as pd
import os
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from statannot import add_stat_annotation
from plotnine import *

In [None]:
def copy_figs_to_bucket(fdir, gdir, wildcard='*'):
    """Copy figure files from a local directory to ``gdir`` using ``gsutil``.

    Parameters
    ----------
    fdir : str or other
        Local directory that holds the figures. Any value that is not a
        ``str`` (e.g. ``None``) falls back to ``sc.settings.figdir``.
    gdir : str
        Destination passed to ``gsutil cp`` (presumably a ``gs://`` bucket
        path -- confirm against callers).
    wildcard : str, default '*'
        Glob pattern appended to ``fdir`` to select which files are copied.

    Notes
    -----
    The copy uses the IPython ``!`` shell escape, so this function only works
    when the notebook is executed by IPython/Jupyter.
    """
    if type(fdir) is not str: 
        fdir = str(sc.settings.figdir)
    fdir1 = fdir+'/'+wildcard
    # -m lets gsutil copy the matched files in parallel.
    !gsutil -m cp $fdir1 $gdir
    # NOTE(review): printed unconditionally; the gsutil exit status is not checked.
    print('Copying complete')

## data handling and merging

In [None]:
def merge_adt_and_gex(adt_pd, gex_ad, ab_list):
    """Copy antibody (ADT) counts into ``gex_ad.obs``, one column per antibody.

    ``adt_pd`` is indexed by antibody name with cell barcodes as columns.
    For every barcode of ``gex_ad.obs`` that also appears in ``adt_pd``, the
    value of each antibody in ``ab_list`` is written to ``gex_ad.obs``.
    Barcodes missing from ``adt_pd`` are left untouched.

    Returns the same ``gex_ad`` object (modified in place) and prints the
    elapsed time.
    """
    import time
    t0 = time.time()
    shared_barcodes = (barcode for barcode in gex_ad.obs.index
                       if barcode in adt_pd.columns)
    for barcode in shared_barcodes:
        for antibody in ab_list:
            gex_ad.obs.at[barcode, antibody] = adt_pd.loc[antibody, barcode]
    elapsed = time.time() - t0
    print('Merged GEX and ADT in %0.3f seconds' %(elapsed) )
    return gex_ad

In [None]:
def calc_mito_ncounts(ad, case='mouse'):
    """Annotate each cell with its mitochondrial fraction and total counts.

    Parameters
    ----------
    ad : AnnData-like
        Object exposing ``X``, ``var_names``, ``obs`` and column slicing
        (``ad[:, mask]``).
    case : {'mouse', 'human', 'hg19'}, default 'mouse'
        Selects the gene-name prefix that marks mitochondrial genes
        (``'mt-'``, ``'MT-'`` or ``'hg19_MT-'``).

    Returns
    -------
    The same ``ad`` with ``obs['frac_mito']`` and ``obs['n_counts']`` set.

    Raises
    ------
    ValueError
        If ``case`` is not one of the supported values (previously this
        surfaced as an unbound-variable error).
    """
    mito_prefixes = {'human': 'MT-', 'hg19': 'hg19_MT-', 'mouse': 'mt-'}
    if case not in mito_prefixes:
        raise ValueError("case must be one of %s, got %r"
                         % (sorted(mito_prefixes), case))
    mito_genes = ad.var_names.str.startswith(mito_prefixes[case])

    def _row_sums(matrix):
        # Sparse matrices return an (n, 1) np.matrix, dense arrays an (n,)
        # ndarray; flatten both to a plain 1-D array.
        return np.asarray(np.sum(matrix, axis=1)).ravel()

    total_counts = _row_sums(ad.X)
    ad.obs['frac_mito'] = _row_sums(ad[:, mito_genes].X) / total_counts

    # add the total counts per cell as observations-annotation to adata
    ad.obs['n_counts'] = total_counts
    return ad

In [None]:
def roundup(x, base=5):
    """Round ``x`` up to the nearest integer multiple of ``base``."""
    from math import ceil
    n_multiples = ceil(x / float(base))
    return int(n_multiples) * base

def round_decimals_up(number:float, decimals:int=2):
    """
    Returns a value rounded up to a specific number of decimal places.

    Parameters
    ----------
    number : float
        Value to round towards positive infinity.
    decimals : int, default 2
        Number of decimal places to keep; must be a non-negative integer.

    Raises
    ------
    TypeError
        If ``decimals`` is not an integer.
    ValueError
        If ``decimals`` is negative.
    """
    # The docstring must be the first statement to be picked up as __doc__;
    # it previously sat after the import and was a no-op string.
    import math

    if not isinstance(decimals, int):
        raise TypeError("decimal places must be an integer")
    elif decimals < 0:
        raise ValueError("decimal places has to be 0 or more")
    elif decimals == 0:
        return math.ceil(number)

    factor = 10 ** decimals
    return math.ceil(number * factor) / factor

## CITE normalizations

In [None]:
def normalize_CITE(ad_obj, ab_list):
    """Add hashtag-normalised and CLR-transformed columns for each antibody.

    For every antibody ``ab`` in ``ab_list`` two columns are written to
    ``ad_obj.obs``:

    * ``ab + '_norm'``: ``(count + 1) / hashtag_counts``
    * ``ab + '_nCLR'``: natural log of the normalised value divided by the
      geometric mean of the non-NaN normalised values.

    Modifies ``ad_obj.obs`` in place and returns ``None``.
    """
    import scipy
    for antibody in ab_list:
        # nADT: pseudocount of 1, scaled by the per-cell hashtag counts
        normalised = (ad_obj.obs[antibody] + 1) / ad_obj.obs['hashtag_counts']
        ad_obj.obs[antibody + '_norm'] = normalised

        # nCLR: centre on the geometric mean, ignoring NaN entries
        valid = normalised[~np.isnan(normalised)]
        centre = scipy.stats.mstats.gmean(valid)
        ad_obj.obs[antibody + '_nCLR'] = np.log(normalised / centre)
        

In [None]:
# batch specific CLR correction
def CITE_binarize_by_batch(ad, ab, norm_type, thresh_map):
    """Binarise an antibody signal using a separate threshold per batch.

    Parameters
    ----------
    ad : AnnData-like
        ``ad.obs`` must contain a categorical ``batch`` column and the column
        ``ab + '_' + norm_type``.
    ab : str
        Antibody name.
    norm_type : str
        Suffix of the normalised column to threshold (e.g. ``'nCLR'``).
    thresh_map : mapping
        Threshold for every batch category; a missing batch raises KeyError.

    Writes ``ad.obs[ab + '_binary']`` (1 where value >= batch threshold,
    otherwise 0) and returns ``None``.
    """
    binary_col = ab + '_binary'
    value_col = ab + '_' + norm_type

    ad.obs[binary_col] = 0
    for batch in ad.obs.batch.cat.categories:
        above_thresh = ((ad.obs['batch'] == batch)
                        & (ad.obs[value_col] >= thresh_map[batch]))
        ad.obs.loc[above_thresh, binary_col] = 1

## subsetting data

In [None]:
def get_feat_values(ad, var_name, scaling='lognorm'):
    """Return the per-cell values of an ``obs`` column or of a gene.

    Parameters
    ----------
    ad : AnnData-like
    var_name : str
        Looked up first in ``ad.obs.columns``; otherwise it must be a gene in
        ``ad.var_names``.
    scaling : str, default 'lognorm'
        Only used for genes. Selects the matrix that is read:
        ``'zscore'`` -> ``layers['zscore']``, ``'lognorm'`` ->
        ``layers['counts']``, ``'spliced'`` -> ``layers['spliced']``,
        ``'unspliced'`` -> ``layers['unspliced']``, ``'zscore_regressed'``
        -> ``X``.

    Returns
    -------
    list
        ``obs`` values as stored, or gene values as Python floats.

    Raises
    ------
    KeyError
        If ``var_name`` is neither an ``obs`` column nor a gene.
    ValueError
        If ``scaling`` is not recognised for a gene lookup.
    """
    if var_name in ad.obs.columns:
        # obs variable
        return list(ad.obs[var_name])

    if var_name not in ad.var_names:
        raise KeyError('%r is neither an obs column nor a gene' % (var_name,))

    # gene variable
    # NOTE(review): 'lognorm' reads the layer called 'counts'; confirm that
    # this layer really holds log-normalised values.
    layer_for_scaling = {'zscore': 'zscore', 'lognorm': 'counts',
                         'spliced': 'spliced', 'unspliced': 'unspliced'}
    gene_view = ad[:, var_name]
    if scaling == 'zscore_regressed':
        matrix = gene_view.X
    elif scaling in layer_for_scaling:
        matrix = gene_view.layers[layer_for_scaling[scaling]]
    else:
        raise ValueError('Unknown scaling %r' % (scaling,))

    # Accept sparse and dense storage alike (X becomes dense after scaling).
    if hasattr(matrix, 'toarray'):
        matrix = matrix.toarray()
    return np.asarray(matrix, dtype=float).ravel().tolist()

def get_linear_regression_stats(ad, xname, yname, scaling='lognorm'):
    """Fit ``yname ~ xname`` and return the points of the fitted line.

    Both variables are fetched with ``get_feat_values``. The slope, R^2 and
    p-value from ``scipy.stats.linregress`` are printed; the returned line is
    built from the ``np.polyfit`` degree-1 coefficients.

    Returns
    -------
    (plot_X, plot_Y)
        ``plot_X`` is the list of x values and ``plot_Y`` the fitted y value
        at each of them.
    """
    from scipy import stats

    x_data = get_feat_values(ad, xname, scaling)
    y_data = get_feat_values(ad, yname, scaling)

    fit_slope, fit_intercept = np.polyfit(x_data, y_data, 1)
    regression = stats.linregress(x_data, y_data)
    slope, r_value, p_value = regression[0], regression[2], regression[3]

    r_squared = r_value**2
    print('Slope %.4f \t R^2 %.4f \t pval %.6f' %(slope, r_squared, p_value) )
    print(r_squared)
    print(p_value)

    # list + numpy scalar broadcasts, so the result is a numpy array
    fitted = [fit_slope*x for x in x_data] + fit_intercept
    return x_data, fitted

In [None]:
def regress_ncounts(ad): 
    """Normalise, log-transform, regress out ``n_counts`` and scale ``ad``.

    Runs the scanpy preprocessing calls below in order on ``ad``. No return
    value is captured, so the function relies on each call modifying ``ad``
    in place. ``ad.obs`` must contain an ``n_counts`` column.
    """
    sc.pp.normalize_total(ad)
    sc.pp.log1p(ad)
    # Keep the log-normalised state in .raw before regression and scaling.
    ad.raw = ad
    sc.pp.regress_out(ad, ['n_counts'])
    # Scaled values are clipped at 10.
    sc.pp.scale(ad, max_value=10)

In [None]:
def get_filtered_df(df, FC_cutoff):
    """Mask small effects and drop rows that end up completely empty.

    Every entry whose absolute value is below ``FC_cutoff`` is replaced by
    NaN, then rows that contain only NaN are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric table (genes x conditions). It is modified **in place**.
    FC_cutoff : float
        Minimum absolute value an entry needs in order to be kept.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, after filtering.
    """
    for col in df.columns:
        # np.nan rather than the np.NAN alias, which NumPy 2.0 removed
        df.loc[df[col].abs() < FC_cutoff, col] = np.nan

    # drop rows with only NaN
    df.dropna(how='all', inplace=True)
    return df
    
def get_mask_subbed_df(df, mask_val=-100):
    """Return a copy of ``df`` in which every NaN is replaced by ``mask_val``."""
    return df.fillna(mask_val)

def get_mask_dropped_df(df):
    """Return ``df`` as a numpy masked array with NaN/inf entries masked."""
    masked_values = np.ma.masked_invalid(df)
    return masked_values

In [None]:
def make_EX_neuron_broad_type(ad_in):
    """Return a copy of ``ad_in`` with all ``Ex.*`` clusters merged.

    Every category of ``obs['annot']`` whose name starts with ``'Ex.'`` is
    relabelled ``'EX_neuron'``; categories that become unused are dropped.
    ``ad_in`` itself is not modified.

    The categorical is rebuilt by assignment because the ``inplace=``
    argument of the ``.cat`` methods was removed in pandas 2.0.
    """
    merged_label = 'EX_neuron'
    clusters = ad_in.obs.annot.cat.categories
    ex_neurons = [cluster for cluster in clusters if cluster.startswith('Ex.')]

    adata_grouped = ad_in.copy()
    annot = adata_grouped.obs['annot'].copy()
    # add_categories raises if the label already exists
    if merged_label not in annot.cat.categories:
        annot = annot.cat.add_categories([merged_label])
    annot.loc[ad_in.obs['annot'].isin(ex_neurons)] = merged_label
    adata_grouped.obs['annot'] = annot.cat.remove_unused_categories()
    return adata_grouped

In [None]:
def combine_EX_and_CA_subtypes(ad_in):
    """Return a copy of ``ad_in`` with CA and granule-cell subtypes merged.

    Categories of ``obs['annot']`` starting with ``'Ex.CA'`` become
    ``'EX_CA'`` and those starting with ``'Ex.GranuleCell'`` become
    ``'EX_GranuleCell'``; categories left unused are dropped. ``ad_in``
    itself is not modified.

    The categorical is rebuilt by assignment because the ``inplace=``
    argument of the ``.cat`` methods was removed in pandas 2.0.
    """
    clusters = ad_in.obs.annot.cat.categories
    merged_groups = {
        'EX_GranuleCell': [c for c in clusters if c.startswith('Ex.GranuleCell')],
        'EX_CA': [c for c in clusters if c.startswith('Ex.CA')],
    }

    ad_CT_grouped = ad_in.copy()
    annot = ad_CT_grouped.obs['annot'].copy()
    # add_categories raises on labels that already exist, so add only new ones
    new_labels = [label for label in ('EX_CA', 'EX_GranuleCell')
                  if label not in annot.cat.categories]
    if new_labels:
        annot = annot.cat.add_categories(new_labels)
    for label, members in merged_groups.items():
        annot.loc[ad_in.obs['annot'].isin(members)] = label
    ad_CT_grouped.obs['annot'] = annot.cat.remove_unused_categories()
    return ad_CT_grouped

# Linear model tools

In [None]:
def fetch_model(form, dataframe, model_type):
    """Fit one statsmodels model and return the fitted result.

    Parameters
    ----------
    form : str
        Patsy formula.
    dataframe : pandas.DataFrame
        Data for the formula. The mixed models additionally use its
        ``batch`` column as the grouping variable.
    model_type : {'OLS', 'GLM_NegBin', 'mixedlm', 'mixedlm_regularized'}

    Returns
    -------
    The fitted results object, or ``None`` when fitting failed with a
    "Singular matrix" ``LinAlgError``.

    Raises
    ------
    ValueError
        If ``model_type`` is not supported (previously this printed a message
        and then failed on an unbound variable).
    numpy.linalg.LinAlgError
        Any linear-algebra failure other than a singular matrix.
    """
    import warnings
    import numpy as np
    # NOTE: this silences warnings process-wide, not just for this call.
    warnings.filterwarnings("ignore")

    valid_types = ('OLS', 'GLM_NegBin', 'mixedlm', 'mixedlm_regularized')
    if model_type not in valid_types:
        raise ValueError('Please enter valid model type (one of %s), got %r'
                         % (', '.join(valid_types), model_type))

    try:
        if model_type == 'OLS':
            return smf.ols(formula=form, data=dataframe).fit()
        if model_type == 'GLM_NegBin':
            return smf.glm(formula=form, data=dataframe,
                           family=sm.families.NegativeBinomial()).fit()
        mixed = smf.mixedlm(form, dataframe, groups=dataframe['batch'])
        if model_type == 'mixedlm':
            return mixed.fit()
        return mixed.fit_regularized()
    except np.linalg.LinAlgError as err:
        if 'Singular matrix' in str(err):
            return None
        raise

In [None]:
def run_model(ad, FORMULA, model_type, cluster_name, CITE, data_mode, 
              regress=False, scale=False, permute=False, min_obs=20):
    """Fit ``FORMULA`` once per gene and collect coefficients and p-values.

    Parameters
    ----------
    ad : AnnData
        Needs ``obs`` columns ``annot``, ``n_counts`` and ``assignment``
        (plus whatever ``FORMULA`` references) and the count layer selected
        by ``data_mode``.
    FORMULA : str
        Patsy formula; the current gene is exposed as the column ``Gene``.
    model_type : str
        Passed to ``fetch_model``.
    cluster_name : str
        Value of ``obs['annot']`` to subset on; ``''`` uses all cells.
    CITE, permute
        Accepted for interface compatibility; not used here.
    data_mode : {'lognorm', 'zscore', 'spliced', 'unspliced'}
        Passed to ``get_feat_values`` as ``scaling``; also selects the layer
        used for the expression filter (``counts``, ``spliced_counts`` or
        ``unspliced_counts``).
    regress : bool
        If true, run ``regress_ncounts`` on the subset first.
    scale : bool
        If true, standardise a fixed set of covariate columns.
    min_obs : int
        A gene is kept when it is detected in more than ``min_obs`` cells.

    Returns
    -------
    (params, pvals)
        DataFrames indexed by gene with one column per model term, or
        ``(None, None)`` when the cluster name is invalid, no gene passes the
        filter, or no model could be fitted.

    Raises
    ------
    ValueError
        If ``data_mode`` is not supported.
    """
    import copy

    counts_layers = {'lognorm': 'counts', 'zscore': 'counts',
                     'spliced': 'spliced_counts', 'unspliced': 'unspliced_counts'}

    # subset cells based on cluster if needed
    if cluster_name != '':
        if cluster_name not in set(ad.obs['annot']):
            print('Please enter valid cluster name')
            # two values, because callers unpack ``params, pvals``
            return None, None
        print('Subsetting cluster %s' %cluster_name)
        ad_clust = ad[ad.obs['annot']==cluster_name].copy()
    else:
        print('Using full adata')
        ad_clust = ad.copy()

    # == instead of ``is``: string identity is an interpreter detail
    if data_mode not in counts_layers:
        raise ValueError('Unknown data_mode %r; expected one of %s'
                         % (data_mode, sorted(counts_layers)))
    counts_type = counts_layers[data_mode]

    # only keep genes found in more than min_obs cells
    # NOTE(review): strict '>' here, while run_two_partmodel uses '>='.
    ad_clust = ad_clust[:, ad_clust.layers[counts_type].astype(bool).sum(axis=0)>min_obs].copy()
    cells, genes = ad_clust.shape
    print('Testing genes with min %i cells' %min_obs)
    if genes>0:
        print('Cluster %s with %i nuclei and %i genes' %(cluster_name,cells,genes))
    else:
        print('Not enough cells with %i genes, aborting cluster analysis' %genes)
        return None, None

    # regress out n_counts here, if need be
    if regress:
        regress_ncounts(ad_clust)

    # gather all variables
    df = copy.deepcopy(ad_clust.obs)
    df['log_ncounts'] = np.log(df['n_counts'])
    print('min max log_ncounts %.3f, %.3f' %(min(df['log_ncounts']), max(df['log_ncounts'])) )
    df['treatment'] = copy.deepcopy(ad_clust.obs['assignment'])

    if scale:
        # imported lazily so sklearn is only required when scaling is requested
        from sklearn import preprocessing
        scale_cols = ['log_ncounts', 'log_hashtag_counts', 'cFos_nCLR', 'p65_nCLR', 'PU1_nCLR', 'NeuN_nCLR']
        df[scale_cols] = preprocessing.scale(df[scale_cols])

    # run model; the output frames are created from the first successful fit,
    # which avoids a separate warm-up loop that could run past the last gene
    print('Running model for all genes')
    start = time.time()
    params, pvals = None, None
    idx = 0
    for gene in ad_clust.var_names:
        df['Gene'] = get_feat_values(ad_clust, gene, scaling=data_mode)
        mod = fetch_model(FORMULA, df, model_type)
        if mod is None:
            continue
        if params is None:
            params = pd.DataFrame([], index=ad_clust.var_names, columns=mod.params.index)
            pvals = pd.DataFrame([], index=ad_clust.var_names, columns=mod.pvalues.index)
        params.loc[gene] = mod.params
        pvals.loc[gene] = mod.pvalues

        idx += 1
        if (idx%1000) == 0: print(idx)

    end = time.time()
    print(end-start)

    return params, pvals

In [None]:
def run_two_partmodel(ad, FORM1, FORM2, model_type, cluster_name, data_mode, scale=False, min_obs=20):
    """Fit a two-stage model per gene and collect second-stage results.

    For every gene, ``FORM1`` is fitted first (gene exposed as ``Gene``); its
    residuals are stored in the column ``Resid`` and ``FORM2`` is fitted
    next. Coefficients and p-values of the second fit are collected.

    Parameters
    ----------
    ad : AnnData
        Needs ``obs`` columns ``annot``, ``n_counts`` and ``assignment`` and
        the count layer selected by ``data_mode``.
    FORM1, FORM2 : str
        Patsy formulas for the first and second stage.
    model_type : str
        Passed to ``fetch_model`` for both stages.
    cluster_name : str
        Value of ``obs['annot']`` to subset on; ``''`` uses all cells.
    data_mode : {'lognorm', 'zscore', 'spliced', 'unspliced'}
        Passed to ``get_feat_values`` as ``scaling``; also selects the layer
        used for the expression filter.
    scale : bool
        If true, standardise a fixed set of covariate columns.
    min_obs : int
        A gene is kept when it is detected in at least ``min_obs`` cells.

    Returns
    -------
    (params, pvals)
        DataFrames indexed by gene with one column per second-stage term, or
        ``(None, None)`` when the cluster name is invalid, no gene passes the
        filter, or no gene could be fitted in both stages.

    Raises
    ------
    ValueError
        If ``data_mode`` is not supported.
    """
    import copy

    counts_layers = {'lognorm': 'counts', 'zscore': 'counts',
                     'spliced': 'spliced_counts', 'unspliced': 'unspliced_counts'}

    # subset cells based on cluster if needed
    if cluster_name != '':
        if cluster_name not in set(ad.obs['annot']):
            print('Please enter valid cluster name')
            # two values, because callers unpack ``params, pvals``
            return None, None
        print('Subsetting cluster %s' %cluster_name)
        ad_clust = ad[ad.obs['annot']==cluster_name].copy()
    else:
        print('Using full adata')
        ad_clust = ad.copy()

    # == instead of ``is``: string identity is an interpreter detail
    if data_mode not in counts_layers:
        raise ValueError('Unknown data_mode %r; expected one of %s'
                         % (data_mode, sorted(counts_layers)))
    counts_type = counts_layers[data_mode]

    # only keep genes found in at least min_obs cells
    ad_clust = ad_clust[:, ad_clust.layers[counts_type].astype(bool).sum(axis=0)>=min_obs].copy()
    cells, genes = ad_clust.shape
    print('Testing genes with min %i cells' %min_obs)
    if genes>0:
        print('Cluster %s with %i nuclei and %i genes' %(cluster_name,cells,genes))
    else:
        print('Not enough cells with %i genes, aborting cluster analysis' %genes)
        return None, None

    # gather all variables
    df = copy.deepcopy(ad_clust.obs)
    df['log_ncounts'] = np.log(df['n_counts'])
    print('min max log_ncounts %.3f, %.3f' %(min(df['log_ncounts']), max(df['log_ncounts'])) )
    df['treatment'] = copy.deepcopy(ad_clust.obs['assignment'])

    if scale:
        # imported lazily so sklearn is only required when scaling is requested
        from sklearn import preprocessing
        scale_cols = ['log_ncounts', 'log_hashtag_counts', 'cFos_nCLR', 'p65_nCLR', 'PU1_nCLR', 'NeuN_nCLR']
        df[scale_cols] = preprocessing.scale(df[scale_cols])

    # run model; the output frames are created from the first gene for which
    # both stages fit, so a failed first-stage fit is never dereferenced
    print('Running model for all genes')
    start = time.time()
    params, pvals = None, None
    idx = 0
    for gene in ad_clust.var_names:
        df['Gene'] = get_feat_values(ad_clust, gene, scaling=data_mode)
        mod0 = fetch_model(FORM1, df, model_type)
        if mod0 is None:
            continue

        df['Resid'] = mod0.resid
        mod = fetch_model(FORM2, df, model_type)
        if mod is None:
            continue

        if params is None:
            params = pd.DataFrame([], index=ad_clust.var_names, columns=mod.params.index)
            pvals = pd.DataFrame([], index=ad_clust.var_names, columns=mod.pvalues.index)
        params.loc[gene] = mod.params
        pvals.loc[gene] = mod.pvalues

        idx += 1
        if (idx%1000) == 0: print(idx)
    end = time.time()
    print(end-start)

    return params, pvals

In [None]:
def run_twostep_linear_model(adata, FORM1, FORM2, run_str, cluster, model='mixedlm', run_mode='zscore', 
                     scale=False, run_repeat=False, min_obs=20):
    """Run ``run_two_partmodel`` for one cluster, caching results on disk.

    Results are pickled under ``./write/<run_str>_<model>_<run_mode>/`` as
    ``params_<cluster>.pickle`` and ``pvals_<cluster>.pickle`` (spaces are
    removed from the cluster name). When both files exist and ``run_repeat``
    is false they are loaded instead of refitting.

    Returns
    -------
    (params, pvals)
        As returned by ``run_two_partmodel``; ``(None, None)`` results are
        passed through and not written to disk.
    """
    import os
    outdir = './write/%s_%s_%s' %(run_str, model, run_mode)
    # makedirs also creates ./write when it does not exist yet
    os.makedirs(outdir, exist_ok=True)

    # check for previous run results
    cluster_tag = cluster.replace(' ','')
    params_file = '%s/params_%s.pickle' %(outdir, cluster_tag)
    pvals_file = '%s/pvals_%s.pickle' %(outdir, cluster_tag)
    have_cache = os.path.isfile(params_file) and os.path.isfile(pvals_file)

    # execute run
    if run_repeat or not have_cache:
        print('Starting run...')
        params, pvals = run_two_partmodel(adata, FORM1, FORM2, model_type=model, cluster_name=cluster, 
                                  data_mode=run_mode, scale=scale, min_obs=min_obs)
        if params is not None:
            params.to_pickle(params_file)
            pvals.to_pickle(pvals_file)
    else:
        print('Loading prior run result')
        # NOTE: pickles are only safe to load when this code wrote them itself.
        params = pd.read_pickle(params_file)
        pvals = pd.read_pickle(pvals_file)

    print('%s model with %s for cluster %s' %(model, run_mode, cluster))

    return params, pvals


In [None]:
def run_linear_model(adata, FORM, tissue, cluster, antibody, model='OLS', run_mode='lognorm', 
                     run_repeat=False, regress=False, permute=False, scale=False, 
                     min_obs=20):  
    """Run ``run_model`` for one cluster, caching results on disk.

    Results are pickled under ``./write/<tissue>_<model>_<run_mode>/`` as
    ``<cluster>_<ab>_params[_permuted].pickle`` and the matching ``pvals``
    file, where ``<ab>`` is ``'all'`` when ``antibody`` is a list and the
    antibody name otherwise. When both files exist and ``run_repeat`` is
    false they are loaded instead of refitting.

    Returns
    -------
    (params, pvals)
        As returned by ``run_model``; ``(None, None)`` results are passed
        through and not written to disk.
    """
    import os
    outdir = './write/%s_%s_%s' %(tissue, model, run_mode)
    # makedirs also creates ./write when it does not exist yet
    os.makedirs(outdir, exist_ok=True)

    ab_prefix = 'all' if isinstance(antibody, list) else antibody

    # check for previous run results
    suffix = '_permuted' if permute else ''
    cluster_tag = cluster.replace(' ','')
    params_file = '%s/%s_%s_params%s.pickle' %(outdir, cluster_tag, ab_prefix, suffix)
    pvals_file = '%s/%s_%s_pvals%s.pickle' %(outdir, cluster_tag, ab_prefix, suffix)
    have_cache = os.path.isfile(params_file) and os.path.isfile(pvals_file)

    # execute run
    if run_repeat or not have_cache:
        params, pvals = run_model(ad=adata, FORMULA=FORM, model_type=model, cluster_name=cluster, 
                                  CITE=antibody, data_mode=run_mode, regress=regress, scale=scale, 
                                  permute=permute, min_obs=min_obs)
        if params is not None:
            params.to_pickle(params_file)
            pvals.to_pickle(pvals_file)
    else:
        print('Loading prior run result')
        # NOTE: pickles are only safe to load when this code wrote them itself.
        params = pd.read_pickle(params_file)
        pvals = pd.read_pickle(pvals_file)

    print('Ran %s model with %s for cluster %s and CITE %s' %(model, run_mode, cluster, ab_prefix))

    return params, pvals


In [None]:
def load_mixedlm_results_scaled(ad, run_name, cts_type, scale=True, run_repeat=False, min_obs=10, THRESHOLD=0.05): 
    """Run or load the per-cluster mixed model and compute FDR significance.

    Parameters
    ----------
    ad : AnnData
    run_name : {'proteins_PBS', 'proteins_KA', 'proteins_interaction',
                'proteins_interaction_flipped'}
        Selects the cell subset, the formula and the cache directory name.
    cts_type : str
        'lognorm' or 'unspliced' or 'zscore'; forwarded as ``run_mode``.
    scale, run_repeat, min_obs
        Forwarded to ``combine_model_across_cell_types``.
    THRESHOLD : float
        Alpha for the Benjamini-Hochberg correction.

    Returns
    -------
    (params, pvals, sig, plot_str)
        ``params``, ``pvals`` and ``sig`` are dicts keyed by model term;
        ``plot_str`` is a label for figure names.

    Raises
    ------
    ValueError
        If ``run_name`` is not supported (previously this failed later on an
        unbound variable).

    Notes
    -----
    The ``*_flipped`` run reorders the categories of ``ad.obs['assignment']``
    on the ``ad`` that was passed in.
    """
    formula = 'Gene ~ log_ncounts + cFos_nCLR + p65_nCLR + PU1_nCLR + NeuN_nCLR + log_hashtag_counts'
    interaction_formula = 'Gene ~ log_ncounts + cFos_nCLR*C(assignment) + p65_nCLR*C(assignment) + PU1_nCLR*C(assignment) + NeuN_nCLR*C(assignment) + log_hashtag_counts'
    antibodies = ['cFos_nCLR','p65_nCLR','PU1_nCLR','NeuN_nCLR']

    if run_name=='proteins_PBS':
        RUN_STR = 'hippocampus_PBS_CITE_nCLR_scaled'
        ad_run = ad[ad.obs['assignment']=='PBS']
        plot_str = 'mixedlm_PBS_proteins_scaled'

    elif run_name=='proteins_KA':
        RUN_STR = 'hippocampus_KA_CITE_nCLR_scaled'
        ad_run = ad[ad.obs['assignment']=='KainicAcid']
        plot_str = 'mixedlm_KA_proteins_scaled'

    elif run_name=='proteins_interaction':
        formula = interaction_formula
        RUN_STR = 'hippocampus_by_cluster_interaction'
        ad_run = ad
        plot_str = 'mixedlm_by_cluster_interaction'

    elif run_name=='proteins_interaction_flipped':
        formula = interaction_formula
        RUN_STR = 'hippocampus_by_cluster_interaction_flipped'
        ad_run = ad
        # assignment instead of inplace=True, which pandas 2.0 removed
        ad_run.obs['assignment'] = ad_run.obs['assignment'].cat.reorder_categories(['PBS','KainicAcid'])
        plot_str = 'mixedlm_by_cluster_interaction'

    else:
        raise ValueError('Unknown run_name %r' % (run_name,))

    params, pvals = combine_model_across_cell_types(ad_run, antibodies, formula, 
                                    run_class=RUN_STR, model='mixedlm', run_mode=cts_type,
                                    scale=scale,
                                    run_repeat=run_repeat,
                                    min_obs=min_obs)

    sig = {}
    for variate in pvals:
        sig[variate] = get_significance_df(pvals[variate], 
                                           method='fdr_bh', alpha=THRESHOLD)
        print(variate)
        print(sig[variate].sum())

    return params, pvals, sig, plot_str

In [None]:
def combine_model_across_cell_types(ad, CITE, formula, run_class='HIP', clusters='all', 
                                    pretty_plot=False, model='OLS', run_mode='lognorm',
                                    permute=False, regress=False, scale=False, run_repeat=False,
                                    min_obs=20): 
    """Fit the model in every cluster and regroup the results by model term.

    Parameters
    ----------
    ad : AnnData
    CITE : str or list
        Forwarded to ``run_linear_model`` as ``antibody``.
    formula : str
        Patsy formula.
    run_class : str
        Forwarded as the ``tissue`` part of the cache directory name.
    clusters : 'all' or sequence of str
        ``'all'`` uses every category of ``ad.obs['annot']``.
    pretty_plot
        Accepted for interface compatibility; not used.
    model, run_mode, permute, regress, scale, run_repeat
        Forwarded to ``run_linear_model``.
    min_obs : int or sequence of int
        One threshold for all clusters, or one per cluster in the same order
        as ``clusters``.

    Returns
    -------
    (params_all, pvals_all)
        Dicts keyed by model term; each value is a genes x clusters
        DataFrame. Clusters without results are left out. Both dicts are
        empty when no cluster produced results.

    Raises
    ------
    ValueError
        If ``min_obs`` is a sequence whose length differs from ``clusters``.
    """
    # Comparing an Index with == is element-wise, so test for str first.
    if isinstance(clusters, str) and clusters == 'all':
        clusters = ad.obs['annot'].cat.categories
    # An Index supports .drop() below even when the caller passed a list.
    clusters = pd.Index(clusters)

    if isinstance(min_obs, (int, np.integer)):
        min_obs_dict = dict.fromkeys(clusters, min_obs)
    else:
        if len(min_obs) != len(clusters):
            raise ValueError('min_obs needs one entry per cluster (%i given, %i clusters)'
                             % (len(min_obs), len(clusters)))
        min_obs_dict = dict(zip(clusters, min_obs))

    # collect all significance and coeffs
    cluster_params = {}
    cluster_pvals = {}
    last_params = None
    for clust in clusters:
        print(clust)

        # run model
        params, pvals = run_linear_model(ad, formula, run_class, clust, 
                                         antibody=CITE,
                                         model=model, run_mode=run_mode, 
                                         run_repeat=run_repeat,
                                         regress=regress, 
                                         scale=scale, 
                                         permute=permute,
                                         min_obs=min_obs_dict[clust])

        # add to aggregate
        if params is not None:
            cluster_params[clust] = params
            cluster_pvals[clust] = pvals
            last_params = params
        # if no data for current cluster, remove
        else:
            clusters = clusters.drop(labels=[clust])

    print(clusters)
    if last_params is None:
        # nothing could be fitted in any cluster
        return {}, {}

    # aggregated dataframes - one for each variate
    params_all = {}
    pvals_all = {}
    for var in last_params.columns:
        params_all[var] = aggregate_variate_dfs(cluster_params, var, clusters)
        pvals_all[var] = aggregate_variate_dfs(cluster_pvals, var, clusters)

    return params_all, pvals_all

In [None]:
def load_mixedlm_results_total_model(ad, run_name, cts_type, run_repeat=False, min_obs=20,
                                     model_type='mixedlm', scale=False, THRESHOLD=0.05): 
    """Run or load the all-cells model (cluster as covariate) with FDR calls.

    Parameters
    ----------
    ad : AnnData
    run_name : {'proteins_PBS', 'proteins_KA', 'treatment_interaction',
                'treatment_interaction_complete',
                'treatment_interaction_complete_flipped'}
        Selects the cell subset, the formula and the cache directory name.
    cts_type : str
        'lognorm' or 'unspliced' or 'zscore'; forwarded as ``run_mode``.
    run_repeat, min_obs, model_type, scale
        Forwarded to ``run_linear_model``.
    THRESHOLD : float
        Alpha for the Benjamini-Hochberg correction.

    Returns
    -------
    (params, pvals, sig, plot_str)
        Three genes x terms DataFrames and a label for figure names.

    Raises
    ------
    ValueError
        If ``run_name`` is not supported (previously this failed later on an
        unbound variable).

    Notes
    -----
    The ``*_flipped`` run reorders the categories of ``ad.obs['assignment']``
    on the ``ad`` that was passed in.
    """
    formula = 'Gene ~ log_ncounts + log_hashtag_counts + C(annot) + cFos_nCLR + p65_nCLR + PU1_nCLR + NeuN_nCLR'
    complete_formula = 'Gene ~ log_ncounts + log_hashtag_counts + C(annot)*C(assignment) + cFos_nCLR*C(assignment) + p65_nCLR*C(assignment) + PU1_nCLR*C(assignment) + NeuN_nCLR*C(assignment)'
    antibodies = ['cFos_nCLR','p65_nCLR','PU1_nCLR','NeuN_nCLR']

    if run_name=='proteins_PBS':
        RUN_STR = 'hippocampus_PBS_CITE_nCLR_cluster_term'
        ad_run = ad[ad.obs['assignment']=='PBS']
        plot_str = 'total_mixedlm_PBS'

    elif run_name=='proteins_KA':
        RUN_STR = 'hippocampus_KA_CITE_nCLR_cluster_term'
        ad_run = ad[ad.obs['assignment']=='KainicAcid']
        plot_str = 'total_mixedlm_KA'

    elif run_name=='treatment_interaction':
        formula = 'Gene ~ log_ncounts + log_hashtag_counts + C(annot)*C(assignment) + cFos_nCLR + p65_nCLR + PU1_nCLR + NeuN_nCLR'
        RUN_STR = 'hippocampus_CITE_all_treatment_interaction'
        ad_run = ad
        plot_str = 'total_mixedlm_treatment_interaction'

    elif run_name=='treatment_interaction_complete':
        formula = complete_formula
        RUN_STR = 'hippocampus_CITE_all_treatment_interaction_complete'
        ad_run = ad
        plot_str = 'total_mixedlm_treatment_interaction'

    elif run_name=='treatment_interaction_complete_flipped':
        formula = complete_formula
        RUN_STR = 'hippocampus_CITE_all_treatment_interaction_complete_flipped'
        ad_run = ad
        # assignment instead of inplace=True, which pandas 2.0 removed
        ad_run.obs['assignment'] = ad_run.obs['assignment'].cat.reorder_categories(['PBS','KainicAcid'])
        plot_str = 'total_mixedlm_treatment_interaction'

    else:
        raise ValueError('Unknown run_name %r' % (run_name,))

    # An empty cluster name makes run_linear_model fit on all cells.
    params, pvals = run_linear_model(ad_run, formula, RUN_STR, 
                                         '', antibodies, model=model_type, run_mode=cts_type, 
                                            scale=scale, min_obs=min_obs,
                                         run_repeat=run_repeat, permute=False)

    sig = get_significance_df(pvals, method='fdr_bh', alpha=THRESHOLD) #fdr_bh

    return params, pvals, sig, plot_str

In [None]:
def aggregate_variate_dfs(df_dict, variate, clusts):
    """Collect one model term from every cluster into a genes x clusters table.

    Parameters
    ----------
    df_dict : dict
        Maps cluster name to a DataFrame indexed by gene with one column per
        model term.
    variate : str
        Column to extract from every per-cluster DataFrame.
    clusts : iterable of str
        Clusters to include; they become the columns of the result.

    Returns
    -------
    pandas.DataFrame
        Indexed by the sorted union of all genes. Genes missing from a
        cluster are NaN in that cluster's column.
    """
    # gather all genes across all clusters; the frame's own index is used so
    # that models fitted without an 'Intercept' term work as well
    all_genes = []
    for clust in clusts:
        all_genes = np.union1d(all_genes, df_dict[clust].index)

    var_df = pd.DataFrame([], columns=clusts, index=all_genes)
    for clust in clusts:
        # assignment aligns on the gene index
        var_df[clust] = df_dict[clust][variate]
    return var_df

In [None]:
def combine_model_across_clusters_two_step(ad, formula1, formula2, run_class='HIP', clusters='all', 
                                    model='mixedlm', run_mode='zscore', run_repeat=False,
                                    scale=False, min_obs=20): 
    """Fit the two-step model in every cluster and regroup results by term.

    Parameters
    ----------
    ad : AnnData
    formula1, formula2 : str
        First- and second-stage Patsy formulas.
    run_class : str
        Forwarded as ``run_str`` (part of the cache directory name).
    clusters : 'all' or sequence of str
        ``'all'`` uses every category of ``ad.obs['annot']``.
    model, run_mode, run_repeat, scale
        Forwarded to ``run_twostep_linear_model``. ``model`` and ``run_mode``
        were previously accepted but not passed on.
    min_obs : int or sequence of int
        One threshold for all clusters, or one per cluster in the same order
        as ``clusters``.

    Returns
    -------
    (params_all, pvals_all)
        Dicts keyed by model term; each value is a genes x clusters
        DataFrame. Clusters without results are left out. Both dicts are
        empty when no cluster produced results.

    Raises
    ------
    ValueError
        If ``min_obs`` is a sequence whose length differs from ``clusters``.
    """
    # Comparing an Index with == is element-wise, so test for str first.
    if isinstance(clusters, str) and clusters == 'all':
        clusters = ad.obs['annot'].cat.categories
    clusters = list(clusters)

    if isinstance(min_obs, (int, np.integer)):
        min_obs_dict = dict.fromkeys(clusters, min_obs)
    else:
        if len(min_obs) != len(clusters):
            raise ValueError('min_obs needs one entry per cluster (%i given, %i clusters)'
                             % (len(min_obs), len(clusters)))
        min_obs_dict = dict(zip(clusters, min_obs))

    # collect all significance and coeffs
    cluster_params = {}
    cluster_pvals = {}
    last_params = None
    for clust in clusters:
        print(clust)

        # run model
        params, pvals = run_twostep_linear_model(ad, formula1, formula2, run_class, 
                                                 cluster=clust, model=model, run_mode=run_mode,
                                                 run_repeat=run_repeat, 
                                                 scale=scale, min_obs=min_obs_dict[clust]) 

        # add to aggregate; clusters without a result are skipped
        if params is not None:
            cluster_params[clust] = params
            cluster_pvals[clust] = pvals
            last_params = params

    if last_params is None:
        # nothing could be fitted in any cluster
        return {}, {}

    fitted_clusters = [clust for clust in clusters if clust in cluster_params]

    # aggregated dataframes - one for each variate
    params_all = {}
    pvals_all = {}
    for var in last_params.columns:
        params_all[var] = aggregate_variate_dfs(cluster_params, var, fitted_clusters)
        pvals_all[var] = aggregate_variate_dfs(cluster_pvals, var, fitted_clusters)

    return params_all, pvals_all

## curating model results

In [None]:
def get_sig_gene_list_cluster(sig,params,var,clust):
    """Return the genes flagged significant for term ``var`` in ``clust``.

    ``sig[var]`` and ``params[var]`` are genes x clusters DataFrames; the
    boolean ``sig`` column selects rows of the ``params`` column and the
    index of the selection is returned.
    """
    coefficients = params[var][clust]
    is_significant = sig[var][clust]
    return coefficients.loc[is_significant].index

def get_sig_gene_list(sig,params,var):
    """Map every cluster to the index of genes significant for term ``var``.

    ``sig[var]`` is a genes x clusters DataFrame of booleans. ``params`` is
    accepted for symmetry with the other helpers but is not used.
    """
    sig_table = sig[var]
    return {
        cluster: sig_table[cluster].loc[sig_table[cluster]==True].index
        for cluster in sig_table.columns
    }

In [None]:
def merge_sig_gene_lists_cluster(var1, sig1, params1, str1, 
                                 var2, sig2, params2, str2, 
                         colors=['#FC942D','#2765D9','#A337F0']): 
    """Per cluster, label significant genes as list 1 only, list 2 only or both.

    Parameters
    ----------
    var1, sig1, params1 / var2, sig2, params2
        Model term plus the ``sig``/``params`` dicts of the two analyses, as
        taken by ``get_sig_gene_list_cluster``.
    str1, str2 : str
        Labels for genes that are significant in only one analysis.
    colors : sequence of three str
        Colours for ``str1``, ``str2`` and ``'both'``, in that order.

    Returns
    -------
    dict
        Maps every cluster (the columns of ``sig1[var1]``) to a DataFrame
        with columns ``gene``, ``list``, ``cluster`` and ``color``. Rows are
        ordered: list-1-only genes, list-2-only genes, shared genes.
    """
    color_map = {str1:colors[0], str2:colors[1], 'both':colors[2]} 

    sig_genes = {}
    for clust in sig1[var1].columns:
        genes_1 = get_sig_gene_list_cluster(sig1, params1, var1, clust)
        genes_2 = get_sig_gene_list_cluster(sig2, params2, var2, clust)

        genes_both = list(np.intersect1d(genes_1, genes_2))
        shared = set(genes_both)

        # Build all rows first; DataFrame.append was removed in pandas 2.0
        # and copied the whole frame for every added row.
        rows = [(g1, str1, clust) for g1 in genes_1 if g1 not in shared]
        rows += [(g2, str2, clust) for g2 in genes_2 if g2 not in shared]
        rows += [(gb, 'both', clust) for gb in genes_both]

        genes_df = pd.DataFrame(rows, columns=['gene','list','cluster'])
        genes_df['color'] = genes_df['list'].map(color_map)

        sig_genes[clust] = genes_df

    return sig_genes

In [None]:
def get_sig_gene_df_with_color(var, sig, params, desc):
    """Build a gene table for the genes significant for term ``var``.

    Parameters
    ----------
    var : str
        Model term.
    sig : pandas.DataFrame or dict
        A genes x terms boolean DataFrame, or a dict of genes x clusters
        DataFrames keyed by term. In the dict case a gene counts when it is
        significant in any cluster.
    params
        Passed through to ``get_sig_gene_list`` in the dict case.
    desc : str
        Value written to the ``list`` column.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene``, ``list`` and ``color`` (always ``'#FC942D'``).
    """
    if isinstance(sig, pd.DataFrame):
        sig_genes = list(sig[sig[var]].index)
    else:
        genes = get_sig_gene_list(sig, params, var)
        # sorted so the row order no longer depends on set iteration order
        sig_genes = sorted({item for val in genes.values() for item in val})

    # One construction instead of row-wise DataFrame.append (removed in
    # pandas 2.0).
    genes_df = pd.DataFrame({'gene': sig_genes, 'list': [desc] * len(sig_genes)},
                            columns=['gene','list'])
    genes_df['color'] = '#FC942D'
    return genes_df

In [None]:
def merge_sig_gene_lists(var1, sig1, params1, str1, var2, sig2, params2, str2, 
                         colors=['#FC942D','#2765D9','#A337F0']): 
    """Label significant genes as belonging to list 1 only, list 2 only or both.

    For each analysis, when ``params`` is a dict the significant genes are
    the union over clusters (via ``get_sig_gene_list``); otherwise ``sig`` is
    treated as a genes x terms boolean DataFrame and column ``var`` selects
    the genes.

    Parameters
    ----------
    var1, sig1, params1 / var2, sig2, params2
        Model term, significance table(s) and parameter table(s) of the two
        analyses.
    str1, str2 : str
        Labels for genes that are significant in only one analysis.
    colors : sequence of three str
        Colours for ``str1``, ``str2`` and ``'both'``, in that order.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene``, ``list`` and ``color``. Rows are ordered:
        list-1-only genes, list-2-only genes, shared genes.
    """
    def _significant_genes(var, sig, params):
        # Genes significant for ``var`` in one analysis.
        if isinstance(params, dict):
            per_cluster = get_sig_gene_list(sig, params, var)
            # sorted so the row order no longer depends on set iteration order
            return sorted({item for val in per_cluster.values() for item in val})
        return list(sig[sig[var]].index)

    sig_genes_1 = _significant_genes(var1, sig1, params1)
    sig_genes_2 = _significant_genes(var2, sig2, params2)

    sig_genes_both = list(np.intersect1d(sig_genes_1, sig_genes_2))
    shared = set(sig_genes_both)

    # Build all rows first; DataFrame.append was removed in pandas 2.0 and
    # copied the whole frame for every added row.
    rows = [(g1, str1) for g1 in sig_genes_1 if g1 not in shared]
    rows += [(g2, str2) for g2 in sig_genes_2 if g2 not in shared]
    rows += [(gb, 'both') for gb in sig_genes_both]
    genes_df = pd.DataFrame(rows, columns=['gene','list'])

    color_map = {str1:colors[0], str2:colors[1], 'both':colors[2]} 
    genes_df['color'] = genes_df['list'].map(color_map)
    return genes_df

In [None]:
def sorted_sig_genes(feat, param, pval, sig):
    """Rank genes by coefficient per cluster and keep the significant ones.

    Parameters
    ----------
    feat : str
        Model term.
    param, sig : dict
        Keyed by term; each value is a genes x clusters DataFrame of
        coefficients / boolean significance calls.
    pval : dict
        Accepted for interface compatibility; not used (the filtered p-value
        column computed previously was never read).

    Returns
    -------
    dict
        Maps every cluster to a DataFrame with columns ``param`` (the
        coefficient), ``sig`` and ``rank`` (0-based position in ascending
        coefficient order among genes with a non-null coefficient),
        restricted to significant genes.
    """
    sig_genes = {}
    for clust in sig[feat].columns:
        param_sorted = param[feat][clust].dropna().sort_values(ascending=True)
        sig_sorted = sig[feat][clust].dropna()[param_sorted.index]
        rank = pd.Series(np.arange(len(param_sorted)), index=param_sorted.index)

        # organize gene list
        plot_df = pd.concat([param_sorted, sig_sorted, rank],
                            axis=1, keys=['param','sig', 'rank'])
        sig_genes[clust] = plot_df[plot_df['sig']]
    return sig_genes

In [None]:
def readjust_sig_threshold(ad, clust, params, pvals, sig, min_obs=10, THRESHOLD=0.05): 
    """Restrict results to sufficiently detected genes and redo the FDR calls.

    Genes are kept when they are detected (non-zero in ``layers['counts']``)
    in more than ``min_obs`` cells of cluster ``clust``. Significance is then
    recomputed with ``get_significance_df`` (``fdr_bh``) on the kept genes.

    Parameters
    ----------
    ad : AnnData
        Needs ``obs['annot']`` and ``layers['counts']``.
    clust : str
        Cluster whose cells define the detection filter.
    params, pvals, sig : pandas.DataFrame or dict of DataFrame
        A single genes x terms table ("global" results), or dicts keyed by
        model term (per-cluster results). In the dict case ``params`` must
        contain the key ``'Intercept'``.
    min_obs : int
        Detection threshold (strictly greater than).
    THRESHOLD : float
        Alpha passed to ``get_significance_df``.

    Returns
    -------
    (params_copy, pvals_copy, sig_copy)
        Filtered copies; the inputs are not modified.
    """
    import copy 
    
    ad_clust = ad[ad.obs['annot']==clust]
    genes_thresh = ad_clust[:,ad_clust.layers['counts'].astype(bool).sum(axis=0)>min_obs].var.index
    
    if isinstance(params, pd.DataFrame):
        # global results: one table, rows are genes
        genes_thresh_sub = [g for g in params.index if g in genes_thresh]
        params_copy = params.loc[genes_thresh_sub]
        pvals_copy = pvals.loc[genes_thresh_sub]
        # Missing p-values are treated as non-significant (p = 1) here, so
        # they do count towards the number of tests in this branch.
        pvals_copy.fillna(1, inplace=True)
        
        sig_copy = get_significance_df(pvals_copy, method='fdr_bh', alpha=THRESHOLD) 

    else: 
        # cluster specific results: dicts keyed by model term
        genes_thresh_sub = [g for g in params['Intercept'].index if g in genes_thresh]
        print('Truncated %i genes to %i genes' %(params['Intercept'].shape[0], len(genes_thresh_sub)))
        
        # Deep copies so the caller's dicts and tables stay untouched.
        params_copy = copy.deepcopy(params)
        pvals_copy = copy.deepcopy(pvals)
        sig_copy = copy.deepcopy(sig)
        
        for variate in sig_copy.keys(): 
            params_copy[variate] = params_copy[variate].loc[genes_thresh_sub]
            pvals_copy[variate] = pvals_copy[variate].loc[genes_thresh_sub]
            sig_copy[variate] = sig_copy[variate].loc[genes_thresh_sub]
        
            # Unlike the global branch, NaN p-values are not filled with 1
            # before the correction. The subset assigned just above is
            # replaced by the recomputed calls.
            sig_copy[variate] = get_significance_df(pvals_copy[variate], 
                                               method='fdr_bh', alpha=THRESHOLD) 

    return params_copy, pvals_copy, sig_copy

In [None]:
def get_significance_df(pvals, method='bonferroni', alpha=0.05):
    """Apply multiple-testing correction to each column of a p-value table.

    Every column is corrected independently with
    ``statsmodels.stats.multitest.multipletests`` using only its non-NaN
    entries; NaN entries are reported as not significant.

    Parameters
    ----------
    pvals : pd.DataFrame of p-values (NaN allowed)
    method : correction method name passed to ``multipletests``
    alpha : family-wise error rate / FDR level passed to ``multipletests``

    Returns
    -------
    pd.DataFrame of booleans with the same index and columns as ``pvals``;
    True where the null hypothesis is rejected.
    """
    # Import the submodule explicitly: `statsmodels.stats.multitest` is only
    # reachable as an attribute if something else already imported it.
    from statsmodels.stats.multitest import multipletests

    df_sig = pd.DataFrame(False, index=pvals.index, columns=pvals.columns)
    for col in pvals.columns:
        cur_pvals = pvals[col]
        not_nan = cur_pvals.notnull()
        cur_sig = pd.Series(False, index=cur_pvals.index)
        # A column without any valid p-value has nothing to correct; skip the
        # call instead of handing multipletests an empty array.
        if not_nan.any():
            reject, _, _, _ = multipletests(cur_pvals.loc[not_nan],
                                            method=method, alpha=alpha)
            cur_sig.loc[not_nan] = reject
        df_sig[col] = cur_sig

    return df_sig


In [None]:
def add_gene_frac(ad):
    """Annotate ``ad.var['gene_frac']`` with the per-column detection fraction.

    For each column of ``ad.layers['counts']`` the number of non-zero rows is
    divided by ``ad.shape[0]``.  ``ad`` is modified in place; nothing is
    returned.
    """
    n_rows = ad.shape[0]
    detected = ad.layers['counts'].astype(bool)
    # .A1 flattens the 1 x n matrix produced by the column-wise sum
    n_nonzero = np.sum(detected, axis=0).A1
    ad.var['gene_frac'] = n_nonzero / n_rows

In [None]:
def make_gene_frac_df(df, ad):
    """Build a table of per-cluster gene detection fractions shaped like ``df``.

    For every column name of ``df`` (treated as a value of ``ad.obs['annot']``)
    the matching rows of ``ad`` are copied, ``add_gene_frac`` is run on the
    copy, and the resulting ``var['gene_frac']`` is written into that column
    (aligned on the index of ``df``).  ``df`` and ``ad`` are not modified.
    """
    frac_table = df.copy()
    annotations = ad.obs['annot']
    for cluster_name in frac_table.columns:
        subset = ad[annotations==cluster_name].copy()
        add_gene_frac(subset)
        frac_table[cluster_name] = subset.var['gene_frac']
    return frac_table

In [None]:
def make_dotplot_df(params, pvals, sig, df_frac, feat, THRESH, COEFF_THRESH=0):
    """Assemble the long-format table used for significance dot plots.

    Parameters
    ----------
    params, pvals, sig : mappings of covariate name -> gene x cluster DataFrame
        (coefficients, p-values and boolean significance calls)
    df_frac : gene x cluster DataFrame; its values fill the ``fraction`` column
    feat : covariate to plot
    THRESH : minimum number of clusters in which a gene must be significant
    COEFF_THRESH : a gene is kept only if, in at least one cluster where it is
        significant, the absolute coefficient exceeds this value

    Returns
    -------
    plot_df : DataFrame with columns gene, cluster, coefficient, neglog_pval,
        significant, fraction -- one row per (gene, cluster) pair whose p-value
        is not NaN.  ``cluster`` is categorical with the column order of
        ``sig[feat]``.
    df_sig : ``sig[feat]`` restricted to the selected genes.
    """
    feat_sig, feat_params, feat_pvals = sig[feat], params[feat], pvals[feat]

    # genes significant in at least THRESH clusters ...
    candidates = feat_sig.loc[feat_sig.sum(axis=1)>=THRESH]
    cand_params = feat_params.loc[candidates.index]
    # ... with a large enough effect in at least one of those clusters
    # (masking with the boolean frame leaves NaN where not significant)
    strong = (np.abs(cand_params[candidates])>COEFF_THRESH).sum(axis=1)>=1
    sig_genes = cand_params[strong].index
    print('%i genes meet thresh' %(len(sig_genes)) )

    df_sig = feat_sig.loc[sig_genes]
    df_coeff = feat_params.loc[sig_genes]
    df_pval = feat_pvals.loc[sig_genes]

    clusters = df_sig.columns
    col_list = ['gene', 'cluster', 'coefficient', 'neglog_pval', 'significant', 'fraction']

    # Collect rows in a list and build the frame once: DataFrame.append copied
    # the whole frame per row and no longer exists in pandas >= 2.0.
    rows = []
    for gene in df_coeff.index:
        for clust in df_coeff.columns:
            p = df_pval.loc[gene, clust]
            if np.isnan(p):
                continue
            rows.append([gene, clust,
                         df_coeff.loc[gene, clust],
                         -np.log10(p),
                         df_sig.loc[gene, clust],
                         df_frac.loc[gene, clust]])
    plot_df = pd.DataFrame(rows, columns=col_list)

    # Fix the category order to the cluster order of the input; clusters with
    # no plotted row are kept as (empty) categories instead of raising.
    plot_df['cluster'] = pd.Categorical(plot_df['cluster'], categories=list(clusters))
    return plot_df, df_sig

In [None]:
def combine_df_sig(df_list, feat, groups):
    """Pull entry ``feat`` out of every item of ``df_list`` and join them column-wise.

    The pieces are concatenated along axis 1, with ``groups`` used as the
    ``keys`` labelling each piece.
    """
    selected = []
    for table in df_list:
        selected.append(table[feat])
    return pd.concat(selected, axis=1, keys=groups)

In [None]:
def combine_df_sig_multi_feats(df_list, feat_list, groups):
    """Like ``combine_df_sig`` but with a separate key per item.

    ``df_list`` and ``feat_list`` are walked in lockstep (stopping at the
    shorter one); ``df[feat]`` is taken from each pair and the pieces are
    concatenated along axis 1 with ``groups`` as the ``keys``.
    """
    selected = [table[feat] for table, feat in zip(df_list, feat_list)]
    return pd.concat(selected, axis=1, keys=groups)

# NMF

In [None]:
def cluster_sig_genes(feat, sig_df, clusters='all'):
    """Return the genes flagged significant for ``feat`` in any requested cluster.

    Parameters
    ----------
    feat : key into ``sig_df`` selecting a boolean gene x cluster DataFrame
    sig_df : mapping of covariate name -> boolean DataFrame
    clusters : ``'all'`` (every column of ``sig_df[feat]``), a single cluster
        name, or an iterable of cluster names

    Returns
    -------
    list of unique gene labels (order is not defined).
    """
    feat_sig = sig_df[feat]

    if isinstance(clusters, str):
        # Compare against the sentinel only for strings: `Index == 'all'` is
        # element-wise and cannot be used as a condition.  The columns come
        # from the table that was passed in, not from a notebook-level global.
        clusters = feat_sig.columns if clusters == 'all' else [clusters]

    sig_genes = set()
    for clust in clusters:
        sig_genes.update(feat_sig[feat_sig[clust]].index)
    return list(sig_genes)

In [None]:
def rev_dict_list(mydict):
    """Invert a mapping of key -> iterable of values.

    Returns a ``defaultdict(list)`` mapping every value to the list of keys
    whose iterable contained it, in encounter order (a key appears once per
    occurrence of the value).
    """
    from collections import defaultdict

    inverted = defaultdict(list)
    pairs = ((member, owner) for owner, members in mydict.items() for member in members)
    for member, owner in pairs:
        inverted[member].append(owner)
    return inverted

In [None]:
def subset_feats_clusters_NMF(clusters, sig_df, ad, HVGs, add_NMF_genes=True, remove_ribo=False, feats=['cFos_nCLR','p65_nCLR']): 
    """Select the genes and cells used as NMF input.

    Steps, in order:
      1. pool the genes returned by ``cluster_sig_genes`` for every entry of
         ``feats`` (restricted to ``clusters``);
      2. start from the sorted union of those genes and ``HVGs`` when
         ``add_NMF_genes`` is true, otherwise from ``HVGs`` alone;
      3. drop a fixed list of genes, optionally names starting with
         ``Rpl``/``Rps`` (``remove_ribo``), and names ending in ``Rik``;
      4. keep only genes present in the index of ``sig_df`` (or of
         ``sig_df['Intercept']`` when ``sig_df`` is not a DataFrame) and with
         ``ad.var['means'] > 0``;
      5. subset ``ad`` to those genes and to rows whose ``obs['annot']`` is in
         ``clusters``.

    Returns ``(ad_run, NMF_genes)``: the subset copy and the gene list.
    """
    # merge genes associated with each feature
    pooled = []
    for feat in feats:
        pooled.extend(cluster_sig_genes(feat, sig_df, clusters))
    DEGs = list(set(pooled))

    # add HVGs
    NMF_genes = list(np.union1d(DEGs, list(HVGs))) if add_NMF_genes else list(HVGs)

    # remove highly expressed genes
    for blocked in ('Gm42418', 'Ttr', 'Fth1', 'Ptgds'):
        if blocked in NMF_genes:
            NMF_genes.remove(blocked)

    if remove_ribo:
        NMF_genes = [g for g in NMF_genes
                     if not g.startswith('Rpl') and not g.startswith('Rps')]

    NMF_genes = [g for g in NMF_genes if not g.endswith('Rik')]

    # genes must have been part of the model fit ...
    if isinstance(sig_df, pd.DataFrame):
        modelled = sig_df.index
    else:
        modelled = sig_df['Intercept'].index
    NMF_genes = [g for g in NMF_genes if g in modelled]
    # ... and have a positive mean in ad.var
    NMF_genes = [g for g in NMF_genes if ad.var.loc[g,'means']>0]

    gene_subset = ad[:,NMF_genes].copy()

    # keep only the requested clusters
    ad_run = gene_subset[gene_subset.obs['annot'].isin(clusters)].copy()

    return ad_run, NMF_genes

In [None]:
def module_top_GO_terms(module_genes):
    """Collect the first five enrichment rows for every module.

    ``module_genes`` is indexed with the string keys ``'0'`` ..
    ``str(len(module_genes) - 1)``.  Each gene list is sent to
    ``parse_GO_query(genes, 'mmusculus', db_to_keep=['GO:BP','KEGG'])`` and the
    first five rows of the answer (in the order returned by that helper) are
    kept and tagged with ``module<i>``.

    Returns a DataFrame with the kept enrichment fields plus a ``module``
    column; it is empty (columns only) when ``module_genes`` is empty.
    """
    fields_to_keep = ['source','name','p_value','significant','description','term_size',
                      'query_size','intersection_size']

    per_module = []
    for i in range(len(module_genes)):
        module_i_GO = parse_GO_query(module_genes[str(i)], 'mmusculus', db_to_keep=['GO:BP','KEGG'])
        # copy so the new column is not written onto a slice of the query result
        df_i = module_i_GO[fields_to_keep].iloc[:5].copy()
        df_i['module'] = 'module%i' %i
        per_module.append(df_i)

    if not per_module:
        return pd.DataFrame([], columns=fields_to_keep+['module'])
    # one concat instead of repeated DataFrame.append (removed in pandas 2.0)
    return pd.concat(per_module)

In [None]:
def sorted_module_genes(module_by_gene, module_num, genes, top_N):
    """Order the genes of one module by decreasing loading.

    Row ``module_num`` of ``module_by_gene`` is sorted from largest to
    smallest; ``genes`` is reordered the same way.  ``top_N`` is accepted but
    not used -- the full ordering is returned.

    Returns ``(sorted_loadings, sorted_genes)``.
    """
    loadings = module_by_gene[module_num,:]
    descending = list(np.argsort(loadings))[::-1]
    return loadings[descending], genes[descending]
    
def corr_gene_vs_module(ad, gene, module_scores):
    """Pearson correlation between one gene's z-scores and a module score vector.

    The gene values are ``ad[:, gene].layers['zscore']`` flattened to 1-D.
    Returns whatever ``scipy.stats.pearsonr`` returns (statistic and p-value).
    """
    from scipy.stats import pearsonr

    gene_zscores = ad[:,gene].layers['zscore'].flatten()
    return pearsonr(gene_zscores, module_scores)

def corr_geneset_vs_module(ad, genes, module_scores):
    """Correlate every gene in ``genes`` with ``module_scores``.

    Returns the list of correlation coefficients (first element of each
    ``corr_gene_vs_module`` result), in the order of ``genes``; the p-values
    are discarded.
    """
    return [corr_gene_vs_module(ad, gene, module_scores)[0] for gene in genes]


In [None]:
def get_corr_df(ad, module_genes):
    """Gene-by-gene correlation matrix of z-scored expression.

    ``ad.layers['zscore']`` is wrapped in a DataFrame (rows ``ad.obs.index``,
    columns ``ad.var.index``), restricted to the requested genes, and passed to
    ``DataFrame.corr()``.

    ``module_genes`` is either a list-like of gene names (order preserved) or
    a dict of gene lists, in which case the unique union of all lists is used
    (order not defined).
    """
    zscores = pd.DataFrame(ad.layers['zscore'], index=ad.obs.index, columns=ad.var.index)

    wanted = module_genes
    if isinstance(module_genes, dict):
        flattened = [gene for members in module_genes.values() for gene in members]
        wanted = list(set(flattened))

    return zscores[wanted].corr()

In [None]:
def map_coeffs_to_color(prot, coeffs, genes_color, max_val=0, cmap='Blues'):
    """Write a hex colour for every coefficient into ``genes_color``.

    Coefficients are normalised on the symmetric range ``[-max_val, max_val]``
    (values outside are clipped) and looked up in the matplotlib colormap
    ``cmap``.  When ``max_val`` is 0 the largest absolute coefficient is used.

    ``coeffs`` is a Series; its index labels select the rows of
    ``genes_color`` and the colour is stored in column ``'<prot>_effect'``.
    ``genes_color`` is modified in place and also returned.
    """
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize, to_hex

    if max_val==0:
        max_val = max(np.abs(c) for c in coeffs)

    scale = Normalize(vmin=-max_val, vmax=max_val, clip=True)
    palette = ScalarMappable(norm=scale, cmap=cmap)

    target_col = '%s_effect' %prot
    for gene, value in zip(coeffs.index, coeffs):
        genes_color.loc[gene, target_col] = to_hex(palette.to_rgba(value))
    return genes_color

In [None]:
def get_cell_type_coefficient(prot, sig, params, clusters):
    """Pick one representative coefficient per significant gene.

    Parameters
    ----------
    prot : protein name; ``'_nCLR'`` is appended to form the lookup key unless
        the name already contains ``'nCLR'``
    sig, params : mappings of covariate name -> gene x cluster DataFrame
        (boolean significance calls and coefficients)
    clusters : columns used to decide which genes are included (a gene needs
        at least one significant call among them)

    Returns
    -------
    DataFrame indexed by the selected genes with a single column
    ``'<prot>_effect'``: the coefficient of the one cluster where the gene is
    significant, or the largest (signed) coefficient when there are several.

    NOTE(review): the coefficient is chosen among *all* columns of
    ``sig[key]`` in which the gene is significant, not only ``clusters`` --
    this matches the previous behaviour; confirm it is intended.
    """
    prot_nCLR = prot if 'nCLR' in prot else '%s_nCLR' %prot
    feat_sig = sig[prot_nCLR]
    feat_params = params[prot_nCLR]

    feat_sub = feat_sig[feat_sig[clusters].sum(axis=1)>0]
    feat_coeffs = pd.DataFrame([], index=feat_sub.index, columns=['%s_effect' %prot])
    for idx, row in feat_sub.iterrows():
        sig_clusters = row[row].index
        if len(sig_clusters) == 0:
            continue
        sig_coeffs = feat_params.loc[idx, sig_clusters]
        if len(sig_clusters) == 1:
            # take the scalar explicitly; float() on a one-element Series is
            # deprecated in recent pandas
            feat_coeffs.loc[idx] = float(sig_coeffs.iloc[0])
        else:
            feat_coeffs.loc[idx] = sig_coeffs.max()

    return feat_coeffs

In [None]:
# construct gene feature colors
def gene_feat_colors_cluster(sig, params, top_module_genes, clusters, 
                             proteins=['p65','cFos'], colors=['PuOr','PiYG_r']): 
    """Build a gene x protein table of hex colours encoding effect sizes.

    Every gene in ``top_module_genes`` starts white (``#ffffff``).  Genes that
    are significant for a protein in at least one of ``clusters`` are first
    marked black and then recoloured by ``map_coeffs_to_color`` from the
    coefficient returned by ``get_cell_type_coefficient`` (fixed colour range
    of +/-0.1, one colormap from ``colors`` per protein).

    Returns the DataFrame, with one ``'<protein>_effect'`` column per protein.
    """
    effect_cols = ['%s_effect' %prot for prot in proteins]
    genes_color = pd.DataFrame('#ffffff', index=top_module_genes, columns=effect_cols)

    # pass 1: flag genes significant for each protein
    for prot in proteins:
        prot_nCLR = prot if 'nCLR' in prot else '%s_nCLR' %prot
        feat_sig = sig[prot_nCLR]
        feat_sig_genes = feat_sig[feat_sig[clusters].sum(axis=1)>0].index
        flagged = genes_color.index.isin(feat_sig_genes)
        genes_color.loc[flagged, '%s_effect' %prot] = '#000000'

    # pass 2: replace the flag with a colour derived from the coefficient
    for prot, prot_color in zip(proteins, colors):
        effect_col = '%s_effect' %prot
        prot_coeffs = get_cell_type_coefficient(prot, sig, params, clusters)
        # NOTE(review): no cell is ever set to '#bfbfbf' in this function, so
        # this filter keeps every gene; the intersection below does the work.
        genes_sig_with_prot = genes_color[genes_color[effect_col]!='#bfbfbf'].index
        shared = np.intersect1d(genes_sig_with_prot, list(prot_coeffs.index))
        prot_coeffs = prot_coeffs.loc[shared]
        genes_color = map_coeffs_to_color(prot, prot_coeffs[effect_col], genes_color,
                                          max_val=0.1, cmap=prot_color)

    return genes_color

In [None]:
def corr_DEGs_modules(mxg, cxm, params, sig, feat, ad, cell_type=''):
    """Correlate modelled genes with module scores and attach model results.

    Parameters
    ----------
    mxg : DataFrame with modules as index and genes as columns
    cxm : per-module score vectors, indexed by module label (``cxm[mod]``)
    params, sig : model coefficients / significance calls, either DataFrames
        (gene x covariate) or mappings covariate -> gene x cluster DataFrame
    feat : covariate whose coefficient and significance are attached
    ad : passed through to ``corr_geneset_vs_module``
    cell_type : cluster column to read when ``params``/``sig`` are mappings

    Returns
    -------
    df : wide DataFrame indexed by the genes present in both ``mxg`` and the
        model, with a ``'<module>_corr'`` column per module plus ``params``
        and ``sig`` (and one all-NaN column per module label from
        initialisation).
    df_boxplot : long DataFrame with columns gene, module, corr, sig, params
        (one row per gene/module pair).
    """
    mod_genes = mxg.columns
    if isinstance(params, pd.DataFrame):
        lm_genes = params.index
        params_type = 'df'
    else:
        lm_genes = params['Intercept'].index
        params_type = 'dict'

    modules = mxg.index
    both_genes = np.intersect1d(mod_genes, lm_genes)
    df = pd.DataFrame([], index=both_genes, columns=modules)

    for mod in modules:
        df['%s_corr' %mod] = corr_geneset_vs_module(ad, df.index, cxm[mod])
        # Re-assigned on every pass with the same values; kept inside the loop
        # so the column order of `df` stays as before.
        if params_type=='df':
            df['params'] = params.loc[both_genes, feat]
            df['sig'] = sig.loc[both_genes, feat]
        elif params_type=='dict':
            df['params'] = params[feat].loc[both_genes,cell_type]
            df['sig'] = sig[feat].loc[both_genes,cell_type]

    # Gather the long-format rows first and build the frame once:
    # DataFrame.append re-copied the frame per row and is gone in pandas 2.0.
    long_rows = []
    for idx, row in df.iterrows():
        for mod in modules:
            long_rows.append({'gene':idx,
                              'module':mod,
                              'corr':row['%s_corr' %mod],
                              'sig':row['sig'],
                              'params':row['params']})
    df_boxplot = pd.DataFrame(long_rows, columns=['gene','module','corr','sig','params'])
    return df, df_boxplot

# plotting tools

## comparing UMI and gene counts

In [None]:
def scatter_UMI_genes_hist(ad, samplestr, density=True, savefig=True): 
    """Scatter log10(n_counts) against log10(n_genes) with marginal histograms.

    Reads ``ad.obs['n_counts']`` and ``ad.obs['n_genes']``.  The central panel
    shows one point per row with a dashed x=y reference line; the top and
    right panels show histograms of the x and y values.

    Parameters
    ----------
    ad : object exposing ``obs['n_counts']`` and ``obs['n_genes']``
    samplestr : string inserted into the output file name
    density : when True, points are coloured by a Gaussian KDE estimate of the
        local point density; otherwise they are drawn in a single gray
    savefig : when True, the figure is written to
        ``<sc.settings.figdir>/scatter_ngenes_UMIs_hist_<samplestr>.png``

    Axis limits and ticks are fixed (x: 1-5.1, y: 1-4.5 in log10 units).
    Nothing is returned.
    """
    from scipy.stats import gaussian_kde

    # definitions for the axes
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]
    rect_histy = [left + width + spacing, bottom, 0.2, height]

    # the figure handle is not used again; the axes below attach to the
    # current figure
    h = plt.figure(figsize=(3,3), dpi=1200)

    ax_scatter = plt.axes(rect_scatter)
    ax_scatter.tick_params(direction='in', top=True, right=True)
    ax_histx = plt.axes(rect_histx)
    ax_histx.tick_params(direction='in', labelbottom=False)
    ax_histy = plt.axes(rect_histy)
    ax_histy.tick_params(direction='in', labelleft=False)

    # plot x=y line
    ax_scatter.plot([0,6],[0,6],'k--',linewidth=1)
    
    x = np.log10(ad.obs['n_counts'])
    y = np.log10(ad.obs['n_genes'])
    
    # calculate density
    if density: 
        xy = np.vstack([x,y])
        # evaluate the KDE at the data points themselves to colour each point
        z = gaussian_kde(xy)(xy)
        ax_sc = ax_scatter.scatter(x, y, c=z, s=3, alpha=0.5, cmap='coolwarm')
#         plt.colorbar(ax_sc)
    
    else: 

        ax_scatter.scatter(x, y, s=3, color='#696969', alpha=0.5)
    
    # shared bin edges for both marginal histograms
    bins = np.arange(0,6,0.1)
    ax_histx.hist(x, bins=bins, facecolor='#696969')
    ax_histy.hist(y, bins=bins, orientation='horizontal', facecolor='#696969')

    # set axis properties 
    ax_scatter.set_xlim([1,5.1])
    ax_scatter.set_xticks([1,2,3,4,5])
    ax_scatter.set_xticklabels([1,2,3,4,5])
    ax_scatter.set_ylim([1,4.5])
    ax_scatter.set_yticks([1,2,3,4])
    ax_scatter.set_xlabel('UMIs', fontsize=14)
    ax_scatter.set_ylabel('Genes', fontsize=14)
    
    # keep the marginals aligned with the scatter panel
    ax_histx.set_xlim(ax_scatter.get_xlim())
    ax_histy.set_ylim(ax_scatter.get_ylim())
    
    ax_histx.spines['top'].set_visible(False)
    ax_histx.spines['right'].set_visible(False)
    ax_histx.spines['left'].set_visible(False)
    ax_histx.set_xticks([])
    ax_histx.set_yticks([])
    
    ax_histy.spines['top'].set_visible(False)
    ax_histy.spines['right'].set_visible(False)
    ax_histy.spines['bottom'].set_visible(False)
    ax_histy.set_xticks([])
    ax_histy.set_yticks([])

    if savefig: 
        plt.savefig('%s/scatter_ngenes_UMIs_hist_%s.png' %(sc.settings.figdir, samplestr), bbox_inches='tight')

## cluster proportions

In [None]:
def get_cluster_proportions(adata,
                            cluster_key="leiden",
                            sample_key="batch",
                            drop_values=None):
    """
    Percentage of each sample's observations that fall in each cluster.

    Input
    =====
    adata : AnnData object (only `adata.obs` is read)
    cluster_key : key of `adata.obs` storing cluster info
    sample_key : key of `adata.obs` storing sample/replicate info
    drop_values : list/iterable of possible values of `sample_key` that you don't want

    Returns
    =======
    pd.DataFrame with samples as the index and clusters as the columns and 0-100 floats
    as values. Every row sums to 100 (or is all zeros for a sample without
    observations). Index and column axes are named `sample_key` and `cluster_key`.

    Raises
    ======
    KeyError if an entry of `drop_values` is not a sample in the result.
    """
    # Count observations per (cluster, sample); absent combinations count as 0.
    # `observed=False` keeps unused categories of categorical keys in the table.
    counts = (adata.obs
              .groupby([cluster_key, sample_key], observed=False)
              .size()
              .unstack(cluster_key, fill_value=0))

    # Normalise within each sample. This replaces a groupby-apply/pivot chain
    # whose result index depends on the pandas version, and avoids copying the
    # whole AnnData just to read `.obs`.
    props = (100 * counts).div(counts.sum(axis=1), axis=0)
    props = props.fillna(0)  # samples with zero observations give 0/0

    if drop_values is not None:
        props = props.drop(index=list(drop_values))
    return props


def plot_cluster_proportions(cluster_props, 
                             cluster_palette=None,
                             xlabel_rotation=0): 
    """Draw a stacked bar chart of a proportions table.

    Parameters
    ----------
    cluster_props : pd.DataFrame; one stacked bar per index entry, one segment
        per column.  Its index must have a name (used as the x-axis label).
    cluster_palette : optional list of colours blended into a colormap for the
        segments; pandas' default colours are used when None
    xlabel_rotation : rotation applied to the x tick parameters.
        NOTE(review): the later ``set_xticklabels(..., rotation = 90)`` call
        sets the label rotation to 90 regardless of this value -- confirm
        which one is intended.

    Returns
    -------
    The matplotlib Figure (300 dpi, white background).
    """
    import seaborn as sns
    fig, ax = plt.subplots(dpi=300)
    fig.patch.set_facecolor("white")
    
    cmap = None
    if cluster_palette is not None:
        # turn the discrete palette into a colormap pandas can sample from
        cmap = sns.palettes.blend_palette(
            cluster_palette, 
            n_colors=len(cluster_palette), 
            as_cmap=True)
   
    cluster_props.plot(
        kind="bar", 
        stacked=True, 
        ax=ax, 
        legend=None, 
        colormap=cmap
    )
    
    # legend placed just outside the right edge of the axes
    ax.legend(bbox_to_anchor=(1.01, 1), frameon=False, title="Replicate / batch")
    sns.despine(fig, ax)
    ax.tick_params(axis="x", rotation=xlabel_rotation)
    ax.set_xlabel(cluster_props.index.name.capitalize())
    ax.set_ylabel("% of nuclei in cluster")
    ax.set_xticklabels(cluster_props.index, rotation = 90)
    ax.set_yticks([0,50,100])
    ax.set_yticklabels([0,50,100])
    
    fig.tight_layout()
    
    return fig

## colormaps

In [None]:
def make_gray_monoscale_cmap():
    """Return a 16-colour ListedColormap: light gray followed by 15 'Blues' steps.

    The blue steps are sampled evenly over [0, 1] from a 500-level 'Blues'
    colormap; the gray entry ``[0.85, 0.85, 0.85, 1.0]`` is placed first.
    The colormap is named ``'blues_with_gray'``.
    """
    from matplotlib import cm
    from matplotlib.colors import ListedColormap

    gray = [0.85, 0.85, 0.85, 1.0]
    blue_steps = cm.get_cmap('Blues', 500)(np.linspace(0, 1, 15)).tolist()
    return ListedColormap([gray] + blue_steps, name='blues_with_gray')

In [None]:
def make_seismic_with_nan():
    """Return a 500-colour ListedColormap: light gray followed by 499 'seismic' steps.

    The seismic steps are sampled evenly over [0, 1] from a 500-level
    'seismic' colormap; the gray entry ``[0.85, 0.85, 0.85, 1.0]`` is placed
    first.  The colormap is named ``'seismic_with_gray'``.
    """
    from matplotlib import cm
    from matplotlib.colors import ListedColormap

    gray = [0.85, 0.85, 0.85, 1.0]
    seismic_steps = cm.get_cmap('seismic', 500)(np.linspace(0, 1, 499)).tolist()
    return ListedColormap([gray] + seismic_steps, name='seismic_with_gray')

In [None]:
def make_YlGn_colorbars():
    """Save standalone 'YlGn' colorbars (range 0-0.8) as PDFs.

    Two files are written to ``sc.settings.figdir``: a vertical bar
    (``YlGn_colorbar.pdf``) and a horizontal bar
    (``YlGn_colorbar_horizontal.pdf``).  Each is drawn in its own new figure.
    """
    # (figure size, colorbar axes rectangle, orientation, file name template)
    layouts = [
        ((0.5, 5), [0.1, 0.2, 0.8, 0.6], "vertical", "%s/YlGn_colorbar.pdf"),
        ((5, 0.5), [0.1, 0.2, 0.6, 0.8], "horizontal", "%s/YlGn_colorbar_horizontal.pdf"),
    ]
    for figsize, cax_rect, orientation, name_template in layouts:
        plt.figure(figsize=figsize)
        # dummy image: only needed as the mappable the colorbar is built from
        plt.imshow(np.array([[0,1]]), cmap="YlGn", vmin=0, vmax=0.8)
        plt.gca().set_visible(False)
        cax = plt.axes(cax_rect)
        plt.colorbar(orientation=orientation, cax=cax)
        figname = name_template %sc.settings.figdir
        print('Saving to %s' %figname)
        plt.savefig(figname, bbox_inches='tight')

In [None]:
def make_vertical_colorbar(cm='coolwarm', vmin=-0.15, vmax=0.15):
    """Save a standalone vertical colorbar for colormap ``cm`` as a PDF.

    The bar spans ``vmin``..``vmax`` and is written to
    ``<sc.settings.figdir>/<cm>_colorbar_vertical.pdf``.
    """
    plt.figure(figsize=(0.5, 4))
    # dummy image: only needed as the mappable the colorbar is built from
    plt.imshow(np.array([[0,1]]), cmap=cm, vmin=vmin, vmax=vmax)
    plt.gca().set_visible(False)
    colorbar_axes = plt.axes([0.1, 0.2, 0.8, 0.6])
    plt.colorbar(orientation="vertical", cax=colorbar_axes)

    out_path = "%s/%s_colorbar_vertical.pdf" %(sc.settings.figdir, cm)
    print('Saving to %s' %out_path)
    plt.savefig(out_path, bbox_inches='tight')

In [None]:
def make_horizontal_colorbar(cm='coolwarm', vmin=-0.15, vmax=0.15):
    """Save a standalone horizontal colorbar for colormap ``cm`` as a PDF.

    The bar spans ``vmin``..``vmax`` and is written to
    ``<sc.settings.figdir>/<cm>_colorbar_horizontal.pdf``.
    """
    plt.figure(figsize=(5,0.5))
    # dummy image: only needed as the mappable the colorbar is built from
    plt.imshow(np.array([[0,1]]), cmap=cm, vmin=vmin, vmax=vmax)
    plt.gca().set_visible(False)
    colorbar_axes = plt.axes([0.1, 0.2, 0.6, 0.8])
    plt.colorbar(orientation="horizontal", cax=colorbar_axes)

    out_path = "%s/%s_colorbar_horizontal.pdf" %(sc.settings.figdir, cm)
    print('Saving to %s' %out_path)
    plt.savefig(out_path, bbox_inches='tight')

In [None]:
def hex_to_rgb(value):
    """Convert a hex colour string such as ``'#ff8000'`` to an integer tuple.

    Leading/trailing ``#`` characters are stripped and the remaining digits
    are cut into chunks of one third of their length, each parsed as a
    base-16 integer.
    """
    digits = value.strip('#')
    chunk = len(digits) // 3
    starts = range(0, len(digits), chunk)
    return tuple(int(digits[start:start + chunk], 16) for start in starts)

def rgb_to_dec(value):
    """Scale 8-bit colour channels (0-255) to floats in [0, 1].

    The divisor is 255, the largest 8-bit value, so that a full-intensity
    channel maps to exactly 1.0 (dividing by 256 left white slightly dark).

    Returns a list with one float per entry of ``value``.
    """
    return [v/255 for v in value]

def get_continuous_cmap(hex_list, map_name, float_list=None):
    """Build a LinearSegmentedColormap that passes through the given hex colours.

    Parameters
    ----------
    hex_list : hex colour strings, converted with
        ``rgb_to_dec(hex_to_rgb(...))``
    map_name : name given to the colormap
    float_list : positions in [0, 1] for each colour; when empty or None the
        colours are spaced evenly from 0 to 1

    Returns
    -------
    matplotlib ``LinearSegmentedColormap`` with 256 levels.
    """
    import matplotlib.colors as mcolors

    rgb_list = [rgb_to_dec(hex_to_rgb(code)) for code in hex_list]
    if not float_list:
        float_list = list(np.linspace(0,1,len(rgb_list)))

    # one [position, value, value] anchor per colour for every channel
    cdict = {
        channel: [[position, rgb_list[i][k], rgb_list[i][k]]
                  for i, position in enumerate(float_list)]
        for k, channel in enumerate(['red','green','blue'])
    }
    return mcolors.LinearSegmentedColormap(map_name, segmentdata=cdict, N=256)

## protein levels

In [None]:
def get_ab_lims(ab):
    """Return the plotting presets for one antibody feature.

    Parameters
    ----------
    ab : feature name, one of ``NeuN_nCLR``, ``NeuN_CLR``, ``cFos_nCLR``,
        ``cFos_CLR``, ``p65_nCLR``, ``p65_CLR``, ``PU1_nCLR``

    Returns
    -------
    xlims : [min, max] x-axis limits
    bins : np.ndarray of histogram bin edges
    batch_thresh : dict mapping batch label (``'0'``/``'1'``) to a threshold
        value (callers draw it as a vertical line / use it to binarize)

    Raises
    ------
    ValueError for an unknown ``ab`` (previously this surfaced as an
    ``UnboundLocalError`` at the return statement).
    """
    # name -> ((bin start, bin stop, bin step), x-limits, per-batch threshold)
    # Built per call so callers always receive fresh, independent objects.
    presets = {
        'NeuN_nCLR': ((-2, 2, 0.05), [-2, 1.5], {'0': -0.2, '1': -0.11}),
        'NeuN_CLR': ((-3, 2.5, 0.05), [-3, 2.5], {'0': -1, '1': 0}),
        'cFos_nCLR': ((-5, 5, 0.1), [-3, 3], {'0': 0.1, '1': -0.65}),
        'cFos_CLR': ((-5, 5, 0.1), [-2, 2], {'0': -0.1, '1': -0.08}),
        'p65_nCLR': ((-2, 3, 0.05), [-3, 3.5], {'0': 0.4, '1': 0.75}),
        'p65_CLR': ((-5, 5, 0.1), [-3, 3], {'0': 0.1, '1': 1.25}),
        'PU1_nCLR': ((-2, 4, 0.2), [-2, 3], {'0': 0.85, '1': 0.6}),
    }
    if ab not in presets:
        raise ValueError('No plotting limits defined for %r; expected one of %s'
                         % (ab, ', '.join(presets)))

    bin_range, xlims, batch_thresh = presets[ab]
    bins = np.arange(*bin_range)
    return xlims, bins, batch_thresh

### histograms

In [None]:
def normalized_disthist(vals, bins, color='#696969', alpha=0.5, kde_bw=0.2, kde_thresh=0.8, ax=None):
    """Draw a density-normalised histogram with a KDE overlay via seaborn.

    Parameters
    ----------
    vals : values to plot
    bins : histogram bin edges
    color : colour of the histogram and KDE line
    alpha : accepted for compatibility; not used by this function
    kde_bw, kde_thresh : forwarded to ``sns.distplot`` as ``kde_kws``
        (``'bw'`` and ``'thresh'``)
    ax : axes to draw on; a new/current one is chosen by seaborn when None

    Returns
    -------
    The matplotlib Axes, with the top/right spines hidden, the bottom/left
    spines thickened, and set as pyplot's current axes.

    NOTE(review): ``sns.distplot`` is deprecated in recent seaborn releases;
    it is kept because its replacement takes different KDE arguments.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    ax = sns.distplot(vals, bins=bins, norm_hist=True, color=color,
                      kde_kws={'bw':kde_bw,
                               'thresh':kde_thresh},
                      ax=ax)
    # Make `ax` the current axes. pyplot.sca is the supported call for this;
    # passing an existing Axes to pyplot.axes is deprecated usage.
    plt.sca(ax)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1.5)
    ax.spines['left'].set_linewidth(1.5)
    return ax

In [None]:
def get_multi_mode(l):
    """Return every value that shares the highest frequency in ``l``.

    Values are returned in the order ``Counter.most_common`` lists them.
    An empty input raises ``StopIteration``.
    """
    from collections import Counter
    from itertools import takewhile

    # most_common() is sorted by decreasing count, so all modes form a prefix
    ranked = Counter(l).most_common()
    top_count = next(iter(ranked))[1]
    return [val for val, _ in takewhile(lambda pair: pair[1] == top_count, ranked)]
    

def plot_disthist_by_batch(ab, ad_fore, fore_c, ad_back, back_c, savefig=False, dotted_line=True):
    """Compare foreground and background distributions of ``ab`` per batch.

    Two stacked panels sharing the x-axis are drawn, one for batch ``'0'`` and
    one for batch ``'1'`` (rows selected by ``obs['batch']``).  In each panel
    the background (``ad_back``, colour ``back_c``) is drawn first and the
    foreground (``ad_fore``, colour ``fore_c``) on top, both with
    ``normalized_disthist``.  Bin edges, x-limits and per-batch thresholds
    come from ``get_ab_lims(ab)``.

    Parameters
    ----------
    ab : name of the ``obs`` column to plot (must be known to ``get_ab_lims``)
    ad_fore, ad_back : objects exposing ``obs['batch']`` and ``obs[ab]``
    fore_c, back_c : colours for the two groups
    savefig : when True, save to
        ``<sc.settings.figdir>/distplot_<ab>_by_batch.pdf``
    dotted_line : when True, mark the batch threshold with a dashed line
    """
    xlims, bins, b_thresh = get_ab_lims(ab)
#     fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(4,7))
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(4,5))

    # batch '0' -- select the panel with pyplot.sca, as is done for ax2 below
    plt.sca(ax1)
    ab_b0 = ad_back[ad_back.obs['batch']=='0'].obs[ab]
    ab_f0 = ad_fore[ad_fore.obs['batch']=='0'].obs[ab]

    normalized_disthist(      ab_b0, bins=bins, color=back_c,
                             kde_bw=0.1, kde_thresh=10, alpha=0.7, ax=ax1)
    ax1 = normalized_disthist(ab_f0, bins=bins, color=fore_c,
                             kde_bw=0.1, kde_thresh=10, alpha=0.7, ax=ax1)
    if dotted_line:
        ax1.axvline(b_thresh['0'], color='k', linestyle='--', alpha=0.5)
    ax1.set_xlim(xlims)

    # batch '1'
    plt.sca(ax2)
    ab_b1 = ad_back[ad_back.obs['batch']=='1'].obs[ab]
    ab_f1 = ad_fore[ad_fore.obs['batch']=='1'].obs[ab]
    ax2 = normalized_disthist(ab_b1, bins=bins, color=back_c,
                             kde_bw=0.1, kde_thresh=10, alpha=0.7, ax=ax2)
    ax2 = normalized_disthist(ab_f1, bins=bins, color=fore_c,
                             kde_bw=0.1, kde_thresh=10, alpha=0.7, ax=ax2)
    if dotted_line:
        ax2.axvline(b_thresh['1'], color='k', linestyle='--', alpha=0.5)
    ax2.set_xlim(xlims)


    if savefig:
        figname = '%s/distplot_%s_by_batch.pdf' %(sc.settings.figdir, ab)
        print('Saving to %s' %figname)
        fig.savefig(figname, bbox_inches='tight')

def plot_disthist_single_batch(ab, ad_fore, fore_c, ad_back, back_c, batch, xlims=[], savefig=False, 
                               vertline=True, figstr='', kde_bw=0.1):
    """Compare foreground and background distributions of ``ab`` in one batch.

    The background group is drawn first and the foreground on top, both with
    ``normalized_disthist``; bin edges, default x-limits and the batch
    threshold come from ``get_ab_lims(ab)``.  The p-value and statistic of an
    independent two-sample t-test (background vs. foreground) are printed.

    Parameters
    ----------
    ab : name of the ``obs`` column to plot (must be known to ``get_ab_lims``)
    ad_fore, ad_back : objects exposing ``obs['batch']`` and ``obs[ab]``
    fore_c, back_c : colours for the two groups
    batch : batch label; selects rows and the threshold to draw
    xlims : x-axis limits; the preset from ``get_ab_lims`` is used when empty
        (the default list is only read, never modified)
    savefig : when True the figure is saved as a PDF in ``sc.settings.figdir``
        -- named ``figstr`` if given, else ``distplot_<ab>_batch<batch>``
    vertline : when True, mark the batch threshold with a dashed line
    kde_bw : KDE bandwidth forwarded to ``normalized_disthist``
    """
    # import the function directly: a bare `import scipy` does not guarantee
    # that the `scipy.stats` submodule has been loaded
    from scipy.stats import ttest_ind

    xlim_pre, bins, b_thresh = get_ab_lims(ab)
    if len(xlims) == 0:
        xlims = xlim_pre

    fig, ax1 = plt.subplots(figsize=(4,3))
    # pyplot.sca is the supported way to select an existing Axes
    plt.sca(ax1)
    back_gr = ad_back[ad_back.obs['batch']==batch].obs[ab]
    normalized_disthist(back_gr, bins=bins, color=back_c,
                             kde_bw=kde_bw, kde_thresh=10, alpha=0.7, ax=ax1)
    fore_gr = ad_fore[ad_fore.obs['batch']==batch].obs[ab]
    ax1 = normalized_disthist(fore_gr, bins=bins, color=fore_c,
                             kde_bw=kde_bw, kde_thresh=10, alpha=0.7, ax=ax1)

    if vertline:
        ax1.axvline(b_thresh[batch], color='k', linestyle='--', alpha=0.5)

    ax1.set_xlim(xlims)
    if savefig and figstr=='':
        fig.savefig('%s/distplot_%s_batch%s.pdf' %(sc.settings.figdir, ab, batch), bbox_inches='tight')
    elif savefig:
        fig.savefig('%s/%s.pdf' %(sc.settings.figdir, figstr), bbox_inches='tight')

    # stats test
    stat, pval = ttest_ind(back_gr,fore_gr)
    print(pval, stat)

### boxplots - single protein

In [None]:
def get_RNA_level_by_protein_bin(ad, gene, cite, scaling='lognorm', remove_zeros=False):
    """Split a gene's expression values by the binarized level of a protein.

    Rows are grouped by ``ad.obs['<cite>_binary']`` (0 = "off", 1 = "on") and
    the gene's values are read with ``get_feat_values(subset, gene, scaling)``.

    Parameters
    ----------
    ad : object indexed like an AnnData, with ``obs['<cite>_binary']``
    gene : gene name
    cite : protein name (prefix of the ``_binary`` column)
    scaling : which representation ``get_feat_values`` should return
    remove_zeros : when True, first drop rows in which the gene has a zero
        entry in the matching count layer (``counts`` for ``lognorm``/
        ``zscore``, ``spliced_counts``, ``unspliced_counts``, or the layer
        named by ``scaling``)

    Returns
    -------
    RNA_neg, RNA_pos : values for the "off" and "on" groups
    RNA_df : long DataFrame with columns ``<gene>`` and ``protein``
        (``'off'``/``'on'``), "off" rows first
    """
    bin_name = cite+'_binary'
    if remove_zeros:
        if scaling=='lognorm' or scaling=='zscore':
            feat_type='counts'
        elif scaling=='spliced':
            feat_type='spliced_counts'
        elif scaling=='unspliced':
            feat_type='unspliced_counts'
        else:
            feat_type=scaling
        ad = ad[ad[:,gene].layers[feat_type].toarray()>0].copy()

    prot_neg = ad[ad.obs[bin_name]==0]
    prot_pos = ad[ad.obs[bin_name]==1]

    RNA_neg = get_feat_values(prot_neg, gene, scaling)
    RNA_pos = get_feat_values(prot_pos, gene, scaling)

    RNA_off_df = pd.DataFrame(list(zip(RNA_neg,len(RNA_neg)*['off'])),columns=[gene,'protein'])
    RNA_on_df = pd.DataFrame(list(zip(RNA_pos,len(RNA_pos)*['on'])),columns=[gene,'protein'])
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
    RNA_df = pd.concat([RNA_off_df, RNA_on_df], ignore_index=True)

    return RNA_neg, RNA_pos, RNA_df

In [None]:
def boxplot_RNA_by_protein_bin(ad, gene, prot, color_dict, ylims=[0,11], desc='', scaling='zscore', 
                               remove_zeros=True):
    """Box/strip plot of a gene's expression in protein-"off" vs "on" cells.

    If ``ad.obs`` has no ``'<prot>_binary'`` column it is created first by
    ``CITE_binarize_by_batch`` using the thresholds from
    ``get_ab_lims('<prot>_nCLR')``.  The values come from
    ``get_RNA_level_by_protein_bin``; an independent t-test between the two
    groups is annotated above the plot with ``statannot``.

    Parameters
    ----------
    ad : data object passed to the helper functions
    gene : gene to plot
    prot : protein name (without the ``_binary`` / ``_nCLR`` suffix)
    color_dict : palette passed to seaborn for the ``off``/``on`` groups
    ylims : y-axis limits (the default list is only read, never modified)
    desc : free text appended to the output file name
    scaling, remove_zeros : forwarded to ``get_RNA_level_by_protein_bin``

    The figure is always saved to
    ``<sc.settings.figdir>/boxplot_<prot>_<gene>_<scaling>_<desc>.pdf``.

    Returns
    -------
    The matplotlib Axes holding the plot.
    """
    from statannot import add_stat_annotation
    if prot+'_binary' not in ad.obs.columns: 
        _,_,thresh = get_ab_lims(prot+'_nCLR')
        CITE_binarize_by_batch(ad, prot, 'nCLR', thresh)
    
    RNA_off, RNA_on, RNA_df = get_RNA_level_by_protein_bin(ad, gene, prot,
                                                            scaling=scaling, 
                                                            remove_zeros=remove_zeros)
    fig, ax = plt.subplots(figsize=(2,3))
    ax = sns.boxplot(data=RNA_df, x='protein', y=gene, 
                     order=['off','on'], width=0.3, 
                     palette=color_dict)
    ax.set_ylim(ylims)
    # make the box faces half transparent so the points stay visible
    for patch in ax.artists:
        r, g, b, a = patch.get_facecolor()
        patch.set_facecolor((r, g, b, 0.5))
        
    ax = sns.stripplot(data=RNA_df, x='protein', y=gene, 
                       order=['off','on'], s=2, 
                       alpha=0.5, jitter=0.08,
                       palette=color_dict)
    test_results = add_stat_annotation(ax, data=RNA_df, x='protein', y=gene, 
                                       order=['off','on'],
                                       box_pairs=[('off', 'on')],
                                       test='t-test_ind', text_format='star',
                                       loc='outside', verbose=2)
    figname = '%s/boxplot_%s_%s_%s_%s.pdf' %(sc.settings.figdir, prot, gene, scaling, desc)
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')
    
    return ax


In [None]:
def boxplot_RNA_by_protein_bin_splicing(ad, gene, prot, color_dict, desc='', remove_zeros=True):
    """Box/strip plot of spliced vs. unspliced signal in protein-"off"/"on" cells.

    If ``ad.obs`` has no ``'<prot>_binary'`` column it is created first by
    ``CITE_binarize_by_batch`` using the thresholds from
    ``get_ab_lims('<prot>_nCLR')``.  Two long tables are fetched with
    ``get_RNA_level_by_protein_bin`` and plotted side by side (hue =
    ``data_type``); an independent t-test "off" vs "on" is annotated for each
    data type.

    NOTE(review): the table labelled ``'spliced'`` is fetched with
    ``scaling='zscore'``, not ``'spliced'`` -- confirm this is intended.

    Parameters
    ----------
    ad : data object passed to the helper functions
    gene : gene to plot
    prot : protein name (without the ``_binary`` / ``_nCLR`` suffix)
    color_dict : accepted for signature compatibility; not used here
    desc : free text appended to the output file name
    remove_zeros : forwarded to ``get_RNA_level_by_protein_bin``

    The figure is always saved to
    ``<sc.settings.figdir>/boxplot_by_splicing_<prot>_<gene>_<desc>.pdf``.

    Returns
    -------
    The matplotlib Axes holding the plot.
    """
    from statannot import add_stat_annotation
    if prot+'_binary' not in ad.obs.columns:
        _,_,thresh = get_ab_lims(prot+'_nCLR')
        CITE_binarize_by_batch(ad, prot, 'nCLR', thresh)

    _, _, RNA_spliced = get_RNA_level_by_protein_bin(ad, gene, prot,
                                                            scaling='zscore',
                                                            remove_zeros=remove_zeros)
    RNA_spliced['data_type'] = 'spliced'

    _, _, RNA_unspliced = get_RNA_level_by_protein_bin(ad, gene, prot,
                                                            scaling='unspliced',
                                                            remove_zeros=remove_zeros)
    RNA_unspliced['data_type'] = 'unspliced'

    RNA_df = pd.concat([RNA_spliced,RNA_unspliced])

    # plt.subplots creates its own figure; the extra plt.figure() that used to
    # precede it only left an empty, never-closed figure behind.
    fig, ax = plt.subplots(figsize=(3,3))
    ax = sns.boxplot(data=RNA_df, x='protein', y=gene,
                     order=['off','on'],
                     hue='data_type',
                     hue_order=['unspliced','spliced'])
    # make the box faces half transparent so the points stay visible
    for patch in ax.artists:
        r, g, b, a = patch.get_facecolor()
        patch.set_facecolor((r, g, b, 0.5))

    ax = sns.stripplot(data=RNA_df, x='protein', y=gene,
                       order=['off','on'],
                       s=2,
                       hue='data_type',
                       split=True,
                       hue_order=['unspliced','spliced'],
                       alpha=0.6, jitter=0.08)

    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    test_results = add_stat_annotation(ax, data=RNA_df, x='protein', y=gene,
                                       order=['off','on'],
                                       hue='data_type',
                                       box_pairs=[(('off','unspliced'),('on','unspliced')),
                                                 (('off','spliced'),('on','spliced'))],
                                       test='t-test_ind', text_format='star',
                                       loc='inside', verbose=2)
    figname = '%s/boxplot_by_splicing_%s_%s_%s.pdf' %(sc.settings.figdir, prot, gene, desc)
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')
    return ax


In [None]:
def boxplot_RNA_by_treatment(adata, gene):
    """Box/strip plot of a gene's z-scored expression in PBS vs. kainic-acid cells.

    Rows are split on ``adata.obs['assignment']`` (``'PBS'`` and
    ``'KainicAcid'``); values are read with
    ``get_feat_values(subset, gene, 'zscore')``.  An independent t-test
    between the two groups is annotated above the plot and the figure is saved
    to ``<sc.settings.figdir>/IEGs_boxplot_by_treatment_<gene>.pdf``.
    Nothing is returned.
    """
    from statannot import add_stat_annotation

    ad_PBS = adata[adata.obs['assignment']=='PBS'].copy()
    ad_KA = adata[adata.obs['assignment']=='KainicAcid'].copy()

    scaling = 'zscore'
    gene_PBS = get_feat_values(ad_PBS, gene, scaling)
    gene_KA = get_feat_values(ad_KA, gene, scaling)

    PBS_df = pd.DataFrame(list(zip(gene_PBS,len(gene_PBS)*['PBS'])),columns=[gene,'treatment'])
    KA_df = pd.DataFrame(list(zip(gene_KA,len(gene_KA)*['KA'])),columns=[gene,'treatment'])
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0
    gene_df = pd.concat([PBS_df, KA_df], ignore_index=True)

    fig, ax = plt.subplots(figsize=(2,2))
    ax = sns.boxplot(data=gene_df, x='treatment', y=gene,
                     order=['PBS','KA'], width=0.3, fliersize=1,
                     palette=['#dcdcdc','#00CC33'])
    # ax.set_ylim(ylims)
    # make the box faces half transparent so the points stay visible
    for patch in ax.artists:
        r, g, b, a = patch.get_facecolor()
        patch.set_facecolor((r, g, b, 0.5))

    ax = sns.stripplot(data=gene_df, x='treatment', y=gene,
                       order=['PBS','KA'], s=1,
                       alpha=1, jitter=0.15,
                       palette=['#696969','#00CC33'])
    test_results = add_stat_annotation(ax, data=gene_df, x='treatment', y=gene,
                                       order=['PBS','KA'],
                                       box_pairs=[('PBS', 'KA')],
                                       test='t-test_ind', text_format='star',
                                       loc='outside', verbose=2)

    figname = '%s/IEGs_boxplot_by_treatment_%s.pdf' %(sc.settings.figdir, gene)
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')

## linear model results

In [None]:
def make_significance_dotplot_R(plot_df, feat, casestr, figsize=(6,12), circle_size='neglog_pval'):
    """Draw a gene x cluster dot plot of model results with plotnine.

    Each dot is one row of ``plot_df``: fill colour encodes ``coefficient``
    (diverging scale, symmetric around 0), outline colour encodes
    ``significant`` and dot size encodes the column named by ``circle_size``.

    Parameters
    ----------
    plot_df : long DataFrame with columns ``cluster``, ``gene``,
        ``coefficient``, ``significant`` and the ``circle_size`` column
    feat, casestr : used in the plot title (``feat`` is wrapped in ``$...$``)
    figsize : figure size handed to the plotnine theme
    circle_size : column mapped to dot size; ``'neglog_pval'`` and
        ``'fraction'`` get descriptive legend titles, any other column name is
        used as its own legend title

    Returns
    -------
    The plotnine ``ggplot`` object (it is also printed, which renders it).

    The plot is built purely with plotnine, so no rpy2 / R installation is
    needed despite the ``_R`` suffix in the name.
    """
    size_labels = {'neglog_pval': '-log10(p-value)',
                   'fraction': 'Frac. cells expressed'}
    # fall back to the column name so an unlisted size column still plots
    circle_text = size_labels.get(circle_size, circle_size)

    # symmetric colour limits around zero
    limit = max(plot_df.coefficient.abs()) * np.array([-1, 1])
    g = (
        ggplot(aes(x='cluster', y='gene'), data=plot_df) +
        geom_point(aes(size=circle_size, fill='coefficient', color='significant'))+
        scale_fill_distiller(type='div', limits=limit, name='DE coefficient') +
        scale_color_manual(values=('#808080', '#000000')) +  # 990E1D
        labs(size = circle_text, y='', x='', title='$%s$ %s'%(feat, casestr) ) +
        guides(size = guide_legend(reverse=True)) +
        theme_bw() +
        scale_size(range = (1,10)) +
        scale_y_discrete(drop=False) +
        theme(
          figure_size=figsize,
          legend_key=element_blank(),
          axis_text_x = element_text(rotation=45, hjust=1.),
        )
    )

    # ggsave(g, 'figure-1-c.pdf', width=9, height=12)
    print(g)
    return(g)

In [None]:
def plot_pretty_volcano_simplified(plot_df, ylims=None, yticks=None, 
                                   xlims=None, xticks=None, xlabel='',
                                   annotate=False, 
                                   annotate_list=None, color=None,
                                   savefig=False, figdir='', filename=''):
    """Draw a volcano plot (coefficient vs. -log10 p-value) from a prepared table.

    Parameters
    ----------
    plot_df : pandas.DataFrame
        Table indexed by the labels to annotate, with the columns ``coeff``,
        ``-log10p`` and a boolean ``sig``.
    ylims, xlims : pair of numbers, optional
        Axis limits. When omitted, the x axis is symmetric around zero and
        sized from the largest absolute coefficient, and the y axis runs from
        0 to the rounded-up maximum of ``-log10p``.
    yticks, xticks : sequence, optional
        Explicit tick positions; ignored when ``None`` or empty.
    xlabel : str
        Prefix for the x-axis label (``'<xlabel> coefficient'``).
    annotate : bool
        Label points with their index value.
    annotate_list : list, optional
        Index labels to annotate. When empty or omitted, every significant
        point whose absolute coefficient exceeds 0.01 is labelled.
    color : optional
        Colour of the significant points; red when omitted.
    savefig : bool
        Save the figure as a PDF under ``figdir``.
    figdir, filename : str
        Output directory and the variable part of the output file name.
    """
    COEFF_THRESH = 0.01  # minimum |coefficient| for automatic labelling
    dotcolor = 'r' if color is None else color
    if annotate_list is None:
        annotate_list = []

    fig, axes = plt.subplots(figsize=(4,4))
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['bottom'].set_linewidth(1)
    axes.spines['left'].set_linewidth(1)

    # Non-significant points first so the significant ones are drawn on top.
    is_sig = plot_df['sig']
    plt.scatter(plot_df.loc[~is_sig,'coeff'], 
                plot_df.loc[~is_sig,'-log10p'], 
               s=1, alpha=0.5, color='#dcdcdc')
    plt.scatter(plot_df.loc[is_sig,'coeff'], 
                plot_df.loc[is_sig,'-log10p'], 
               s=4, alpha=1, color=dotcolor)

    # determine xaxis min/max
    if xlims is None:
        xmax = round_decimals_up(max(np.abs(plot_df['coeff']))+0.01,2)
        xmin = -xmax
    else:
        xmin, xmax = xlims

    if ylims is None:
        ymax = roundup(max(plot_df['-log10p']))
        ymin = 0
    else:
        ymin, ymax = ylims

    # Use the requested lower bound; it was previously replaced by -xmax.
    plt.xlim([xmin,xmax])
    plt.ylim([ymin,ymax])

    # Explicit None/length checks so numpy arrays of ticks are accepted too.
    if yticks is not None and len(yticks) > 0:
        plt.yticks(yticks,fontsize=10)
    if xticks is not None and len(xticks) > 0:
        plt.xticks(xticks,fontsize=10)

    # labels
    if xlabel=='':
        plt.xlabel('coefficient',fontsize=14)
    else:
        plt.xlabel('%s coefficient' %xlabel, fontsize=14)
    plt.ylabel('-log10(P)',fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    if annotate:
        if len(annotate_list)==0:
            auto_label = is_sig & (np.abs(plot_df['coeff'])>COEFF_THRESH)
            labels = plot_df[auto_label].index
        else:
            labels = annotate_list
        for idx in labels:
            axes.annotate(idx, (plot_df.loc[idx, 'coeff'], 
                              plot_df.loc[idx,'-log10p']),
                             fontsize=10, fontweight='normal')

    if savefig:
        if annotate:
            figname = '%s/volcano_%s_annotated.pdf' %(figdir,filename)
        else:
            figname = '%s/volcano_%s_no_annotation.pdf' %(figdir,filename)

        fig.savefig(figname, bbox_inches='tight')
        print('Saved volcano_%s' %figname)

In [None]:
def get_volcano_plot_df(params, pvals, sig, feat, clust):
    """Assemble the volcano-plot table for one feature in one cluster.

    ``params``, ``pvals`` and ``sig`` are each indexed first by ``feat`` and
    then by ``clust``; the three selected columns are placed side by side as
    ``coeff``, ``pval`` and ``sig``, and a ``-log10p`` column is appended.
    For the feature ``'C(assignment)[T.PBS]'`` the sign of the coefficients
    is flipped.

    Returns
    -------
    pandas.DataFrame
        Columns ``coeff``, ``pval``, ``sig`` and ``-log10p``.
    """
    pieces = {
        'coeff': params[feat][clust],
        'pval': pvals[feat][clust],
        'sig': sig[feat][clust],
    }
    volcano = pd.concat(list(pieces.values()), axis=1, sort=False,
                        keys=list(pieces.keys()))

    flip_sign = feat=='C(assignment)[T.PBS]'
    if flip_sign:
        volcano['coeff'] = -volcano['coeff']

    neg_log_p = -np.log10(volcano['pval'].tolist())
    volcano['-log10p'] = neg_log_p
    return volcano


In [None]:
def plot_pretty_volcano(param_df, pval_df, feat_name, clust, annotate=False, savefig=False, figdir=''):
    """Draw a volcano plot for one model term, with FDR-based highlighting.

    The p-values in ``pval_df[feat_name]`` are corrected with
    Benjamini-Hochberg (``fdr_bh``) at alpha = 0.01. Note that the columns
    ``log10``, ``sig`` and ``p_adj`` are written into the caller's
    ``pval_df``.

    Parameters
    ----------
    param_df, pval_df : pandas.DataFrame
        Coefficients and p-values; both must have a ``feat_name`` column.
    feat_name : str
        Model term to plot. ``'C(treatment)[T.PBS]'`` is plotted with its sign
        flipped and labelled ``'treatment'``; ``'p65_norm'`` and
        ``'cFos_norm'`` use their own labelling thresholds.
    clust : str
        Used as the plot title and in the output file name.
    annotate : bool
        Label significant points whose absolute coefficient exceeds the
        threshold chosen for ``feat_name``.
    savefig : bool
        Save the current figure as a PDF under ``figdir``.
    figdir : str
        Output directory.
    """
    import statsmodels.stats as sms

    pval_df['log10'] = -np.log10(list(pval_df[feat_name]))
    correction = sms.multitest.multipletests(pval_df[feat_name],
                                             method='fdr_bh',
                                             alpha=0.01)
    pval_df['sig'], pval_df['p_adj'], _, _ = correction

    # Per-term settings: annotation threshold, axis label and sign convention.
    x_padding = 0.1
    if feat_name=='p65_norm':
        COEFF_THRESH = 0.45
    elif feat_name=='cFos_norm':
        COEFF_THRESH = 0.2
    else:
        COEFF_THRESH = 0.1

    is_treatment = feat_name=='C(treatment)[T.PBS]'
    if is_treatment:
        filename = 'treatment'
        param_plot = -param_df[feat_name]
    else:
        filename = feat_name
        param_plot = param_df[feat_name]

    plt.figure(figsize=(3,3))
    axes = plt.axes()
    for hidden in ('top', 'right'):
        axes.spines[hidden].set_visible(False)
    for shown in ('bottom', 'left'):
        axes.spines[shown].set_linewidth(1)

    is_sig = pval_df['sig']
    plt.scatter(param_plot.loc[~is_sig],
                pval_df.loc[~is_sig,'log10'],
                s=1, alpha=0.7, color='#bfbfbf')
    plt.scatter(param_plot.loc[is_sig],
                pval_df.loc[is_sig,'log10'],
                s=2, alpha=0.8, color='#EB5A49')

    # symmetric x axis, padded beyond the largest absolute coefficient
    half_width = max(map(abs,param_plot))+x_padding
    plt.xlim([-half_width, half_width])

    # labels
    plt.xlabel('%s coefficient' %filename,fontsize=12)
    plt.ylabel('-log10(p)',fontsize=12)
    plt.title(clust,fontsize=12)

    if annotate:
        to_label = (is_sig) & (np.abs(param_plot)>COEFF_THRESH)
        for gene in pval_df[to_label].index:
            point = (param_plot.loc[gene], pval_df.loc[gene,'log10'])
            axes.annotate(gene, point, fontsize=8, fontweight='normal')

    if not savefig:
        return
    if annotate:
        figname = '%s/OLS_volcano_%s_%s_annotated.pdf' %(figdir,clust,filename)
    else:
        figname = '%s/OLS_volcano_%s_%s_no_annotation.pdf' %(figdir,clust,filename)
    print('Saving to %s' %figname)
    plt.savefig(figname, bbox_inches='tight')


In [None]:
def make_significant_heatmap(params_i, filename, min_ct=1, FC_cutoff=0.1, zmin=-3, zmax=6, fontsize=0.7):
    """Cluster a coefficient table and draw it as a clustermap and a compact heatmap.

    Rows of ``params_i`` with fewer than ``min_ct`` non-zero, non-NaN entries
    are dropped. When ``FC_cutoff`` is truthy the table is passed through
    ``get_filtered_df``; otherwise it is used as is. Rows left without any
    non-zero entry are removed, the result is clustered with
    ``sns.clustermap`` and the dendrogram order is reused for a matplotlib
    ``pcolor`` heatmap of the coefficients. Nothing is written to disk.

    Parameters
    ----------
    params_i : pandas.DataFrame
        Coefficient table (rows are plotted along y, columns along x). It is
        copied, not modified.
    filename : str
        Currently unused (kept for interface compatibility).
    min_ct : int
        Minimum number of non-zero entries a row needs to be kept.
    FC_cutoff : float or None
        Cutoff forwarded to ``get_filtered_df``; a falsy value skips that
        step.
    zmin, zmax : float
        Colour-scale limits for both plots.
    fontsize : float
        seaborn ``font_scale`` used for the clustermap.

    Returns
    -------
    pandas.DataFrame
        The filtered table that was clustered.
    """
    from matplotlib.colors import Normalize

    class MidpointNormalize(Normalize):
        """Piecewise-linear norm mapping vmin -> 0, midpoint -> 0.5, vmax -> 1."""
        def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
            self.midpoint = midpoint
            Normalize.__init__(self, vmin, vmax, clip)

        def __call__(self, value, clip=None):
            # Masked values and other edge cases are deliberately ignored.
            x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
            return np.ma.masked_array(np.interp(value, x, y))

    seismic_gray = make_seismic_with_nan()

    params = params_i.copy()
    params = params[params.fillna(0).astype(bool).sum(axis=1)>=min_ct]

    if FC_cutoff:
        filtered_df = get_filtered_df(params, FC_cutoff)
    else:
        # No cutoff requested: cluster the count-filtered table directly
        # instead of failing on an undefined variable.
        filtered_df = params

    idx_to_keep = filtered_df.index[filtered_df.fillna(0).astype(bool).sum(axis=1)>0]
    filtered_df = filtered_df.loc[idx_to_keep]

    # get masked df
    mask_val=-100
    df_sub = get_mask_subbed_df(filtered_df, mask_val)

    # plot clustergram
    sns.set(font_scale=fontsize)
    sns.set_style('white')
    g = sns.clustermap(df_sub,
                       mask=df_sub==mask_val,
                       center=0,
                       yticklabels=df_sub.index,
                       xticklabels=df_sub.columns,
                       metric='euclidean',
                       vmin=zmin, vmax=zmax,
                       cmap=seismic_gray)
    row_idx = g.dendrogram_row.reordered_ind
    col_idx = g.dendrogram_col.reordered_ind
    g.ax_row_dendrogram.set_visible(False)
    g.ax_col_dendrogram.set_visible(False)
    sns.reset_orig()

    # matplotlib version
    # row_idx holds positions within df_sub, which can have fewer rows than
    # params, so translate positions to labels before selecting from params.
    # NOTE(review): assumes get_filtered_df / get_mask_subbed_df keep the row
    # labels of params - confirm against those helpers.
    ordered_params = params.loc[df_sub.index[row_idx]]
    print(ordered_params.columns)
    print(col_idx)
    ordered_cols = [ordered_params.columns[i] for i in col_idx]
    print(ordered_cols)
    ordered_params = ordered_params[ordered_cols]
    ordered_df = get_mask_dropped_df(ordered_params)

    fig,ax = plt.subplots(figsize=(2,ordered_df.shape[0]*(1/10)))
    # The limits go on the norm itself: matplotlib does not accept a Normalize
    # instance together with vmin/vmax.
    ax.pcolor(ordered_df, cmap=seismic_gray,
              norm=MidpointNormalize(vmin=zmin, vmax=zmax, midpoint=0),
              linewidth=2)
    ax.patch.set(facecolor='#d3d3d3', edgecolor='black')
    ax.set_xticks(np.arange(ordered_params.shape[1])+0.5, minor=False)
    ax.set_xticklabels(ordered_params.columns, rotation=90, fontsize=6)
    ax.set_yticks(np.arange(0,ordered_params.shape[0],1)+0.5, minor=False)
    ax.set_yticklabels(ordered_params.index, fontsize=5.5)

    return filtered_df


In [None]:
def plot_feat_coefficients_PBS_KA(feat, params_PBS, params_KA, pvals_PBS, pvals_KA, 
                                  sig_PBS, sig_KA, GENE_TO_PLOT=None,
                                  x_sig_lim=None, y_sig_lim=None):
    """Scatter the PBS coefficient against the KA coefficient of one protein feature.

    The tables for the two treatments are combined with ``combine_df_sig``
    for the term ``'<feat>_nCLR'``. Each point is coloured by where it is
    significant: neither (light gray), PBS only, KA only, or both (black).
    The figure is saved as a PDF under ``sc.settings.figdir``.

    Parameters
    ----------
    feat : str
        Protein name without the ``_nCLR`` suffix. ``cFos``, ``p65``, ``PU1``
        and ``NeuN`` get preset ticks and axis limits; any other feature keeps
        matplotlib's automatic axes.
    params_PBS, params_KA : coefficient tables for the two treatments.
    pvals_PBS, pvals_KA : accepted for interface compatibility; not used.
    sig_PBS, sig_KA : significance tables for the two treatments.
    GENE_TO_PLOT : str or list of str, optional
        Index label(s) to annotate on the plot.
    x_sig_lim, y_sig_lim : float, optional
        Minimum absolute PBS / KA coefficient a point needs to stay
        significant; the thresholds are also drawn as dashed guide lines.

    Returns
    -------
    The combined coefficient table with the columns ``PBS`` and ``KA``.
    """
    prot_name = '%s_nCLR' %feat

    params = combine_df_sig([params_PBS, params_KA], prot_name, ['PBS','KA'])
    sigs = combine_df_sig([sig_PBS, sig_KA], prot_name, ['PBS','KA'])

    # remove sigs that don't meet params threshold
    if x_sig_lim is not None:
        print('limiting PBS sigs')
        sigs.loc[(np.abs(params['PBS'])<x_sig_lim), 'PBS'] = False
    if y_sig_lim is not None:
        print('limiting KA sigs')
        sigs.loc[(np.abs(params['KA'])<y_sig_lim), 'KA'] = False

    sigs['combined'] = 'neither'
    sigs.loc[((sigs['PBS']==False) & (sigs['KA']==True)), 'combined'] = 'KA'
    sigs.loc[((sigs['PBS']==True) & (sigs['KA']==False)), 'combined'] = 'PBS'
    sigs.loc[((sigs['PBS']==True) & (sigs['KA']==True)), 'combined'] = 'both'

    fig = plt.figure(figsize=(5,5))
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)

    # (ticks shared by both axes, x limits, y limits) per protein
    axis_presets = {
        'cFos_nCLR': ([-0.1,-0.05,0,0.05,0.1], [-0.1, 0.1], [-0.075, 0.12]),
        'p65_nCLR': ([-0.4,-0.3,-0.2,-0.1,0,0.1,0.2,0.3,0.4], [-0.2, 0.4], [-0.1, 0.2]),
        'PU1_nCLR': ([-0.1,-0.05,0,0.05,0.1], [-0.12, 0.12], [-0.075, 0.075]),
        'NeuN_nCLR': ([-0.4,-0.2,0,0.2,0.4], [-0.6, 0.3], [-0.4, 0.3]),
    }
    preset = axis_presets.get(prot_name)
    if preset is not None:
        tx, xlims, ylims = preset
        axes.set_xticks(tx)
        axes.set_xticklabels(labels=tx, fontsize=10)
        axes.set_xlim(xlims)

        axes.set_yticks(tx)
        axes.set_yticklabels(labels=tx, fontsize=10)
        axes.set_ylim(ylims)

    # Drawn in this order so points significant in both end up on top.
    group_styles = [
        ('neither', 2, '#dcdcdc'),
        ('PBS', 9, '#708a97'),
        ('KA', 9, '#00CC33'),
        ('both', 9, 'k'),
    ]
    for group, size, group_color in group_styles:
        axes.scatter(data=params.loc[sigs['combined']==group],
                     x='PBS', y='KA', s=size, color=group_color)

    plt.axvline(0, color='k', linewidth=0.8)
    plt.axhline(0, color='k', linewidth=0.8)
    guide_style = dict(color='#ff4500', linestyle='--', linewidth=0.5, alpha=0.8)
    if x_sig_lim is not None:
        plt.axvline(x_sig_lim, **guide_style)
        plt.axvline(-x_sig_lim, **guide_style)
    if y_sig_lim is not None:
        plt.axhline(y_sig_lim, **guide_style)
        plt.axhline(-y_sig_lim, **guide_style)

    plt.xlabel('%s coefficient (PBS)' %feat, fontsize=12)
    plt.ylabel('%s coefficient (KA)' %feat, fontsize=12)

    if GENE_TO_PLOT is not None:
        if isinstance(GENE_TO_PLOT, (list, tuple)):
            genes_to_label = GENE_TO_PLOT
        else:
            genes_to_label = [GENE_TO_PLOT]
        for gene in genes_to_label:
            axes.annotate(gene, (params.loc[gene,'PBS'],
                          params.loc[gene,'KA']),
                         fontsize=10, fontweight='medium')

    figname = '%s/%s.pdf' %(sc.settings.figdir, 'DEG_coeff_scatter_PBSvsKA_mixedlm_regularized_%s' %feat)
    # '%s' here: the previous bare '%' raised ValueError before the save.
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')
    return params

In [None]:
def params_linear_fit(params):
    """Print the least-squares fit of the ``KA`` column against the ``PBS`` column.

    The table is cast to float and rows containing NaN are dropped (on a
    copy; ``params`` itself is left untouched). The slope, R^2 and p-value
    from ``scipy.stats.linregress`` are printed, followed by the unformatted
    R^2 and p-value. Nothing is returned.

    Parameters
    ----------
    params : pandas.DataFrame
        Table with the columns ``PBS`` and ``KA``.
    """
    from scipy import stats

    fit_input = params.astype(float).dropna()
    # linregress already provides slope and intercept, so no separate
    # polynomial fit is needed.
    slope, intercept, r_value, p_value, std_err = stats.linregress(
        fit_input['PBS'], fit_input['KA'])
    print('Slope %.4f \t R^2 %.4f \t pval %.6f' %(slope, r_value**2, p_value) )
    print(r_value**2, p_value)


In [None]:
def plot_CITE_gene_scatter(ad, gene, cite, scaling='lognorm', remove_zeros=False, 
                           savefig=False, figdir='', fignote='', plotline=False):
    """Scatter one feature (``gene``, x axis) against another (``cite``, y axis) per cell.

    Values are fetched with ``get_feat_values(ad, <feature>, scaling)``. A
    small uniform jitter (+/- 0.005) is added to the x values to reduce
    overplotting.

    Parameters
    ----------
    ad : AnnData
        Data object the values are read from.
    gene, cite : str
        Feature names for the x and y axes.
    scaling : str
        Forwarded to ``get_feat_values`` and ``get_linear_regression_stats``.
        It also selects the layer used by ``remove_zeros``: ``'lognorm'`` and
        ``'zscore'`` use ``'counts'``, ``'spliced'`` / ``'unspliced'`` use
        ``'spliced_counts'`` / ``'unspliced_counts'``, and any other value is
        used as the layer name itself.
    remove_zeros : bool
        Keep only cells whose value for ``gene`` in that layer is above zero.
    savefig : bool
        Save the figure as a PDF under ``figdir``.
    figdir : str
        Output directory; must be non-empty when ``savefig`` is set.
    fignote : str
        Suffix for the output file name.
    plotline : bool
        Overlay the line returned by ``get_linear_regression_stats``.
    """
    if remove_zeros:
        layer_for_scaling = {
            'lognorm': 'counts',
            'zscore': 'counts',
            'spliced': 'spliced_counts',
            'unspliced': 'unspliced_counts',
        }
        feat_type = layer_for_scaling.get(scaling, scaling)
        ad = ad[ad[:,gene].layers[feat_type].toarray()>0].copy()

    plt.figure(figsize=(3,3))
    axes = plt.axes()

    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['bottom'].set_linewidth(1)
    axes.spines['left'].set_linewidth(1)
    axes.set_xlabel(gene,fontsize=12)
    axes.set_ylabel(cite,fontsize=12)

    x_list = get_feat_values(ad, gene, scaling)
    y_list = get_feat_values(ad, cite, scaling)

    jitter = np.random.uniform(low=-0.005,high=0.005,size=len(x_list))
    axes.scatter(x_list+jitter, y_list,
                 s=4,alpha=0.7,color='#bcbcbc')

    # Called even without plotline, as before, in case the helper reports the
    # fit as a side effect.
    X_V, Y_V = get_linear_regression_stats(ad, gene, cite, scaling)

    if plotline:
        plt.plot(X_V, Y_V, color='k', linewidth=0.5)

    if savefig:
        # Validate the directory that is actually used in the path below.
        if figdir=='':
            print('Please enter valid figdir')
            return
        plt.savefig('%s/scatter_%s_%s_%s.pdf' %(figdir, gene, cite, fignote),
                    bbox_inches='tight')

## dotplot tools

In [None]:
def cluster_df(df):
    """Return the row and column order from hierarchically clustering ``df``.

    ``df`` is first passed through ``get_mask_subbed_df`` with the
    placeholder value -100; cells equal to that placeholder are masked in the
    seaborn clustermap, which uses the 'rogerstanimoto' metric. The figure is
    closed again; only the dendrogram leaf orders are returned.

    Returns
    -------
    tuple
        ``(row_idx, col_idx)``: positional indices in dendrogram order.
    """
    MASK_VALUE = -100
    masked = get_mask_subbed_df(df, MASK_VALUE)

    clustermap_options = dict(
        mask=(masked==MASK_VALUE),
        metric='rogerstanimoto',
        yticklabels=df.index,
        xticklabels=df.columns,
    )
    grid = sns.clustermap(masked, **clustermap_options)

    leaf_order = (grid.dendrogram_row.reordered_ind,
                  grid.dendrogram_col.reordered_ind)
    plt.close()
    return leaf_order

def get_ordered_dotplot_df(params, pvals, sig, frac, feat, THRESH, COEFF_THRESH=0):
    """Build the dot-plot table for ``feat`` with genes ordered by clustering.

    ``make_dotplot_df`` supplies the long-format dot-plot table and the matrix
    it was derived from. The matrix is clustered with ``cluster_df`` and the
    resulting row order becomes the category order of the ``gene`` column, so
    plots that respect categorical order show genes in dendrogram order.

    Returns
    -------
    pandas.DataFrame
        The dot-plot table with ``gene`` as an ordered-by-clustering
        categorical column.
    """
    dotplot_df, df = make_dotplot_df(params, pvals, sig, frac, feat, THRESH, COEFF_THRESH=COEFF_THRESH)

    rows, _ = cluster_df(df)
    gene_categories = dotplot_df['gene'].astype('category')
    # reorder_categories returns a new Series; its ``inplace`` option no
    # longer exists in current pandas, so assign the result back.
    dotplot_df['gene'] = gene_categories.cat.reorder_categories(new_categories=df.index[rows])
    return dotplot_df

In [None]:
def get_volcano_plot_df_total_model(params, pvals, sig, feat):
    """Assemble the volcano-plot table for one feature of the all-cells model.

    The ``feat`` column of each of ``params``, ``pvals`` and ``sig`` is placed
    side by side as ``coeff``, ``pval`` and ``sig``, and a ``-log10p`` column
    is appended. No sign flip is applied to any feature here.

    Returns
    -------
    pandas.DataFrame
        Columns ``coeff``, ``pval``, ``sig`` and ``-log10p``.
    """
    column_names = ['coeff', 'pval', 'sig']
    tables = [params, pvals, sig]
    combined = pd.concat([table[feat] for table in tables],
                         axis=1, sort=False, keys=column_names)
    neg_log_p = -np.log10(combined['pval'].tolist())
    return combined.assign(**{'-log10p': neg_log_p})

## cross correlation tools

In [None]:
from scipy.stats import pearsonr
import pandas as pd

def calculate_pvalues(df):
    """Return the matrix of Pearson-correlation p-values between numeric columns.

    Rows containing any NaN are dropped first (judged over all columns, not
    only the numeric ones); the numeric columns are then correlated pairwise
    with ``scipy.stats.pearsonr`` and each p-value is rounded to 4 decimals.

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    pandas.DataFrame
        Square object-dtype table indexed and labelled by the numeric column
        names of ``df``.
    """
    # The private accessor is kept so the column selection stays exactly as before.
    numeric = df.dropna()._get_numeric_data()
    cols = numeric.columns
    pvalues = pd.DataFrame(index=cols, columns=cols, dtype=object)
    for r in cols:
        for c in cols:
            # One .loc write instead of chained pvalues[r][c], which does not
            # reach the frame when pandas Copy-on-Write is active.
            pvalues.loc[c, r] = round(pearsonr(numeric[r], numeric[c])[1], 4)
    return pvalues


In [None]:
def add_coefficients(df, feat, params_unspliced, params_spliced):
    """Add a ``coeff`` column with each gene's coefficient for ``feat``.

    The source of the coefficient depends on the row's ``list`` value:
    ``'unspliced'`` or ``'spliced'`` read it from the matching table, and
    ``'both'`` takes the median of the two values when their signs agree.
    Rows with disagreeing signs (a warning is printed) and rows with any other
    ``list`` value keep 0.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns ``gene`` and ``list``; modified in place.
    feat : str
        Column to read from the two coefficient tables.
    params_unspliced, params_spliced : pandas.DataFrame
        Coefficient tables indexed by gene.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with ``coeff`` added.
    """
    # Float from the start so writing coefficients never changes the dtype.
    df['coeff'] = 0.0
    for idx, row in df.iterrows():
        gene, source = row['gene'], row['list']
        if source=='unspliced':
            df.loc[idx,'coeff'] = params_unspliced.loc[gene,feat]
        elif source=='spliced':
            df.loc[idx,'coeff'] = params_spliced.loc[gene,feat]
        elif source=='both':
            unsp = params_unspliced.loc[gene,feat]
            sp = params_spliced.loc[gene,feat]
            if np.sign(unsp)==np.sign(sp):
                df.loc[idx,'coeff'] = np.median([unsp,sp])
            else:
                print('WARNING: opposite signs for %s' %gene)
    return df

In [None]:
def add_coefficients_cluster(sig_dict, feat, params_unspliced, params_spliced):
    """Add a per-cluster ``coeff`` column to every non-empty table in ``sig_dict``.

    For each cluster, the coefficient of a gene is read from
    ``params_*[feat].loc[gene, cluster]``. The row's ``list`` value selects
    the source: ``'unspliced'`` or ``'spliced'`` use the matching table, and
    ``'both'`` takes the median of the two values when their signs agree.
    Rows with disagreeing signs (a warning is printed) and rows with any other
    ``list`` value keep 0. Empty tables are left untouched.

    Parameters
    ----------
    sig_dict : dict
        Maps cluster name to a DataFrame with the columns ``gene`` and
        ``list``; the DataFrames are modified in place.
    feat : str
        Key into the two coefficient collections.
    params_unspliced, params_spliced
        Indexable by ``feat``, yielding a gene-by-cluster DataFrame.

    Returns
    -------
    dict
        The same ``sig_dict`` object.
    """
    for clust, clust_df in sig_dict.items():
        if clust_df.shape[0]==0:
            continue
        # Float from the start so writing coefficients never changes the dtype.
        clust_df['coeff'] = 0.0
        for idx, row in clust_df.iterrows():
            gene, source = row['gene'], row['list']
            if source=='unspliced':
                clust_df.loc[idx,'coeff'] = params_unspliced[feat].loc[gene,clust]
            elif source=='spliced':
                clust_df.loc[idx,'coeff'] = params_spliced[feat].loc[gene,clust]
            elif source=='both':
                unsp = params_unspliced[feat].loc[gene,clust]
                sp = params_spliced[feat].loc[gene,clust]
                if np.sign(unsp)==np.sign(sp):
                    clust_df.loc[idx,'coeff'] = np.median([unsp,sp])
                else:
                    print('WARNING: opposite signs for %s' %gene)
    return sig_dict

In [None]:
# def DEG_graph(ad, sig_genes_df, figstr, thresh=0.08, color_by='all', 
#               data_mode='zscore', color_mode='fraction', random_seed=42): 
#     import matplotlib as mpl

#     P_THRESH = 0.05
    
#     if color_by=='all': 
#         ad_col = ad
#     elif color_by=='neuron':
#         ad_col = ad_neuron
#     else: 
#         if type(color_by) is list: 
#             ad_col = ad[ad.obs['annot'].isin(color_by)]
#         else: 
#             ad_col = ad[ad.obs['annot'].isin([color_by])]
#     add_gene_frac(ad_col)   

#     DEG_corr, DEG_pval = get_significant_corr(ad, sig_genes_df, data_mode)
#     DEG_sig_corr = DEG_corr[DEG_pval<=P_THRESH]

#     # Transform it in a links data frame (3 columns only):
#     links = DEG_sig_corr.stack().reset_index()
#     links.columns = ['Gene1','Gene2','corr']
    
#     # Keep only correlation over a threshold and remove self correlation (cor(A,A)=1)
#     links_filtered = links.loc[ (links['corr'] > thresh) & (links['Gene1'] != links['Gene2'])]
    
#     import networkx as nx
#     G=nx.from_pandas_edgelist(links_filtered, 'Gene1', 'Gene2')
    
#     # color by fraction expressed
#     if color_mode=='fraction': 
#         node_vals = ad_col.var['gene_frac'].loc[list(G.nodes)]
#         cmap = 'YlGn'
#         vmin= 0
#         vmax= 0.8
#         nodescale = 1200
#         node_size = nodescale*node_vals
#     elif color_mode=='coeff': 
#         vmin = -0.15
#         vmax = 0.15
#         df_color = sig_genes_df.copy()
        
#         df_color.set_index('gene', inplace=True)
#         df_color = df_color[df_color.index.isin(list(G.nodes))]
#         df_color = df_color.reindex(G.nodes())
#         node_vals = df_color['coeff']
#         cmap = 'coolwarm'
#         node_size = 350
        
#     # node edge color by gene type
#     gene_color_dict = dict(zip(sig_genes_df['gene'], sig_genes_df['color']))
#     splice_type_color = [gene_color_dict[g] for g in list(G.nodes)]
        
#     plt.figure(figsize=[10,10])
# #     np.random.seed(42)
#     np.random.seed(random_seed)
#     nx.draw_kamada_kawai(G, with_labels=True,  
#                          edge_color='#D3D3D3', # edge
#                          width=2, # edge
#                          edgecolors=splice_type_color, # node
#                          linewidths=2.5,  # node
#                          font_size=14, # node
#                          node_color=node_vals,  # node
#                          node_size=node_size, # node
#                          alpha=0.8, 
#                          vmin=vmin,
#                          vmax=vmax,
#                          cmap=cmap)

#     plt.tight_layout()
#     plt.savefig('%s/DEG_graph_%0.3f_%s.pdf' %(sc.settings.figdir, thresh, figstr), bbox_inches='tight')
    
#     return DEG_sig_corr


In [None]:
def DEG_corr_clustermap(df_corr, figstr, thresh=0.08, savefig=True):
    """Draw a clustermap of a thresholded gene-gene correlation matrix.

    Entries below ``thresh``, NaN and infinite entries are set to 0. Only
    columns whose total exceeds 1 are kept, with the same selection applied
    to the rows. The cleaning is done on a copy, so ``df_corr`` is not
    modified.

    Parameters
    ----------
    df_corr : pandas.DataFrame
        Square correlation matrix with identical row and column labels.
    figstr : str
        Tag used in the output file name.
    thresh : float
        Correlations below this value are zeroed.
    savefig : bool
        Save the clustermap as a PDF under ``sc.settings.figdir``.
    """
    cleaned = df_corr.copy()
    cleaned[cleaned<thresh]=0
    cleaned[np.isnan(cleaned)]=0
    cleaned[np.isinf(cleaned)]=0

    keep = cleaned.sum()>1
    df_off_diag = cleaned.loc[keep, keep]

    sns.set(font_scale=0.8)
    try:
        # clustermap creates its own figure, so none is opened beforehand.
        f = sns.clustermap(df_off_diag,
                           xticklabels=df_off_diag.index,
                           yticklabels=df_off_diag.columns,
                           cmap='coolwarm',
                           vmax=0.5)
        if savefig:
            f.savefig('%s/DEG_corr_clustermap_%s_%0.2f.pdf' %(sc.settings.figdir, figstr, thresh), bbox_inches='tight')
    finally:
        # Restore the plotting defaults even if clustering or saving fails.
        sns.reset_orig()

In [None]:
def pairwise_corr(ad, params, sig, cell_type, interaction=False, corr_thresh=0.05):
    """Correlate the cFos/p65-associated genes of one cell type and colour them by feature.

    Genes flagged in ``sig['cFos_nCLR']`` or ``sig['p65_nCLR']`` for
    ``cell_type`` (plus ``sig['cFos_nCLR:p65_nCLR']`` when ``interaction`` is
    set) are pooled and de-duplicated. Their colour annotation comes from
    ``gene_feat_colors_cluster`` and their correlation table from
    ``get_corr_df`` on the cells whose ``annot`` equals ``cell_type``.

    Parameters
    ----------
    ad : AnnData
        Must provide ``obs['annot']``.
    params, sig
        Coefficient and significance collections keyed by feature name.
    cell_type : str
        Cluster / annotation to analyse.
    interaction : bool
        Also include the interaction term and the interaction features.
    corr_thresh : float
        Not used by this function.

    Returns
    -------
    tuple
        ``(df_corr, genes_color)``.
    """
    def _significant_genes(feature):
        table = sig[feature]
        return list(table[table[cell_type]].index)

    gene_pool = _significant_genes('cFos_nCLR') + _significant_genes('p65_nCLR')
    prot_list = ['NeuN','PU1','p65','cFos']
    if interaction:
        gene_pool = gene_pool + _significant_genes('cFos_nCLR:p65_nCLR')
        prot_list = prot_list + ['cFos_nCLR:p65_nCLR','cFos_nCLR:NeuN_nCLR','p65_nCLR:NeuN_nCLR']

    DEGs = list(set(gene_pool))
    # every feature uses the same colormap
    color_list = ['PiYG_r'] * len(prot_list)

    genes_color = gene_feat_colors_cluster(sig, params, DEGs, [cell_type],
                                           proteins=prot_list, colors=color_list)

    cells_of_type = ad[ad.obs['annot']==cell_type]
    df_corr = get_corr_df(cells_of_type, DEGs)

    return df_corr, genes_color

In [None]:
def clustermap_corr_DEGs(DEG_corr, genes_color, figstr='', vmin=-0.5, vmax=0.5, savefig=False):
    """Draw a clustered gene-gene correlation heatmap with a column colour strip.

    Row labels are shown, column labels are suppressed, both dendrograms are
    hidden and black lines are drawn along the top and left edge of the
    heatmap.

    Parameters
    ----------
    DEG_corr : pandas.DataFrame
        Correlation matrix to cluster.
    genes_color
        Passed to ``sns.clustermap`` as ``col_colors``.
    figstr : str
        Tag used in the output file name.
    vmin, vmax : float
        Colour-scale limits.
    savefig : bool
        Save the current figure as a PDF under ``sc.settings.figdir``.
    """
    from matplotlib.patches import Rectangle

    sns.set(font_scale=0.8)
    clustermap_options = dict(
        xticklabels=[],
        yticklabels=DEG_corr.index,
        col_colors=genes_color,
        dendrogram_ratio=0.08,
        colors_ratio=0.025,
        vmin=vmin, vmax=vmax,
        cmap='bwr',
        cbar_pos=[1,0.6,0.02,0.1],
    )
    grid = sns.clustermap(DEG_corr, **clustermap_options)

    for dendrogram_ax in (grid.ax_col_dendrogram, grid.ax_row_dendrogram):
        dendrogram_ax.set_visible(False)

    heatmap_ax = grid.ax_heatmap
    heatmap_ax.axhline(y=0, color='k',linewidth=2)
    heatmap_ax.axvline(x=0, color='k',linewidth=2)

    figname = '%s/modules_cross_corr_cluster_%s.pdf' %(sc.settings.figdir, figstr)
    if savefig:
        print('Saving to %s' %figname)
        plt.savefig(figname, bbox_inches='tight')
    sns.reset_orig()

## NMF plotting

In [None]:
def plot_top_genes_per_module(ax, top_scores, top_genes):
    """Draw a horizontal bar chart of ``top_scores`` labelled by ``top_genes`` on ``ax``.

    The first gene is shown at the top (the y axis is inverted), the x axis
    is labelled 'Correlation' and the top and right spines are hidden.
    """
    bar_positions = range(len(top_genes))

    ax.barh(bar_positions, top_scores, color='#bfbfbf')
    ax.set_xlabel('Correlation')
    ax.set_yticks(bar_positions)
    ax.set_yticklabels(top_genes, rotation=0)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.invert_yaxis()


In [None]:
def heatmap_module_genes(ad, mod_genes):
    """Draw an unclustered correlation heatmap of all module genes.

    The gene lists in ``mod_genes`` are concatenated in dictionary order
    (duplicates are kept) and handed to ``get_corr_df``; the result is shown
    with a blue-white-red colormap limited to [-0.3, 0.3].

    Parameters
    ----------
    ad : AnnData
        Forwarded to ``get_corr_df``.
    mod_genes : dict
        Maps a module identifier to an iterable of gene names.
    """
    all_genes = [gene for genes in mod_genes.values() for gene in list(genes)]

    corr = get_corr_df(ad, all_genes)

    sns.set(font_scale=0.4)
    sns.heatmap(corr,
                xticklabels=corr.index,
                yticklabels=corr.index,
                vmin=-0.3, vmax=0.3,
                cmap='bwr')
    sns.reset_orig()

In [None]:
def module_score_treatment_boxplot(ad, module, color_dict, figstr=''):
    """Box plot with overlaid points of a module score for PBS vs. KainicAcid cells.

    Values come from ``ad.obs``, grouped by the ``assignment`` column. The
    two groups are compared with statannot's ``'t-test_ind'`` test and the
    result is annotated as stars above the plot. The figure is saved as a PDF
    under ``sc.settings.figdir``.

    Parameters
    ----------
    ad : AnnData
        ``ad.obs`` must contain ``assignment`` and the ``module`` column.
    module : str
        Name of the ``obs`` column to plot.
    color_dict
        Palette passed to seaborn.
    figstr : str
        Tag used in the output file name.
    """
    from statannot import add_stat_annotation

    obs = ad.obs
    treatment_order = ['PBS','KainicAcid']

    fig, ax = plt.subplots(figsize=(2,3))
    ax = sns.boxplot(data=obs, x='assignment', y=module,
                     order=treatment_order, width=0.3,
                     fliersize=0,
                     palette=color_dict)

    # set the alpha of each box artist's face colour to 0.5
    for box in ax.artists:
        red, green, blue, _ = box.get_facecolor()
        box.set_facecolor((red, green, blue, 0.5))

    ax = sns.stripplot(data=obs, x='assignment', y=module,
                       order=treatment_order, s=1,
                       alpha=0.5, jitter=0.08,
                       palette=color_dict)
    add_stat_annotation(ax, data=obs,
                        x='assignment', y=module,
                        order=treatment_order,
                        box_pairs=[('PBS', 'KainicAcid')],
                        test='t-test_ind', text_format='star',
                        loc='outside', verbose=2)

    figname = '%s/module_score_by_treatment_boxplot_%s.pdf' %(sc.settings.figdir, figstr)
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')
    
def module_score_treatment_violinplot(ad, module, color_dict, figstr=''):
    """Violin plot of a module score for PBS vs. KainicAcid cells.

    Values come from ``ad.obs``, grouped by the ``assignment`` column. The
    x tick labels are replaced by 'PBS' and 'KA', the x-axis label is removed
    and the top and right spines are hidden. The figure is saved as a PDF
    under ``sc.settings.figdir``.

    Parameters
    ----------
    ad : AnnData
        ``ad.obs`` must contain ``assignment`` and the ``module`` column.
    module : str
        Name of the ``obs`` column to plot.
    color_dict
        Palette passed to seaborn.
    figstr : str
        Tag used in the output file name.
    """
    from statannot import add_stat_annotation

    obs = ad.obs
    fig, ax = plt.subplots(figsize=(2,2))
    violin_options = dict(
        order=['PBS','KainicAcid'], width=0.3,
        scale='width', saturation=0.7,
        palette=color_dict,
    )
    ax = sns.violinplot(data=obs, x='assignment', y=module, **violin_options)

    # set the alpha of each artist's face colour to 0.5
    for artist in ax.artists:
        red, green, blue, _ = artist.get_facecolor()
        artist.set_facecolor((red, green, blue, 0.5))

    ax.set(xlabel=None)
    ax.set(xticklabels=['PBS','KA'])
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    figname = '%s/module_score_by_treatment_violinplot_%s.pdf' %(sc.settings.figdir, figstr)
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')

# GO analysis

In [None]:
def parse_GO_query(gene_list, species, db_to_keep='all'):
    """Run ``sc.queries.enrich`` and keep the significant hits from the chosen sources.

    Parameters
    ----------
    gene_list : list of str
        Genes to test for enrichment.
    species : str
        Organism identifier, passed to ``sc.queries.enrich`` as ``org``.
    db_to_keep : list of str or 'all'
        Values of the ``source`` column to keep. ``'all'`` stands for GO:BP,
        GO:MF, KEGG, REAC and TF.

    Returns
    -------
    pandas.DataFrame
        Rows of the enrichment result with ``significant == True`` and
        ``source`` in the selected list.
    """
    sources = ['GO:BP', 'GO:MF', 'KEGG', 'REAC', 'TF'] if db_to_keep=='all' else db_to_keep

    enrichment = sc.queries.enrich(gene_list, org=species)
    significant = enrichment[enrichment['significant']==True]
    from_sources = significant['source'].isin(sources)
    return significant[from_sources]

In [None]:
def sig_genes_GO_query(sig, params, variate, clust_lim=1000, species='mmusculus'):
    """Collect GO:BP / KEGG enrichment results for each cluster's significant genes.

    Gene lists come from ``get_sig_gene_list(sig, params, variate)``, indexed
    by the column names of ``sig[variate]``. Clusters with at most one gene
    are skipped. For every other cluster, ``parse_GO_query`` is run and up to
    ``clust_lim`` of its rows are kept, in the order returned.

    Parameters
    ----------
    sig, params
        Significance and coefficient collections keyed by ``variate``.
    variate : str
        Model term whose significant genes are tested.
    clust_lim : int
        Maximum number of terms kept per cluster.
    species : str
        Organism identifier forwarded to ``parse_GO_query``.

    Returns
    -------
    pandas.DataFrame
        One row per kept term with the columns ``cluster``, ``source``,
        ``name``, ``p_value``, ``description``, ``native`` and ``parents``;
        empty (with those columns) when nothing was found.
    """
    columns = ['cluster','source','name','p_value','description','native','parents']
    term_fields = columns[1:]

    sig_genes = get_sig_gene_list(sig,params,variate)

    # Rows are gathered in a list and turned into a frame once at the end,
    # rather than concatenating a new frame for every term.
    records = []
    for clust in sig[variate].columns:
        if len(sig_genes[clust])<=1:
            continue
        GO_df = parse_GO_query(sig_genes[clust],species,['GO:BP','KEGG'])
        for n_kept, (_, row) in enumerate(GO_df.iterrows()):
            if n_kept>=clust_lim:
                break
            record = {'cluster': clust}
            record.update({field: row[field] for field in term_fields})
            records.append(record)

    return pd.DataFrame(records, columns=columns)

In [None]:
def plot_GO_terms(df,alpha,filename,colormap='#d3d3d3',xlims=(0,5)):
    """Draw a horizontal bar chart of -log10(p) for the enriched terms in ``df``.

    Only rows with ``p_value <= alpha`` are shown, in table order from top to
    bottom. The figure is saved as a PDF under ``sc.settings.figdir``. The
    caller's ``df`` is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns ``name`` and ``p_value``, plus ``cluster``
        when a cluster-to-colour mapping is given.
    alpha : float
        Upper p-value bound for the terms that are shown.
    filename : str
        Tag used in the output file name.
    colormap : str or mapping
        A single colour for all bars, or a mapping from ``cluster`` value to
        colour.
    xlims : pair of numbers
        x-axis limits.
    """
    # Filter first so colours and bars refer to the same rows.
    shown = df.loc[df['p_value']<=alpha].copy()
    if shown.empty:
        print('No terms with p_value <= %s; nothing to plot' %alpha)
        return

    if isinstance(colormap, str):
        color = colormap
    else:
        shown['color'] = shown['cluster'].map(colormap)
        color = shown['color'].tolist()

    fig_height = shown.shape[0]*(1/10)

    fig, ax = plt.subplots(figsize=(3,fig_height))
    y_pos = np.arange(shown.shape[0])
    log10p = -np.log10(shown['p_value'].tolist())

    sns.reset_orig()
    ax.barh(y_pos, log10p, align='center', color=color)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(shown['name'],fontsize=6)
    ax.invert_yaxis()
    ax.set_xlabel('-log10(P)')
    ax.set_xlim(xlims)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(1)

    figname = '%s/GO_hbar_%s.pdf' %(sc.settings.figdir, filename)
    print('Saving to %s' %figname)
    fig.savefig(figname, bbox_inches='tight')

# Cell cycle

In [None]:
def add_three_state_cell_cycle(ad): 
    # load cell cycle genes - 3 phases, from Seurat tutorial
    seurat_cc_genes = [x.strip() for x in open('./regev_lab_cell_cycle_genes.txt')]
    s_genes = seurat_cc_genes[:43]
    g2m_genes = seurat_cc_genes[43:]
    seurat_cc_genes = [x for x in seurat_cc_genes if x in ad.var_names]
    
    # make a var field for True/False cell_cycle_genes
    ad.var['cell_cycle'] = False
    ad.var.loc[seurat_cc_genes, 'cell_cycle'] = True
    
    # score
    sc.tl.score_genes_cell_cycle(ad, s_genes=s_genes, g2m_genes=g2m_genes)
    
    sc.pl.scatter(ad,x='S_score',y='G2M_score',color='phase')
    
    adata_cc_genes = ad[:, seurat_cc_genes]
    sc.pp.scale(adata_cc_genes, max_value=10)
    sc.tl.pca(adata_cc_genes)
    sc.pl.pca_scatter(adata_cc_genes, color=['phase','assignment'])