### Run this notebook if you used the bash script to create the BF file

In [1]:
import os
import sys
import pandas as pd
import numpy as np
from numpy import inf
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import pickle
import operator
import matplotlib
import scipy.stats as stats
import statsmodels.stats.multitest as multi
from collections import defaultdict
from ast import literal_eval
import scanpy as sc
import csv 
import anndata
csv.field_size_limit(sys.maxsize)
from matplotlib.patches import Patch
from matplotlib import cm
%matplotlib inline
# Default matplotlib figure size (width, height in inches).
plt.rcParams['figure.figsize'] = [15, 10]
# Font type 42 embeds TrueType fonts in PDF/PS output so text stays editable in vector editors.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['font.size'] = 6
# Limit how many rows/columns pandas prints when a frame is displayed.
pd.set_option("display.max_rows", 50, "display.max_columns", 50)
sns.set_style("ticks")
# scanpy figure defaults: 80 dpi on screen, 300 dpi when saved, PDF as the save format.
sc.set_figure_params(scanpy=True, dpi=80, dpi_save=300, frameon=True, vector_friendly=True, fontsize=20, figsize=None, color_map=None, format='pdf', facecolor=None, transparent=False, ipython_format='png2x')

Define helper functions

In [69]:
def intersection(lst1, lst2):
    """Return the distinct elements present in both inputs, as a list.

    Both inputs are converted to sets, so duplicates are dropped and the
    order of the returned list is not defined.
    """
    shared = set(lst1) & set(lst2)
    return list(shared)

def union(lst1, lst2):
    """Return ``lst1`` followed by ``lst2`` as one new sequence.

    Despite the name this is a plain concatenation: elements that occur in
    both inputs (or several times in one) are kept, and neither input is
    modified.
    """
    return lst1 + lst2

def ftest(aba_spec_cutoff,st_spec_cutoff):
    """Test the gene overlap of every ST region / ABA ident pair with Fisher's exact test.

    For each ``condition_1`` group of ``st_spec_cutoff``, each non-NaN ``AAR1``
    region and each ``ident`` of ``aba_spec_cutoff``, a one-sided
    (``alternative='greater'``) Fisher's exact test is run on the 2x2 table
    ``[[g4, g2], [g3, g1]]`` where

    * g1 = genes in both the ST region and the ABA ident,
    * g2 = genes only in the ABA ident,
    * g3 = genes only in the ST region,
    * g4 = genes shared by the remaining ST regions and the remaining ABA idents.

    All p-values are then corrected together with Benjamini-Hochberg FDR.

    Inputs:
        aba_spec_cutoff - DataFrame with columns ``ident`` and ``gene``
        st_spec_cutoff  - DataFrame with columns ``condition_1``, ``AAR1`` and ``gene_new``

    Returns:
        DataFrame with one row per (condition, region, ident) test and columns
        ``condition``, ``AAR_ST``, ``ident``, ``Odds ratio``, ``p value``,
        ``Num shared genes``, ``shared genes`` and ``p-value, corrected``.
        An empty frame with the same columns is returned when there is nothing to test.
    """
    result_columns = ['condition', 'AAR_ST', 'ident', 'Odds ratio', 'p value',
                      'Num shared genes', 'shared genes']

    # ABA gene sets depend only on the ident, so build them once instead of per test.
    ct = np.unique(aba_spec_cutoff['ident'].tolist())
    aba_ident = aba_spec_cutoff['ident']
    aba_sets = [(j,
                 set(aba_spec_cutoff[aba_ident == j]['gene'].tolist()),
                 set(aba_spec_cutoff[aba_ident != j]['gene'].tolist()))
                for j in ct]

    # Regions are taken from the whole ST table (not per condition), in order of
    # first appearance so that the row order of the result is reproducible.
    regions = [x for x in st_spec_cutoff['AAR1'].unique() if str(x) != 'nan']

    rows = []
    pval_list = []
    for condition, df in st_spec_cutoff.groupby('condition_1'):
        print(condition)

        for i in regions:
            # ST genes in this region, and ST genes in all other regions
            st_genes = set(df[df['AAR1'] == i]['gene_new'].tolist())
            st_rest = set(df[df['AAR1'] != i]['gene_new'].tolist())

            for j, aba_genes, aba_rest in aba_sets:
                shared = st_genes & aba_genes

                g1 = len(shared)
                g2 = len(aba_genes - st_genes)
                g3 = len(st_genes - aba_genes)
                g4 = len(st_rest & aba_rest)

                # Fisher's test
                oddsratio, pvalue = stats.fisher_exact([[g4, g2], [g3, g1]], alternative='greater')

                # Store pvalues in list to use for multiple corrections testing
                pval_list.append(pvalue)
                rows.append([condition, i, j, oddsratio, pvalue, g1, list(shared)])

    # Build the frame once; object dtype and the 'idx' columns name match the
    # layout the previous column-by-column construction produced.
    df_ff_t = pd.DataFrame(rows, columns=result_columns, dtype=object)
    df_ff_t.columns.name = 'idx'

    if not rows:
        # Nothing was tested: return an empty, correctly shaped result.
        df_ff_t['p-value, corrected'] = pd.Series(dtype=float)
        return df_ff_t

    # Do multiple testing correction on the pvalues
    pp = multi.multipletests(pval_list, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)

    # Add corrected p-values
    df_ff_t['p-value, corrected'] = list(pp[1])

    return df_ff_t

def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None):
    """Average expression per observation group.

    Inputs:
        adata        - AnnData-like object (uses ``obs``, ``var``, ``var_names``,
                       ``shape``, ``X`` / ``layers`` and positional row indexing)
        group_key    - column name (or list of column names) of ``adata.obs`` to group by
        layer        - optional key of ``adata.layers`` to average instead of ``adata.X``
        gene_symbols - optional column of ``adata.var`` to use as the row labels of
                       the result instead of ``adata.var_names``

    Returns:
        DataFrame with one row per variable and one column per group, holding
        the float64 mean of the chosen matrix over the observations of that group.
    """
    if layer is not None:
        getX = lambda x: x.layers[layer]
    else:
        getX = lambda x: x.X

    # Row labels of the result: the requested symbol column, or var_names by default.
    if gene_symbols is not None:
        new_idx = pd.Index(adata.var[gene_symbols])
    else:
        new_idx = adata.var_names

    grouped = adata.obs.groupby(group_key)
    out = pd.DataFrame(
        np.zeros((adata.shape[1], len(grouped)), dtype=np.float64),
        columns=list(grouped.groups.keys()),
        index=new_idx
    )

    # grouped.indices maps each group to the integer positions of its observations
    for group, idx in grouped.indices.items():
        X = getX(adata[idx])
        out[group] = np.ravel(X.mean(axis=0, dtype=np.float64))
    return out

def cluster_color_map(cc):
    """Return the colours of a named qualitative matplotlib palette as hex strings.

    Inputs:
        cc - one of 'tab20b', 'tab20c', 'Accent', 'Set2', 'Pastel1', 'Set1',
             'Dark2', or 'all' (the 'tab20b' colours followed by the 'tab20c' colours)

    Returns:
        list of hex colour strings, or None when ``cc`` is not one of the names above.
    """
    def _hex_colors(name):
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in recent releases.
        cmap = plt.get_cmap(name)
        return [matplotlib.colors.rgb2hex(cmap(i)) for i in range(cmap.N)]

    if cc == 'all':
        return _hex_colors('tab20b') + _hex_colors('tab20c')
    if cc in ('tab20b', 'tab20c', 'Accent', 'Set2', 'Pastel1', 'Set1', 'Dark2'):
        # Only the requested palette is built.
        return _hex_colors(cc)
    return None

def labeled_clustermap(a, gene, obs_cat,use_common_regions = False):
    """Plot a heatmap of group-averaged expression with colour-annotated rows and columns.

    Expression of ``gene`` is averaged per ``obs_cat`` group with
    ``grouped_obs_mean``. When ``obs_cat`` holds 'annotation' together with at
    least one other key, annotations are pivoted into the heatmap columns (next
    to the genes) and the remaining keys form the rows; otherwise every key
    forms a row level and the columns are the genes.

    Inputs:
        a                  - AnnData object, subset as ``a[:, gene]``
        gene               - gene name or list of gene names, e.g. ['Tff3', 'Actb']
        obs_cat            - list of ``a.obs`` columns to group by,
                             e.g. ['Genotype', 'Sex', 'annotation']
        use_common_regions - if true (and annotations are pivoted), keep only the
                             columns in which no cell holds the fill value, i.e.
                             the (gene, annotation) pairs present in every row

    Returns:
        the seaborn ClusterGrid created by ``sns.clustermap``
        (row/column clustering is switched off).
    """
    # pre-process: mean expression per group, rows = groups, columns = genes
    tmp =  a[:,gene]
    gene_df = grouped_obs_mean(tmp, obs_cat).T

    obs_keys = [obs_cat] if isinstance(obs_cat, str) else list(obs_cat)

    # Position of 'annotation' among the grouping keys. None (not a falsy
    # placeholder) marks "absent", so that position 0 is handled correctly.
    ann_level = obs_keys.index('annotation') if 'annotation' in obs_keys else None

    n_levels = gene_df.index.nlevels
    # Annotations can only move to the columns if other keys remain for the rows.
    pivot_on_annotation = ann_level is not None and n_levels > 1

    if n_levels > 1:
        row_level_ids = [i for i in range(n_levels) if i != ann_level]
        index_levels = [gene_df.index.get_level_values(i) for i in row_level_ids]
    else:
        row_level_ids = [0]
        index_levels = [gene_df.index]
    # Names of the keys that end up as row levels, in row-level order.
    row_names = [obs_keys[i] for i in row_level_ids]

    # Missing (row, column) combinations are filled with the overall minimum.
    pivot_kwargs = dict(values=gene_df.columns, index=index_levels,
                        fill_value=min(gene_df.min()))
    if pivot_on_annotation:
        pivot_kwargs['columns'] = [gene_df.index.get_level_values(ann_level)]
    htdata2 = pd.pivot_table(gene_df, **pivot_kwargs)

    # subset to common areas
    if pivot_on_annotation and use_common_regions:
        htdata2 = htdata2.loc[:, (htdata2 != min(htdata2.min())).all(axis=0)]

    # set row colors: one palette per row level, the last palette is reused beyond four levels
    row_palettes = ['Accent', 'Set2', 'Set1', 'Pastel1']
    row_colors = []
    row_colors_dict = dict()
    if htdata2.index.nlevels > 1:
        for i in range(htdata2.index.nlevels):
            c = cluster_color_map(row_palettes[min(i, len(row_palettes) - 1)])
            level_values = list(htdata2.index.get_level_values(i))
            row_cols = dict(zip(np.unique(level_values), c))
            row_color = pd.Series(level_values).map(row_cols)
            row_color.name = row_names[i]
            row_colors.append(row_color)
            row_colors_dict.update(row_cols)
    else:
        c = cluster_color_map('Accent')
        level_values = list(htdata2.index)
        row_cols = dict(zip(np.unique(level_values), c))
        row_color = pd.Series(level_values).map(row_cols)
        row_color.name = row_names[0]
        row_colors = [row_color]
        row_colors_dict = row_cols

    # set column colors: level 0 holds the genes, level 1 (if present) the annotations
    columns_colors = []
    cat_cols_dict = dict()
    if htdata2.columns.nlevels > 1:
        for i in range(htdata2.columns.nlevels):
            c = cluster_color_map('Dark2') if i == 0 else cluster_color_map('all')
            level_values = list(htdata2.columns.get_level_values(i))
            cat_cols = dict(zip(np.unique(level_values), c))
            col_color = pd.Series(level_values).map(cat_cols)
            col_color.name = 'annotation' if i == 1 else 'Genes'
            columns_colors.append(col_color)
            cat_cols_dict.update(cat_cols)
    else:
        c = cluster_color_map('Dark2')
        level_values = list(htdata2.columns)
        cat_cols = dict(zip(np.unique(level_values), c))
        col_color = pd.Series(level_values).map(cat_cols)
        col_color.name = 'Genes'
        columns_colors = [col_color]
        cat_cols_dict = cat_cols

    # gets final colors, aligned with the heatmap rows / columns
    rcol = pd.DataFrame(row_colors).T
    rcol.index = htdata2.index
    ccol = pd.DataFrame(columns_colors).T
    ccol.index = htdata2.columns

    # Plot heatmap
    sns.set(font_scale=.45)
    sns.set_style('ticks')

    hb = sns.clustermap(htdata2,row_cluster=False, vmin = 0, row_colors = rcol, col_colors = ccol,
                        col_cluster=False, cmap = 'magma', linewidth = 0.05, 
                        linecolor = 'black', cbar_kws={'label': u'lambda', 'pad':0})
    plt.setp(hb.ax_heatmap.get_yticklabels(), rotation=0, ha="left",
         rotation_mode="anchor")
    hb.ax_heatmap.set_xlabel('')
    hb.ax_heatmap.set_ylabel('')

    # One legend entry per coloured value; row entries win if a label occurs on both axes.
    legend_colors = {**cat_cols_dict, **row_colors_dict}
    handles = [Patch(facecolor=color) for color in legend_colors.values()]
    plt.legend(handles, list(legend_colors), title='color annotations',
               bbox_to_anchor=(0, 1), bbox_transform=plt.gcf().transFigure, loc='upper right')

    return hb

#Define cluster score for individual genes
def marker_gene_expression(anndata, marker_dict, gene_symbol_key=None, partition_key='louvain_r1'):
    """
    A function to get mean z-score expressions of marker genes
    # 
    # Inputs:
    #    anndata         - An AnnData object containing the data set and a partition
    #                      (note: this parameter name shadows the imported anndata module inside the function)
    #    marker_dict     - A dictionary with cell-type markers. The markers should be stored as anndata.var_names or 
    #                      an anndata.var field with the key given by the gene_symbol_key input
    #    gene_symbol_key - The key for the anndata.var field with gene IDs or names that correspond to the marker 
    #                      genes
    #    partition_key   - The key for the anndata.obs field where the cluster IDs are stored. The default is
    #                      'louvain_r1' 
    #
    # Returns:
    #    A DataFrame with one row per marker gene found in the data (indexed by gene), one column per
    #    cluster holding the mean z-score of that gene in the cluster, and a 'cell_type' column holding
    #    the marker_dict key the gene was listed under.
    #
    # Raises:
    #    KeyError if partition_key is not in anndata.obs or gene_symbol_key is not in anndata.var
    """

    #Test inputs
    if partition_key not in anndata.obs.columns.values:
        raise KeyError('The partition key was not found in the passed AnnData object. '
                       'Have you done the clustering? If so, please pass the cluster IDs with the AnnData object!')

    if (gene_symbol_key is not None) and (gene_symbol_key not in anndata.var.columns.values):
        raise KeyError('The provided gene symbol key was not found in the passed AnnData object. '
                       'Check that your cell type markers are given in a format that your anndata object knows!')

    if gene_symbol_key:
        gene_ids = anndata.var[gene_symbol_key]
    else:
        gene_ids = anndata.var_names

    clusters = anndata.obs[partition_key].cat.categories
    marker_exp = pd.DataFrame(columns=clusters)
    marker_exp['cell_type'] = pd.Series({}, dtype='str')
    marker_names = []

    # Scale a copy so the caller's object is left untouched
    z_scores = sc.pp.scale(anndata, copy=True)
    cluster_labels = z_scores.obs[partition_key]

    i = 0
    for group in marker_dict:
        # Find the corresponding columns and get their mean expression in the cluster
        for gene in marker_dict[group]:
            gene_mask = np.isin(gene_ids, gene) #Note there may be multiple mappings
            if not gene_mask.any():
                continue

            # Per-observation mean over all columns mapped to this gene
            # (works for both single and multiple mapping)
            gene_z = pd.Series(np.ravel(np.asarray(z_scores.X[:, gene_mask].mean(1))),
                               index=z_scores.obs.index)

            # observed=False keeps one value per cluster category (NaN for empty clusters),
            # so the row always lines up with the cluster columns
            clust_marker_exp = gene_z.groupby(cluster_labels, observed=False).mean().reindex(clusters).tolist()
            clust_marker_exp.append(group)
            marker_exp.loc[i] = clust_marker_exp
            marker_names.append(gene)
            i+=1

    #Replace the rownames with informative gene symbols
    marker_exp.index = marker_names

    return(marker_exp)

def splotch2anndata(ST_top_gene_dict, a, mode):
    """
    A function to add DE genes as ranked genes to scanpy anndata object
    # 
    # Inputs:
    #    a                  - An AnnData object containing the data set and a partition:conditions and annotation
    #    ST_top_gene_dict   - A pd.DataFrame with fields: age_1, age_2, region_1, region_2, AAR1, AAR2, logBFs (list), Delta (list), genes (list)  
    #    mode               - A string denoting the type of analysis to be collected: annotation_analysis,genotype_analysis,temporal_analysis
    #
    # Side effects:
    #    adds a 'final_conditions' column to both ST_top_gene_dict and a.obs and stores the
    #    ranked-genes dictionary in a.uns[mode]. Nothing is returned.
    #
    # Raises:
    #    ValueError if mode is not one of the three analysis names (checked before anything is modified)
    """
    if mode not in ('annotation_analysis', 'genotype_analysis', 'temporal_analysis'):
        raise ValueError("mode must be 'annotation_analysis', 'genotype_analysis' or "
                         "'temporal_analysis', got %r" % (mode,))

    ### Add DE genes as ranked genes to scanpy anndata object

    # make sure something to merge on
    ST_top_gene_dict['final_conditions'] = [i+"_"+j+"_"+k for i,j,k in zip(ST_top_gene_dict['region_1'], ST_top_gene_dict['age_1'], ST_top_gene_dict['AAR1'])]
    a.obs['final_conditions'] = [i+"_"+j for i,j in zip(a.obs['conditions'], a.obs['annotation'])]

    # filters ST_top for the requested analysis
    # NOTE(review): this function compares AAR2 with 'Rest' while the later versions
    # in this notebook compare with 'rest' -- confirm which spelling the data uses.
    same_age = ST_top_gene_dict['age_1'] == ST_top_gene_dict['age_2']
    same_region = ST_top_gene_dict['region_1'] == ST_top_gene_dict['region_2']
    if mode == 'annotation_analysis':
        ann_pd = ST_top_gene_dict[same_age & same_region & (ST_top_gene_dict['AAR2'] == 'Rest')]
    elif mode == 'genotype_analysis':
        ann_pd = ST_top_gene_dict[same_age & (ST_top_gene_dict['AAR2'] != 'Rest')]
    else:
        ann_pd = ST_top_gene_dict[same_region & (ST_top_gene_dict['AAR2'] != 'Rest')]
    ann_pd_merged = pd.merge(ann_pd, a.obs, left_on = ['final_conditions'], right_on = ['final_conditions'])[['annotation', 'logBFs', 'Delta','genes']]
    # keep a single entry per annotation
    ann_pd_merged = ann_pd_merged.drop_duplicates(subset=['annotation'], keep='first').reset_index(drop=True)

    def _per_annotation_frame(column):
        # One row per annotation; the per-row lists are padded with NaN to a common length
        return pd.DataFrame([pd.Series(i) for i in ann_pd_merged[column]], index = ann_pd_merged['annotation'])

    def _per_annotation_records(column):
        # Record array with one field per annotation
        return _per_annotation_frame(column).T.to_records(column_dtypes='O',index=False)

    # creates ranked genes object
    rank_genes_groups = dict()
    rank_genes_groups['params'] = dict(groupby = 'annotation',
                                       reference = 'rest',
                                       method = mode,
                                       use_raw = False,
                                       layer = None,)

    rank_genes_groups['names'] = _per_annotation_records('genes')
    rank_genes_groups['logfoldchanges'] = _per_annotation_records('Delta')
    # the log Bayes factors are stored in the 'pvals' slot
    rank_genes_groups['pvals'] = _per_annotation_records('logBFs')

    # creates markers dict: annotation -> genes, without the NaN padding
    tmp = _per_annotation_frame('genes')
    markers_dict_annotation_analysis = {}
    for i,j in enumerate(tmp.index):
        y = tmp.iloc[i,:]
        markers_dict_annotation_analysis[j] = [x for x in y if str(x) != 'nan']

    # makes zscores
    marker_gene_expressions = marker_gene_expression(a, markers_dict_annotation_analysis, gene_symbol_key=None, partition_key='annotation')
    marker_gene_expressions = marker_gene_expressions.drop(labels='cell_type', axis=1).to_records(column_dtypes='O',index=False)
    rank_genes_groups['scores'] = marker_gene_expressions

    # store the result on the passed anndata object
    print(mode)
    a.uns[mode] = rank_genes_groups
    
def _splotch_v3_rank_fields(ann_pd, a, merge_on, extra_cols, dup_cols, dup_sep, label_cols, conditions_order):
    """Build the 'names', 'logfoldchanges', 'pvals' and 'scores' record arrays for one analysis mode."""
    left_key, right_key = merge_on

    # The merge needs the observation names as an 'index' column; derive it locally
    # when the caller has not added it to a.obs.
    obs = a.obs if 'index' in a.obs.columns else a.obs.assign(index=a.obs.index)
    keep_cols = ['index', 'annotation', 'logBFs', 'Delta', 'genes'] + extra_cols
    merged = pd.merge(ann_pd, obs, left_on = [left_key], right_on = [right_key])[keep_cols]

    # keep one row per comparison and annotation
    merged['dups'] = [dup_sep.join(parts) for parts in zip(*[merged[c] for c in dup_cols])]
    merged = merged.drop_duplicates(subset=['dups'], keep='first').reset_index(drop=True)

    # label of the group every row belongs to
    merged['group_label'] = ['_vs_'.join(parts) for parts in zip(*[merged[c] for c in label_cols])]
    if conditions_order is None:
        conditions_order = np.unique(merged['group_label'])
    merged = merged[merged['group_label'].isin(conditions_order)]
    inx_name = list(merged['group_label'])

    def _records(column, dtype, blank_nan_strings=False):
        # one field per group label, taking the first non-null entry of each group
        frame = pd.DataFrame([pd.Series(i) for i in merged[column]], index = inx_name).groupby(level=0).first().T
        if blank_nan_strings:
            frame = frame.replace('nan', '')
        return frame.to_records(column_dtypes=dtype,index=False)

    fields = dict()
    fields['names'] = _records('genes', 'O', blank_nan_strings=True)
    fields['logfoldchanges'] = _records('Delta', 'float32')
    # the log Bayes factors are stored in the 'pvals' slot
    fields['pvals'] = _records('logBFs', 'float64')

    # scores: mean expression of each group's genes over the observations merged into that group
    means_pd_index = merged.groupby('group_label')['index'].apply(list).reset_index()
    present_labels = set(inx_name)
    marker_gene_means = []
    cond_names = []
    for cond in np.array(conditions_order):
        if cond not in present_labels:
            continue
        cond_names.append(cond)
        sub = means_pd_index[means_pd_index['group_label'] == cond]['index'].iloc[0]
        asub = a[a.obs.index.isin(sub)]
        asub = asub[:,[x for x in fields['names'][cond] if str(x) != 'nan']]
        asub.obs['merging'] = cond
        asub.var_names_make_unique()
        marker_gene_means.append(list(grouped_obs_mean(asub, group_key = 'merging').reindex(fields['names'][cond])[cond]))
    fields['scores'] = pd.DataFrame(marker_gene_means, index = cond_names).T.to_records(column_dtypes='float32',index=False)

    return fields


def splotch2anndata_v3(ST_top_gene_dict, a, mode, conditions_order = None):
    """
    A function to add DE genes as ranked genes to scanpy anndata object
    # 
    # Inputs:
    #    ST_top_gene_dict   - A pd.DataFrame with fields: age_1, age_2, region_1, region_2, AAR1, AAR2, logBFs (list), Delta (list), genes (list)
    #    a                  - An AnnData object; a.obs must hold 'annotation' plus the merge column of the chosen mode
    #                         ('annotation', 'Region' or 'Age'). An 'index' column in a.obs is used when present,
    #                         otherwise the observation names are used.
    #    mode               - annotation_analysis (groups = annotations), genotype_analysis (groups = region_1_vs_region_2)
    #                         or temporal_analysis (groups = age_1_vs_age_2)
    #    conditions_order   - optional list/array of group labels to keep, in the order used for the scores;
    #                         all groups (sorted) when None
    #
    # Side effects:
    #    stores the ranked-genes dictionary in a.uns[mode]; nothing is returned. For any other mode
    #    only the 'params' entry is stored.
    """
    # creates ranked genes object
    rank_genes_groups = dict()
    rank_genes_groups['params'] = dict(groupby = 'annotation',
                                       reference = 'rest',
                                       method = mode,
                                       use_raw = False,
                                       layer = None,)

    top = ST_top_gene_dict
    spec = None
    if mode == 'annotation_analysis':
        # each annotation against the rest, within one age and region
        spec = dict(rows = (top['age_1'] == top['age_2']) & (top['region_1'] == top['region_2']) & (top['AAR2'] == 'rest'),
                    merge_on = ('AAR1', 'annotation'),
                    extra_cols = ['age_1', 'region_1', 'AAR1'],
                    dup_cols = ['age_1', 'region_1', 'annotation'], dup_sep = '_',
                    label_cols = ['annotation'])
    elif mode == 'genotype_analysis':
        # region against region, within one age
        spec = dict(rows = (top['age_1'] == top['age_2']) & (top['region_1'] != top['region_2']) & (top['AAR2'] != 'rest'),
                    merge_on = ('region_1', 'Region'),
                    extra_cols = ['age_1', 'region_1', 'region_2'],
                    dup_cols = ['age_1', 'region_1', 'region_2', 'annotation'], dup_sep = '',
                    label_cols = ['region_1', 'region_2'])
    elif mode == 'temporal_analysis':
        # age against age, within one region
        spec = dict(rows = (top['region_1'] == top['region_2']) & (top['age_1'] != top['age_2']) & (top['AAR2'] != 'rest'),
                    merge_on = ('age_1', 'Age'),
                    extra_cols = ['age_1', 'age_2', 'region_1'],
                    dup_cols = ['region_1', 'age_1', 'age_2', 'annotation'], dup_sep = '',
                    label_cols = ['age_1', 'age_2'])

    if spec is not None:
        rank_genes_groups.update(_splotch_v3_rank_fields(top[spec['rows']], a, spec['merge_on'],
                                                         spec['extra_cols'], spec['dup_cols'], spec['dup_sep'],
                                                         spec['label_cols'], conditions_order))

    # store the result on the passed anndata object
    print(mode)
    a.uns[mode] = rank_genes_groups


def splotch2anndata_v2(ST_top_gene_dict, a, mode):
    """
    A function to add DE genes as ranked genes to scanpy anndata object
    # 
    # Inputs:
    #    a                  - An AnnData object containing the data set and a partition:conditions and annotation
    #    ST_top_gene_dict   - A pd.DataFrame with fields: age_1, age_2, region_1, region_2, AAR1, AAR2, logBFs (list), Delta (list), genes (list)  
    #    mode               - A string denoting the type of analysis to be collected: annotation_analysis,genotype_analysis,temporal_analysis
    #
    # Side effects:
    #    adds a 'final_conditions' column to ST_top_gene_dict and a.obs, stores the ranked-genes
    #    dictionary in a.uns[mode] and copies the observation names into a.obs['index'].
    #    Nothing is returned.
    #
    # Raises:
    #    ValueError if mode is not one of the three analysis names (checked before anything is modified)
    """
    if mode not in ('annotation_analysis', 'genotype_analysis', 'temporal_analysis'):
        raise ValueError("mode must be 'annotation_analysis', 'genotype_analysis' or "
                         "'temporal_analysis', got %r" % (mode,))

    # make sure something to merge on
    ST_top_gene_dict['final_conditions'] = [i+"_"+j+"_"+k for i,j,k in zip(ST_top_gene_dict['region_1'], ST_top_gene_dict['age_1'], ST_top_gene_dict['AAR1'])]
    a.obs['final_conditions'] = [i+"_"+j for i,j in zip(a.obs['conditions'], a.obs['annotation'])]

    top = ST_top_gene_dict
    if mode == 'annotation_analysis':
        # each annotation against the rest, within one age and region
        rows = (top['age_1'] == top['age_2']) & (top['region_1'] == top['region_2']) & (top['AAR2'] == 'rest')
        extra_cols = ['age_1', 'region_1', 'AAR1']
        dup_cols = ['age_1', 'region_1', 'annotation']
        name_cols, name_sep = ['age_1', 'region_1'], '_'
    elif mode == 'genotype_analysis':
        # region against region, within one age
        rows = (top['age_1'] == top['age_2']) & (top['region_1'] != top['region_2']) & (top['AAR2'] != 'rest')
        extra_cols = ['age_1', 'region_1', 'region_2']
        dup_cols = ['age_1', 'region_1', 'region_2', 'annotation']
        name_cols, name_sep = ['region_1', 'region_2'], '_vs_'
    else:
        # temporal_analysis: age against age, within one region
        rows = (top['region_1'] == top['region_2']) & (top['age_1'] != top['age_2']) & (top['AAR2'] != 'rest')
        extra_cols = ['age_1', 'age_2', 'region_1']
        dup_cols = ['region_1', 'age_1', 'age_2', 'annotation']
        name_cols, name_sep = ['age_1', 'age_2'], '_vs_'

    ann_pd = top[rows]
    ann_pd_merged = pd.merge(ann_pd, a.obs, left_on = ['final_conditions'], right_on = ['final_conditions'])[['annotation', 'logBFs', 'Delta','genes'] + extra_cols]
    # keep one row per comparison and annotation
    ann_pd_merged['dups'] = [''.join(parts) for parts in zip(*[ann_pd_merged[c] for c in dup_cols])]
    ann_pd_merged = ann_pd_merged.drop_duplicates(subset=['dups'], keep='first').reset_index(drop=True)
    inx_name = [name_sep.join(parts) for parts in zip(*[ann_pd_merged[c] for c in name_cols])]

    def _records(column, dtype, blank_nan_strings=False):
        # one field per group name, taking the first non-null entry of each group
        frame = pd.DataFrame([pd.Series(i) for i in ann_pd_merged[column]], index = inx_name).groupby(level=0).first().T
        if blank_nan_strings:
            frame = frame.replace('nan', '')
        return frame.to_records(column_dtypes=dtype,index=False)

    # creates ranked genes object
    rank_genes_groups = dict()
    rank_genes_groups['params'] = dict(groupby = 'annotation',
                                       reference = 'rest',
                                       method = mode,
                                       use_raw = False,
                                       layer = None,)

    rank_genes_groups['names'] = _records('genes', 'O', blank_nan_strings=True)
    rank_genes_groups['logfoldchanges'] = _records('Delta', 'float32')
    # the log Bayes factors are stored in the 'pvals' slot
    rank_genes_groups['pvals'] = _records('logBFs', 'float64')

    # creates markers dict: group name -> genes, without the NaN padding
    # (rows sharing a group name overwrite each other, the last one wins)
    tmp_genes = pd.DataFrame([pd.Series(i) for i in ann_pd_merged['genes']], index = inx_name)
    markers_dict_annotation_analysis = {}
    for i,j in enumerate(tmp_genes.index):
        y = tmp_genes.iloc[i,:]
        markers_dict_annotation_analysis[j] = [x for x in y if str(x) != 'nan']

    # makes zscores
    marker_gene_expressions = marker_gene_expression(a, markers_dict_annotation_analysis, gene_symbol_key=None, partition_key='annotation')
    marker_gene_expressions = marker_gene_expressions.drop(labels='cell_type', axis=1).to_records(column_dtypes='O',index=False)
    rank_genes_groups['scores'] = marker_gene_expressions

    # store the result on the passed anndata object
    print(mode)
    a.uns[mode] = rank_genes_groups

    # other functions in this notebook read a.obs['index'], so the observation names are exposed as a column
    a.obs['index'] = a.obs.index

def _fill_rank_genes_groups(rank_genes_groups, ann_pd_merged, group_col, conditions_order, a):
    """
    Collapses the merged comparison table per group and fills rank_genes_groups in place.

    Shared by every mode of splotch2anndata_v4; only the grouping column differs between modes.

    Inputs:
       rank_genes_groups - dict that receives the 'names', 'logfoldchanges', 'pvals' and 'scores' record arrays
       ann_pd_merged     - merged comparison table with 'genes', 'Delta', 'logBFs' and group_col columns
       group_col         - name of the column that defines the groups (one record field per group)
       conditions_order  - groups to score; also the order of the fields in 'scores'
       a                 - AnnData object the scores are computed from
    """
    #makes sure unique gene names per group
    # NOTE(review): the sum presumably concatenates the per-row lists of each group - confirm
    grouped = ann_pd_merged.groupby(group_col).sum()
    grouped = grouped.replace(to_replace='None', value=np.nan).dropna()
    per_group = pd.DataFrame(columns = ['genes', 'Delta', 'logBFs'])
    for i in range(len(grouped.index)):
        row = grouped.iloc[i,:]
        genes = list(row['genes'])
        deltas = list(row['Delta'])
        logbfs = list(row['logBFs'])
        # position of the first occurrence of every distinct gene name
        first_idx = np.unique(genes, return_index=True)[1]
        # distinct genes ordered by decreasing logBF
        order = np.argsort(np.array([logbfs[j] for j in first_idx]))[::-1]
        per_group.at[i,'genes'] = [genes[first_idx[k]] for k in order]
        per_group.at[i,'Delta'] = [deltas[first_idx[k]] for k in order]
        per_group.at[i,'logBFs'] = [logbfs[first_idx[k]] for k in order]
    per_group[group_col] = grouped.index
    inx_name = per_group[group_col]

    def _padded(column):
        # one row per group (shorter lists are padded with NaN), transposed so every group becomes a field
        frame = pd.DataFrame([pd.Series(values) for values in per_group[column]], index = inx_name)
        return pd.concat([frame]).groupby(level=0).first().T

    rank_genes_groups['names'] = _padded('genes').replace('nan', '').to_records(column_dtypes='O',index=False)
    rank_genes_groups['logfoldchanges'] = _padded('Delta').to_records(column_dtypes='float32',index=False)
    # the 'pvals' slot carries the log Bayes factors
    rank_genes_groups['pvals'] = _padded('logBFs').to_records(column_dtypes='float64',index=False)

    # scores: per group, the values grouped_obs_mean returns on the scaled subset of that group's genes
    marker_gene_means = []
    cond_names = []
    groups_present = list(per_group[group_col])
    for cond in np.array(conditions_order):
        if cond not in groups_present:
            continue
        cond_names.append(cond)
        asub = a[:,[x for x in rank_genes_groups['names'][cond] if str(x) != 'nan']]
        asub.obs['merging'] = cond
        sc.pp.scale(asub, max_value=10)
        marker_gene_means.append(list(grouped_obs_mean(asub, group_key = 'merging').reindex(rank_genes_groups['names'][cond])[cond]))
    rank_genes_groups['scores'] = pd.DataFrame(marker_gene_means, index = cond_names).T.to_records(column_dtypes='float32',index=False)


def splotch2anndata_v4(ST_top_gene_dict, a, mode, conditions_order = None):
    """
    Stores ranked DE genes in a.uns[mode] using the layout of a scanpy rank_genes_groups result.

    Inputs:
       ST_top_gene_dict - pd.DataFrame with the columns read here: region_1, region_2, age_1, age_2,
                          AAR1, AAR2, genes (list), logBFs (list), Delta (list).
                          A 'final_conditions' column is added to it in place.
       a                - AnnData object; a.obs must provide 'conditions' and 'annotation'
                          ('Age' as well for 'temporal_analysis'). a.obs['final_conditions'] is added in place.
                          The merged table must expose an 'index' column
                          (presumably a.obs['index'], set by an earlier converter - confirm).
       mode             - 'temporal_analysis'   : same region, different ages, grouped by "age_1_vs_age_2"
                          'genotype_analysis'   : same age, different regions, grouped by "region_1_vs_region_2"
                          'annotation_analysis' : same region, different ages, grouped by annotation
                          Any other value stores only the 'params' entry.
       conditions_order - optional list/array of group names to keep; defaults to all groups found.

    Returns nothing; the result is written to a.uns[mode] and mode is printed.
    """

    ST_top_gene_dict['final_conditions'] = [i+"_"+j+"_"+k for i,j,k in zip(ST_top_gene_dict['region_1'], ST_top_gene_dict['age_1'], ST_top_gene_dict['AAR1'])]
    a.obs['final_conditions'] = [i+"_"+j for i,j in zip(a.obs['conditions'], a.obs['annotation'])]

    # creates ranked genes object
    rank_genes_groups = dict()
    rank_genes_groups['params'] = dict(groupby = 'annotation',
                                       reference = 'rest',
                                       method = mode,
                                       use_raw = False,
                                       layer = None,)

    # NOTE(review): the filters below compare AAR2 against lowercase 'rest'; other cells of this
    # notebook use 'Rest' - confirm which spelling the input carries.
    if mode == 'temporal_analysis':

        ann_pd = ST_top_gene_dict[(ST_top_gene_dict['region_1'] == ST_top_gene_dict['region_2']) & (ST_top_gene_dict['age_1'] != ST_top_gene_dict['age_2']) & (ST_top_gene_dict['AAR2'] != 'rest')]
        ann_pd_merged = pd.merge(ann_pd, a.obs, left_on = ['age_1'], right_on = ['Age'])[['index', 'annotation', 'logBFs', 'Delta','genes','age_1','age_2', 'region_1']]
        # keep one row per (region, age pair, annotation)
        ann_pd_merged['dups'] = [i+j+k+l for i,j,k,l in zip(ann_pd_merged['region_1'], ann_pd_merged['age_1'], ann_pd_merged['age_2'], ann_pd_merged['annotation'])]
        ann_pd_merged = ann_pd_merged.drop_duplicates(subset=['dups'], keep='first').reset_index(drop=True)
        ann_pd_merged['age'] = [i+"_vs_"+j for i,j in zip(ann_pd_merged['age_1'], ann_pd_merged['age_2'])]
        # `is None` rather than `== None`: the latter is elementwise (and fails in `if`) for arrays
        if conditions_order is None:
            conditions_order = np.unique(ann_pd_merged['age'])
        ann_pd_merged = ann_pd_merged[ann_pd_merged['age'].isin(conditions_order)]

        _fill_rank_genes_groups(rank_genes_groups, ann_pd_merged, 'age', conditions_order, a)

    elif mode == 'genotype_analysis':
        ann_pd = ST_top_gene_dict[(ST_top_gene_dict['age_1'] == ST_top_gene_dict['age_2']) & (ST_top_gene_dict['region_1'] != ST_top_gene_dict['region_2']) & (ST_top_gene_dict['AAR2'] != 'rest')]
        ann_pd_merged = pd.merge(ann_pd, a.obs, left_on = ['final_conditions'], right_on = ['final_conditions'])[['index', 'annotation', 'logBFs', 'Delta','genes','age_1','region_1', 'region_2']]
        # keep one row per (age, region pair, annotation)
        ann_pd_merged['dups'] = [i+j+k+l for i,j,k,l in zip(ann_pd_merged['age_1'], ann_pd_merged['region_1'], ann_pd_merged['region_2'], ann_pd_merged['annotation'])]
        ann_pd_merged = ann_pd_merged.drop_duplicates(subset=['dups'], keep='first').reset_index(drop=True)
        ann_pd_merged['region'] = [i+"_vs_"+j for i,j in zip(ann_pd_merged['region_1'],ann_pd_merged['region_2'])]
        if conditions_order is None:
            conditions_order = np.unique(ann_pd_merged['region'])
        ann_pd_merged = ann_pd_merged[ann_pd_merged['region'].isin(conditions_order)]

        _fill_rank_genes_groups(rank_genes_groups, ann_pd_merged, 'region', conditions_order, a)

    elif mode == 'annotation_analysis':
        ann_pd = ST_top_gene_dict[(ST_top_gene_dict['region_1'] == ST_top_gene_dict['region_2']) & (ST_top_gene_dict['age_1'] != ST_top_gene_dict['age_2']) & (ST_top_gene_dict['AAR2'] != 'rest')]
        ann_pd_merged = pd.merge(ann_pd, a.obs, left_on = ['final_conditions'], right_on = ['final_conditions'])[['index', 'annotation', 'logBFs', 'Delta','genes','age_1','region_1', 'region_2']]
        # keep one row per (age, region, annotation)
        ann_pd_merged['dups'] = [i+j+k for i,j,k in zip(ann_pd_merged['age_1'], ann_pd_merged['region_1'],  ann_pd_merged['annotation'])]
        ann_pd_merged = ann_pd_merged.drop_duplicates(subset=['dups'], keep='first').reset_index(drop=True)
        if conditions_order is None:
            conditions_order = np.unique(ann_pd_merged['annotation'])
        ann_pd_merged = ann_pd_merged[ann_pd_merged['annotation'].isin(conditions_order)]

        _fill_rank_genes_groups(rank_genes_groups, ann_pd_merged, 'annotation', conditions_order, a)

    # stores the result on the anndata object that was passed in
    print(mode)
    a.uns[mode] = rank_genes_groups

def filter_rank_genes_groups(
    adata,
    key=None,
    key_added='rank_genes_groups_filtered',
    min_fold_change=0,
    min_pvals_change=0.5,
    min_scores=0
) -> None:
    """
    Filters a ranked-genes result stored in adata.uns and saves it under a new key.

    Inputs:
       adata            - object whose .uns[key] holds the 'names', 'logfoldchanges', 'pvals' and
                          'scores' record arrays (one field per group)
       key              - entry of adata.uns to read; defaults to 'rank_genes_groups'
       key_added        - entry of adata.uns to write
       min_fold_change  - entries with logfoldchanges below this are blanked
       min_pvals_change - entries with 'pvals' below this are blanked
       min_scores       - entries with scores below this are blanked

    Within every group a gene name is kept only at its first occurrence; repeats and the
    literal string 'nan' become NaN. Entries failing a threshold become NaN in 'names' and in
    the matrix that threshold applies to. 'pvals_adj' is written as a copy of the filtered 'pvals'.
    Returns nothing; adata.uns[key] itself is left as it was.
    """

    if key is None:
        key = 'rank_genes_groups'

    # convert structured numpy array into DataFrame
    gene_names = pd.DataFrame(adata.uns[key]['names'])

    #gets unique names
    # Built column by column without writing into gene_names: the former chained assignment
    # (gene_names[col][i] = nan) does nothing under pandas Copy-on-Write.
    deduped = {}
    for group in gene_names.columns:
        seen = set()
        kept = []
        for gene in gene_names[group].tolist():
            if str(gene) in seen:
                kept.append(np.nan)
                continue
            if isinstance(gene, str):
                if gene == 'nan':
                    gene = np.nan
                else:
                    seen.add(gene)
            kept.append(gene)
        deduped[group] = kept
    gene_names = pd.DataFrame(deduped, index=gene_names.index, columns=gene_names.columns, dtype=object)

    fold_change_matrix = pd.DataFrame(adata.uns[key]['logfoldchanges'])
    pvals_change_matrix = pd.DataFrame(adata.uns[key]['pvals'])
    scores_change_matrix = pd.DataFrame(adata.uns[key]['scores'])

    # filter original_matrix (boolean-frame indexing blanks the failing cells, shape is kept)
    gene_names = gene_names[(fold_change_matrix >= min_fold_change)]
    gene_names = gene_names[(pvals_change_matrix >= min_pvals_change)]
    gene_names = gene_names[(scores_change_matrix >= min_scores)]
    fold_change_matrix = fold_change_matrix[(fold_change_matrix >= min_fold_change)]
    pvals_change_matrix = pvals_change_matrix[(pvals_change_matrix >= min_pvals_change)]
    scores_change_matrix = scores_change_matrix[(scores_change_matrix >= min_scores)]

    # create new structured array using 'key_added'.
    adata.uns[key_added] = adata.uns[key].copy()
    adata.uns[key_added]['names'] = gene_names.to_records(index=False,column_dtypes='O')
    adata.uns[key_added]['logfoldchanges'] = fold_change_matrix.to_records(index=False,column_dtypes='float32')
    adata.uns[key_added]['pvals'] = pvals_change_matrix.to_records(index=False,column_dtypes='float64')
    adata.uns[key_added]['pvals_adj'] = pvals_change_matrix.to_records(index=False,column_dtypes='float64')
    adata.uns[key_added]['scores'] = scores_change_matrix.to_records(index=False,column_dtypes='float32')

def filter_BFs(st_spec, max_genes=3999):

    """
    A function that filters the BF files from cloud
    # 
    # Inputs:
    #    st_spec    - BF data frame with columns gene, condition_1, condition_2, AAR1, AAR2, logBF, Delta.
    #                 condition_* is split on a space into genotype and sex.
    #                 'logBF' must already be present; it is not computed here.
    #                 The columns gene_new, genotype_1, sex_1, genotype_2, sex_2 are added in place.
    #    max_genes  - maximum number of genes kept per comparison (None keeps all)
    #
    # Returns:
    #    ST_top_gene_dict - A pd.DataFrame with one row per comparison and fields: genotype_1, genotype_2,
    #                 sex_1, sex_2, AAR1, AAR2, genes (list), logBFs (list), Delta (list),
    #                 the lists being ordered by decreasing logBF

    """

    # rename gene names: the gene symbol is the text before the first underscore
    st_spec['gene_new'] = [i.split("_")[0] for i in st_spec['gene']]
    st_spec['genotype_1'] = [i.split(" ")[0] for i in st_spec['condition_1']]
    st_spec['sex_1'] = [i.split(" ")[1] for i in st_spec['condition_1']]
    st_spec['genotype_2'] = [i.split(" ")[0] for i in st_spec['condition_2']]
    st_spec['sex_2'] = [i.split(" ")[1] for i in st_spec['condition_2']]

    ## Top ST genes per condition and per region
    group_keys = ['genotype_1', 'genotype_2', 'sex_1', 'sex_2', 'AAR1', 'AAR2']
    records = []
    for label, dfs in st_spec.groupby(group_keys): # this is for splotch_one_level
        if len(dfs) == 0:
            continue

        # sort once and reuse the order for all three lists
        ranked = dfs.sort_values(by=['logBF',], ascending=[False,])
        record = dict(zip(group_keys, label))
        record['genes'] = ranked['gene_new'].tolist()[:max_genes]
        record['logBFs'] = ranked['logBF'].tolist()[:max_genes]
        record['Delta'] = ranked['Delta'].tolist()[:max_genes]
        records.append(record)

    # build the frame in one go instead of enlarging it cell by cell
    ST_top_gene_dict = pd.DataFrame(records, columns = group_keys + ['genes', 'logBFs', 'Delta'])
    print('done clean up')
    return ST_top_gene_dict


### Reads in tsv file and makes a dataframe for bacterial comparisons

In [3]:
# Load ST files  
path = '/home/sanjavickovic/data/host-microbiome_data/splotch_outputs/wt_gf_zeros/'

# Read file
filename = os.path.join(path, 'BF.tsv.gz')

# makes initial pd
fields = ['gene', 'condition_1', 'condition_2', 'AAR1', 'AAR2', 'BF', 'Delta']
# reader = pd.read_csv(filename, sep='\t', chunksize=int(100000), engine='python', skipinitialspace=True, usecols=fields)    

# # makes small pkl files
# for i, chunk in enumerate(reader):
#     out_file = path + "data_subset_{}.pkl".format(i+1)
#     print("Processing chunk: ", i+1)
#     with open(out_file, "wb") as f:
#         pickle.dump(chunk,f,pickle.HIGHEST_PROTOCOL)

# arranges pkl files into a df
# sorted() makes the row order reproducible; glob alone returns files in arbitrary order
# NOTE: pickles execute code on load - only read chunk files produced by the cell above
data_p_files = sorted(glob.glob(path + "data_subset_*.pkl"))
bf_chunks = []
for i, pkl_file in enumerate(data_p_files):
    print("Reading chunk: ", i+1, ' out of: ', len(data_p_files))
    tmp = pd.read_pickle(pkl_file)
    # drops rows that repeat the header; .copy() so the casts below do not write into a slice
    tmp = tmp[tmp['gene'] != 'gene'].copy()
    tmp['BF'] = tmp['BF'].astype(float)
    tmp['Delta'] = tmp['Delta'].astype(float)
    # keeps rows with positive BF and positive Delta only
    tmp = tmp[(tmp['BF']>0) & (tmp['Delta']>0)]
    bf_chunks.append(tmp)
# one concat instead of DataFrame.append per chunk (quadratic, and removed in pandas 2.0)
df_bf = pd.concat(bf_chunks, ignore_index=True) if bf_chunks else pd.DataFrame([])
    
    

Reading chunk:  1  out of:  35
Reading chunk:  2  out of:  35
Reading chunk:  3  out of:  35
Reading chunk:  4  out of:  35
Reading chunk:  5  out of:  35
Reading chunk:  6  out of:  35
Reading chunk:  7  out of:  35
Reading chunk:  8  out of:  35
Reading chunk:  9  out of:  35
Reading chunk:  10  out of:  35
Reading chunk:  11  out of:  35
Reading chunk:  12  out of:  35
Reading chunk:  13  out of:  35
Reading chunk:  14  out of:  35
Reading chunk:  15  out of:  35
Reading chunk:  16  out of:  35
Reading chunk:  17  out of:  35
Reading chunk:  18  out of:  35
Reading chunk:  19  out of:  35
Reading chunk:  20  out of:  35
Reading chunk:  21  out of:  35
Reading chunk:  22  out of:  35
Reading chunk:  23  out of:  35
Reading chunk:  24  out of:  35
Reading chunk:  25  out of:  35
Reading chunk:  26  out of:  35
Reading chunk:  27  out of:  35
Reading chunk:  28  out of:  35
Reading chunk:  29  out of:  35
Reading chunk:  30  out of:  35
Reading chunk:  31  out of:  35
Reading chunk:  3

In [4]:
'Subset anndata and ST_dict to same genes'
'Read in large anndata'
# Reads the AnnData object from disk; later cells take their observations (a.obs) from `a`.
a = sc.read_h5ad('/home/sanjavickovic/data/host-microbiome_data/st_data/anndata_hm_norm_all_Feb2022.h5ad')

In [5]:
'Add WT bacterial lambdas in place'
def rename_lambdas_index(lambdas_file): 
    """Returns the index labels of lambdas_file as strings, each cut at its first underscore."""
    return [str(label).split("_")[0] for label in lambdas_file.index]

# Load Lambda pmean df as from gcp
# NOTE: this rebinds `path` (no trailing slash here, unlike the BF-loading cell)
path = '/home/sanjavickovic/data/host-microbiome_data/splotch_outputs/wt_gf_zeros'

# Read expression file
# Two header rows are read (header=[0,1]), so the columns form a two-level index.
# NOTE(review): the file is named .tsv.gz but is parsed with sep=',' - confirm the delimiter.
filename = os.path.join(path, 'lambdas_python.tsv.gz')  
lambda_posterior_means = pd.read_csv(filename, index_col=0, low_memory=False, header=[0,1], sep=',') 

#rename genes (keep the text before the first underscore of every row label)
lambda_posterior_means.index = rename_lambdas_index(lambda_posterior_means)

#rename columns: join the two header levels with '_', then keep the last path component
# without the "_stdata_adjusted.tsv" suffix
count_files = lambda_posterior_means.columns.map('_'.join).str.strip('_')
count_files_names = [i.split("/")[-1].replace("_stdata_adjusted.tsv", "") for i in count_files]
lambda_posterior_means.columns = count_files_names

# Take exp()
# NOTE(review): the value exponentiated is (lambda - 1), not lambda - confirm the -1 is intended.
lambda_posterior_means = lambda_posterior_means.astype(float)
lambda_posterior_means = np.exp(lambda_posterior_means-1)

#prep lambdas: transpose (original columns become rows), then drop the columns whose name
# starts with 'coordinate'
lambda_posterior_means_t = lambda_posterior_means.T
lambda_posterior_means_t = lambda_posterior_means_t.loc[:,~lambda_posterior_means_t.columns.str.startswith('coordinate')]

# subset to bacteria only
#lambda_posterior_means_t = lambda_posterior_means_t[[i for i in lambda_posterior_means_t.columns if i.endswith("Bacteria")]]

In [6]:
# "Subset to GF"
# a_gf = a[~a.obs.index.isin(lambda_posterior_means_t.index)]

In [7]:
# "Subset to WT"
# a_wt = a[a.obs.index.isin(lambda_posterior_means_t.index)]
# lambda_posterior_means_t = lambda_posterior_means_t.reindex(a_wt.obs.index)

In [8]:
# "Subset WT to bac vs. genes"
# a_wt_bac = a_wt[:,a_wt.var_names.isin(lambda_posterior_means_t.columns)]
# a_wt_bac.X = lambda_posterior_means_t.values
# a_wt_genes = a_wt[:,~a_wt.var_names.isin(lambda_posterior_means_t.columns)]
# a_wt_n = a_wt.copy()
# a_wt_n.X = np.concatenate((a_wt_genes.X, a_wt_bac.X), axis = 1)

In [9]:
# "Make new anndata"
# a_n = a_wt_n.concatenate(a_gf)
# a_n.obs.drop(["batch", "Gnai3"], axis = 1, inplace = True)


In [10]:
'Put in new lambdas with exhisting observations'
# Reorder the rows to follow a.obs.index; labels missing from the lambdas become all-NaN rows
lambda_posterior_means_t = lambda_posterior_means_t.reindex(a.obs.index)
# drop every row that contains a NaN (this includes the rows introduced by the reindex)
lambda_posterior_means_t = lambda_posterior_means_t.dropna()
# keep only the first column for every repeated column name
lambda_posterior_means_t = lambda_posterior_means_t.loc[:,~lambda_posterior_means_t.columns.duplicated()]

In [11]:
# observation metadata restricted to (and ordered like) the rows kept in lambda_posterior_means_t
obs = a.obs.loc[lambda_posterior_means_t.index]

In [12]:
# puts the new lambdas into a fresh AnnData object: X holds the lambda values, obs the matching
# observation metadata, and the variable names are the lambda column names
a_n = anndata.AnnData(X = lambda_posterior_means_t.values, obs = obs)
a_n.var_names = lambda_posterior_means_t.columns

In [13]:
# remove two extra bacteria
#a_n = a_n[:,a_n.var_names.isin([i for i in a_n.var_names if 'Lacrimispora' not in i])]
#a_n = a_n[:,a_n.var_names.isin([i for i in a_n.var_names if 'Acetobacterium' not in i])]

In [14]:
'Adds bacterial to genes expression lambdas'
# Writes the AnnData object built from the lambdas to disk as h5ad
a_n.write_h5ad(filename = '/home/sanjavickovic/data/host-microbiome_data/st_data/anndata_hm_norm_all_n_Feb2022.h5ad')

In [15]:
'Filter to get only bacteria'
# gene symbol = text before the first underscore
df_bf['gene_new'] = df_bf['gene'].map(lambda gene: gene.split("_")[0])
#df_bf = df_bf.loc[[i[0] for i in enumerate(df_bf.gene_new) if i[1] in a.var_names]]
#df_bf.reset_index(inplace = True)
# df_bf.condition_1 = 'WT F'
# df_bf.condition_2 = 'WT F'

'This are the wt bacterial comparisons'
# rows comparing a region against 'Rest' where both conditions are 'WT F'
versus_rest = df_bf['AAR2'] == 'Rest'
wt_f_both_sides = (df_bf['condition_1'] == 'WT F') & (df_bf['condition_2'] == 'WT F')
tmp = df_bf[versus_rest & wt_f_both_sides]
# of those, the rows whose gene name contains "Bacteria"
tmp_bac = tmp[np.array(["Bacteria" in gene for gene in tmp['gene']], dtype=bool)]

In [16]:
# remove two additional bacteria: drop rows whose gene name mentions either genus
tmp_bac = tmp_bac[np.array(
    [("Lacrimispora" not in gene) and ("Acetobacterium" not in gene) for gene in tmp_bac['gene']],
    dtype=bool)]

In [17]:
# Bacterial rows (gene name contains "Bacteria") are flagged once and reused for both directions,
# instead of rebuilding a frame-sized list and running isin() against it twice.
is_bacterial = np.array(["Bacteria" in gene for gene in df_bf.gene], dtype=bool)
gf_conditions = ['GF F', 'GF M']

# WT F compared against GF (either sex), and the reverse direction
tmp_wt_vs_gf = df_bf[(df_bf.condition_1 == 'WT F') & df_bf.condition_2.isin(gf_conditions) & is_bacterial]
tmp_gf_vs_wt = df_bf[df_bf.condition_1.isin(gf_conditions) & (df_bf.condition_2 == 'WT F') & is_bacterial]

In [18]:
# remove two additional bacteria: drop rows whose gene name mentions either genus
tmp_wt_vs_gf = tmp_wt_vs_gf[np.array(
    [("Lacrimispora" not in gene) and ("Acetobacterium" not in gene) for gene in tmp_wt_vs_gf['gene']],
    dtype=bool)]
tmp_gf_vs_wt = tmp_gf_vs_wt[np.array(
    [("Acetobacterium" not in gene) and ("Lacrimispora" not in gene) for gene in tmp_gf_vs_wt['gene']],
    dtype=bool)]

In [19]:
# every row whose gene name does not contain "Bacteria"
tmp_wt1 = df_bf[np.array(["Bacteria" not in gene for gene in df_bf['gene']], dtype=bool)]

### Reads in tsv file and makes a dataframe for wt vs gf gene comparisons

In [63]:
# # Load ST files  
# path = '/home/brittalotstedt/host-microbiome/data/bfs/'

# # Read file
# filename = os.path.join(path, 'BF.tsv.gz')

# # makes initial pd
# fields = ['gene', 'condition_1', 'condition_2', 'AAR1', 'AAR2', 'BF', 'Delta']
# # reader = pd.read_csv(filename, sep='\t', chunksize=int(2000000), engine='python', skipinitialspace=True, usecols=fields)    

# # # makes small pkl files
# # for i, chunk in enumerate(reader):
# #     out_file = path + "data_subset_{}.pkl".format(i+1)
# #     print("Processing chunk: ", i+1)
# #     with open(out_file, "wb") as f:
# #         pickle.dump(chunk,f,pickle.HIGHEST_PROTOCOL)

# # arranges pkl files into a df
# data_p_files=[]
# for name in glob.glob(path + "data_subset*.pkl"):
#        data_p_files.append(name)
# df_bf = pd.DataFrame([])
# for i in range(len(data_p_files)):
#     #if i % 20 == 0:
#     print("Reading chunk: ", i+1, ' out of: ', len(data_p_files))
#     tmp = pd.read_pickle(data_p_files[i])
#     tmp = tmp[tmp['gene'] != 'gene']    
#     tmp['BF'] = tmp['BF'].astype(float)
#     tmp['Delta'] = tmp['Delta'].astype(float)
#     tmp = tmp[(tmp['BF']>0) & (tmp['Delta']>0)]
#     df_bf = df_bf.append(tmp,ignore_index=True)
    
    

In [64]:
# 'Remove bacterial comparisons from df file'
# df_bf['gene_new'] = [i.split("_")[0] for i in df_bf.gene]
# #tmp_wt_vs_gf = df_bf[(df_bf.condition_1 == 'WT F') & ((df_bf.condition_2 == 'GF F') | (df_bf.condition_2 == 'GF M')) & (df_bf.gene.isin([i for i in df_bf.gene if "Bacteria" in i]))]
# tmp_wt = df_bf[~df_bf.gene.isin([i for i in df_bf.gene if "Bacteria" in i])]

In [58]:
# Stack the four subsets; reset_index() renumbers the rows and keeps the old labels in an 'index' column
df_bf_combined = pd.concat([tmp_bac, tmp_wt1, tmp_wt_vs_gf, tmp_gf_vs_wt]).reset_index()

In [59]:
# Cast BF to float64, clamp infinite values to the float range, then take the natural log.
# .loc replaces the chained form `df.BF[mask] = value`, which raises SettingWithCopyWarning
# and does not modify the frame under pandas Copy-on-Write.
df_bf_combined['BF'] = df_bf_combined['BF'].astype(np.float64)
df_bf_combined.loc[df_bf_combined['BF'] == -np.inf, 'BF'] = sys.float_info.min # makes sure no inf
df_bf_combined.loc[df_bf_combined['BF'] == np.inf, 'BF'] = sys.float_info.max # makes sure no inf
df_bf_combined['logBF'] = np.log(df_bf_combined['BF'])

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  This is separate from the ipykernel package so we can avoid doing imports until


In [70]:
# One row per comparison with its genes ranked by logBF; filter_BFs also adds helper columns
# (gene_new, genotype_*, sex_*) to df_bf_combined in place.
ST_top_gene_dict = filter_BFs(df_bf_combined)

done clean up


In [71]:
# display the resulting table
ST_top_gene_dict

Unnamed: 0,genotype_1,genotype_2,sex_1,sex_2,AAR1,AAR2,genes,logBFs,Delta
0,GF,GF,F,F,crypt apex and crypt mid,Rest,"[Plac8, Slco6c1, Uckl1, Gm12367, Tspear, Mtmr1...","[2.0122117282095053, 1.542921882516165, 1.4583...","[0.65223562354725, 3.8131681028384583, 1.25715..."
1,GF,GF,F,F,crypt apex and mucosa,Rest,"[Mtmr11, Gm32200, Scd4, Cyp2d34, Gimd1, Lyzl4o...","[0.28519925093975473, 0.029994950489092207, -0...","[0.6372250459891668, 1.2219540229105, 0.415214..."
2,GF,GF,F,F,crypt base,Rest,"[Bpifc, Mapre3, Usp9y, Dydc1, Gm4955, Gm15655,...","[1.0082060694559545, 0.8305594259474725, 0.734...","[2.5167639623125004, 1.8401814456808336, 2.079..."
3,GF,GF,F,F,crypt mid,Rest,"[Ydjc, Aknad1, Stard4, Rnf4, Phb, Ctc1, Smim4,...","[3.4466429431031056, 2.533555270712622, 1.4905...","[1.937673891208333, 4.860282585321833, 0.70985..."
4,GF,GF,F,F,epithelium,Rest,"[Fgf23, Igkv8-31, Npr1, Gm43705, Tnfsf11, Clec...","[0.21825720536127427, -0.06046643125800337, -0...","[0.05216546580691617, 0.7135883506803333, 0.08..."
...,...,...,...,...,...,...,...,...,...
139,WT,WT,F,F,mucosa and pellet,Rest,"[Saa1, Lypd8, Car4, Abcb1a, Sepp1, Hist1h1c, P...","[3.5810409565550567, 1.1435238760455382, 1.131...","[1.142071956405275, 0.5003454392266667, 0.8313..."
140,WT,WT,F,F,mucosae and interna,Rest,"[P4ha1, Mylk3, Gm3985, Dhx32, Tmem45a, Nog, Gm...","[1.30868985066532, 1.2956374600132459, 1.14272...","[1.1789173022458332, 2.065658233677333, 2.0385..."
141,WT,WT,F,F,muscle and submucosa,Rest,"[Cd2, Lrp8os3, Cyp4a10, Oas1g, Gm38073, 493057...","[0.2730773858106193, 0.17285301481707396, 0.16...","[1.7318922431808335, 0.30870173119785826, 0.32..."
142,WT,WT,F,F,pellet,Rest,"[Clca1, Rbm47, Slc26a3, Saa1, Massilistercora-...","[2.4439504979436975, 0.6585590799053893, 0.574...","[0.31390891916666597, 0.28664277937833305, 0.3..."


Saves formatted DE genes df

In [72]:
# NOTE: rebinds `path` to the parent splotch_outputs directory
path = '/home/sanjavickovic/data/host-microbiome_data/splotch_outputs'
# Writes the table as CSV (row index included).
# NOTE(review): the list-valued columns are presumably stored as their text representation
# and need parsing when the file is read back - confirm.
ST_top_gene_dict.to_csv(os.path.join(path, 'ST_top_gene_dict_BF0_combined.csv'))